diff --git a/CMakeLists.txt b/CMakeLists.txt
index 4458e219abb3fb0ebb67095626f824ee2faf8b56..c664f43e9e446a08bdcbe844ee7741a86a72660e 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -23,7 +23,7 @@
 file(GLOB_RECURSE PADDLE_MOBILE_CC src/*.cc src/*.cpp src/*.c src/*.mm)
 file(GLOB_RECURSE PADDLE_MOBILE_H src/*.h)
 include_directories(src/)
-set(CMAKE_CXX_FLAGS "-O3 -s -DNDEBUG ${CMAKE_CXX_FLAGS}")
+set(CMAKE_CXX_FLAGS "-O3 -s -DNDEBUG ${CMAKE_CXX_FLAGS} -Wno-attributes")
 if(IS_IOS)
   set(CMAKE_CXX_FLAGS "-mfpu=neon -marm -fobjc-abi-version=2 -fobjc-arc \
     -std=gnu++11 -stdlib=libc++ -isysroot ${CMAKE_OSX_SYSROOT} ${CMAKE_CXX_FLAGS}")
diff --git a/src/common/log.h b/src/common/log.h
index d574818f865ab6b2af748a5b3162b589f396a564..282ee2780993447051143866f65907ba7ce17be3 100644
--- a/src/common/log.h
+++ b/src/common/log.h
@@ -31,7 +31,8 @@ namespace paddle_mobile {

 #ifdef ANDROID

-extern const char *ANDROID_LOG_TAG;
+static const char *ANDROID_LOG_TAG =
+    "paddle_mobile LOG built on " __DATE__ " " __TIME__;

 #define ANDROIDLOGI(...)                                               \
   __android_log_print(ANDROID_LOG_INFO, ANDROID_LOG_TAG, __VA_ARGS__); \
diff --git a/src/common/type_define.h b/src/common/type_define.h
index 389f9a715f8cec3f0b494ae3b43b3952e49677f8..a25a19f11f444c1a92b033c0f721174bda44ed50 100644
--- a/src/common/type_define.h
+++ b/src/common/type_define.h
@@ -37,8 +37,7 @@ template <typename Dtype>
 using OpCreator = std::function<framework::OperatorBase<Dtype> *(
     const std::string & /*type*/, const VariableNameMap & /*inputs*/,
     const VariableNameMap & /*outputs*/,
-    const framework::AttributeMap & /*attrs*/,
-    std::shared_ptr<framework::Scope> /*scope*/)>;
+    const framework::AttributeMap & /*attrs*/, framework::Scope * /*scope*/)>;

 using InferVarTypeFN = std::function;
diff --git a/src/common/types.cpp b/src/common/types.cpp
old mode 100755
new mode 100644
diff --git a/src/common/types.h b/src/common/types.h
index e3b5e52218edb70186aec9452a96e6191ee30290..35c1659c5a246c49f6e0ccb31c3e63ce4fdd2e71 100755
--- a/src/common/types.h
+++ b/src/common/types.h
@@ -205,6 +205,8 @@ extern const char *G_OP_TYPE_FUSION_DECONV_ADD_BN_RELU;
 extern const char *G_OP_TYPE_FUSION_DECONV_ADD_BN;
 extern const char *G_OP_TYPE_FUSION_DECONV_BN_RELU;

+extern const char *G_OP_TYPE_PAD2D;
+
 extern std::unordered_map<
     std::string, std::pair<std::vector<std::string>, std::vector<std::string>>>
     op_input_output_key;
diff --git a/src/framework/attribute.h b/src/framework/attribute.h
index 3bc9284bfa6873b9c3f97071b2f96677937c4206..e00cee09b36a4372c938f356900faab88e610010 100644
--- a/src/framework/attribute.h
+++ b/src/framework/attribute.h
@@ -91,7 +91,14 @@ class Attribute {
         break;
       }
       case PADDLE_MOBILE__FRAMEWORK__PROTO__ATTR_TYPE__BLOCK: {
-        attr.Set<int>(attr_desc->block_idx);
+        break;
+      }
+      case PADDLE_MOBILE__FRAMEWORK__PROTO__ATTR_TYPE__LONGS: {
+        vector<int64_t> val(attr_desc->n_longs);
+        for (int i = 0; i < attr_desc->n_longs; ++i) {
+          val[i] = attr_desc->longs[i];
+        }
+        attr.Set<vector<int64_t>>(val);
         break;
       }
       default:
@@ -139,6 +146,14 @@ class Attribute {
       return vistor(attr.variant_.Get<vector<bool>>());
     } else if (attr.variant_.TypeId() == typeid(int64_t).hash_code()) {
       return vistor(attr.variant_.Get<int64_t>());
+    } else if (attr.variant_.TypeId() ==
+               typeid(framework::BlockDesc *).hash_code()) {
+      return vistor(attr.variant_.Get<framework::BlockDesc *>());
+    } else if (attr.variant_.TypeId() ==
+               typeid(vector<framework::BlockDesc *>).hash_code()) {
+      return vistor(attr.variant_.Get<vector<framework::BlockDesc *>>());
+    } else if (attr.variant_.TypeId() == typeid(vector<int64_t>).hash_code()) {
+      return vistor(attr.variant_.Get<vector<int64_t>>());
     } else {
       PADDLE_MOBILE_THROW_EXCEPTION("type not support");
     }
@@ -146,7 +161,8 @@ class Attribute {

  private:
  Variant<int, float, string, vector<int>, vector<float>, vector<string>, bool,
-          vector<bool>, BlockDesc *, int64_t>
+          vector<bool>, BlockDesc *, vector<BlockDesc *>, int64_t,
+          vector<int64_t>>
      variant_;
 };
diff --git a/src/framework/data_layout.h b/src/framework/data_layout.h
index 665b5315bc1c0fca7b9e62f89062f375a9a011be..fd0bec39132e04cc0b5ef6b30ec48b106c79b534 100644
--- a/src/framework/data_layout.h
+++ b/src/framework/data_layout.h
@@ -42,6 +42,7 @@ inline DataLayout StringToDataLayout(const std::string &str) {
   } else {
     PADDLE_MOBILE_THROW_EXCEPTION("Unknown storage order string: %s", s.c_str())
   }
+  return DataLayout::kNCHW;
 }

 inline std::string DataLayoutToString(const DataLayout &data_layout) {
diff --git a/src/framework/dim.h b/src/framework/dim.h
index 7c78659e3baacdf707dc46884c099dfd0cd284bb..e11d6fe39abf4b6f03c58a0b0ee1d4d3c442642b 100644
--- a/src/framework/dim.h
+++ b/src/framework/dim.h
@@ -82,6 +82,8 @@ struct Dim<0> {

   int64_t &operator[](int idx);
   int64_t operator[](int idx) const;
+
+  int64_t head;
 };

 namespace {
@@ -131,6 +133,7 @@ int64_t &indexer(Dim<D> &dim, int idx) {
 template <>
 int64_t &indexer<0>(Dim<0> &dim, int idx) {
   PADDLE_MOBILE_THROW_EXCEPTION("Invalid index")
+  return dim.head;
 }

 template <int D>
@@ -147,6 +150,7 @@ int64_t indexer(const Dim<D> &dim, int idx) {
 template <>
 int64_t indexer<0>(const Dim<0> &dim, int idx) {
   PADDLE_MOBILE_THROW_EXCEPTION("Invalid index")
+  return dim.head;
 }

 }  // namespace
diff --git a/src/framework/executor.cpp b/src/framework/executor.cpp
index 89253188951e8bc8268999c15298ceeb2ae0d69f..b5fab192aaed8ecb7796fc81b2ac67d810654c4c 100644
--- a/src/framework/executor.cpp
+++ b/src/framework/executor.cpp
@@ -57,32 +57,30 @@ Executor<Device, T>::Executor(const Program<Device> &program,
   PADDLE_MOBILE_ENFORCE(program_desc_ != nullptr,
                         "program_desc_ should not be nullptr");
   const auto &blocks = program_desc_->Blocks();
-  ops_of_block_.resize(blocks.size());
-
-  for (int i = 0; i < blocks.size(); ++i) {
-    std::shared_ptr<BlockDesc> block_desc = blocks[i];
-    std::vector<std::shared_ptr<OpDesc>> ops = block_desc->Ops();
-    for (int j = 0; j < ops.size(); ++j) {
-      std::shared_ptr<OpDesc> op_desc = ops[j];
-      DLOG << "create op: " << op_desc->Type();
-
-      auto op_handler = OpRegistry<Device>::CreateOp(
-          op_desc->Type(), op_desc->GetInputs(), op_desc->GetOutputs(),
-          op_desc->GetAttrMap(), program_.scope);
-      // infer shape to reshape inputs and outputs before predict,
-      // but for lod mode, it still need to infer shape in runtime
-      if (!lod_mode) {
-        op_handler->InferShape();
-      }
-      ops_of_block_[i].push_back(op_handler);
+
+  std::shared_ptr<BlockDesc> block_desc = blocks[0];
+  std::vector<std::shared_ptr<OpDesc>> ops = block_desc->Ops();
+  for (int j = 0; j < ops.size(); ++j) {
+    std::shared_ptr<OpDesc> op_desc = ops[j];
+    DLOG << "create op: " << op_desc->Type();
+
+    auto op_handler = OpRegistry<Device>::CreateOp(
+        op_desc->Type(), op_desc->GetInputs(), op_desc->GetOutputs(),
+        op_desc->GetAttrMap(), program_.scope.get());
+    // infer shape to reshape inputs and outputs before predict,
+    // but for lod mode, it still need to infer shape in runtime
+    if (!lod_mode) {
+      op_handler->InferShape();
     }
+    ops_of_block0_.push_back(op_handler);
   }
-
   if (program_.combined) {
     InitCombineMemory();
   } else {
     InitMemory();
   }
+  // resize feed and fetch list
+  InitFeedFetchList();

 #ifdef PADDLE_MOBILE_FPGA
   program_.scope->EraseVars({"feed", "fetch"});
@@ -90,13 +88,37 @@
 #endif

   int count = 0;
-  for (int block_id = 0; block_id < ops_of_block_.size(); ++block_id) {
-    for (auto &op_handler : ops_of_block_[block_id]) {
-      DLOG << "Initialize op[" << count++ << "]: " << op_handler->Type();
-      op_handler->Init();
-      ops_list_.push_back(op_handler);
+  for (auto &op_handler : ops_of_block0_) {
+    DLOG << "Initialize op[" << count++ << "]: " << op_handler->Type();
+    op_handler->Init();
+  }
+}
+
+template <typename Device, typename T>
+void Executor<Device, T>::InitFeedFetchList() {
+  std::unordered_map<std::string, int> feed_indices, fetch_indices;
+  for (const auto &block : program_desc_->Blocks()) {
+    for (const auto &op_desc : block->Ops()) {
+      if (op_desc->Type() == "feed") {
+        std::string name = op_desc->Output("Out")[0];
+        feed_indices[name] = op_desc->GetAttr("col").Get<int>();
+      } else if (op_desc->Type() == "fetch") {
+        std::string name = op_desc->Input("X")[0];
+        fetch_indices[name] = op_desc->GetAttr("col").Get<int>();
+      }
+    }
+  }
+  feed_indices_.swap(feed_indices);
+  fetch_indices_.swap(fetch_indices);
+
+  auto *feed_var = program_.scope->Var("feed");
+  auto *feed_list = feed_var->template GetMutable<framework::LoDTensorArray>();
+  feed_list->resize(feed_indices_.size());
+
+  auto *fetch_var = program_.scope->Var("fetch");
+  auto *fetch_list =
+      fetch_var->template GetMutable<framework::LoDTensorArray>();
+  fetch_list->resize(fetch_indices_.size());
 }

 template <typename Device, typename T>
@@ -181,20 +203,20 @@ void Executor<Device, T>::InitMemory() {
   for (const auto &block : program_desc_->Blocks()) {
     for (const auto &var_desc : block->Vars()) {
       auto var = program_.scope->Var(var_desc->Name());
-      auto tensor = var->template GetMutable<LoDTensor>();
       if (var_desc->Persistable()) {
         if (var_desc->Name() == "feed" || var_desc->Name() == "fetch") {
+          var->template GetMutable<framework::LoDTensorArray>();
           continue;
         }
         char *origin_data =
             ReadFileToBuff(program_.model_path + "/" + var_desc->Name());
         char *data = origin_data;
+        auto tensor = var->template GetMutable<LoDTensor>();
         LoadMemory(reinterpret_cast<void **>(&data), var_desc, tensor);
         delete[] origin_data;
       } else {
-        if (var_desc->Type() == VARTYPE_TYPE_LOD_TENSOR) {
-          varInputMemory(var_desc, var, tensor);
-        }
+        DLOG << "init no persistable var: " << var_desc->Name();
+        varInputMemory(var_desc, var);
       }
     }
   }
@@ -216,23 +238,18 @@ void Executor<Device, T>::InitCombineMemory() {
   for (const auto &block : program_desc_->Blocks()) {
     for (const auto &var_desc : block->Vars()) {
       auto var = program_.scope->Var(var_desc->Name());
-      auto tensor = var->template GetMutable<LoDTensor>();
       if (var_desc->Persistable()) {
         if (var_desc->Name() == "feed" || var_desc->Name() == "fetch") {
+          var->template GetMutable<framework::LoDTensorArray>();
           continue;
         }

         DLOG << " init combine memory persistable: " << var_desc->Name();
-
+        auto tensor = var->template GetMutable<LoDTensor>();
         LoadMemory(reinterpret_cast<void **>(&data), var_desc, tensor);
       } else {
-        if (var_desc->Type() == VARTYPE_TYPE_LOD_TENSOR) {
-          DLOG << " init combine memory no persistable in lod: "
-               << var_desc->Name();
-          varInputMemory(var_desc, var, tensor);
-        } else {
-          DLOG << " init combine memory no persistable: " << var_desc->Name();
-        }
+        DLOG << " init combine memory no persistable: " << var_desc->Name();
+        varInputMemory(var_desc, var);
       }
     }
   }
@@ -250,6 +267,7 @@ void Executor<Device, T>::InitNoPersistableMemory(const Tensor &input_tensor) {
       auto tensor = var->template GetMutable<LoDTensor>();
       if (var_desc->Persistable()) {
         if (var_desc->Name() == "feed" || var_desc->Name() == "fetch") {
+          var->template GetMutable<framework::LoDTensorArray>();
           continue;
         }
       } else {
@@ -260,6 +278,9 @@ void Executor<Device, T>::InitNoPersistableMemory(const Tensor &input_tensor) {
                            input_tensor.dims()[3]});
           tensor->Resize(new_dim);
           tensor->template mutable_data<T>();
+        } else {
+          PADDLE_MOBILE_THROW_EXCEPTION("Unsupported var type `%d`",
+                                        var_desc->Type());
         }
       }
     }
@@ -272,34 +293,44 @@ void Executor<Device, T>::InitNoPersistableMemory(const Tensor &input_tensor) {

 template <typename Device, typename T>
 bool Executor<Device, T>::varInputMemory(
-    const std::shared_ptr<VarDesc> &var_desc, Variable *var,
-    LoDTensor *tensor) const {
+    const std::shared_ptr<VarDesc> &var_desc, Variable *var) const {
 #ifdef PADDLE_MOBILE_FPGA
+  framework::LoDTensor *tensor = var->template GetMutable<LoDTensor>();
   tensor->init(typeid(float));
   return true;
 #endif
-  auto type = var_desc->Tensor_desc().DataType();
-  switch (type) {
-    case VARTYPE_TYPE_FP32:
-      tensor->mutable_data<float>();
-      break;
-    case VARTYPE_TYPE_INT8:
-      tensor->mutable_data<int8_t>();
-      break;
-    case VARTYPE_TYPE_INT32:
-      tensor->mutable_data<int>();
-      break;
-    case VARTYPE_TYPE_INT64:
-      tensor->mutable_data<int64_t>();
-      break;
-    default:
-      break;
+  auto TypeId = [](const VarType_Type &type) -> std::type_index {
+    switch (type) {
+      case VARTYPE_TYPE_BOOL:
+        return typeid(bool);
+      case VARTYPE_TYPE_FP32:
+        return typeid(float);
+      case VARTYPE_TYPE_INT8:
+        return typeid(int8_t);
+      case VARTYPE_TYPE_INT32:
+        return typeid(int);
+      case VARTYPE_TYPE_INT64:
+        return typeid(int64_t);
+      default:
+        PADDLE_MOBILE_THROW_EXCEPTION("got unhandled var type `%d`", type);
+    }
+  };
+
+  auto type = var_desc->Type();
+  if (type == VARTYPE_TYPE_LOD_TENSOR) {
+    auto data_type = var_desc->Tensor_desc().DataType();
+    framework::LoDTensor *tensor = var->template GetMutable<LoDTensor>();
+    tensor->mutable_data(TypeId(data_type));
+  } else if (type == VARTYPE_TYPE_STEP_SCOPES) {
+    std::vector<framework::Scope *> *step_scopes =
+        var->template GetMutable<std::vector<framework::Scope *>>();
+  } else if (type == VARTYPE_TYPE_STEP_LOD_TENSOR_ARRAY) {
+    framework::LoDTensorArray *tensor_array =
+        var->template GetMutable<framework::LoDTensorArray>();
+  } else {
+    PADDLE_MOBILE_THROW_EXCEPTION("got unhandled var type `%d`", type);
   }
-  bool is_mute_match =
-      (type == VARTYPE_TYPE_FP32) || (type == VARTYPE_TYPE_INT8) ||
-      (type == VARTYPE_TYPE_INT32) || (type == VARTYPE_TYPE_INT64);
-  PADDLE_MOBILE_ENFORCE(is_mute_match, "got unhandled data type : %d", type);
-  return is_mute_match;
+  return true;
 }

 template <typename Device, typename T>
@@ -323,11 +354,19 @@ PMStatus Executor<Device, T>::Predict(
 template <typename Device, typename T>
 std::vector<T> Executor<Device, T>::Predict(const std::vector<T> &input,
                                             const std::vector<int64_t> &dims) {
+  PADDLE_MOBILE_ENFORCE(feed_indices_.size() != 0,
+                        "We don't know which tensor should be assign, since no "
+                        "feed op found in this model");
+  PADDLE_MOBILE_ENFORCE(fetch_indices_.size() != 0,
+                        "We don't know which tensor should be fetch out, since "
+                        "no fetch op found in this model");
+  std::string input_name = feed_indices_.begin()->first;
   Tensor feed_tensor(input, make_ddim(dims));
-  SetInput(feed_tensor, "feed");
+  SetInput(feed_tensor, input_name);
   std::vector<T> output;
   if (this->Predict() == PMSuccess) {
-    const auto output_tensor = GetOutput("fetch");
+    std::string output_name = fetch_indices_.begin()->first;
+    const auto output_tensor = GetOutput(output_name);
     output.resize(output_tensor->numel());
     memcpy(output.data(), output_tensor->template data<T>(),
            output.size() * sizeof(T));
@@ -338,11 +377,13 @@ std::vector<T> Executor<Device, T>::Predict(const std::vector<T> &input,
 template <typename Device, typename T>
 void Executor<Device, T>::SetInput(const Tensor &input,
                                    const std::string &var_name) {
-  auto *target_var = program_.scope->FindVar(var_name);
-  PADDLE_MOBILE_ENFORCE(target_var != nullptr, "Variable %s is not exist",
-                        var_name.c_str());
-
-  auto *target_tensor = target_var->template GetMutable<LoDTensor>();
+  int index = 0;
+  if (feed_indices_.find(var_name) != feed_indices_.end()) {
+    index = feed_indices_.find(var_name)->second;
+  }
+  auto *feed_var = program_.scope->Var("feed");
+  framework::LoDTensor &target =
+      feed_var->template GetMutable<framework::LoDTensorArray>()->at(index);

   if (config_.load_when_predict) {
     if (input_dim_last_ != input.dims()) {
@@ -351,68 +392,92 @@ void Executor<Device, T>::SetInput(const Tensor &input,
     }
   }

-  target_tensor->Resize(input.dims());
-  target_tensor->ShareDataWith(input);
+  target.Resize(input.dims());
+  target.ShareDataWith(input);
 }

 template <typename Device, typename T>
 void Executor<Device, T>::SetInput(const LoDTensor &input,
                                    const std::string &var_name) {
-  auto *target_var = program_.scope->FindVar(var_name);
-  PADDLE_MOBILE_ENFORCE(target_var != nullptr, "Variable %s is not exist",
-                        var_name.c_str());
-  auto *target_tensor = target_var->template GetMutable<LoDTensor>();
+  int index = 0;
+  if (feed_indices_.find(var_name) != feed_indices_.end()) {
+    index = feed_indices_.find(var_name)->second;
+  }
+  auto *feed_var = program_.scope->Var("feed");
+  framework::LoDTensor &target =
+      feed_var->template GetMutable<framework::LoDTensorArray>()->at(index);

   if (config_.load_when_predict) {
     if (input_dim_last_ != input.dims()) {
-      InitNoPersistableMemory(*target_tensor);
+      InitNoPersistableMemory(input);
       input_dim_last_ = input.dims();
     }
   }

-  target_tensor->Resize(input.dims());
-  target_tensor->ShareDataWith(input);
-  target_tensor->set_lod(input.lod());
+  target.Resize(input.dims());
+  target.ShareDataWith(input);
+  target.set_lod(input.lod());
+}
+
+template <typename Device, typename T>
+std::shared_ptr<LoDTensor> Executor<Device, T>::GetOutput(
+    const std::string &var_name) {
+  const auto &iter = fetch_indices_.find(var_name);
+  if (var_name == "fetch" || iter != fetch_indices_.end()) {
+    int index = 0;
+    if (iter != fetch_indices_.end()) {
+      index = iter->second;
+    }
+    auto *fetch_var = program_.scope->Var("fetch");
+    framework::LoDTensor &target =
+        fetch_var->template GetMutable<framework::LoDTensorArray>()->at(index);
+
+    return std::make_shared<LoDTensor>(target);
+  } else {
+    auto *fetch_var = program_.scope->Var(var_name);
+    framework::LoDTensor *target =
+        fetch_var->template GetMutable<framework::LoDTensor>();
+    return std::make_shared<LoDTensor>(*target);
+  }
 }

 template <typename Device, typename T>
 PMStatus Executor<Device, T>::Predict() {
 #ifdef PADDLE_MOBILE_PROFILE
-  std::vector<ProfInfo> profile(ops_list_.size());
+  std::vector<ProfInfo> profile(ops_of_block0_.size());
   struct timespec ts;
   int op_index = 0;
 #endif
-  for (auto &block : ops_of_block_) {
-    for (auto &op_handler : block) {
+  for (auto &op_handler : ops_of_block0_) {
 #ifdef PADDLE_MOBILE_PROFILE
-      clock_gettime(CLOCK_MONOTONIC, &ts);
-      profile[op_index].runBegin = (uint64_t)ts.tv_sec * 1e9 + ts.tv_nsec;
+    clock_gettime(CLOCK_MONOTONIC, &ts);
+    profile[op_index].runBegin = (uint64_t)ts.tv_sec * 1e9 + ts.tv_nsec;
 #endif
-      if (lod_mode_) {
-        op_handler->InferShape();
-      }
-      op_handler->Run();
+    if (lod_mode_) {
+      op_handler->InferShape();
+    }
+    op_handler->Run();
 #ifdef PADDLE_MOBILE_PROFILE
-      clock_gettime(CLOCK_MONOTONIC, &ts);
-      profile[op_index].runEnd = (uint64_t)ts.tv_sec * 1e9 + ts.tv_nsec;
-      ++op_index;
+    clock_gettime(CLOCK_MONOTONIC, &ts);
+    profile[op_index].runEnd = (uint64_t)ts.tv_sec * 1e9 + ts.tv_nsec;
+    ++op_index;
 #endif
-    }
   }
 #ifdef PADDLE_MOBILE_PROFILE
   std::unordered_map<std::string, uint64_t> _tp;
   for (int i = 0; i < profile.size(); i++) {
     const auto &pInfo = profile[i];
     uint64_t timeCost = pInfo.runEnd - pInfo.runBegin;
-    if (ops_list_[i]->Type() == "conv2d" ||
-        ops_list_[i]->Type() == "depthwise_conv2d") {
-      auto inputs = ops_list_[i]->Inputs();
+    if (ops_of_block0_[i]->Type() == "conv2d" ||
+        ops_of_block0_[i]->Type() == "depthwise_conv2d") {
+      auto inputs = ops_of_block0_[i]->Inputs();
       auto *filter = GetVarValue<LoDTensor>("Filter", inputs, *(program_.scope));
       int kernel_size = filter->dims()[2];
-      _tp[ops_list_[i]->Type() + "_" + std::to_string(kernel_size)] += timeCost;
+      _tp[ops_of_block0_[i]->Type() + "_" + std::to_string(kernel_size)] +=
+          timeCost;
     } else {
-      _tp[ops_list_[i]->Type()] += timeCost;
+      _tp[ops_of_block0_[i]->Type()] += timeCost;
     }
   }
   printf("====================[ profile ]======================\n");
@@ -437,16 +502,6 @@ PMStatus Executor<Device, T>::Predict() {
   return PMSuccess;
 }

-template <typename Device, typename T>
-std::shared_ptr<LoDTensor> Executor<Device, T>::GetOutput(
-    const std::string &var_name) {
-  auto *target_var = program_.scope->FindVar(var_name);
-  PADDLE_MOBILE_ENFORCE(target_var != nullptr, "Variable %s is not exist",
-                        var_name.c_str());
-  auto *output_tensor = target_var->template GetMutable<LoDTensor>();
-  return std::make_shared<LoDTensor>(*output_tensor);
-}
-
 #ifdef PADDLE_MOBILE_FPGA
 template <typename Device, typename T>
 void Executor<Device, T>::InjectVariable(const Tensor &t,
@@ -476,20 +531,6 @@ void Executor<Device, T>::FeedData(const std::vector<void *> &v) {
   }
 }

-template <typename Device, typename T>
-void Executor<Device, T>::FeedTensorData(const vector<framework::Tensor> &v) {
-  auto input_size = v.size();
-  int index = 0;
-  auto vars = program_.scope->VarContain("feed", &index);
-  PADDLE_MOBILE_ENFORCE(input_size == vars.size(),
-                        "input data number not correct");
-  for (int i = 0; i < input_size; i++) {
-    auto var = program_.scope->Var("feed", i + index);
-    auto feed_tensor = var->template GetMutable<LoDTensor>();
-    feed_tensor->ShareDataWith(v[i]);
-  }
-}
-
 template <typename Device, typename T>
 void Executor<Device, T>::GetResults(std::vector<void *> *v) {
   auto output_size = v->size();
@@ -524,11 +565,11 @@ framework::Tensor *Executor<Device, T>::GetTensorByName(
     const std::string &name) {
   auto var = program_.scope->Var(name);
   return var->template GetMutable<LoDTensor>();
-};
+}

 template <typename Device, typename T>
 std::shared_ptr<Tensor> Executor<Device, T>::FetchResult(int id) {
-  auto &ops = ops_of_block_[0];
+  auto &ops = ops_of_block0_;
   PADDLE_MOBILE_ENFORCE(id < (int)ops.size(), "Index out of range");

   auto op = id < 0 ? ops[ops.size() - 1] : ops[id];
@@ -542,7 +583,7 @@ std::shared_ptr<Tensor> Executor<Device, T>::FetchResult(int id) {

 template <typename Device, typename T>
 void Executor<Device, T>::Predict_From_To(int start, int end) {
-  auto &ops = ops_of_block_[0];
+  auto &ops = ops_of_block0_;
   end = end < 0 ? static_cast<int>(ops.size()) : end;
   PADDLE_MOBILE_ENFORCE(start >= 0 && start < end && end <= ops.size(),
                         "start or end parameter is wrong");
diff --git a/src/framework/executor.h b/src/framework/executor.h
index 65484b11465089eb90a99ac21caa9e1c386d4d10..853914c54cb962c570ae2a9751500d3275091499 100644
--- a/src/framework/executor.h
+++ b/src/framework/executor.h
@@ -53,7 +53,6 @@ class Executor {
   void InjectVariable(const Tensor &t, std::string var_name);
   void FeedData(const Tensor &t);
   void FeedData(const std::vector<void *> &v);
-  void FeedTensorData(const std::vector<framework::Tensor> &v);
   void GetResults(std::vector<void *> *v);
   void GetTensorResults(std::vector<framework::Tensor *> *v);

@@ -68,8 +67,9 @@ class Executor {
  protected:
   Executor() = default;
-  bool varInputMemory(const std::shared_ptr<VarDesc> &var_desc, Variable *var,
-                      LoDTensor *tensor) const;
+  bool varInputMemory(const std::shared_ptr<VarDesc> &var_desc,
+                      Variable *var) const;
+  void InitFeedFetchList();
   void InitMemory();
   void InitCombineMemory();
   void InitNoPersistableMemory(const Tensor &input_tensor);
@@ -85,10 +85,9 @@ class Executor {
   PaddleMobileConfigInternal config_;
   Program<Device> program_;
   std::shared_ptr<ProgramDesc> program_desc_;
-  typedef std::shared_ptr<OperatorBase<Device>> OperatorBasePtr;
-  std::vector<std::vector<OperatorBasePtr>> ops_of_block_;
-  // operators list
-  std::vector<OperatorBasePtr> ops_list_;
+  std::vector<std::shared_ptr<OperatorBase<Device>>> ops_of_block0_;
+  std::unordered_map<std::string, int> feed_indices_;
+  std::unordered_map<std::string, int> fetch_indices_;

   // for super resoltion
   DDim input_dim_last_;
diff --git a/src/framework/framework.pb-c.c b/src/framework/framework.pb-c.c
index bbccc76a22f5efbe69b58e6a546d063923077af6..394c17f09724a9db2bfc62603a3ffa46cf032899 100644
--- a/src/framework/framework.pb-c.c
+++ b/src/framework/framework.pb-c.c
@@ -13,13 +13,6 @@ void paddle_mobile__framework__proto__version__init(
       PADDLE_MOBILE__FRAMEWORK__PROTO__VERSION__INIT;
   *message = init_value;
 }
-size_t paddle_mobile__framework__proto__version__get_packed_size(
-    const PaddleMobile__Framework__Proto__Version *message) {
-  assert(message->base.descriptor ==
-         &paddle_mobile__framework__proto__version__descriptor);
-  return protobuf_c_message_get_packed_size(
-      (const ProtobufCMessage *)(message));
-}
 PaddleMobile__Framework__Proto__Version *
 paddle_mobile__framework__proto__version__unpack(ProtobufCAllocator *allocator,
                                                  size_t len,
@@ -54,13 +47,6 @@ void paddle_mobile__framework__proto__op_desc__init(
       PADDLE_MOBILE__FRAMEWORK__PROTO__OP_DESC__INIT;
   *message = init_value;
 }
-size_t paddle_mobile__framework__proto__op_desc__get_packed_size(
-    const PaddleMobile__Framework__Proto__OpDesc *message) {
-  assert(message->base.descriptor ==
-         &paddle_mobile__framework__proto__op_desc__descriptor);
-  return protobuf_c_message_get_packed_size(
-      (const ProtobufCMessage *)(message));
-}
 PaddleMobile__Framework__Proto__OpDesc *
 paddle_mobile__framework__proto__op_desc__unpack(ProtobufCAllocator *allocator,
                                                  size_t len,
@@ -95,13 +81,6 @@ void paddle_mobile__framework__proto__op_proto__init(
       PADDLE_MOBILE__FRAMEWORK__PROTO__OP_PROTO__INIT;
   *message = init_value;
 }
-size_t paddle_mobile__framework__proto__op_proto__get_packed_size(
-    const PaddleMobile__Framework__Proto__OpProto *message) {
-  assert(message->base.descriptor ==
-         &paddle_mobile__framework__proto__op_proto__descriptor);
-  return protobuf_c_message_get_packed_size(
-      (const ProtobufCMessage *)(message));
-}
 PaddleMobile__Framework__Proto__OpProto *
 paddle_mobile__framework__proto__op_proto__unpack(ProtobufCAllocator *allocator,
                                                   size_t len,
@@ -162,13 +141,6 @@ void paddle_mobile__framework__proto__var_type__init(
       PADDLE_MOBILE__FRAMEWORK__PROTO__VAR_TYPE__INIT;
   *message = init_value;
 }
-size_t paddle_mobile__framework__proto__var_type__get_packed_size(
-    const PaddleMobile__Framework__Proto__VarType *message) {
-  assert(message->base.descriptor ==
-         &paddle_mobile__framework__proto__var_type__descriptor);
-  return protobuf_c_message_get_packed_size(
-      (const ProtobufCMessage *)(message));
-}
 PaddleMobile__Framework__Proto__VarType *
 paddle_mobile__framework__proto__var_type__unpack(ProtobufCAllocator *allocator,
                                                   size_t len,
@@ -191,13 +163,6 @@ void paddle_mobile__framework__proto__var_desc__init(
       PADDLE_MOBILE__FRAMEWORK__PROTO__VAR_DESC__INIT;
   *message = init_value;
 }
-size_t paddle_mobile__framework__proto__var_desc__get_packed_size(
-    const PaddleMobile__Framework__Proto__VarDesc *message) {
-  assert(message->base.descriptor ==
-         &paddle_mobile__framework__proto__var_desc__descriptor);
-  return protobuf_c_message_get_packed_size(
-      (const ProtobufCMessage *)(message));
-}
 PaddleMobile__Framework__Proto__VarDesc *
 paddle_mobile__framework__proto__var_desc__unpack(ProtobufCAllocator *allocator,
                                                   size_t len,
@@ -220,13 +185,6 @@ void paddle_mobile__framework__proto__block_desc__init(
       PADDLE_MOBILE__FRAMEWORK__PROTO__BLOCK_DESC__INIT;
   *message = init_value;
 }
-size_t paddle_mobile__framework__proto__block_desc__get_packed_size(
-    const PaddleMobile__Framework__Proto__BlockDesc *message) {
-  assert(message->base.descriptor ==
-         &paddle_mobile__framework__proto__block_desc__descriptor);
-  return protobuf_c_message_get_packed_size(
-      (const ProtobufCMessage *)(message));
-}
 PaddleMobile__Framework__Proto__BlockDesc *
 paddle_mobile__framework__proto__block_desc__unpack(
     ProtobufCAllocator *allocator, size_t len, const uint8_t *data) {
@@ -248,13 +206,6 @@ void paddle_mobile__framework__proto__program_desc__init(
       PADDLE_MOBILE__FRAMEWORK__PROTO__PROGRAM_DESC__INIT;
   *message = init_value;
 }
-size_t paddle_mobile__framework__proto__program_desc__get_packed_size(
-    const PaddleMobile__Framework__Proto__ProgramDesc *message) {
-  assert(message->base.descriptor ==
-         &paddle_mobile__framework__proto__program_desc__descriptor);
-  return protobuf_c_message_get_packed_size(
-      (const ProtobufCMessage *)(message));
-}
 PaddleMobile__Framework__Proto__ProgramDesc *
 paddle_mobile__framework__proto__program_desc__unpack(
     ProtobufCAllocator *allocator, size_t len, const uint8_t *data) {
@@ -310,7 +261,7 @@ const ProtobufCMessageDescriptor
     NULL /* reserved[123] */
 };
 static const ProtobufCFieldDescriptor
-    paddle_mobile__framework__proto__op_desc__attr__field_descriptors[13] = {
+    paddle_mobile__framework__proto__op_desc__attr__field_descriptors[14] = {
        {
            "name", 1, PROTOBUF_C_LABEL_REQUIRED, PROTOBUF_C_TYPE_STRING,
            0, /* quantifier_offset */
@@ -405,6 +356,13 @@ static const ProtobufCFieldDescriptor
            NULL, NULL, 0, /* flags */
            0, NULL, NULL  /* reserved1,reserved2, etc */
        },
+       {
+           "longs", 15, PROTOBUF_C_LABEL_REPEATED, PROTOBUF_C_TYPE_INT64,
+           offsetof(PaddleMobile__Framework__Proto__OpDesc__Attr, n_longs),
+           offsetof(PaddleMobile__Framework__Proto__OpDesc__Attr, longs), NULL,
+           NULL, 0, /* flags */
+           0, NULL, NULL /* reserved1,reserved2, etc */
+       },
 };
 static const unsigned
     paddle_mobile__framework__proto__op_desc__attr__field_indices_by_name[] = {
@@ -417,6 +375,7 @@ static const unsigned
        2,  /* field[2] = i */
        5,  /* field[5] = ints */
        11, /* field[11] = l */
+       13, /* field[13] = longs */
        0,  /* field[0] = name */
        4,  /* field[4] = s */
        7,  /* field[7] = strings */
@@ -424,7 +383,7 @@ static const unsigned
 };
 static const ProtobufCIntRange
     paddle_mobile__framework__proto__op_desc__attr__number_ranges[2 + 1] = {
-        {1, 0}, {10, 8}, {0, 13}};
+        {1, 0}, {10, 8}, {0, 14}};
 const ProtobufCMessageDescriptor
     paddle_mobile__framework__proto__op_desc__attr__descriptor = {
        PROTOBUF_C__MESSAGE_DESCRIPTOR_MAGIC,
@@ -433,7 +392,7 @@ const ProtobufCMessageDescriptor
        "PaddleMobile__Framework__Proto__OpDesc__Attr",
        "paddle_mobile.framework.proto",
        sizeof(PaddleMobile__Framework__Proto__OpDesc__Attr),
-       13,
+       14,
        paddle_mobile__framework__proto__op_desc__attr__field_descriptors,
        paddle_mobile__framework__proto__op_desc__attr__field_indices_by_name,
        2,
@@ -1448,7 +1407,7 @@ const ProtobufCMessageDescriptor
     NULL /* reserved[123] */
 };
 static const ProtobufCEnumValue
-    paddle_mobile__framework__proto__attr_type__enum_values_by_number[11] = {
+    paddle_mobile__framework__proto__attr_type__enum_values_by_number[12] = {
        {"INT", "PADDLE_MOBILE__FRAMEWORK__PROTO__ATTR_TYPE__INT", 0},
        {"FLOAT", "PADDLE_MOBILE__FRAMEWORK__PROTO__ATTR_TYPE__FLOAT", 1},
        {"STRING", "PADDLE_MOBILE__FRAMEWORK__PROTO__ATTR_TYPE__STRING", 2},
@@ -1460,15 +1419,16 @@ static const ProtobufCEnumValue
        {"BLOCK", "PADDLE_MOBILE__FRAMEWORK__PROTO__ATTR_TYPE__BLOCK", 8},
        {"LONG", "PADDLE_MOBILE__FRAMEWORK__PROTO__ATTR_TYPE__LONG", 9},
        {"BLOCKS", "PADDLE_MOBILE__FRAMEWORK__PROTO__ATTR_TYPE__BLOCKS", 10},
+       {"LONGS", "PADDLE_MOBILE__FRAMEWORK__PROTO__ATTR_TYPE__LONGS", 11},
 };
 static const ProtobufCIntRange
     paddle_mobile__framework__proto__attr_type__value_ranges[] = {{0, 0},
-                                                                  {0, 11}};
+                                                                  {0, 12}};
 static const ProtobufCEnumValueIndex
-    paddle_mobile__framework__proto__attr_type__enum_values_by_name[11] = {
+    paddle_mobile__framework__proto__attr_type__enum_values_by_name[12] = {
        {"BLOCK", 8},   {"BLOCKS", 10}, {"BOOLEAN", 6}, {"BOOLEANS", 7},
        {"FLOAT", 1},   {"FLOATS", 4},  {"INT", 0},     {"INTS", 3},
-       {"LONG", 9},    {"STRING", 2},  {"STRINGS", 5},
{"LONGS", 11}, {"STRING", 2}, {"STRINGS", 5}, }; const ProtobufCEnumDescriptor paddle_mobile__framework__proto__attr_type__descriptor = { @@ -1477,9 +1437,9 @@ const ProtobufCEnumDescriptor "AttrType", "PaddleMobile__Framework__Proto__AttrType", "paddle_mobile.framework.proto", - 11, + 12, paddle_mobile__framework__proto__attr_type__enum_values_by_number, - 11, + 12, paddle_mobile__framework__proto__attr_type__enum_values_by_name, 1, paddle_mobile__framework__proto__attr_type__value_ranges, diff --git a/src/framework/framework.pb-c.h b/src/framework/framework.pb-c.h index b7bac7ef9c99f62489bcd74936b3c0b55374abfb..a0f2eaee12acded26ce210c1016aeba0c4eba4ed 100644 --- a/src/framework/framework.pb-c.h +++ b/src/framework/framework.pb-c.h @@ -102,8 +102,9 @@ typedef enum _PaddleMobile__Framework__Proto__AttrType { PADDLE_MOBILE__FRAMEWORK__PROTO__ATTR_TYPE__BOOLEANS = 7, PADDLE_MOBILE__FRAMEWORK__PROTO__ATTR_TYPE__BLOCK = 8, PADDLE_MOBILE__FRAMEWORK__PROTO__ATTR_TYPE__LONG = 9, - PADDLE_MOBILE__FRAMEWORK__PROTO__ATTR_TYPE__BLOCKS = - 10 PROTOBUF_C__FORCE_ENUM_TO_BE_INT_SIZE( + PADDLE_MOBILE__FRAMEWORK__PROTO__ATTR_TYPE__BLOCKS = 10, + PADDLE_MOBILE__FRAMEWORK__PROTO__ATTR_TYPE__LONGS = + 11 PROTOBUF_C__FORCE_ENUM_TO_BE_INT_SIZE( PADDLE_MOBILE__FRAMEWORK__PROTO__ATTR_TYPE) } PaddleMobile__Framework__Proto__AttrType; @@ -152,13 +153,15 @@ struct _PaddleMobile__Framework__Proto__OpDesc__Attr { int64_t l; size_t n_blocks_idx; int32_t *blocks_idx; + size_t n_longs; + int64_t *longs; }; #define PADDLE_MOBILE__FRAMEWORK__PROTO__OP_DESC__ATTR__INIT \ { \ PROTOBUF_C_MESSAGE_INIT( \ &paddle_mobile__framework__proto__op_desc__attr__descriptor) \ , NULL, PADDLE_MOBILE__FRAMEWORK__PROTO__ATTR_TYPE__INT, 0, 0, 0, 0, NULL, \ - 0, NULL, 0, NULL, 0, NULL, 0, 0, 0, NULL, 0, 0, 0, 0, 0, NULL \ + 0, NULL, 0, NULL, 0, NULL, 0, 0, 0, NULL, 0, 0, 0, 0, 0, NULL, 0, NULL \ } struct _PaddleMobile__Framework__Proto__OpDesc__Var { @@ -417,8 +420,6 @@ struct _PaddleMobile__Framework__Proto__ProgramDesc { /* PaddleMobile__Framework__Proto__Version methods */ void paddle_mobile__framework__proto__version__init( PaddleMobile__Framework__Proto__Version *message); -size_t paddle_mobile__framework__proto__version__get_packed_size( - const PaddleMobile__Framework__Proto__Version *message); PaddleMobile__Framework__Proto__Version * paddle_mobile__framework__proto__version__unpack(ProtobufCAllocator *allocator, size_t len, @@ -435,8 +436,6 @@ void paddle_mobile__framework__proto__op_desc__var__init( /* PaddleMobile__Framework__Proto__OpDesc methods */ void paddle_mobile__framework__proto__op_desc__init( PaddleMobile__Framework__Proto__OpDesc *message); -size_t paddle_mobile__framework__proto__op_desc__get_packed_size( - const PaddleMobile__Framework__Proto__OpDesc *message); PaddleMobile__Framework__Proto__OpDesc * paddle_mobile__framework__proto__op_desc__unpack(ProtobufCAllocator *allocator, size_t len, @@ -453,8 +452,6 @@ void paddle_mobile__framework__proto__op_proto__attr__init( /* PaddleMobile__Framework__Proto__OpProto methods */ void paddle_mobile__framework__proto__op_proto__init( PaddleMobile__Framework__Proto__OpProto *message); -size_t paddle_mobile__framework__proto__op_proto__get_packed_size( - const PaddleMobile__Framework__Proto__OpProto *message); PaddleMobile__Framework__Proto__OpProto * paddle_mobile__framework__proto__op_proto__unpack(ProtobufCAllocator *allocator, size_t len, @@ -483,8 +480,6 @@ void paddle_mobile__framework__proto__var_type__tuple__init( /* PaddleMobile__Framework__Proto__VarType methods */ 
 void paddle_mobile__framework__proto__var_type__init(
     PaddleMobile__Framework__Proto__VarType *message);
-size_t paddle_mobile__framework__proto__var_type__get_packed_size(
-    const PaddleMobile__Framework__Proto__VarType *message);
 PaddleMobile__Framework__Proto__VarType *
 paddle_mobile__framework__proto__var_type__unpack(ProtobufCAllocator *allocator,
                                                   size_t len,
@@ -495,8 +490,6 @@ void paddle_mobile__framework__proto__var_type__free_unpacked(
 /* PaddleMobile__Framework__Proto__VarDesc methods */
 void paddle_mobile__framework__proto__var_desc__init(
     PaddleMobile__Framework__Proto__VarDesc *message);
-size_t paddle_mobile__framework__proto__var_desc__get_packed_size(
-    const PaddleMobile__Framework__Proto__VarDesc *message);
 PaddleMobile__Framework__Proto__VarDesc *
 paddle_mobile__framework__proto__var_desc__unpack(ProtobufCAllocator *allocator,
                                                   size_t len,
@@ -507,8 +500,6 @@ void paddle_mobile__framework__proto__var_desc__free_unpacked(
 /* PaddleMobile__Framework__Proto__BlockDesc methods */
 void paddle_mobile__framework__proto__block_desc__init(
     PaddleMobile__Framework__Proto__BlockDesc *message);
-size_t paddle_mobile__framework__proto__block_desc__get_packed_size(
-    const PaddleMobile__Framework__Proto__BlockDesc *message);
 PaddleMobile__Framework__Proto__BlockDesc *
 paddle_mobile__framework__proto__block_desc__unpack(
     ProtobufCAllocator *allocator, size_t len, const uint8_t *data);
@@ -518,8 +509,6 @@ void paddle_mobile__framework__proto__block_desc__free_unpacked(
 /* PaddleMobile__Framework__Proto__ProgramDesc methods */
 void paddle_mobile__framework__proto__program_desc__init(
     PaddleMobile__Framework__Proto__ProgramDesc *message);
-size_t paddle_mobile__framework__proto__program_desc__get_packed_size(
-    const PaddleMobile__Framework__Proto__ProgramDesc *message);
 PaddleMobile__Framework__Proto__ProgramDesc *
 paddle_mobile__framework__proto__program_desc__unpack(
     ProtobufCAllocator *allocator, size_t len, const uint8_t *data);
diff --git a/src/framework/framework.proto b/src/framework/framework.proto
index 4f41e26dc2df8550a6ce318d6e39ef4f3e875e73..27a98e0d6178b0fb20dcf640635413691efb7f10 100644
--- a/src/framework/framework.proto
+++ b/src/framework/framework.proto
@@ -35,6 +35,7 @@ enum AttrType {
   BLOCK = 8;
   LONG = 9;
   BLOCKS = 10;
+  LONGS = 11;
 }

 // OpDesc describes an instance of a C++ framework::OperatorBase
@@ -55,6 +56,7 @@ message OpDesc {
     optional int32 block_idx = 12;
     optional int64 l = 13;
     repeated int32 blocks_idx = 14;
+    repeated int64 longs = 15;
   };

   message Var {
diff --git a/src/framework/load_ops.h b/src/framework/load_ops.h
index 3b214a52f1314202d183b871784bebae0b6ec795..e72c55f5f736b81362f461952a706127998f9ade 100644
--- a/src/framework/load_ops.h
+++ b/src/framework/load_ops.h
@@ -125,10 +125,6 @@ LOAD_OP1(prior_box, CPU);
 LOAD_OP2(fusion_conv_add_relu, CPU, FPGA);
 LOAD_FUSION_MATCHER(fusion_conv_add_relu);
 #endif
-#ifdef FUSION_CONVADDADDPRELU_OP
-LOAD_OP2(fusion_conv_add_add_prelu, CPU, FPGA);
-LOAD_FUSION_MATCHER(fusion_conv_add_add_prelu);
-#endif
 #ifdef FUSION_CONVADD_OP
 LOAD_OP2(fusion_conv_add, CPU, MALI_GPU);
 LOAD_FUSION_MATCHER(fusion_conv_add);
@@ -178,10 +174,6 @@ LOAD_FUSION_MATCHER(fusion_conv_add_bn);
 #ifdef DROPOUT_OP
 LOAD_OP2(dropout, CPU, FPGA);
 #endif
-#ifdef FUSION_CONVADDPRELU_OP
-LOAD_OP2(fusion_conv_add_prelu, CPU, FPGA);
-LOAD_FUSION_MATCHER(fusion_conv_add_prelu);
-#endif
 #ifdef FUSION_DWCONVBNRELU_OP
 LOAD_OP1(fusion_dwconv_bn_relu, CPU);
 LOAD_FUSION_MATCHER(fusion_dwconv_bn_relu);
@@ -324,3 +316,15 @@ LOAD_OP1(psroi_pool, CPU);
 #ifdef ROI_PERSPECTIVE_OP
 LOAD_OP1(roi_perspective_transform, CPU);
 #endif
+#ifdef BEAM_SEARCH_OP
+LOAD_OP1(beam_search, CPU);
+#endif
+#ifdef BEAM_SEARCH_DECODE_OP
+LOAD_OP1(beam_search_decode, CPU);
+#endif
+#ifdef PAD2D_OP
+LOAD_OP1(pad2d, CPU);
+#endif
+#ifdef ONE_HOT_OP
+LOAD_OP1(one_hot, CPU);
+#endif
diff --git a/src/framework/lod_tensor.h b/src/framework/lod_tensor.h
index 8c35a48f5fa2f13aeb2474eaf6e369dab09ee6d7..e96fe0e501f06814914677d0fb6248a11fa1dde5 100644
--- a/src/framework/lod_tensor.h
+++ b/src/framework/lod_tensor.h
@@ -221,6 +221,8 @@ inline Print &operator<<(Print &printer, const LoDTensor &tensor) {
       printer << static_cast<int>(tensor.data<int8_t>()[i]) << " ";
     } else if (tensor.type() == typeid(int32_t)) {
       printer << tensor.data<int32_t>()[i] << " ";
+    } else if (tensor.type() == typeid(bool)) {
+      printer << tensor.data<bool>()[i] << " ";
     }
   }
 #endif  // PADDLE_MOBILE_FPGA
diff --git a/src/framework/op_registry.h b/src/framework/op_registry.h
index 219385ab1429fefddc9d380799259f7562e0030f..f57519ee0272d74507ad9e53864262310adced4d 100644
--- a/src/framework/op_registry.h
+++ b/src/framework/op_registry.h
@@ -58,8 +58,7 @@ struct OpInfoFiller {
   void operator()(const std::string& op_type, OpInfo<Dtype>* info) const {
     info->creator_ = [](const std::string& type, const VariableNameMap& inputs,
                         const VariableNameMap& outputs,
-                        const AttributeMap& attrs,
-                        std::shared_ptr<Scope> scope) {
+                        const AttributeMap& attrs, framework::Scope* scope) {
       return new T(type, inputs, outputs, attrs, scope);
     };
   }
@@ -91,7 +90,7 @@ class OpRegistry {
   static std::shared_ptr<OperatorBase<Dtype>> CreateOp(
       const std::string& type, const VariableNameMap& inputs,
       const VariableNameMap& outputs, const AttributeMap attrs,
-      std::shared_ptr<paddle_mobile::framework::Scope> scope) {
+      paddle_mobile::framework::Scope* scope) {
     auto& info = OpInfoMap<Dtype>::Instance()->Get(type);
     auto op = info.Creator()(type, inputs, outputs, attrs, scope);
     return std::shared_ptr<OperatorBase<Dtype>>(op);
diff --git a/src/framework/operator.cpp b/src/framework/operator.cpp
index 49507dc75dbcbd1bac9385ed6fab14b694c8f7be..74398bbc5b368236d56e5180452b5b05d7d156ad 100644
--- a/src/framework/operator.cpp
+++ b/src/framework/operator.cpp
@@ -43,7 +43,7 @@ OperatorBase<Dtype>::OperatorBase(const std::string &type,
                                   const VariableNameMap &inputs,
                                   const VariableNameMap &outputs,
                                   const AttributeMap &attrs,
-                                  std::shared_ptr<Scope> scope)
+                                  framework::Scope *scope)
     : type_(type),
       inputs_(inputs),
       outputs_(outputs),
@@ -67,30 +67,22 @@ void OperatorBase<Dtype>::Run() {
   for (const auto key : input_keys) {
     auto var_vec_in = inputs_.at(key);
     for (int i = 0; i < var_vec_in.size(); ++i) {
-      auto vari = this->scope_->FindVar(var_vec_in[i]);
-      if (vari->IsInitialized()) {
-        const Tensor *tensor = vari->template Get<framework::LoDTensor>();
-        if (tensor) {
-          DLOG << type_ << " input- " << key << "=" << *tensor;
-#ifdef PADDLE_MOBILE_FPGA
-          DLOG << var_vec_in[i];
-#endif
-        }
+      auto var = this->scope_->FindVar(var_vec_in[i]);
+      if (var->IsInitialized() &&
+          var->template IsType<framework::LoDTensor>()) {
+        const Tensor *tensor = var->template Get<framework::LoDTensor>();
+        if (tensor) DLOG << type_ << " input- " << key << "=" << *tensor;
       }
     }
   }
   for (const auto key : GetOutKeys()) {
     auto var_vec_out = outputs_.at(key);
     for (int i = 0; i < var_vec_out.size(); ++i) {
-      auto vari = scope_->FindVar(var_vec_out[i]);
-      if (vari->IsInitialized()) {
-        const Tensor *tensor = vari->template Get<framework::LoDTensor>();
-        if (tensor) {
-          DLOG << type_ << " output- " << key << "=" << *tensor;
-#ifdef PADDLE_MOBILE_FPGA
-          DLOG << var_vec_out[i];
-#endif
-        }
+      auto var = scope_->FindVar(var_vec_out[i]);
+      if (var->IsInitialized() &&
+          var->template IsType<framework::LoDTensor>()) {
+        const Tensor *tensor = var->template Get<framework::LoDTensor>();
+        if (tensor) DLOG << type_ << " output- " << key << "=" << *tensor;
       }
     }
   }
diff --git a/src/framework/operator.h b/src/framework/operator.h
index 7a3f020638b775f36d8d716eacb729e0437b949c..aaddb9c5649dca1c55daec0497354b127a118605 100644
--- a/src/framework/operator.h
+++ b/src/framework/operator.h
@@ -15,7 +15,6 @@ limitations under the License. */
 #pragma once

 #include <map>
-#include <memory>
 #include <string>
 #include <tuple>
 #include <utility>
@@ -58,7 +57,7 @@ class OperatorBase {
  public:
   OperatorBase(const std::string &type, const VariableNameMap &inputs,
                const VariableNameMap &outputs, const AttributeMap &attrs,
-               std::shared_ptr<Scope> scope);
+               framework::Scope *scope);
   virtual ~OperatorBase() {}

   virtual void Init() = 0;
@@ -81,11 +80,10 @@ class OperatorBase {
   }
 #ifdef PADDLE_MOBILE_FPGA
   void InsertTensors();
-  void ChangeNameMap(string key, std::vector<string> value);
 #endif

  protected:
-  std::shared_ptr<Scope> scope_;
+  framework::Scope *scope_;
   std::string type_;
   VariableNameMap inputs_;
   VariableNameMap outputs_;
@@ -98,35 +96,15 @@ class OperatorBase {
 template <typename Dtype, typename ParamType, typename KernelType>
 class OperatorWithKernel : public OperatorBase<Dtype> {
  public:
-#ifndef PADDLE_MOBILE_FPGA1
   OperatorWithKernel(const std::string &type, const VariableNameMap &inputs,
                      const VariableNameMap &outputs, const AttributeMap &attrs,
-                     std::shared_ptr<Scope> scope)
+                     framework::Scope *scope)
       : OperatorBase<Dtype>(type, inputs, outputs, attrs, scope),
-        param_(inputs, outputs, attrs, scope.get()) {
+        param_(inputs, outputs, attrs, scope) {
 #ifdef PADDLE_MOBILE_CL
     kernel_.InitCLHelper(scope->GetCLScpoe());
 #endif
   }
-#else
-  OperatorWithKernel(const std::string &type, const VariableNameMap inputs,
-                     const VariableNameMap &outputs, const AttributeMap &attrs,
-                     std::shared_ptr<Scope> scope)
-      : OperatorBase<Dtype>(type, inputs, outputs, attrs, scope) {
-    static int feed_num = 0;
-    static int fetch_num = 0;
-    if (type == "feed") {
-      auto new_name = string("feed") + std::to_string(feed_num++);
-      auto var = scope->Var(new_name);
-      (const_cast<VariableNameMap &>(inputs)).at("X") = {string(new_name)};
-    } else if (type == "fetch") {
-      auto new_name = string("fetch") + std::to_string(fetch_num++);
-      auto var = scope->Var(new_name);
-      (const_cast<VariableNameMap &>(outputs)).at("Out") = {string(new_name)};
-    }
-    param_ = ParamType(inputs, outputs, attrs, *scope);
-  }
-#endif

   virtual void RunImpl() { this->kernel_.Compute(this->param_); }

   virtual void InferShape() const = 0;
@@ -198,21 +176,20 @@ class FusionOpMatcher {
   std::shared_ptr<OpDesc> new_opdesc_;
 };

-#define DECLARE_OPERATOR(OpName, OpParam, OpKernel)                          \
-  template <typename DeviceType, typename T>                                 \
-  class OpName##Op : public framework::OperatorWithKernel<                   \
-                         DeviceType, OpParam<DeviceType>,                    \
-                         operators::OpKernel<DeviceType, T>> {               \
-   public:                                                                   \
-    OpName##Op(const std::string &type, const VariableNameMap &inputs,       \
-               const VariableNameMap &outputs,                               \
-               const framework::AttributeMap &attrs,                         \
-               std::shared_ptr<framework::Scope> scope)                      \
-        : framework::OperatorWithKernel<DeviceType, OpParam<DeviceType>,     \
-                                        operators::OpKernel<DeviceType, T>>( \
-              type, inputs, outputs, attrs, scope) {}                        \
-                                                                             \
-    void InferShape() const override;                                        \
+#define DECLARE_OPERATOR(OpName, OpParam, OpKernel)                           \
+  template <typename DeviceType, typename T>                                  \
+  class OpName##Op : public framework::OperatorWithKernel<                    \
+                         DeviceType, OpParam<DeviceType>,                     \
+                         operators::OpKernel<DeviceType, T>> {                \
+   public:                                                                    \
+    OpName##Op(const std::string &type, const VariableNameMap &inputs,        \
+               const VariableNameMap &outputs,                                \
+               const framework::AttributeMap &attrs, framework::Scope *scope) \
+        : framework::OperatorWithKernel<DeviceType, OpParam<DeviceType>,      \
+                                        operators::OpKernel<DeviceType, T>>(  \
+              type, inputs, outputs, attrs, scope) {}                         \
+                                                                              \
+    void InferShape() const override;                                         \
   };

 #define DECLARE_KERNEL(OpName, OpParam)                                   \
@@ -228,7 +205,7 @@ class FusionOpMatcher {
   cls(const std::string &type, const ::paddle_mobile::VariableNameMap &inputs, \
       const ::paddle_mobile::VariableNameMap &outputs,                         \
       const ::paddle_mobile::framework::AttributeMap &attrs,                   \
-      std::shared_ptr<::paddle_mobile::framework::Scope> scope)                \
+      ::paddle_mobile::framework::Scope *scope)                                \
       : parent_cls(type, inputs, outputs, attrs, scope) {}

 }  // namespace framework
diff --git a/src/framework/program/op_desc.cpp b/src/framework/program/op_desc.cpp
index 318dd643bbc1d34cf1ab9bc97218caddd8c8d528..ba3105778eddc553b43bd30a1da020666c09f994 100644
--- a/src/framework/program/op_desc.cpp
+++ b/src/framework/program/op_desc.cpp
@@ -42,9 +42,15 @@ OpDesc::OpDesc(PaddleMobile__Framework__Proto__OpDesc *desc) {
     PaddleMobile__Framework__Proto__OpDesc__Attr *attr = desc->attrs[k];
     std::string attr_name(attr->name);
     attrs_[attr_name] = Attribute::GetAttrValue(attr);
+    proto_attrs_.push_back(*attr);
   }
 }

+const std::vector<PaddleMobile__Framework__Proto__OpDesc__Attr>
+    &OpDesc::GetProtoAttr() const {
+  return proto_attrs_;
+}
+
 const std::vector<std::string> &OpDesc::Input(const std::string &name) const {
   return inputs_.find(name)->second;
 }
@@ -58,6 +64,15 @@ Attribute OpDesc::GetAttr(const std::string &name) const {
   return it->second;
 }

+void OpDesc::SetBlockAttr(const std::string &name, BlockDesc *block) {
+  this->attrs_[name].Set<BlockDesc *>(block);
+}
+
+void OpDesc::SetBlocksAttr(const std::string &name,
+                           std::vector<BlockDesc *> blocks) {
+  this->attrs_[name].Set<std::vector<BlockDesc *>>(blocks);
+}
+
 std::unordered_map<std::string, Attribute> &OpDesc::GetAttrMap() {
   return attrs_;
 }
diff --git a/src/framework/program/op_desc.h b/src/framework/program/op_desc.h
index 4fdfac253f0525b288983e8bcf9c1b4eff8f393d..f9579df034ccfb3e68863b27eef1df935ba266d5 100644
--- a/src/framework/program/op_desc.h
+++ b/src/framework/program/op_desc.h
@@ -29,11 +29,13 @@ class OpDesc {
   friend class ProgramOptimize;
   friend class FusionOpMatcher;
   friend class Node;
+
   explicit OpDesc(PaddleMobile__Framework__Proto__OpDesc *op_desc);
   OpDesc(const OpDesc &op_desc) : type_(op_desc.type_) {
     this->inputs_ = op_desc.inputs_;
     this->outputs_ = op_desc.outputs_;
     this->attrs_ = op_desc.attrs_;
+    this->proto_attrs_ = op_desc.proto_attrs_;
   }

   OpDesc() {}
@@ -41,6 +43,12 @@ class OpDesc {
   const std::vector<std::string> &Output(const std::string &name) const;
   Attribute GetAttr(const std::string &name) const;

+  const std::vector<PaddleMobile__Framework__Proto__OpDesc__Attr>
+      &GetProtoAttr() const;
+
+  void SetBlockAttr(const std::string &name, BlockDesc *block);
+  void SetBlocksAttr(const std::string &name, std::vector<BlockDesc *> block);
+
   VariableNameMap &GetInputs() { return inputs_; }

   VariableNameMap &GetOutputs() { return outputs_; }
@@ -60,6 +68,7 @@ class OpDesc {
   VariableNameMap inputs_;
   VariableNameMap outputs_;
   AttributeMap attrs_;
+  std::vector<PaddleMobile__Framework__Proto__OpDesc__Attr> proto_attrs_;
 };

 Print &operator<<(Print &printer, const OpDesc &op_desc);
diff --git a/src/framework/program/program_desc.cpp b/src/framework/program/program_desc.cpp
index 6c203865a564c4a02bbd5a037de745bb67d5310a..b66c7a0dcf97ef8517e1122d2834aa992736c6e7 100644
--- a/src/framework/program/program_desc.cpp
+++ b/src/framework/program/program_desc.cpp
@@ -15,8 +15,8 @@ limitations under the License. */
 #include <string>
 #include <vector>
+#include "framework/program/program_desc.h"
 #include "framework/program/tensor_desc.h"
-#include "program_desc.h"

 namespace paddle_mobile {
 namespace framework {
@@ -25,6 +25,25 @@ ProgramDesc::ProgramDesc(PaddleMobile__Framework__Proto__ProgramDesc *desc) {
   for (int i = 0; i < desc->n_blocks; ++i) {
     blocks_.emplace_back(std::make_shared<BlockDesc>(desc->blocks[i]));
   }
+  for (auto &block : blocks_) {
+    for (auto op : block->Ops()) {
+      for (const auto &attr : op->GetProtoAttr()) {
+        if (attr.type == PADDLE_MOBILE__FRAMEWORK__PROTO__ATTR_TYPE__BLOCK) {
+          size_t blk_idx = attr.block_idx;
+          op->SetBlockAttr(attr.name, this->MutableBlock(blk_idx));
+        } else if (attr.type ==
+                   PADDLE_MOBILE__FRAMEWORK__PROTO__ATTR_TYPE__BLOCKS) {
+          size_t n_blocks_idx = attr.n_blocks_idx;
+          int32_t *blks_idx = attr.blocks_idx;
+          std::vector<BlockDesc *> block_descs;
+          for (size_t i = 0; i < n_blocks_idx; ++i) {
+            block_descs.push_back(this->MutableBlock(blks_idx[i]));
+          }
+          op->SetBlocksAttr(attr.name, block_descs);
+        }
+      }
+    }
+  }
 }

 void ProgramDesc::Description(std::string header) {
@@ -60,9 +79,8 @@ void ProgramDesc::Description(std::string header) {
       }

       for (const auto &var_desc : block->Vars()) {
+        LOG(kLOG_DEBUG1) << "var name: " << var_desc->Name();
         if (var_desc->Type() == VARTYPE_TYPE_LOD_TENSOR) {
-          LOG(kLOG_DEBUG1) << "var name: " << var_desc->Name();
-
           const TensorDesc &tensor_desc = var_desc->Tensor_desc();

           LOG(kLOG_DEBUG2) << "in var tensor desc dims size: "
diff --git a/src/framework/program/program_desc.h b/src/framework/program/program_desc.h
index 5c87f565e13df1564343b43150a5696c3adaca39..5c75c915223d0768120b4153c38a3772ba74d8e9 100644
--- a/src/framework/program/program_desc.h
+++ b/src/framework/program/program_desc.h
@@ -14,6 +14,7 @@ limitations under the License. */

 #pragma once

+#include <memory>
 #include <vector>

 #include "common/types.h"
@@ -31,6 +32,14 @@ class ProgramDesc {

   std::shared_ptr<BlockDesc> Block(size_t idx);

+  BlockDesc *MutableBlock(size_t idx) {
+    if (idx == -1) {
+      return nullptr;
+    } else {
+      return blocks_[idx].get();
+    }
+  }
+
   const std::vector<std::shared_ptr<BlockDesc>> &Blocks() { return blocks_; }
   ProgramDesc(const ProgramDesc &program_desc) {
     for (auto &block : program_desc.blocks_) {
diff --git a/src/framework/scope.h b/src/framework/scope.h
index 2bf9ed85bd36013814a6adb80884693d7a944c90..4193db30e4bb487a323a188a95e4e8bf156549d9 100644
--- a/src/framework/scope.h
+++ b/src/framework/scope.h
@@ -32,15 +32,7 @@ class Scope {
   Scope() = default;

   ~Scope() {
-    for (auto &var : vars_) {
-      delete var.second;
-    }
-    vars_.clear();
-    for (auto kid : kids_) {
-      delete kid;
-    }
-    kids_.clear();
-
+    DropKids();
 #ifdef PADDLE_MOBILE_CL
     delete cl_scope_;
 #endif
diff --git a/src/framework/tensor.h b/src/framework/tensor.h
index 8b633ec5cca6719dc3b1ebf5637ca8796e90046f..24f09662ea5ecca2a96ccdac7e863034f6a3a311 100644
--- a/src/framework/tensor.h
+++ b/src/framework/tensor.h
@@ -209,8 +209,9 @@ class Tensor : public TensorBase {
   }
   inline void set_type(std::type_index type) { holder_->set_type(type); }
   inline void *get_data() {
-    return (void *)(((PlaceholderImpl *)(holder_.get()))->ptr_.get());
-  }  // NOLINT
+    return (
+        void *)(((PlaceholderImpl *)(holder_.get()))->ptr_.get());  // NOLINT
+  }

   inline void *init(std::type_index type) {
     if (holder_ != nullptr) {
diff --git a/src/framework/tensor_util.h b/src/framework/tensor_util.h
index f888049b395e48b9d10cea731b092c899952e3d8..31fc5148c7c08bb0bb01ea19f7eaa97d2eb02123 100644
--- a/src/framework/tensor_util.h
+++ b/src/framework/tensor_util.h
@@ -14,13 +14,26 @@ limitations under the License. */
 #pragma once
 #include <vector>
+#include "framework/tensor.h"
 #include "memory/t_malloc.h"
-#include "tensor.h"

 namespace paddle_mobile {
 namespace framework {

-void TensorCopy(const Tensor &src, Tensor *dst);
+void TensorCopy(const Tensor& src, Tensor* dst);
+
+template <typename T>
+void TensorFromVector(const std::vector<T>& src, Tensor* dst);
+
+template <typename T>
+void TensorFromVector(const std::vector<T>& src, Tensor* dst) {
+  auto src_ptr = static_cast<const void*>(src.data());
+  dst->Resize({static_cast<int64_t>(src.size())});
+  auto dst_ptr = static_cast<void*>(dst->mutable_data<T>());
+  auto size = src.size() * sizeof(T);
+
+  memory::Copy(dst_ptr, src_ptr, size);
+}

 }  // namespace framework
 }  // namespace paddle_mobile
diff --git a/src/io/api_paddle_mobile.cc b/src/io/api_paddle_mobile.cc
index b7a8ebba9c1d9adac3993d22aa33fb5445235349..5839a279cdfc03472628cf7650b30064281a226e 100644
--- a/src/io/api_paddle_mobile.cc
+++ b/src/io/api_paddle_mobile.cc
@@ -13,6 +13,7 @@
 // limitations under the License.

 #include "io/api_paddle_mobile.h"
+#include <memory>
 #include <vector>
 #include "common/enforce.h"
 #include "framework/tensor.h"
@@ -145,7 +146,7 @@ void PaddleMobilePredictor<Device, T>::FeedPaddleTensors(
     tensors[i].init(typeid(float));
     ConvertPaddleTensors(inputs[i], &tensors[i]);
   }
-  paddle_mobile_->FeedTensorData(tensors);
+  // paddle_mobile_->FeedTensorData(tensors);
 }

 template <typename Device, typename T>
@@ -169,7 +170,7 @@ void PaddleMobilePredictor<Device, T>::GetPaddleTensor(const std::string &name,
                                                        PaddleTensor *output) {
   framework::Tensor *t = paddle_mobile_->GetTensorByName(name);
   ConvertTensors(*t, output);
-};
+}

 template <typename Device, typename T>
 void PaddleMobilePredictor<Device, T>::Predict_From_To(int start, int end) {
diff --git a/src/io/api_paddle_mobile.h b/src/io/api_paddle_mobile.h
index aa0050ca057baa94c457dc0eb313b2bd7f8355be..38af541a9262ea1f4c9ea0f8e4229316c54a4a18 100644
--- a/src/io/api_paddle_mobile.h
+++ b/src/io/api_paddle_mobile.h
@@ -14,6 +14,7 @@ limitations under the License. */
 #pragma once

+#include <memory>
 #include <vector>
 #include "common/types.h"
 #include "io/paddle_inference_api.h"
diff --git a/src/io/jni/paddle_mobile_jni.cpp b/src/io/jni/paddle_mobile_jni.cpp
index 12c0a6cbca1721578efe175d8c108e30de18be7d..63511a2226e9563e758f87fea4fed67438eda8f6 100644
--- a/src/io/jni/paddle_mobile_jni.cpp
+++ b/src/io/jni/paddle_mobile_jni.cpp
@@ -39,8 +39,6 @@ using framework::Tensor;
 using paddle_mobile::CPU;
 using std::string;

-const char *ANDROID_LOG_TAG =
-    "paddle_mobile LOG built on " __DATE__ " " __TIME__;
 paddle_mobile::PaddleMobile<CPU> paddle_mobile;
 static std::mutex shared_mutex;

diff --git a/src/io/paddle_mobile.cpp b/src/io/paddle_mobile.cpp
index 6027f2d3c7396bfa33fd7674abc379390f784586..6294f6bf467b1c1684d87c51b9a3b04508d56016 100644
--- a/src/io/paddle_mobile.cpp
+++ b/src/io/paddle_mobile.cpp
@@ -152,14 +152,14 @@ PMStatus PaddleMobile<Device, T>::Predict() {
 }

 template <typename Device, typename T>
-void PaddleMobile<Device, T>::Feed(const framework::Tensor &input,
-                                   const std::string &var_name) {
+void PaddleMobile<Device, T>::Feed(const std::string &var_name,
+                                   const framework::Tensor &input) {
   executor_->SetInput(input, var_name);
 }

 template <typename Device, typename T>
-void PaddleMobile<Device, T>::Feed(const framework::LoDTensor &input,
-                                   const std::string &var_name) {
+void PaddleMobile<Device, T>::Feed(const std::string &var_name,
+                                   const framework::LoDTensor &input) {
   executor_->SetInput(input, var_name);
 }

@@ -227,16 +227,11 @@ template <typename Device, typename T>
 void PaddleMobile<Device, T>::FeedData(const framework::Tensor &t) {
   executor_->FeedData(t);
 }
+
 template <typename Device, typename T>
 void PaddleMobile<Device, T>::FeedData(const std::vector<void *> &v) {
   executor_->FeedData(v);
-};
-
-template <typename Device, typename T>
-void PaddleMobile<Device, T>::FeedTensorData(
-    const std::vector<framework::Tensor> &v) {
-  executor_->FeedTensorData(v);
-};
+}

 template <typename Device, typename T>
 void PaddleMobile<Device, T>::GetResults(std::vector<void *> *v) {
@@ -253,7 +248,7 @@ template <typename Device, typename T>
 framework::Tensor *PaddleMobile<Device, T>::GetTensorByName(
     const std::string &name) {
   return executor_->GetTensorByName(name);
-};
+}

 template <typename Device, typename T>
 std::shared_ptr<framework::Tensor> PaddleMobile<Device, T>::FetchResult(
diff --git a/src/io/paddle_mobile.h b/src/io/paddle_mobile.h
index d835eed2daad2a78e3c0e430cb73a8fbc078b667..7983541a221fb63f573dfa8186599934cd97387b 100644
--- a/src/io/paddle_mobile.h
+++ b/src/io/paddle_mobile.h
@@ -33,7 +33,7 @@ namespace paddle_mobile {
 template <typename Device, typename T = float>
 class PaddleMobile {
  public:
-  PaddleMobile(PaddleMobileConfigInternal config) : config_(config) {
+  explicit PaddleMobile(PaddleMobileConfigInternal config) : config_(config) {
 #ifndef PADDLE_MOBILE_CL
     bool is_gpu = std::is_same<DeviceType<kGPU_CL>, Device>::value;
     PADDLE_MOBILE_ENFORCE(!is_gpu, "Please recompile with GPU_CL is on");
@@ -69,8 +69,8 @@ class PaddleMobile {
                          const std::vector<int64_t> &dims);
   PMStatus Predict();

-  void Feed(const framework::LoDTensor &input, const std::string &var_name);
-  void Feed(const framework::Tensor &input, const std::string &var_name);
+  void Feed(const std::string &var_name, const framework::LoDTensor &input);
+  void Feed(const std::string &var_name, const framework::Tensor &input);

   typedef std::shared_ptr<framework::LoDTensor> LoDTensorPtr;
   LoDTensorPtr Fetch(const std::string &var_name);
@@ -91,7 +91,6 @@ class PaddleMobile {
   void InjectVariable(const framework::Tensor &t, std::string var_name);
   void FeedData(const framework::Tensor &t);
   void FeedData(const std::vector<void *> &v);
-  void FeedTensorData(const std::vector<framework::Tensor> &v);
   void GetResults(std::vector<void *> *v);
   void GetTensorResults(std::vector<framework::Tensor *> *v);

diff --git a/src/operators/activation_op.cpp b/src/operators/activation_op.cpp
index d98fcb92297dd485fc3e59cfe592e00672f4ffca..158eb8eb47e872ed3c90fd4ae3ea1a9d257333e6 100644
--- a/src/operators/activation_op.cpp
+++ b/src/operators/activation_op.cpp
@@ -17,11 +17,12 @@ limitations under the License. */
 namespace paddle_mobile {
 namespace operators {

-#define DEFINE_ACTIVATION_INFERSHAPE(OpName)                  \
-  template <typename DeviceType, typename T>                  \
-  void OpName##Op<DeviceType, T>::InferShape() const {        \
-    const auto &input_dims = this->param_.InputX()->dims();   \
-    this->param_.Out()->Resize(input_dims);                   \
+#define DEFINE_ACTIVATION_INFERSHAPE(OpName)                   \
+  template <typename DeviceType, typename T>                   \
+  void OpName##Op<DeviceType, T>::InferShape() const {         \
+    const auto &input_dims = this->param_.InputX()->dims();    \
+    this->param_.Out()->Resize(input_dims);                    \
+    this->param_.Out()->set_lod(this->param_.InputX()->lod()); \
   }

 #ifdef RELU_OP
diff --git a/src/operators/batchnorm_op.h b/src/operators/batchnorm_op.h
index a6df70c9356c9bdb8b1fe3ef4520f26ce911490a..ed46c8657f4a505364f41f4a9695bf9f4cb57fc9 100644
--- a/src/operators/batchnorm_op.h
+++ b/src/operators/batchnorm_op.h
@@ -32,8 +32,7 @@ class BatchNormOp
  public:
   BatchNormOp(const string &type, const VariableNameMap &inputs,
               const VariableNameMap &outputs,
-              const framework::AttributeMap &attrs,
-              std::shared_ptr<framework::Scope> scope)
+              const framework::AttributeMap &attrs, framework::Scope *scope)
       : framework::OperatorWithKernel<DeviceType, BatchNormParam<DeviceType>,
                                       BatchNormKernel<DeviceType, T>>(
             type, inputs, outputs, attrs, scope) {}
diff --git a/src/operators/kernel/arm/conv_add_kernel.cpp b/src/operators/beam_search_decode_op.cpp
similarity index 64%
rename from src/operators/kernel/arm/conv_add_kernel.cpp
rename to src/operators/beam_search_decode_op.cpp
index e016b8efbd15472ae0d77423d84dc19671bfa316..9b01d2e17f363d3b729102a9747f6dc6682ea8aa 100644
--- a/src/operators/kernel/arm/conv_add_kernel.cpp
+++ b/src/operators/beam_search_decode_op.cpp
@@ -11,27 +11,26 @@ distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
-#ifdef FUSION_CONVADD_OP
-#include "operators/kernel/conv_add_kernel.h"
-#include "../central-arm-func/conv_add_arm_func.h"
+#ifdef BEAM_SEARCH_DECODE_OP

-namespace paddle_mobile {
-namespace operators {
+#pragma once

-template <>
-bool ConvAddKernel<CPU, float>::Init(FusionConvAddParam<CPU> *param) {
-  return true;
-}
+#include "operators/beam_search_decode_op.h"

-template <>
-void ConvAddKernel<CPU, float>::Compute(const FusionConvAddParam<CPU> &param) {
-  ConvAddCompute<float>(param);
-}
+namespace paddle_mobile {
+namespace operators {

-template class ConvAddKernel<CPU, float>;
+template <typename DeviceType, typename T>
+void BeamSearchDecodeOp<DeviceType, T>::InferShape() const {}

 }  // namespace operators
 }  // namespace paddle_mobile

+namespace ops = paddle_mobile::operators;
+
+#ifdef PADDLE_MOBILE_CPU
+REGISTER_OPERATOR_CPU(beam_search_decode, ops::BeamSearchDecodeOp);
 #endif
+
+#endif  // BEAM_SEARCH_DECODE_OP
diff --git a/src/operators/kernel/arm/conv_add_prelu_kernel.cpp b/src/operators/beam_search_decode_op.h
similarity index 60%
rename from src/operators/kernel/arm/conv_add_prelu_kernel.cpp
rename to src/operators/beam_search_decode_op.h
index f04a9a7d746f2d970196945707bd05409c5fa340..f212959474eade3da0f026bcdb1e3d15ddd30c6d 100644
--- a/src/operators/kernel/arm/conv_add_prelu_kernel.cpp
+++ b/src/operators/beam_search_decode_op.h
@@ -12,27 +12,21 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
-#ifdef FUSION_CONVADDPRELU_OP
+#ifdef BEAM_SEARCH_DECODE_OP
 
-#include "operators/kernel/conv_add_prelu_kernel.h"
-#include "operators/kernel/central-arm-func/conv_add_prelu_arm_func.h"
+#pragma once
+
+#include
+#include "framework/operator.h"
+#include "operators/kernel/beam_search_decode_kernel.h"
 
 namespace paddle_mobile {
 namespace operators {
 
-template <>
-bool ConvAddPReluKernel<CPU, float>::Init(FusionConvAddPReluParam<CPU> *param) {
-  return true;
-}
-
-template <>
-void ConvAddPReluKernel<CPU, float>::Compute(
-    const FusionConvAddPReluParam<CPU> &param) {
-  ConvAddPReluCompute<float>(param);
-}
-template class ConvAddPReluKernel<CPU, float>;
+DECLARE_OPERATOR(BeamSearchDecode, BeamSearchDecodeParam,
+                 BeamSearchDecodeKernel);
 
 }  // namespace operators
 }  // namespace paddle_mobile
 
-#endif
+#endif  // BEAM_SEARCH_DECODE_OP

diff --git a/src/operators/kernel/arm/conv_add_relu_kernel.cpp b/src/operators/beam_search_op.cpp
similarity index 61%
rename from src/operators/kernel/arm/conv_add_relu_kernel.cpp
rename to src/operators/beam_search_op.cpp
index e318a866a345d9d938471b262e78e2dc30153c40..502510ebeefd29336531fac24d279e009f6b8d6d 100644
--- a/src/operators/kernel/arm/conv_add_relu_kernel.cpp
+++ b/src/operators/beam_search_op.cpp
@@ -12,27 +12,25 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
-#ifdef FUSION_CONVADDRELU_OP
+#ifdef BEAM_SEARCH_OP
 
-#include "operators/kernel/conv_add_relu_kernel.h"
-#include "operators/kernel/central-arm-func/conv_add_relu_arm_func.h"
+#pragma once
+
+#include "operators/beam_search_op.h"
 
 namespace paddle_mobile {
 namespace operators {
 
-template <>
-bool ConvAddReluKernel<CPU, float>::Init(FusionConvAddReluParam<CPU> *param) {
-  return true;
-}
-
-template <>
-void ConvAddReluKernel<CPU, float>::Compute(
-    const FusionConvAddReluParam<CPU> &param) {
-  ConvAddReluCompute<float>(param);
-}
-template class ConvAddReluKernel<CPU, float>;
+template <typename DeviceType, typename T>
+void BeamSearchOp<DeviceType, T>::InferShape() const {}
 
 }  // namespace operators
 }  // namespace paddle_mobile
 
+namespace ops = paddle_mobile::operators;
+
+#ifdef PADDLE_MOBILE_CPU
+REGISTER_OPERATOR_CPU(beam_search, ops::BeamSearchOp);
 #endif
+
+#endif  // BEAM_SEARCH_OP

diff --git a/src/operators/beam_search_op.h b/src/operators/beam_search_op.h
new file mode 100644
index 0000000000000000000000000000000000000000..985552d9f6efde5a474ca57672b8500bfc558e32
--- /dev/null
+++ b/src/operators/beam_search_op.h
@@ -0,0 +1,31 @@
+/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#ifdef BEAM_SEARCH_OP
+
+#pragma once
+
+#include
+#include "framework/operator.h"
+#include "operators/kernel/beam_search_kernel.h"
+
+namespace paddle_mobile {
+namespace operators {
+
+DECLARE_OPERATOR(BeamSearch, BeamSearchParam, BeamSearchKernel);
+
+}  // namespace operators
+}  // namespace paddle_mobile
+
+#endif  // BEAM_SEARCH_OP

diff --git a/src/operators/bilinear_interp_op.h b/src/operators/bilinear_interp_op.h
index 2bb61d129d5ba45900f1c67b8c202e958a004bb7..2fee40859b071a00464f7a982f5bd4b8b2139df9 100644
--- a/src/operators/bilinear_interp_op.h
+++ b/src/operators/bilinear_interp_op.h
@@ -34,8 +34,7 @@ class BilinearOp : public framework::OperatorWithKernel<
  public:
   BilinearOp(const std::string &type, const VariableNameMap &inputs,
              const VariableNameMap &outputs,
-             const framework::AttributeMap &attrs,
-             std::shared_ptr<framework::Scope> scope)
+             const framework::AttributeMap &attrs, framework::Scope *scope)
       : framework::OperatorWithKernel<
             DeviceType, BilinearInterpParam<DeviceType>,
             operators::BilinearInterpKernel<DeviceType, T>>(

diff --git a/src/operators/box_coder_op.h b/src/operators/box_coder_op.h
index 3a3048c6624996892333a71773c33ee2f6e18e0a..417783ca939359929841d4e90e70d5ca9daf836b 100644
--- a/src/operators/box_coder_op.h
+++ b/src/operators/box_coder_op.h
@@ -34,8 +34,7 @@ class BoxCoderOp : public framework::OperatorWithKernel<
  public:
   BoxCoderOp(const std::string &type, const VariableNameMap &inputs,
              const VariableNameMap &outputs,
-             const framework::AttributeMap &attrs,
-             std::shared_ptr<framework::Scope> scope)
+             const framework::AttributeMap &attrs, framework::Scope *scope)
       : framework::OperatorWithKernel<DeviceType, BoxCoderParam<DeviceType>,
                                       operators::BoxCoderKernel<DeviceType, T>>(
             type, inputs, outputs, attrs, scope) {}

diff --git a/src/operators/cast_op.h b/src/operators/cast_op.h
index 3bb91972cc93c3c28354897ad0318985344a3d5c..a244d5cfaff4d2cd5eb6c807138e263fa78e593f 100644
--- a/src/operators/cast_op.h
+++ b/src/operators/cast_op.h
@@ -31,7 +31,7 @@ class CastOp : public framework::OperatorWithKernel<
  public:
   CastOp(const std::string &type, const VariableNameMap &inputs,
          const VariableNameMap &outputs, const framework::AttributeMap &attrs,
-         std::shared_ptr<framework::Scope> scope)
+         framework::Scope *scope)
       : framework::OperatorWithKernel<DeviceType, CastParam<DeviceType>,
                                       operators::CastKernel<DeviceType, T>>(
             type, inputs, outputs, attrs, scope) {}

diff --git a/src/operators/concat_op.h b/src/operators/concat_op.h
index a01e066edd1082bc109ba7eb0f31a2ac42ab865a..94c402cd8566485b8fbf5e5e96dce5a5afdcd0f3 100644
--- a/src/operators/concat_op.h
+++ b/src/operators/concat_op.h
@@ -30,7 +30,7 @@ class ConcatOp : public framework::OperatorWithKernel<
  public:
   ConcatOp(const string &type, const VariableNameMap &inputs,
            const VariableNameMap &outputs, const framework::AttributeMap &attrs,
-           std::shared_ptr<framework::Scope> scope)
+           framework::Scope *scope)
       : framework::OperatorWithKernel<DeviceType, ConcatParam<DeviceType>,
                                       operators::ConcatKernel<DeviceType, T>>(
             type, inputs, outputs, attrs, scope) {}

diff --git a/src/operators/conv_op.cpp b/src/operators/conv_op.cpp
index 2c70f42f56530c2d21252d6b51c228e7c49ca8bf..ad778b1fef7fe400e1df645703cf3ebfb1b22727 100644
--- a/src/operators/conv_op.cpp
+++ b/src/operators/conv_op.cpp
@@ -18,7 +18,7 @@ limitations under the License.
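Note: both new operators are declared through DECLARE_OPERATOR instead of a
hand-written class. The macro body is not shown in this patch; a rough sketch
of what such an expansion has to provide (approximate, not the literal macro,
which lives in framework/operator.h):

    // Approximate shape of DECLARE_OPERATOR(BeamSearch, BeamSearchParam,
    //                                        BeamSearchKernel);
    template <typename DeviceType, typename T>
    class BeamSearchOp
        : public framework::OperatorWithKernel<
              DeviceType, BeamSearchParam<DeviceType>,
              operators::BeamSearchKernel<DeviceType, T>> {
     public:
      BeamSearchOp(const std::string &type, const VariableNameMap &inputs,
                   const VariableNameMap &outputs,
                   const framework::AttributeMap &attrs,
                   framework::Scope *scope)
          : framework::OperatorWithKernel<
                DeviceType, BeamSearchParam<DeviceType>,
                operators::BeamSearchKernel<DeviceType, T>>(
                type, inputs, outputs, attrs, scope) {}
      void InferShape() const override;
    };

This is why beam_search_op.cpp only defines InferShape and registers the op.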
*/ #include #include "framework/op_proto_maker.h" #include "framework/op_registry.h" -#include "operators/math/conv_func.h" +#include "operators/kernel/central-arm-func/conv_arm_func.h" namespace paddle_mobile { namespace operators { @@ -39,9 +39,9 @@ void ConvOp::InferShape() const { std::vector output_shape({in_dims[0], filter_dims[0]}); for (size_t i = 0; i < strides.size(); ++i) { - output_shape.push_back( - math::ConvOutputSize(in_dims[i + 2], filter_dims[i + 2], dilations[i], - paddings[i], strides[i])); + output_shape.push_back(ConvOutputSize(in_dims[i + 2], filter_dims[i + 2], + dilations[i], paddings[i], + strides[i])); } framework::DDim ddim = framework::make_ddim(output_shape); diff --git a/src/operators/conv_op.h b/src/operators/conv_op.h index 1b8bd70805ccff8946c1ab12a207618849fc9ca4..f023e60e72b4bee52dfae816f352f06ed5297196 100644 --- a/src/operators/conv_op.h +++ b/src/operators/conv_op.h @@ -30,7 +30,7 @@ class ConvOp : public framework::OperatorWithKernel< public: ConvOp(const std::string &type, const VariableNameMap &inputs, const VariableNameMap &outputs, const framework::AttributeMap &attrs, - std::shared_ptr scope) + framework::Scope *scope) : framework::OperatorWithKernel, operators::ConvKernel>( type, inputs, outputs, attrs, scope) {} diff --git a/src/operators/conv_transpose_op.h b/src/operators/conv_transpose_op.h index 4e6464b3a4b19316315eb68739c40654de3eb018..89dab62e04df368b17703addbeb40fb9a365adac 100644 --- a/src/operators/conv_transpose_op.h +++ b/src/operators/conv_transpose_op.h @@ -31,8 +31,7 @@ class ConvOpTranspose : public framework::OperatorWithKernel< public: ConvOpTranspose(const std::string &type, const VariableNameMap &inputs, const VariableNameMap &outputs, - const framework::AttributeMap &attrs, - std::shared_ptr scope) + const framework::AttributeMap &attrs, framework::Scope *scope) : framework::OperatorWithKernel< DeviceType, ConvTransposeParam, operators::ConvTransposeKernel>( diff --git a/src/operators/crf_op.h b/src/operators/crf_op.h index dca481bb2dd08dc65fb94e41d0573277c9b143c7..fb0fd908898bb64bb84ee319d2b285de4739475b 100644 --- a/src/operators/crf_op.h +++ b/src/operators/crf_op.h @@ -33,7 +33,7 @@ class CrfOp : public framework::OperatorWithKernel< public: CrfOp(const std::string &type, const VariableNameMap &inputs, const VariableNameMap &outputs, const framework::AttributeMap &attrs, - std::shared_ptr scope) + framework::Scope *scope) : framework::OperatorWithKernel, operators::CrfKernel>( type, inputs, outputs, attrs, scope) {} diff --git a/src/operators/depthwise_conv_op.cpp b/src/operators/depthwise_conv_op.cpp index 2e7f193c5c9f66668411bb115da9d3cd980f8a6b..0e74654e1f661d55a263f9f9b57a1ba2a32dfd74 100644 --- a/src/operators/depthwise_conv_op.cpp +++ b/src/operators/depthwise_conv_op.cpp @@ -19,7 +19,7 @@ limitations under the License. 
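Note: the conv-style InferShape implementations now call ConvOutputSize from
operators/kernel/central-arm-func/conv_arm_func.h instead of
math::ConvOutputSize. The arithmetic is the standard convolution output-size
formula; a sketch consistent with the call sites above (the real helper may
differ in details):

    // output = (input + 2 * padding - (dilation * (filter - 1) + 1)) / stride + 1
    inline int ConvOutputSize(int input_size, int filter_size, int dilation,
                              int padding, int stride) {
      const int dkernel = dilation * (filter_size - 1) + 1;
      return (input_size + 2 * padding - dkernel) / stride + 1;
    }

For example, input 224, filter 3, dilation 1, padding 1, stride 2 gives
(224 + 2 - 3) / 2 + 1 = 112.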
*/ #include "framework/op_proto_maker.h" #include "framework/op_registry.h" #include "operators/conv_op.h" -#include "operators/math/conv_func.h" +#include "operators/kernel/central-arm-func/conv_arm_func.h" namespace paddle_mobile { namespace operators { @@ -40,9 +40,9 @@ void DepthwiseConvOp::InferShape() const { std::vector output_shape({in_dims[0], filter_dims[0]}); for (size_t i = 0; i < strides.size(); ++i) { - output_shape.push_back( - math::ConvOutputSize(in_dims[i + 2], filter_dims[i + 2], dilations[i], - paddings[i], strides[i])); + output_shape.push_back(ConvOutputSize(in_dims[i + 2], filter_dims[i + 2], + dilations[i], paddings[i], + strides[i])); } framework::DDim ddim = framework::make_ddim(output_shape); diff --git a/src/operators/depthwise_conv_op.h b/src/operators/depthwise_conv_op.h index 26253e0e0a7d3c52808a691d4257e7074e1da6e2..d1cbeeab06182814c298989cd6e4c25d38405252 100644 --- a/src/operators/depthwise_conv_op.h +++ b/src/operators/depthwise_conv_op.h @@ -30,8 +30,7 @@ class DepthwiseConvOp : public framework::OperatorWithKernel< public: DepthwiseConvOp(const std::string &type, const VariableNameMap &inputs, const VariableNameMap &outputs, - const framework::AttributeMap &attrs, - std::shared_ptr scope) + const framework::AttributeMap &attrs, framework::Scope *scope) : framework::OperatorWithKernel, operators::ConvKernel>( type, inputs, outputs, attrs, scope) {} diff --git a/src/operators/dequantize_op.h b/src/operators/dequantize_op.h index 101b77e95a6535cf824a40b743e34064835eb069..81ab62bee8cff8ea37e4178ec266fec5b88174a0 100644 --- a/src/operators/dequantize_op.h +++ b/src/operators/dequantize_op.h @@ -32,8 +32,7 @@ class DequantizeOp public: DequantizeOp(const std::string &type, const VariableNameMap &inputs, const VariableNameMap &outputs, - const framework::AttributeMap &attrs, - std::shared_ptr scope) + const framework::AttributeMap &attrs, framework::Scope *scope) : framework::OperatorWithKernel, DequantizeKernel>( type, inputs, outputs, attrs, scope) {} diff --git a/src/operators/dropout_op.h b/src/operators/dropout_op.h index ce8acd5966439808f7a03f18cf3d29a1b5c0487e..132b94af692d9d8f3cb2fa4b146e8265893db231 100644 --- a/src/operators/dropout_op.h +++ b/src/operators/dropout_op.h @@ -34,7 +34,7 @@ class DropoutOp : public framework::OperatorWithKernel< public: DropoutOp(const std::string &type, const VariableNameMap &inputs, const VariableNameMap &outputs, const framework::AttributeMap attrs, - std::shared_ptr scope) + framework::Scope *scope) : framework::OperatorWithKernel, operators::DropoutKernel>( type, inputs, outputs, attrs, scope) {} diff --git a/src/operators/elementwise_add_op.cpp b/src/operators/elementwise_add_op.cpp index 281cd3d5084a1a15502e1e06865e1024d3b2b639..6fde477f228d140f28525989bdbba564ed88854d 100644 --- a/src/operators/elementwise_add_op.cpp +++ b/src/operators/elementwise_add_op.cpp @@ -23,6 +23,7 @@ template void ElementwiseAddOp::InferShape() const { auto x_dim = this->param_.InputX()->dims(); this->param_.Out()->Resize(x_dim); + this->param_.Out()->set_lod(this->param_.InputX()->lod()); } } // namespace operators diff --git a/src/operators/elementwise_add_op.h b/src/operators/elementwise_add_op.h index a853b40ff7ccf323911f2ea1bf6e23d67d111db2..7819765813e463a4d916a95c3882768603471fbf 100644 --- a/src/operators/elementwise_add_op.h +++ b/src/operators/elementwise_add_op.h @@ -32,7 +32,7 @@ class ElementwiseAddOp : public framework::OperatorWithKernel< ElementwiseAddOp(const string &type, const VariableNameMap &inputs, const 
VariableNameMap &outputs, const framework::AttributeMap &attrs, - std::shared_ptr scope) + framework::Scope *scope) : framework::OperatorWithKernel< DeviceType, ElementwiseAddParam, operators::ElementwiseAddKernel>( diff --git a/src/operators/elementwise_mul_op.h b/src/operators/elementwise_mul_op.h index 991b03a486d65c720b88b80a1aece417b9919d3d..53a90180b69124fddf4ce5091be0d169c9dcdfb1 100644 --- a/src/operators/elementwise_mul_op.h +++ b/src/operators/elementwise_mul_op.h @@ -32,7 +32,7 @@ class ElementwiseMulOp : public framework::OperatorWithKernel< ElementwiseMulOp(const string &type, const VariableNameMap &inputs, const VariableNameMap &outputs, const framework::AttributeMap &attrs, - std::shared_ptr scope) + framework::Scope *scope) : framework::OperatorWithKernel< DeviceType, ElementwiseMulParam, operators::ElementwiseMulKernel>( diff --git a/src/operators/elementwise_sub_op.h b/src/operators/elementwise_sub_op.h index 2edd2581a9d3929a29459df60f514132796a53e2..ce3b310ef334356b54d4e03f2df90eeef4303e5f 100644 --- a/src/operators/elementwise_sub_op.h +++ b/src/operators/elementwise_sub_op.h @@ -32,7 +32,7 @@ class ElementwiseSubOp : public framework::OperatorWithKernel< ElementwiseSubOp(const string &type, const VariableNameMap &inputs, const VariableNameMap &outputs, const framework::AttributeMap &attrs, - std::shared_ptr scope) + framework::Scope *scope) : framework::OperatorWithKernel< DeviceType, ElementwiseSubParam, operators::ElementwiseSubKernel>( diff --git a/src/operators/feed_op.cpp b/src/operators/feed_op.cpp index 4e496fb51d16c47d801eabada7c36dbdefdd2140..9e0b037c8dff4e4ea27d6f2f3155d06c9ed4821f 100644 --- a/src/operators/feed_op.cpp +++ b/src/operators/feed_op.cpp @@ -21,7 +21,8 @@ template void FeedOp::InferShape() const { auto out_dims = this->param_.Out()->dims(); out_dims[0] = this->param_.BatchSize(); - auto input_dims = this->param_.InputX()->dims(); + int col = this->param_.Col(); + auto input_dims = this->param_.InputX()->at(col).dims(); if (input_dims.size() == 4) { this->param_.Out()->Resize(input_dims); } else { diff --git a/src/operators/feed_op.h b/src/operators/feed_op.h index 57932474184fd5431e5b6ac5756ab28faa2b1b9e..fda259b58556f55cd0e98af71295f4827be80006 100644 --- a/src/operators/feed_op.h +++ b/src/operators/feed_op.h @@ -31,7 +31,7 @@ class FeedOp public: FeedOp(const std::string &type, const VariableNameMap &inputs, const VariableNameMap &outputs, const framework::AttributeMap attrs, - std::shared_ptr scope) + framework::Scope *scope) : framework::OperatorWithKernel, FeedKernel>( diff --git a/src/operators/fetch_op.cpp b/src/operators/fetch_op.cpp index 50e53c30cfd06a8fae8c9e18dd4aa985a056a13e..2d0ac82ec8a1d9338b4e1784d19587cf09fdba74 100644 --- a/src/operators/fetch_op.cpp +++ b/src/operators/fetch_op.cpp @@ -18,8 +18,9 @@ namespace operators { template void FetchOp::InferShape() const { + int col = this->param_.Col(); auto x_dims = this->param_.InputX()->dims(); - this->param_.Out()->Resize(x_dims); + this->param_.Out()->at(col).Resize(x_dims); } } // namespace operators diff --git a/src/operators/fetch_op.h b/src/operators/fetch_op.h index f92c66a05f121b3f6b78c244dd01d81393fa5c68..72c8e1997f54c85e5a59c1ec318cf09230a66196 100644 --- a/src/operators/fetch_op.h +++ b/src/operators/fetch_op.h @@ -30,7 +30,7 @@ class FetchOp public: FetchOp(const string &type, const VariableNameMap &inputs, const VariableNameMap &outputs, const framework::AttributeMap attrs, - std::shared_ptr scope) + framework::Scope *scope) : framework::OperatorWithKernel, 
FetchKernel>( type, inputs, outputs, attrs, scope) {} diff --git a/src/operators/fill_constant_op.h b/src/operators/fill_constant_op.h index 3b0acff2ed8aa95c2dd22a8d178952ca4ecf22ca..0a51f8494d56b2490f073bfee9a71950f6075647 100644 --- a/src/operators/fill_constant_op.h +++ b/src/operators/fill_constant_op.h @@ -31,11 +31,10 @@ class FillConstantOp : public framework::OperatorBase { public: FillConstantOp(const std::string &type, const VariableNameMap &inputs, const VariableNameMap &outputs, - const framework::AttributeMap attrs, - std::shared_ptr scope) + const framework::AttributeMap attrs, framework::Scope *scope) : framework::OperatorBase(type, inputs, outputs, attrs, scope), - param_(inputs, outputs, attrs, scope.get()) {} + param_(inputs, outputs, attrs, scope) {} void RunImpl() { auto data_type = static_cast<_PaddleMobile__Framework__Proto__VarType__Type>( diff --git a/src/operators/flatten_op.h b/src/operators/flatten_op.h index a7a91e60701cf559cb35238aa2966c02c869e844..daad2d82d8e37a92bcdbbd33e64ed7e505b5b2a9 100644 --- a/src/operators/flatten_op.h +++ b/src/operators/flatten_op.h @@ -49,8 +49,7 @@ class FlattenOp : public framework::OperatorWithKernel< public: FlattenOp(const std::string &type, const VariableNameMap &inputs, const VariableNameMap &outputs, - const framework::AttributeMap &attrs, - std::shared_ptr scope) + const framework::AttributeMap &attrs, framework::Scope *scope) : framework::OperatorWithKernel, operators::FlattenKernel>( type, inputs, outputs, attrs, scope) {} diff --git a/src/operators/fusion_conv_add_add_prelu_op.cpp b/src/operators/fusion_conv_add_add_prelu_op.cpp deleted file mode 100644 index 2f3d29dc74ed3a852b5c41a64d46b8710ebec599..0000000000000000000000000000000000000000 --- a/src/operators/fusion_conv_add_add_prelu_op.cpp +++ /dev/null @@ -1,61 +0,0 @@ -/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
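Note: FeedOp and FetchOp InferShape now index a tensor list through
param_.Col(), so one program can expose several feed and fetch slots; the col
value presumably comes from the op's "col" attribute, as in mainline Paddle.
The access pattern, restated from the hunks above (container types assumed):

    // FeedOp: the program's inputs live in one vector; pick the col-th slot.
    int col = this->param_.Col();
    auto input_dims = this->param_.InputX()->at(col).dims();
    this->param_.Out()->Resize(input_dims);

    // FetchOp: the mirror image; only the col-th output slot is resized.
    this->param_.Out()->at(col).Resize(this->param_.InputX()->dims());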
*/ - -#ifdef FUSION_CONVADDADDPRELU_OP - -#include "operators/fusion_conv_add_add_prelu_op.h" -#include "operators/math/conv_func.h" - -namespace paddle_mobile { -namespace operators { - -template -void FusionConvAddAddPReluOp::InferShape() const { - auto in_dims = this->param_.Input()->dims(); - auto filter_dims = this->param_.Filter()->dims(); - const std::vector &strides = this->param_.Strides(); - std::vector paddings = this->param_.Paddings(); - int groups = this->param_.Groups(); - std::vector dilations = this->param_.Dilations(); - - PADDLE_MOBILE_ENFORCE((in_dims.size() == filter_dims.size() && - dilations.size() == paddings.size() && - paddings.size() == strides.size()), - "ConvParam is not suitable"); - - std::vector output_shape({in_dims[0], filter_dims[0]}); - for (size_t i = 0; i < strides.size(); ++i) { - output_shape.push_back( - math::ConvOutputSize(in_dims[i + 2], filter_dims[i + 2], dilations[i], - paddings[i], strides[i])); - } - framework::DDim ddim = framework::make_ddim(output_shape); - this->param_.Output()->Resize(ddim); -} - -} // namespace operators -} // namespace paddle_mobile - -namespace ops = paddle_mobile::operators; -REGISTER_FUSION_MATCHER(fusion_conv_add_add_prelu, - ops::FusionConvAddAddPReluOpMatcher); - -#ifdef PADDLE_MOBILE_CPU -REGISTER_OPERATOR_CPU(fusion_conv_add_add_prelu, ops::FusionConvAddAddPReluOp); -#endif -#ifdef PADDLE_MOBILE_FPGA -REGISTER_OPERATOR_FPGA(fusion_conv_add_add_prelu, ops::FusionConvAddAddPReluOp); -#endif - -#endif // FUSION_CONVADDADDPRELU_OP diff --git a/src/operators/fusion_conv_add_add_prelu_op.h b/src/operators/fusion_conv_add_add_prelu_op.h deleted file mode 100644 index 4ec76b500812f95eb64e27564d0e63b2c1b2c2d3..0000000000000000000000000000000000000000 --- a/src/operators/fusion_conv_add_add_prelu_op.h +++ /dev/null @@ -1,79 +0,0 @@ -/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
*/ - -#ifdef FUSION_CONVADDADDPRELU_OP - -#pragma once - -#include -#include -#include -#include "framework/operator.h" -#include "framework/program/program-optimize/fusion_op_register.h" -#include "operators/kernel/conv_add_add_prelu_kernel.h" -#include "operators/op_param.h" - -namespace paddle_mobile { -namespace operators { - -class FusionConvAddAddPReluOpMatcher : public framework::FusionOpMatcher { - public: - FusionConvAddAddPReluOpMatcher() { - node_ = framework::Node(G_OP_TYPE_CONV); - node_ > std::make_shared(G_OP_TYPE_ELEMENTWISE_ADD) > - std::make_shared(G_OP_TYPE_ELEMENTWISE_ADD) > - std::make_shared(G_OP_TYPE_PRELU); - } - - void FolderNodes( - framework::Node *node, - std::vector> *removed_nodes) { - node->Folder(node_.Depth(), Type(), - {{G_OP_TYPE_ELEMENTWISE_ADD, - {{"Y", "Y"}, {"Out", "addOut"}, {"X", "addX"}}}, - {G_OP_TYPE_PRELU, {{"Alpha", "Alpha"}}}}, - - removed_nodes); - } - std::string Type() { return G_OP_TYPE_FUSION_CONV_ADD_ADD_PRELU; } - - std::vector> NeedCheck() { - DLOG << " conv add add prelu check add X "; - return {{2, "Y"}, {2, "X"}}; - } -}; - -template -class FusionConvAddAddPReluOp - : public framework::OperatorWithKernel< - DeviceType, FusionConvAddAddPReluParam, - operators::ConvAddAddPReluKernel> { - public: - FusionConvAddAddPReluOp(const string &type, const VariableNameMap &inputs, - const VariableNameMap &outputs, - const framework::AttributeMap &attrs, - std::shared_ptr scope) - : framework::OperatorWithKernel< - DeviceType, FusionConvAddAddPReluParam, - operators::ConvAddAddPReluKernel>( - type, inputs, outputs, attrs, scope) {} - void InferShape() const override; - - protected: -}; - -} // namespace operators -} // namespace paddle_mobile - -#endif diff --git a/src/operators/fusion_conv_add_bn_op.cpp b/src/operators/fusion_conv_add_bn_op.cpp index e8daba7e9ba209cf078323ea79dd6f6a9b6e8200..27e3c04d62c29abe69adef7457bc633d294e2cdc 100644 --- a/src/operators/fusion_conv_add_bn_op.cpp +++ b/src/operators/fusion_conv_add_bn_op.cpp @@ -15,7 +15,7 @@ limitations under the License. */ #ifdef FUSION_CONVADDBN_OP #include "operators/fusion_conv_add_bn_op.h" -#include "operators/math/conv_func.h" +#include "operators/kernel/central-arm-func/conv_arm_func.h" namespace paddle_mobile { namespace operators { @@ -36,9 +36,9 @@ void FusionConvAddBNOp::InferShape() const { std::vector output_shape({in_dims[0], filter_dims[0]}); for (size_t i = 0; i < strides.size(); ++i) { - output_shape.push_back( - math::ConvOutputSize(in_dims[i + 2], filter_dims[i + 2], dilations[i], - paddings[i], strides[i])); + output_shape.push_back(ConvOutputSize(in_dims[i + 2], filter_dims[i + 2], + dilations[i], paddings[i], + strides[i])); } framework::DDim ddim = framework::make_ddim(output_shape); diff --git a/src/operators/fusion_conv_add_bn_op.h b/src/operators/fusion_conv_add_bn_op.h index c4260aef42f9d74cc1f7069c3ae26ccf58f75280..0618f8051286f634c71984a634de399cda2ffec1 100644 --- a/src/operators/fusion_conv_add_bn_op.h +++ b/src/operators/fusion_conv_add_bn_op.h @@ -20,8 +20,8 @@ limitations under the License. 
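Note: the deleted matchers describe the fused pattern by chaining nodes with
an overloaded operator>; the template argument stripped from the deleted lines
above is framework::Node, so the conv+add+add+prelu chain read:

    // Pattern: conv2d > elementwise_add > elementwise_add > prelu.
    node_ = framework::Node(G_OP_TYPE_CONV);
    node_ > std::make_shared<framework::Node>(G_OP_TYPE_ELEMENTWISE_ADD) >
        std::make_shared<framework::Node>(G_OP_TYPE_ELEMENTWISE_ADD) >
        std::make_shared<framework::Node>(G_OP_TYPE_PRELU);

FolderNodes then collapsed the matched subgraph into one fused op, remapping
the inner inputs (Y, addX, Alpha) onto it. Removing the matcher together with
its REGISTER_FUSION_MATCHER and REGISTER_OPERATOR_* calls is enough to retire
the fusion.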
*/ #include #include "framework/operator.h" #include "framework/program/program-optimize/fusion_op_register.h" -#include "op_param.h" #include "operators/kernel/conv_add_bn_kernel.h" +#include "operators/op_param.h" namespace paddle_mobile { namespace operators { @@ -59,7 +59,7 @@ class FusionConvAddBNOp : public framework::OperatorWithKernel< FusionConvAddBNOp(const string &type, const VariableNameMap &inputs, const VariableNameMap &outputs, const framework::AttributeMap &attrs, - std::shared_ptr scope) + framework::Scope *scope) : framework::OperatorWithKernel< DeviceType, FusionConvAddBNParam, operators::ConvAddBNKernel>(type, inputs, outputs, diff --git a/src/operators/fusion_conv_add_bn_relu_op.cpp b/src/operators/fusion_conv_add_bn_relu_op.cpp index b9bc948fe0e77741a36f959e29eb2a4c82e82b72..8f162a2d29de32340b8f7f3fe3094a230212929d 100644 --- a/src/operators/fusion_conv_add_bn_relu_op.cpp +++ b/src/operators/fusion_conv_add_bn_relu_op.cpp @@ -15,7 +15,7 @@ limitations under the License. */ #ifdef FUSION_CONVADDBNRELU_OP #include "operators/fusion_conv_add_bn_relu_op.h" -#include "operators/math/conv_func.h" +#include "operators/kernel/central-arm-func/conv_arm_func.h" namespace paddle_mobile { namespace operators { @@ -36,9 +36,9 @@ void FusionConvAddBNReluOp::InferShape() const { std::vector output_shape({in_dims[0], filter_dims[0]}); for (size_t i = 0; i < strides.size(); ++i) { - output_shape.push_back( - math::ConvOutputSize(in_dims[i + 2], filter_dims[i + 2], dilations[i], - paddings[i], strides[i])); + output_shape.push_back(ConvOutputSize(in_dims[i + 2], filter_dims[i + 2], + dilations[i], paddings[i], + strides[i])); } framework::DDim ddim = framework::make_ddim(output_shape); diff --git a/src/operators/fusion_conv_add_bn_relu_op.h b/src/operators/fusion_conv_add_bn_relu_op.h index 6ecc9bdc4a90530221c70651c52457874e3eaaa8..9dd2fd406a6696310a73369c33d55f4d92e2b50c 100644 --- a/src/operators/fusion_conv_add_bn_relu_op.h +++ b/src/operators/fusion_conv_add_bn_relu_op.h @@ -61,7 +61,7 @@ class FusionConvAddBNReluOp FusionConvAddBNReluOp(const string &type, const VariableNameMap &inputs, const VariableNameMap &outputs, const framework::AttributeMap &attrs, - std::shared_ptr scope) + framework::Scope *scope) : framework::OperatorWithKernel< DeviceType, FusionConvAddBNReluParam, operators::ConvAddBNReluKernel>( diff --git a/src/operators/fusion_conv_add_op.cpp b/src/operators/fusion_conv_add_op.cpp index 731bb631bb98490d580e0c6fe28c24312f6ccb57..49cf29c38e40f5a55fa0546e988d2860a6842f6b 100644 --- a/src/operators/fusion_conv_add_op.cpp +++ b/src/operators/fusion_conv_add_op.cpp @@ -15,7 +15,7 @@ limitations under the License. 
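Note: this patch switches every operator constructor from
std::shared_ptr<framework::Scope> (the type stripped from the deleted lines)
to a raw framework::Scope *. The scope is created and owned by the executor
and outlives the operators, so shared ownership presumably bought no safety,
only ref-count traffic on every op construction. The post-change shape,
repeated across all operators (template arguments as in the hunk above):

    // Operators now borrow the scope; the executor remains the owner.
    FusionConvAddBNOp(const string &type, const VariableNameMap &inputs,
                      const VariableNameMap &outputs,
                      const framework::AttributeMap &attrs,
                      framework::Scope *scope)  // borrowed, not owned
        : framework::OperatorWithKernel<
              DeviceType, FusionConvAddBNParam<DeviceType>,
              operators::ConvAddBNKernel<DeviceType, T>>(type, inputs, outputs,
                                                         attrs, scope) {}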
*/ #ifdef FUSION_CONVADD_OP #include "operators/fusion_conv_add_op.h" -#include "operators/math/conv_func.h" +#include "operators/kernel/central-arm-func/conv_arm_func.h" namespace paddle_mobile { namespace operators { @@ -36,9 +36,9 @@ void FusionConvAddOp::InferShape() const { std::vector output_shape({in_dims[0], filter_dims[0]}); for (size_t i = 0; i < strides.size(); ++i) { - output_shape.push_back( - math::ConvOutputSize(in_dims[i + 2], filter_dims[i + 2], dilations[i], - paddings[i], strides[i])); + output_shape.push_back(ConvOutputSize(in_dims[i + 2], filter_dims[i + 2], + dilations[i], paddings[i], + strides[i])); } framework::DDim ddim = framework::make_ddim(output_shape); diff --git a/src/operators/fusion_conv_add_op.h b/src/operators/fusion_conv_add_op.h index eef143ce8716ce856784bb01dd3d58a26746b4e8..22ecab45e694278d22e3a736e415b5a34af623da 100644 --- a/src/operators/fusion_conv_add_op.h +++ b/src/operators/fusion_conv_add_op.h @@ -50,8 +50,7 @@ class FusionConvAddOp : public framework::OperatorWithKernel< public: FusionConvAddOp(const string &type, const VariableNameMap &inputs, const VariableNameMap &outputs, - const framework::AttributeMap &attrs, - std::shared_ptr scope) + const framework::AttributeMap &attrs, framework::Scope *scope) : framework::OperatorWithKernel, operators::ConvAddKernel>( diff --git a/src/operators/fusion_conv_add_prelu_op.cpp b/src/operators/fusion_conv_add_prelu_op.cpp deleted file mode 100644 index 9273af388c2c0a8644b29e1f40a5238b0e092523..0000000000000000000000000000000000000000 --- a/src/operators/fusion_conv_add_prelu_op.cpp +++ /dev/null @@ -1,61 +0,0 @@ -/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
*/ - -#ifdef FUSION_CONVADDPRELU_OP - -#include "operators/fusion_conv_add_prelu_op.h" -#include "operators/math/conv_func.h" - -namespace paddle_mobile { -namespace operators { - -template -void FusionConvAddPReluOp::InferShape() const { - auto in_dims = this->param_.Input()->dims(); - auto filter_dims = this->param_.Filter()->dims(); - const std::vector &strides = this->param_.Strides(); - std::vector paddings = this->param_.Paddings(); - int groups = this->param_.Groups(); - std::vector dilations = this->param_.Dilations(); - - PADDLE_MOBILE_ENFORCE((in_dims.size() == filter_dims.size() && - dilations.size() == paddings.size() && - paddings.size() == strides.size()), - "ConvParam is not suitable"); - - std::vector output_shape({in_dims[0], filter_dims[0]}); - for (size_t i = 0; i < strides.size(); ++i) { - output_shape.push_back( - math::ConvOutputSize(in_dims[i + 2], filter_dims[i + 2], dilations[i], - paddings[i], strides[i])); - } - framework::DDim ddim = framework::make_ddim(output_shape); - this->param_.Output()->Resize(ddim); -} - -} // namespace operators -} // namespace paddle_mobile - -namespace ops = paddle_mobile::operators; -REGISTER_FUSION_MATCHER(fusion_conv_add_prelu, - ops::FusionConvAddPReluOpMatcher); - -#ifdef PADDLE_MOBILE_CPU -REGISTER_OPERATOR_CPU(fusion_conv_add_prelu, ops::FusionConvAddPReluOp); -#endif -#ifdef PADDLE_MOBILE_FPGA -REGISTER_OPERATOR_FPGA(fusion_conv_add_prelu, ops::FusionConvAddPReluOp); -#endif - -#endif diff --git a/src/operators/fusion_conv_add_prelu_op.h b/src/operators/fusion_conv_add_prelu_op.h deleted file mode 100644 index fc1143099e16b8b7f7c44d7fe5a5694a278a1906..0000000000000000000000000000000000000000 --- a/src/operators/fusion_conv_add_prelu_op.h +++ /dev/null @@ -1,71 +0,0 @@ -/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
*/ - -#ifdef FUSION_CONVADDPRELU_OP - -#pragma once - -#include -#include -#include "framework/operator.h" -#include "framework/program/program-optimize/fusion_op_register.h" -#include "operators/kernel/conv_add_prelu_kernel.h" -#include "operators/op_param.h" - -namespace paddle_mobile { -namespace operators { - -class FusionConvAddPReluOpMatcher : public framework::FusionOpMatcher { - public: - FusionConvAddPReluOpMatcher() { - node_ = framework::Node(G_OP_TYPE_CONV); - node_ > std::make_shared(G_OP_TYPE_ELEMENTWISE_ADD) > - std::make_shared(G_OP_TYPE_PRELU); - } - - void FolderNodes( - framework::Node *node, - std::vector> *removed_nodes) { - node->Folder(node_.Depth(), Type(), - {{G_OP_TYPE_ELEMENTWISE_ADD, {{"Y", "Y"}}}, - {G_OP_TYPE_PRELU, {{"Alpha", "Alpha"}}}}, - removed_nodes); - } - std::string Type() { return G_OP_TYPE_FUSION_CONV_ADD_PRELU; } -}; - -template -class FusionConvAddPReluOp - : public framework::OperatorWithKernel< - DeviceType, FusionConvAddPReluParam, - operators::ConvAddPReluKernel> { - public: - FusionConvAddPReluOp(const string &type, const VariableNameMap &inputs, - const VariableNameMap &outputs, - const framework::AttributeMap &attrs, - std::shared_ptr scope) - : framework::OperatorWithKernel< - DeviceType, FusionConvAddPReluParam, - operators::ConvAddPReluKernel>(type, inputs, outputs, - attrs, scope) {} - - void InferShape() const override; - - protected: -}; - -} // namespace operators -} // namespace paddle_mobile - -#endif diff --git a/src/operators/fusion_conv_add_relu_op.cpp b/src/operators/fusion_conv_add_relu_op.cpp index bb4b6666a881de0989d43840806b9d5d720b3b66..163dfba3cc8706dac96697974ef7224b3f625ae1 100644 --- a/src/operators/fusion_conv_add_relu_op.cpp +++ b/src/operators/fusion_conv_add_relu_op.cpp @@ -15,7 +15,7 @@ limitations under the License. 
*/ #ifdef FUSION_CONVADDRELU_OP #include "operators/fusion_conv_add_relu_op.h" -#include "operators/math/conv_func.h" +#include "operators/kernel/central-arm-func/conv_arm_func.h" namespace paddle_mobile { namespace operators { @@ -36,9 +36,9 @@ void FusionConvAddReluOp::InferShape() const { std::vector output_shape({in_dims[0], filter_dims[0]}); for (size_t i = 0; i < strides.size(); ++i) { - output_shape.push_back( - math::ConvOutputSize(in_dims[i + 2], filter_dims[i + 2], dilations[i], - paddings[i], strides[i])); + output_shape.push_back(ConvOutputSize(in_dims[i + 2], filter_dims[i + 2], + dilations[i], paddings[i], + strides[i])); } framework::DDim ddim = framework::make_ddim(output_shape); this->param_.Output()->Resize(ddim); diff --git a/src/operators/fusion_conv_add_relu_op.h b/src/operators/fusion_conv_add_relu_op.h index c4cd61016c6da1100819b3531bd41466726f292a..7a1cfd19414a18e727fb2003b603a611ebc603e5 100644 --- a/src/operators/fusion_conv_add_relu_op.h +++ b/src/operators/fusion_conv_add_relu_op.h @@ -51,7 +51,7 @@ class FusionConvAddReluOp : public framework::OperatorWithKernel< FusionConvAddReluOp(const string &type, const VariableNameMap &inputs, const VariableNameMap &outputs, const framework::AttributeMap &attrs, - std::shared_ptr scope) + framework::Scope *scope) : framework::OperatorWithKernel< DeviceType, FusionConvAddReluParam, operators::ConvAddReluKernel>(type, inputs, outputs, diff --git a/src/operators/fusion_conv_bn_add_relu_op.cpp b/src/operators/fusion_conv_bn_add_relu_op.cpp index 9a3926353319aa267814097d93a6d9b1fa20bd2d..c2bb2c744d5599558f14e2f1d169b00a1492e135 100644 --- a/src/operators/fusion_conv_bn_add_relu_op.cpp +++ b/src/operators/fusion_conv_bn_add_relu_op.cpp @@ -15,7 +15,7 @@ limitations under the License. */ #ifdef FUSION_CONVBNADDRELU_OP #include "operators/fusion_conv_bn_add_relu_op.h" -#include "operators/math/conv_func.h" +#include "operators/kernel/central-arm-func/conv_arm_func.h" namespace paddle_mobile { namespace operators { @@ -36,9 +36,9 @@ void FusionConvBNAddReluOp::InferShape() const { std::vector output_shape({in_dims[0], filter_dims[0]}); for (size_t i = 0; i < strides.size(); ++i) { - output_shape.push_back( - math::ConvOutputSize(in_dims[i + 2], filter_dims[i + 2], dilations[i], - paddings[i], strides[i])); + output_shape.push_back(ConvOutputSize(in_dims[i + 2], filter_dims[i + 2], + dilations[i], paddings[i], + strides[i])); } framework::DDim ddim = framework::make_ddim(output_shape); diff --git a/src/operators/fusion_conv_bn_add_relu_op.h b/src/operators/fusion_conv_bn_add_relu_op.h index 303668a89bf7869e72a4b546c5d96be24b26c4ec..676d30ce2698154b104c9dc100ea404d1cc0ba73 100644 --- a/src/operators/fusion_conv_bn_add_relu_op.h +++ b/src/operators/fusion_conv_bn_add_relu_op.h @@ -67,7 +67,7 @@ class FusionConvBNAddReluOp FusionConvBNAddReluOp(const string &type, const VariableNameMap &inputs, const VariableNameMap &outputs, const framework::AttributeMap &attrs, - std::shared_ptr scope) + framework::Scope *scope) : framework::OperatorWithKernel< DeviceType, FusionConvBNAddReluParam, operators::ConvBNAddReluKernel>( diff --git a/src/operators/fusion_conv_bn_op.cpp b/src/operators/fusion_conv_bn_op.cpp index 7786cd713b5f838e22aa3080697d551609d81036..4939123a77a072ea410bfa96547b8a0ed276c28d 100644 --- a/src/operators/fusion_conv_bn_op.cpp +++ b/src/operators/fusion_conv_bn_op.cpp @@ -15,6 +15,7 @@ limitations under the License. 
*/ #ifdef FUSION_CONVBN_OP #include "operators/fusion_conv_bn_op.h" +#include "operators/kernel/central-arm-func/conv_arm_func.h" namespace paddle_mobile { namespace operators { @@ -35,9 +36,9 @@ void FusionConvBNOp::InferShape() const { std::vector output_shape({in_dims[0], filter_dims[0]}); for (size_t i = 0; i < strides.size(); ++i) { - output_shape.push_back( - math::ConvOutputSize(in_dims[i + 2], filter_dims[i + 2], dilations[i], - paddings[i], strides[i])); + output_shape.push_back(ConvOutputSize(in_dims[i + 2], filter_dims[i + 2], + dilations[i], paddings[i], + strides[i])); } framework::DDim ddim = framework::make_ddim(output_shape); diff --git a/src/operators/fusion_conv_bn_op.h b/src/operators/fusion_conv_bn_op.h index f393928665301da0dd0076b33e81ca79791794f7..385bb539fd9f5bb1e1e94e8e26830538daed5297 100644 --- a/src/operators/fusion_conv_bn_op.h +++ b/src/operators/fusion_conv_bn_op.h @@ -56,8 +56,7 @@ class FusionConvBNOp : public framework::OperatorWithKernel< public: FusionConvBNOp(const string &type, const VariableNameMap &inputs, const VariableNameMap &outputs, - const framework::AttributeMap &attrs, - std::shared_ptr scope) + const framework::AttributeMap &attrs, framework::Scope *scope) : framework::OperatorWithKernel, operators::ConvBNKernel>( type, inputs, outputs, attrs, scope) {} diff --git a/src/operators/fusion_conv_bn_relu_op.cpp b/src/operators/fusion_conv_bn_relu_op.cpp index 54c9f85cbb7dc00bd0df5747caac8fd2ee9e2782..0e8eec65f2e46e1314c11b7f6bceade861445ef6 100644 --- a/src/operators/fusion_conv_bn_relu_op.cpp +++ b/src/operators/fusion_conv_bn_relu_op.cpp @@ -15,7 +15,7 @@ limitations under the License. */ #ifdef FUSION_CONVBNRELU_OP #include "operators/fusion_conv_bn_relu_op.h" -#include "operators/math/conv_func.h" +#include "operators/kernel/central-arm-func/conv_arm_func.h" namespace paddle_mobile { namespace operators { @@ -36,9 +36,9 @@ void FusionConvBNReluOp::InferShape() const { std::vector output_shape({in_dims[0], filter_dims[0]}); for (size_t i = 0; i < strides.size(); ++i) { - output_shape.push_back( - math::ConvOutputSize(in_dims[i + 2], filter_dims[i + 2], dilations[i], - paddings[i], strides[i])); + output_shape.push_back(ConvOutputSize(in_dims[i + 2], filter_dims[i + 2], + dilations[i], paddings[i], + strides[i])); } framework::DDim ddim = framework::make_ddim(output_shape); diff --git a/src/operators/fusion_conv_bn_relu_op.h b/src/operators/fusion_conv_bn_relu_op.h index 9bc534fe333c76e8f533c904560b8228760c66e5..2f49df081cef8d5d6ac785f7512e3737d9c7593d 100644 --- a/src/operators/fusion_conv_bn_relu_op.h +++ b/src/operators/fusion_conv_bn_relu_op.h @@ -58,7 +58,7 @@ class FusionConvBNReluOp : public framework::OperatorWithKernel< FusionConvBNReluOp(const string &type, const VariableNameMap &inputs, const VariableNameMap &outputs, const framework::AttributeMap &attrs, - std::shared_ptr scope) + framework::Scope *scope) : framework::OperatorWithKernel< DeviceType, FusionConvBNReluParam, operators::ConvBNReluKernel>(type, inputs, outputs, diff --git a/src/operators/fusion_deconv_add_bn_op.h b/src/operators/fusion_deconv_add_bn_op.h index f7f9b9e2094a7228c944b70b88ae3105ae9f37e8..618545044136e42e750fd4c71ce96bd861954b71 100644 --- a/src/operators/fusion_deconv_add_bn_op.h +++ b/src/operators/fusion_deconv_add_bn_op.h @@ -57,7 +57,7 @@ class FusionDeconvAddBNOp : public framework::OperatorWithKernel< FusionDeconvAddBNOp(const string &type, const VariableNameMap &inputs, const VariableNameMap &outputs, const framework::AttributeMap &attrs, - 
std::shared_ptr scope) + framework::Scope *scope) : framework::OperatorWithKernel< DeviceType, FusionDeconvAddBNParam, operators::DeconvAddBNKernel>(type, inputs, outputs, diff --git a/src/operators/fusion_deconv_add_bn_relu_op.h b/src/operators/fusion_deconv_add_bn_relu_op.h index 97070ef01e544839be8eab6ddba21c43dfa9a26e..1c6cfd7318e48cad16e1d274b5724c832c70d8c8 100644 --- a/src/operators/fusion_deconv_add_bn_relu_op.h +++ b/src/operators/fusion_deconv_add_bn_relu_op.h @@ -59,7 +59,7 @@ class FusionDeconvAddBNReluOp FusionDeconvAddBNReluOp(const string &type, const VariableNameMap &inputs, const VariableNameMap &outputs, const framework::AttributeMap &attrs, - std::shared_ptr scope) + framework::Scope *scope) : framework::OperatorWithKernel< DeviceType, FusionDeconvAddBNReluParam, operators::DeconvAddBNReluKernel>( diff --git a/src/operators/fusion_deconv_add_op.h b/src/operators/fusion_deconv_add_op.h index 45d6d33aad683b1d45480c304cf44ef10349cdf3..406f81318a28889f066b03eb6cfeb954939b0f1a 100644 --- a/src/operators/fusion_deconv_add_op.h +++ b/src/operators/fusion_deconv_add_op.h @@ -49,7 +49,7 @@ class FusionDeconvAddOp : public framework::OperatorWithKernel< FusionDeconvAddOp(const string &type, const VariableNameMap &inputs, const VariableNameMap &outputs, const framework::AttributeMap &attrs, - std::shared_ptr scope) + framework::Scope *scope) : framework::OperatorWithKernel< DeviceType, FusionDeconvAddParam, operators::DeconvAddKernel>(type, inputs, outputs, diff --git a/src/operators/fusion_deconv_add_relu_op.h b/src/operators/fusion_deconv_add_relu_op.h index eeef6d3958746edfdc114192c3b923db1e102ced..735e126b033c5872ed66900acc5c56dd76b8ad85 100644 --- a/src/operators/fusion_deconv_add_relu_op.h +++ b/src/operators/fusion_deconv_add_relu_op.h @@ -51,7 +51,7 @@ class FusionDeconvAddReluOp FusionDeconvAddReluOp(const string &type, const VariableNameMap &inputs, const VariableNameMap &outputs, const framework::AttributeMap &attrs, - std::shared_ptr scope) + framework::Scope *scope) : framework::OperatorWithKernel< DeviceType, FusionDeconvAddReluParam, operators::DeconvAddReluKernel>( diff --git a/src/operators/fusion_deconv_bn_relu_op.h b/src/operators/fusion_deconv_bn_relu_op.h index ad0920ebd69b1a13ebc0e85f2c5f6008379715da..92bb97445d1442056843efb1fd66fa3fb1e54511 100644 --- a/src/operators/fusion_deconv_bn_relu_op.h +++ b/src/operators/fusion_deconv_bn_relu_op.h @@ -56,7 +56,7 @@ class FusionDeconvBNReluOp FusionDeconvBNReluOp(const string &type, const VariableNameMap &inputs, const VariableNameMap &outputs, const framework::AttributeMap &attrs, - std::shared_ptr scope) + framework::Scope *scope) : framework::OperatorWithKernel< DeviceType, FusionDeconvBNReluParam, operators::DeconvBNReluKernel>(type, inputs, outputs, diff --git a/src/operators/fusion_deconv_relu_op.h b/src/operators/fusion_deconv_relu_op.h index e87d5d3798930d745b82c8e5a3cca793c12ee4b1..c290a8da081591d0eddf3ef075ae57d3869b7725 100644 --- a/src/operators/fusion_deconv_relu_op.h +++ b/src/operators/fusion_deconv_relu_op.h @@ -48,7 +48,7 @@ class FusionDeconvReluOp : public framework::OperatorWithKernel< FusionDeconvReluOp(const string &type, const VariableNameMap &inputs, const VariableNameMap &outputs, const framework::AttributeMap &attrs, - std::shared_ptr scope) + framework::Scope *scope) : framework::OperatorWithKernel< DeviceType, FusionDeconvReluParam, operators::DeconvReluKernel>(type, inputs, outputs, diff --git a/src/operators/fusion_dequant_add_bn_op.h b/src/operators/fusion_dequant_add_bn_op.h index 
f619c607f349da489e35efa0f3a00f069471c67c..b838b544ce249029bfdbad77f62c8f393006ebd2 100644 --- a/src/operators/fusion_dequant_add_bn_op.h +++ b/src/operators/fusion_dequant_add_bn_op.h @@ -60,7 +60,7 @@ class FusionDequantAddBNOp FusionDequantAddBNOp(const std::string &type, const VariableNameMap &inputs, const VariableNameMap &outputs, const framework::AttributeMap &attrs, - std::shared_ptr scope) + framework::Scope *scope) : framework::OperatorWithKernel< DeviceType, FusionDequantAddBNParam, operators::FusionDequantAddBNKernel>( diff --git a/src/operators/fusion_dequant_add_bn_relu_op.h b/src/operators/fusion_dequant_add_bn_relu_op.h index ba7ba50358d35b38cf621f374ae28787fa82e1a6..e2762923c511858d5cea77f3301919d4c0b8fa4b 100644 --- a/src/operators/fusion_dequant_add_bn_relu_op.h +++ b/src/operators/fusion_dequant_add_bn_relu_op.h @@ -62,7 +62,7 @@ class FusionDequantAddBNReluOp const VariableNameMap &inputs, const VariableNameMap &outputs, const framework::AttributeMap &attrs, - std::shared_ptr scope) + framework::Scope *scope) : framework::OperatorWithKernel< DeviceType, FusionDequantAddBNParam, operators::FusionDequantAddBNReluKernel>( diff --git a/src/operators/fusion_dequant_add_bn_relu_quant_op.h b/src/operators/fusion_dequant_add_bn_relu_quant_op.h index e03963975c362f99947cf8fd043499a62a3029c8..6caa8daeb3f54312462b185f78ff07fbdf69cd7d 100644 --- a/src/operators/fusion_dequant_add_bn_relu_quant_op.h +++ b/src/operators/fusion_dequant_add_bn_relu_quant_op.h @@ -62,7 +62,7 @@ class FusionDequantAddBNReluQuantOp const VariableNameMap &inputs, const VariableNameMap &outputs, const framework::AttributeMap &attrs, - std::shared_ptr scope) + framework::Scope *scope) : framework::OperatorWithKernel< DeviceType, FusionDequantAddBNReluQuantParam, operators::FusionDequantAddBNReluQuantKernel>( @@ -109,7 +109,7 @@ class FusionDequantAddBNQuantOp const VariableNameMap &inputs, const VariableNameMap &outputs, const framework::AttributeMap &attrs, - std::shared_ptr scope) + framework::Scope *scope) : framework::OperatorWithKernel< DeviceType, FusionDequantAddBNQuantParam, operators::FusionDequantAddBNQuantKernel>( diff --git a/src/operators/fusion_dequant_bn_op.h b/src/operators/fusion_dequant_bn_op.h index 496bba73fed71f20f0f41b5cd68788f1e4bf3bba..ac2237b77acceab09ad9120a7d177ea5e0051697 100644 --- a/src/operators/fusion_dequant_bn_op.h +++ b/src/operators/fusion_dequant_bn_op.h @@ -58,7 +58,7 @@ class FusionDequantBNOp : public framework::OperatorWithKernel< FusionDequantBNOp(const std::string &type, const VariableNameMap &inputs, const VariableNameMap &outputs, const framework::AttributeMap &attrs, - std::shared_ptr scope) + framework::Scope *scope) : framework::OperatorWithKernel< DeviceType, FusionDequantBNParam, operators::FusionDequantBNKernel>( @@ -87,7 +87,7 @@ class FusionDequantBNReluOp FusionDequantBNReluOp(const std::string &type, const VariableNameMap &inputs, const VariableNameMap &outputs, const framework::AttributeMap &attrs, - std::shared_ptr scope) + framework::Scope *scope) : framework::OperatorWithKernel< DeviceType, FusionDequantBNParam, operators::FusionDequantBNReluKernel>( diff --git a/src/operators/fusion_dequant_bn_relu_op.h b/src/operators/fusion_dequant_bn_relu_op.h index 7e12cb1ae0cdaa9ea97ae84d76fc190920331c1b..be3b5293a334ca8bf275deabbccf0693679cde18 100644 --- a/src/operators/fusion_dequant_bn_relu_op.h +++ b/src/operators/fusion_dequant_bn_relu_op.h @@ -59,7 +59,7 @@ class FusionDequantBNReluOp FusionDequantBNReluOp(const std::string &type, const VariableNameMap 
&inputs, const VariableNameMap &outputs, const framework::AttributeMap &attrs, - std::shared_ptr scope) + framework::Scope *scope) : framework::OperatorWithKernel< DeviceType, FusionDequantBNReluParam, operators::FusionDequantBNReluKernel>( diff --git a/src/operators/fusion_dwconv_bn_relu_op.cpp b/src/operators/fusion_dwconv_bn_relu_op.cpp index f5040987e42f9c0b3068d730a9926b9fcff8c8c3..d4c04f67fc637266cf95af2e7fe518682e212d98 100644 --- a/src/operators/fusion_dwconv_bn_relu_op.cpp +++ b/src/operators/fusion_dwconv_bn_relu_op.cpp @@ -15,7 +15,7 @@ limitations under the License. */ #ifdef FUSION_DWCONVBNRELU_OP #include "operators/fusion_dwconv_bn_relu_op.h" -#include "operators/math/conv_func.h" +#include "operators/kernel/central-arm-func/conv_arm_func.h" namespace paddle_mobile { namespace operators { @@ -36,9 +36,9 @@ void FusionDWConvBNReluOp::InferShape() const { std::vector output_shape({in_dims[0], filter_dims[0]}); for (size_t i = 0; i < strides.size(); ++i) { - output_shape.push_back( - math::ConvOutputSize(in_dims[i + 2], filter_dims[i + 2], dilations[i], - paddings[i], strides[i])); + output_shape.push_back(ConvOutputSize(in_dims[i + 2], filter_dims[i + 2], + dilations[i], paddings[i], + strides[i])); } framework::DDim ddim = framework::make_ddim(output_shape); diff --git a/src/operators/fusion_dwconv_bn_relu_op.h b/src/operators/fusion_dwconv_bn_relu_op.h index d7a74d896e904971e21c28fab29771b34a049921..0fb2e5c70cabb3cdadc56b1fc2f50148f5b42f0e 100644 --- a/src/operators/fusion_dwconv_bn_relu_op.h +++ b/src/operators/fusion_dwconv_bn_relu_op.h @@ -59,7 +59,7 @@ class FusionDWConvBNReluOp FusionDWConvBNReluOp(const string &type, const VariableNameMap &inputs, const VariableNameMap &outputs, const framework::AttributeMap &attrs, - std::shared_ptr scope) + framework::Scope *scope) : framework::OperatorWithKernel< DeviceType, FusionDWConvBNReluParam, operators::DWConvBNReluKernel>(type, inputs, outputs, diff --git a/src/operators/fusion_elementwise_add_relu_op.h b/src/operators/fusion_elementwise_add_relu_op.h index 6434e726ccd8df8cf97736bfa65904674c73ad03..c90d4e041ede54988cd1e9089991c9df8c594ab4 100644 --- a/src/operators/fusion_elementwise_add_relu_op.h +++ b/src/operators/fusion_elementwise_add_relu_op.h @@ -17,6 +17,7 @@ limitations under the License. 
*/ #pragma once #include +#include #include "framework/operator.h" #include "framework/program/program-optimize/fusion_op_register.h" #include "operators/kernel/elementwise_add_relu_kernel.h" @@ -50,7 +51,7 @@ class FusionElementwiseAddReluOp FusionElementwiseAddReluOp(const string &type, const VariableNameMap &inputs, const VariableNameMap &outputs, const framework::AttributeMap &attrs, - std::shared_ptr scope) + framework::Scope *scope) : framework::OperatorWithKernel< DeviceType, ElementwiseAddReluParam, operators::ElementwiseAddReluKernel>( diff --git a/src/operators/fusion_fc_op.h b/src/operators/fusion_fc_op.h index 26cb40aac8e47203f125417e1f6b5df75d7835b5..a88add4584060079fd437c3fff2e4228571d186f 100644 --- a/src/operators/fusion_fc_op.h +++ b/src/operators/fusion_fc_op.h @@ -50,8 +50,7 @@ class FusionFcOp : public framework::OperatorWithKernel< public: FusionFcOp(const std::string &type, const VariableNameMap &inputs, const VariableNameMap &outputs, - const framework::AttributeMap &attrs, - std::shared_ptr scope) + const framework::AttributeMap &attrs, framework::Scope *scope) : framework::OperatorWithKernel, operators::FusionFcKernel>( type, inputs, outputs, attrs, scope) {} diff --git a/src/operators/fusion_fc_relu_op.h b/src/operators/fusion_fc_relu_op.h index 7324f94138e59c4a4a93fe2658b38ddbdf6fa651..253335c8f258aa8a21b639378744d9bd4767a344 100644 --- a/src/operators/fusion_fc_relu_op.h +++ b/src/operators/fusion_fc_relu_op.h @@ -49,8 +49,7 @@ class FusionFcReluOp : public framework::OperatorWithKernel< public: FusionFcReluOp(const string &type, const VariableNameMap &inputs, const VariableNameMap &outputs, - const framework::AttributeMap &attrs, - std::shared_ptr scope) + const framework::AttributeMap &attrs, framework::Scope *scope) : framework::OperatorWithKernel< DeviceType, FusionFcReluParam, operators::FusionFcReluKernel>(type, inputs, outputs, diff --git a/src/operators/gru_op.h b/src/operators/gru_op.h index 5e66b497af15c498e2af5ff5903ef88a16db1832..80bbd7c22233edba9902f636ae3c11c484761e2b 100644 --- a/src/operators/gru_op.h +++ b/src/operators/gru_op.h @@ -33,7 +33,7 @@ class GruOp : public framework::OperatorWithKernel< public: GruOp(const std::string &type, const VariableNameMap &inputs, const VariableNameMap &outputs, const framework::AttributeMap &attrs, - std::shared_ptr scope) + framework::Scope *scope) : framework::OperatorWithKernel, operators::GruKernel>( type, inputs, outputs, attrs, scope) {} diff --git a/src/operators/gru_unit_op.h b/src/operators/gru_unit_op.h index 4188662d05e79a97fa2f0dba62303391ae8e0d70..8821212bfa3232e2713dfae86513693c09af8290 100644 --- a/src/operators/gru_unit_op.h +++ b/src/operators/gru_unit_op.h @@ -16,6 +16,7 @@ limitations under the License. 
*/ #pragma once +#include #include "framework/operator.h" #include "operators/kernel/gru_unit_kernel.h" #include "operators/op_param.h" @@ -30,10 +31,10 @@ class GruUnitOp : public framework::OperatorWithKernel< public: GruUnitOp(const std::string &type, const VariableNameMap &inputs, const VariableNameMap &outputs, const AttributeMap &attrs, - std::shared_ptr scope) + framework::Scope *scope) : framework::OperatorWithKernel, operators::GruUnitKernel>( - type, inputs, outputs, attrs, scope){}; + type, inputs, outputs, attrs, scope) {} void InferShape() const override; }; diff --git a/src/operators/im2sequence_op.h b/src/operators/im2sequence_op.h index 036b496ca8293432aa30ae86542e78880143f086..4361380b8f7a4c4c89b63e7c1125e0c60dc79eba 100644 --- a/src/operators/im2sequence_op.h +++ b/src/operators/im2sequence_op.h @@ -31,8 +31,7 @@ class Im2SequenceOp : public framework::OperatorWithKernel< public: Im2SequenceOp(const std::string &type, const VariableNameMap &inputs, const VariableNameMap &outputs, - const framework::AttributeMap &attrs, - std::shared_ptr scope) + const framework::AttributeMap &attrs, framework::Scope *scope) : framework::OperatorWithKernel< DeviceType, Im2SequenceParam, operators::Im2SequenceKernel>(type, inputs, outputs, diff --git a/src/operators/increment_op.h b/src/operators/increment_op.h index 5212630cc4d01ca6d432f5e340d3cf0fd89782b5..e0455b911342f44be2001fa858dca500bab2b591 100644 --- a/src/operators/increment_op.h +++ b/src/operators/increment_op.h @@ -32,8 +32,7 @@ class IncrementOp public: IncrementOp(const string &type, const VariableNameMap &inputs, const VariableNameMap &outputs, - const framework::AttributeMap &attrs, - std::shared_ptr scope) + const framework::AttributeMap &attrs, framework::Scope *scope) : framework::OperatorWithKernel, IncrementKernel>( type, inputs, outputs, attrs, scope) {} diff --git a/src/operators/is_empty_op.h b/src/operators/is_empty_op.h index 45af2646a7a8a2691868ea204a8602a34898412d..1f31f25796fcb1c33f1b94a1dac81d714d1bbca5 100644 --- a/src/operators/is_empty_op.h +++ b/src/operators/is_empty_op.h @@ -31,8 +31,7 @@ class IsEmptyOp public: IsEmptyOp(const string &type, const VariableNameMap &inputs, const VariableNameMap &outputs, - const framework::AttributeMap &attrs, - std::shared_ptr scope) + const framework::AttributeMap &attrs, framework::Scope *scope) : framework::OperatorWithKernel, IsEmptyKernel>( type, inputs, outputs, attrs, scope) {} diff --git a/src/operators/kernel/arm/beam_search_decode_kernel.cpp b/src/operators/kernel/arm/beam_search_decode_kernel.cpp new file mode 100644 index 0000000000000000000000000000000000000000..f22c032347faa2eb85c6594df54a5f91f214903c --- /dev/null +++ b/src/operators/kernel/arm/beam_search_decode_kernel.cpp @@ -0,0 +1,277 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
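Note: the decode kernel below leans on a two-level LoD convention. A concrete
example (numbers invented): with two source sentences and

    // kSourceLevel: sentence 0 owns prefixes [0, 3), sentence 1 owns [3, 6).
    // kSentenceLevel: prefix 0 owns candidates [0, 2), prefix 1 owns [2, 5), ...
    framework::LoD lod;
    lod.push_back({0, 3, 6});
    lod.push_back({0, 2, 5, 8, 10, 13, 16});

each source sentence currently keeps three live prefixes, and the second level
maps every prefix to its span of candidate ids in the flat data array.
Backtrace walks these offsets from the last step back to the first.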
*/ + +#ifdef BEAM_SEARCH_DECODE_OP + +#include "operators/kernel/beam_search_decode_kernel.h" +#include "framework/data_type.h" + +namespace paddle_mobile { +namespace operators { + +using LoDTensor = framework::LoDTensor; +using LoDTensorArray = framework::LoDTensorArray; + +// all the lod have 2 levels. +// The first is source level, the second is sentence level. +// source level describe how many prefixes (branchs) for each source sentece +// (beam). sentence level describe how these candidates belong to the prefixes. +const size_t kSourceLevel = 0; +const size_t kSentenceLevel = 1; + +template +struct Sentence { + std::vector word_ids; + std::vector scores; +}; + +template +using SentenceVector = std::vector>; + +template +struct BeamSearchDecoder { + BeamSearchDecoder(size_t beam_size, int end_id) + : beam_size_(beam_size), end_id_(end_id) {} + + /** + * convert the result sentence_vector for each source sentence into two + * LodTensor. + * One is all candidate sentences with word id, one is all candidate sentences + * with word score. + * Param: + * sentence_vector_list: sentence_vector for each source sentence. + * id_tensor: result LoDTensor for sentences of id. + * score_tensor: result LoDTensor for sentences of score. + * reverse: whether ids of sentence in sentence_vector_list is reversed + * sort_by_score: whether to sort hypotheses of each sentence by scores. + */ + void ConvertSentenceVectorToLodTensor( + std::vector> sentence_vector_list, LoDTensor* id_tensor, + LoDTensor* score_tensor, bool reverse = true, + bool sort_by_score = true) const; + + /** + * Gather the hypotheses for each source sentence by backtrace though the + * LoDTensorArray step_ids whose lods reserve the path in the tree. + */ + void Backtrace(const LoDTensorArray& step_ids, + const LoDTensorArray& step_scores, LoDTensor* id_tensor, + LoDTensor* score_tensor) const; + + size_t beam_size_; + int end_id_; +}; + +template +void BeamSearchDecoder::ConvertSentenceVectorToLodTensor( + std::vector> sentence_vector_list, LoDTensor* id_tensor, + LoDTensor* score_tensor, bool reverse, bool sort_by_score) const { + size_t src_num = sentence_vector_list.size(); + + PADDLE_MOBILE_ENFORCE(src_num > 0, "src_num should be larger than 0"); + + std::vector source_level_lod = {0}; + std::vector sentence_level_lod = {0}; + std::vector id_data; + std::vector score_data; + + for (size_t src_idx = 0; src_idx < src_num; ++src_idx) { + if (sort_by_score) { + sort(sentence_vector_list[src_idx].begin(), + sentence_vector_list[src_idx].end(), + [reverse](const Sentence& a, const Sentence& b) { + if (reverse) + return a.scores.front() > b.scores.front(); + else + return a.scores.back() > b.scores.back(); + }); + } + for (Sentence& sentence : sentence_vector_list[src_idx]) { + if (reverse) { + id_data.insert(id_data.end(), sentence.word_ids.rbegin(), + sentence.word_ids.rend()); + score_data.insert(score_data.end(), sentence.scores.rbegin(), + sentence.scores.rend()); + } else { + id_data.insert(id_data.end(), sentence.word_ids.begin(), + sentence.word_ids.end()); + score_data.insert(score_data.end(), sentence.scores.begin(), + sentence.scores.end()); + } + + sentence_level_lod.push_back(sentence_level_lod.back() + + sentence.word_ids.size()); + } + source_level_lod.push_back(source_level_lod.back() + + sentence_vector_list[src_idx].size()); + } + + framework::LoD lod; + lod.push_back(source_level_lod); + lod.push_back(sentence_level_lod); + + id_tensor->set_lod(lod); + id_tensor->Resize({static_cast(id_data.size())}); + 
+  id_tensor->mutable_data<int64_t>();
+  framework::TensorFromVector(id_data, id_tensor);
+
+  score_tensor->set_lod(lod);
+  score_tensor->Resize({static_cast<int64_t>(score_data.size())});
+  score_tensor->mutable_data<T>();
+  framework::TensorFromVector(score_data, score_tensor);
+}
+
+template <typename T>
+void BeamSearchDecoder<T>::Backtrace(const LoDTensorArray& step_ids,
+                                     const LoDTensorArray& step_scores,
+                                     LoDTensor* id_tensor,
+                                     LoDTensor* score_tensor) const {
+  PADDLE_MOBILE_ENFORCE(!step_ids.empty(), "step num should be larger than 0");
+  PADDLE_MOBILE_ENFORCE(step_ids.size() == step_scores.size(),
+                        "step_ids and step_scores should have the same size");
+  const size_t step_num = step_ids.size();
+  const size_t src_num = step_ids.at(0).lod().at(kSourceLevel).size() - 1;
+  std::vector<SentenceVector<T>> sentence_vector_list(
+      src_num, SentenceVector<T>(beam_size_));
+  std::vector<std::vector<size_t>> prefix_idx_vector_list(src_num);
+  for (int step_id = step_num - 1; step_id >= 0; --step_id) {
+    auto& cur_ids = step_ids.at(step_id);
+    auto& cur_scores = step_scores.at(step_id);
+    for (size_t src_idx = 0; src_idx < src_num; ++src_idx) {
+      // for each source sentence
+      auto& sentence_vector = sentence_vector_list.at(src_idx);
+      auto& prefix_idx_vector = prefix_idx_vector_list.at(src_idx);
+      size_t src_prefix_start = cur_ids.lod().at(kSourceLevel)[src_idx];
+      size_t src_prefix_end = cur_ids.lod().at(kSourceLevel)[src_idx + 1];
+      if (prefix_idx_vector.empty()) {  // the beam was finished and pruned at
+                                        // this step, or this is the last step
+        for (size_t prefix_idx = src_prefix_start; prefix_idx < src_prefix_end;
+             ++prefix_idx) {
+          size_t candidate_start = cur_ids.lod().at(kSentenceLevel)[prefix_idx];
+          size_t candidate_end =
+              cur_ids.lod().at(kSentenceLevel)[prefix_idx + 1];
+          for (size_t candidate_idx = candidate_start;
+               candidate_idx < candidate_end; ++candidate_idx) {
+            prefix_idx_vector.push_back(prefix_idx);
+            size_t idx = prefix_idx_vector.size() - 1;
+            auto cur_id = cur_ids.data<int64_t>()[candidate_idx];
+            auto cur_score = cur_scores.data<T>()[candidate_idx];
+            sentence_vector.at(idx).word_ids.push_back(cur_id);
+            sentence_vector.at(idx).scores.push_back(cur_score);
+          }
+        }
+      } else {  // use prefix_idx_vector to backtrace
+        size_t src_candidate_start =
+            cur_ids.lod().at(kSentenceLevel)[src_prefix_start];
+        size_t prefix_idx = src_prefix_start;
+        size_t candidate_num =
+            cur_ids.lod().at(kSentenceLevel)[prefix_idx + 1] -
+            cur_ids.lod().at(kSentenceLevel)[prefix_idx];
+        for (size_t idx = 0; idx < prefix_idx_vector.size(); ++idx) {
+          auto candidate_idx = prefix_idx_vector.at(idx);
+          auto cur_id = cur_ids.data<int64_t>()[candidate_idx];
+          auto cur_score = cur_scores.data<T>()[candidate_idx];
+          if (cur_id != end_id_ || sentence_vector.at(idx).word_ids.empty()) {
+            // skip redundant end tokens
+            sentence_vector.at(idx).word_ids.push_back(cur_id);
+            sentence_vector.at(idx).scores.push_back(cur_score);
+          }
+
+          while (src_candidate_start + candidate_num <=
+                 candidate_idx) {  // search the corresponding prefix
+            prefix_idx++;
+            candidate_num += cur_ids.lod().at(kSentenceLevel)[prefix_idx + 1] -
+                             cur_ids.lod().at(kSentenceLevel)[prefix_idx];
+          }
+          prefix_idx_vector.at(idx) = prefix_idx;
+        }
+      }
+    }
+  }
+
+  ConvertSentenceVectorToLodTensor(sentence_vector_list, id_tensor,
+                                   score_tensor, true, true);
+}
+
+struct BeamSearchDecodeFunctor {
+  BeamSearchDecodeFunctor(const LoDTensorArray& step_ids,
+                          const LoDTensorArray& step_scores,
+                          LoDTensor* id_tensor, LoDTensor* score_tensor,
+                          size_t beam_size, int end_id)
+      : beam_size_(beam_size),
+        end_id_(end_id),
+        step_ids_(step_ids),
step_scores_(step_scores), + id_tensor_(id_tensor), + score_tensor_(score_tensor) {} + + template + void apply() const; + + size_t beam_size_; + int end_id_; + const LoDTensorArray& step_ids_; + const LoDTensorArray& step_scores_; + LoDTensor* id_tensor_; + LoDTensor* score_tensor_; +}; + +template +void BeamSearchDecodeFunctor::apply() const { + BeamSearchDecoder beam_search_decoder(beam_size_, end_id_); + beam_search_decoder.Backtrace(step_ids_, step_scores_, id_tensor_, + score_tensor_); +} + +template <> +void BeamSearchDecodeFunctor::apply() const { + PADDLE_MOBILE_THROW_EXCEPTION("beam search decode op does not support bool."); +} + +template <> +bool BeamSearchDecodeKernel::Init( + BeamSearchDecodeParam* param) { + return true; +} + +template <> +void BeamSearchDecodeKernel::Compute( + const BeamSearchDecodeParam& param) { + const LoDTensorArray* ids = param.ids_; + const LoDTensorArray* scores = param.scores_; + + const size_t step_num = ids->size(); + PADDLE_MOBILE_ENFORCE(step_num > 0, + "beam search steps should be larger than 0"); + + for (size_t i = 0; i < step_num; ++i) { + PADDLE_MOBILE_ENFORCE(ids->at(i).lod().size() == 2, + "Level of LodTensor should be 2"); + } + const size_t source_num = ids->at(0).lod().at(0).size() - 1; + PADDLE_MOBILE_ENFORCE(source_num > 0, "source num should be larger than 0"); + + LoDTensor* sentence_ids = param.sentence_ids_; + LoDTensor* sentence_scores = param.sentence_scores_; + + framework::VisitDataType( + framework::ToDataType(scores->at(0).type()), + BeamSearchDecodeFunctor(*ids, *scores, sentence_ids, sentence_scores, + param.beam_size_, param.end_id_)); +} + +} // namespace operators +} // namespace paddle_mobile + +#endif diff --git a/src/operators/kernel/arm/beam_search_kernel.cpp b/src/operators/kernel/arm/beam_search_kernel.cpp new file mode 100644 index 0000000000000000000000000000000000000000..5e88e2f18eed1d9aefbeb954a02245ff6daae036 --- /dev/null +++ b/src/operators/kernel/arm/beam_search_kernel.cpp @@ -0,0 +1,261 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/
+
+#ifdef BEAM_SEARCH_OP
+
+#include "operators/kernel/beam_search_kernel.h"
+#include <numeric>
+
+namespace paddle_mobile {
+namespace operators {
+
+template <typename T>
+class BeamSearchFunctor {
+ public:
+  void operator()(const framework::LoDTensor *pre_ids,
+                  const framework::LoDTensor *pre_scores,
+                  const framework::LoDTensor *ids,
+                  const framework::LoDTensor *scores,
+                  framework::LoDTensor *selected_ids,
+                  framework::LoDTensor *selected_scores,
+                  framework::Tensor *parent_idx, size_t level,
+                  size_t beam_size, int end_id, bool is_accumulated) {
+    auto abs_lod = framework::ToAbsOffset(scores->lod());
+    auto &high_level = abs_lod[level];
+
+    auto items = SelectTopBeamSizeItems(pre_ids, pre_scores, ids, scores,
+                                        level, beam_size, end_id,
+                                        is_accumulated);
+    auto selected_items = ToMap(items, high_level.back());
+
+    PruneEndBeams(pre_ids, abs_lod, &selected_items, level, end_id);
+    // calculate the output tensor's height
+    size_t num_instances = std::accumulate(
+        std::begin(selected_items), std::end(selected_items), 0,
+        [](size_t a, std::vector<Item> &b) { return a + b.size(); });
+    // the output tensor shape should be [num_instances, 1]
+    auto dims = framework::make_ddim(
+        std::vector<int64_t>({static_cast<int>(num_instances), 1}));
+    selected_ids->Resize(dims);
+    selected_scores->Resize(dims);
+    parent_idx->Resize({static_cast<int64_t>(num_instances)});
+
+    auto *selected_ids_data = selected_ids->mutable_data<int64_t>();
+    auto *selected_scores_data = selected_scores->mutable_data<float>();
+    auto *parent_idx_data = parent_idx->mutable_data<int>();
+
+    // fill in data
+    std::vector<size_t> low_level;
+    size_t low_offset = 0;
+    for (auto &items : selected_items) {
+      low_level.push_back(low_offset);
+      for (auto &item : items) {
+        parent_idx_data[low_offset] = static_cast<int>(low_level.size() - 1);
+        selected_ids_data[low_offset] = item.id;
+        selected_scores_data[low_offset] = item.score;
+        low_offset++;
+      }
+    }
+    low_level.push_back(low_offset);
+
+    // fill lod
+    framework::LoD lod(2);
+    lod[0].assign(high_level.begin(), high_level.end());
+    lod[1].assign(low_level.begin(), low_level.end());
+    selected_ids->set_lod(lod);
+    selected_scores->set_lod(lod);
+  }
+
+  /*
+   * The basic item used for sorting.
+   */
+  struct Item {
+    Item() {}
+    Item(size_t offset, size_t id, float score)
+        : offset(offset), id(id), score(score) {}
+    // offset in the higher lod level.
+    size_t offset;
+    // prefix id in the lower lod level.
+    // size_t prefix;
+    // the candidate id
+    size_t id;
+    // the corresponding score
+    float score;
+
+    inline bool operator<(const Item &in) const {
+      return (score < in.score) ||
+             ((score == in.score) && (offset < in.offset));
+    }
+
+    inline void operator=(const Item &in) {
+      offset = in.offset;
+      id = in.id;
+      score = in.score;
+    }
+  };
+
+ protected:
+  /*
+   * Prune the source sentences whose branches are all finished; this is
+   * optional. Pruning must happen one step later than finishing (hence
+   * pre_ids is needed here), since the end tokens must be written out.
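+   * For example, with end_id = 0, a source sentence is considered finished
+   * only when every selected item at this step has id == end_id and its
+   * pre_id is already end_id; all items of that source are then cleared so
+   * the beam emits no further candidates.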
+   */
+  void PruneEndBeams(const framework::LoDTensor *pre_ids,
+                     const framework::LoD &abs_lod,
+                     std::vector<std::vector<Item>> *items, size_t lod_level,
+                     int end_id) {
+    auto *pre_ids_data = pre_ids->data<int64_t>();
+    auto &high_level = abs_lod[lod_level];
+    for (size_t src_idx = 0; src_idx < high_level.size() - 1; ++src_idx) {
+      size_t src_prefix_start = high_level[src_idx];
+      size_t src_prefix_end = high_level[src_idx + 1];
+      bool finish_flag = true;
+      for (size_t offset = src_prefix_start; offset < src_prefix_end;
+           offset++) {
+        for (auto &item : items->at(offset)) {
+          if (item.id != static_cast<size_t>(end_id) ||
+              pre_ids_data[offset] != end_id) {
+            finish_flag = false;
+            break;
+          }
+        }
+        if (!finish_flag) break;
+      }
+      if (finish_flag) {  // all branches of the beam (source sentence) end,
+                          // so prune this beam
+        for (size_t offset = src_prefix_start; offset < src_prefix_end;
+             offset++)
+          items->at(offset).clear();
+      }
+    }
+  }
+
+  /*
+   * Transform the items into a map whose key is offset, value is the items.
+   * NOTE: low performance.
+   */
+  std::vector<std::vector<Item>> ToMap(
+      const std::vector<std::vector<Item>> &items, size_t element_num) {
+    std::vector<std::vector<Item>> result;
+    result.resize(element_num);
+    for (auto &entries : items) {
+      for (const auto &item : entries) {
+        result[item.offset].push_back(item);
+      }
+    }
+    return result;
+  }
+
+  void Insert(std::vector<Item> *top_beam_ptr, const Item &item,
+              size_t beam_size) {
+    std::vector<Item> &top_beam = *top_beam_ptr;
+
+    size_t num_beams = top_beam.size();
+    if (num_beams < beam_size) {
+      top_beam.resize(num_beams + 1);
+      num_beams++;
+    } else {
+      if (item < top_beam[beam_size - 1]) {
+        return;
+      }
+    }
+
+    for (int k = static_cast<int>(num_beams) - 2; k >= 0; --k) {
+      if (top_beam[k] < item) {
+        top_beam[k + 1] = top_beam[k];
+      } else {
+        top_beam[k + 1] = item;
+        return;
+      }
+    }
+    top_beam[0] = item;
+  }
+
+  /*
+   * For each source, select the top beam_size records.
+   */
+  std::vector<std::vector<Item>> SelectTopBeamSizeItems(
+      const framework::LoDTensor *pre_ids,
+      const framework::LoDTensor *pre_scores, const framework::LoDTensor *ids,
+      const framework::LoDTensor *scores, size_t lod_level, size_t beam_size,
+      int end_id, bool is_accumulated) {
+    std::vector<std::vector<Item>> result;
+
+    // find the current candidates
+    auto abs_lod = framework::ToAbsOffset(scores->lod());
+
+    auto *pre_ids_data = pre_ids->data<int64_t>();
+    auto *pre_scores_data = pre_scores->data<float>();
+
+    auto *ids_data = ids ? ids->data<int64_t>() : nullptr;
+    auto *scores_data = scores->data<float>();
+
+    size_t num_seqs = scores->NumElements(lod_level);
+    size_t seq_width = 1;
+    for (int i = 1; i < scores->dims().size(); i++) {
+      seq_width *= scores->dims()[i];
+    }
+
+    for (size_t seq_id = 0; seq_id < num_seqs; ++seq_id) {
+      size_t seq_offset_start = abs_lod[lod_level][seq_id];
+      size_t seq_offset_end = abs_lod[lod_level][seq_id + 1];
+
+      std::vector<Item> top_beam;
+      top_beam.reserve(beam_size);
+
+      for (size_t offset = seq_offset_start; offset < seq_offset_end;
+           ++offset) {
+        auto pre_id = pre_ids_data[offset];
+        auto pre_score = pre_scores_data[offset];
+        if (pre_id == end_id) {
+          // Allocate all probability mass to end_id for finished branches;
+          // the other candidate ids can be ignored.
+          Item item(offset, end_id, pre_score);
+          Insert(&top_beam, item, beam_size);
+        } else {
+          size_t index = offset * seq_width;
+          for (size_t d = 0; d < seq_width; d++, index++) {
+            int64_t id = ids_data ? ids_data[index] : static_cast<int64_t>(d);
+            float score = is_accumulated
+                              ?
scores_data[index] + : pre_score + std::log(scores_data[index]); + Item item(offset, id, score); + Insert(&top_beam, item, beam_size); + } + } + } + + result.emplace_back(top_beam); + } + + return result; + } +}; + +template <> +bool BeamSearchKernel::Init(BeamSearchParam *param) { + return true; +} + +template <> +void BeamSearchKernel::Compute(const BeamSearchParam ¶m) { + BeamSearchFunctor alg; + alg(param.pre_ids_, param.pre_scores_, param.ids_, param.scores_, + param.selected_ids_, param.selected_scores_, param.parent_idx_, + param.level_, param.beam_size_, param.end_id_, param.is_accumulated_); +} + +} // namespace operators +} // namespace paddle_mobile + +#endif diff --git a/src/operators/kernel/arm/conv_kernel.cpp b/src/operators/kernel/arm/conv_kernel.cpp deleted file mode 100644 index de19127e68361bd51f25a15c4c7ab69639707433..0000000000000000000000000000000000000000 --- a/src/operators/kernel/arm/conv_kernel.cpp +++ /dev/null @@ -1,137 +0,0 @@ -/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#ifdef CONV_OP - -#include "operators/kernel/conv_kernel.h" -#include "operators/kernel/central-arm-func/conv_arm_func.h" - -namespace paddle_mobile { -namespace operators { - -template <> -bool ConvKernel::Init(ConvParam *param) { - bool conv3x3 = param->Filter()->dims()[2] == param->Filter()->dims()[3] && - param->Filter()->dims()[2] == 3; - bool conv5x5 = param->Filter()->dims()[2] == param->Filter()->dims()[3] && - param->Filter()->dims()[2] == 5; - bool depth3x3 = conv3x3 && param->Groups() == param->Input()->dims()[1] && - param->Input()->dims()[1] == param->Output()->dims()[1]; - bool depth5x5 = conv5x5 && param->Groups() == param->Input()->dims()[1] && - param->Input()->dims()[1] == param->Output()->dims()[1]; - if (param->Filter()->type() == typeid(int8_t)) { -#ifndef __aarch64__ - if (depth3x3 && param->Strides()[0] < 3 && - param->Strides()[0] == param->Strides()[1]) { - param->ExecMode() = ConvParam::EXEC_DEPTHWISE3x3_INT8; - } else if (depth5x5 && param->Strides()[0] < 2 && - param->Strides()[0] == param->Strides()[1]) { - param->ExecMode() = ConvParam::EXEC_DEPTHWISE5x5_INT8; - } else { -#endif // __aarch64__ - param->ExecMode() = ConvParam::EXEC_GEMM_INT8; -#ifndef __aarch64__ - } -#endif // __aarch64__ - } else { - if (depth3x3 && param->Strides()[0] == param->Strides()[1] && - param->Strides()[0] == 1 && param->Paddings()[0] == 1 && - param->Paddings()[0] == param->Paddings()[1]) { - param->ExecMode() = ConvParam::EXEC_DEPTHWISE3x3S1P1_FLOAT; - } else if (depth3x3 && param->Strides()[0] == param->Strides()[1] && - param->Strides()[0] == 2 && param->Paddings()[0] == 0 && - param->Paddings()[0] == param->Paddings()[1]) { - param->ExecMode() = ConvParam::EXEC_DEPTHWISE3x3S2P0_FLOAT; - } else if (depth3x3 && param->Strides()[0] == param->Strides()[1] && - param->Strides()[0] == 2 && param->Paddings()[0] == 1 && - param->Paddings()[0] == param->Paddings()[1]) { - param->ExecMode() = ConvParam::EXEC_DEPTHWISE3x3S2P1_FLOAT; - } else if 
(depth3x3) { - param->ExecMode() = ConvParam::EXEC_DEPTHWISE3x3_FLOAT; -#ifndef __aarch64__ - } else if (depth5x5 && param->Strides()[0] == param->Strides()[1] && - param->Strides()[0] == 1) { - param->ExecMode() = ConvParam::EXEC_DEPTHWISE5x5_FLOAT; - } else if (conv3x3 && param->Strides()[0] == param->Strides()[1] && - param->Dilations()[0] == param->Dilations()[1] && - param->Strides()[0] == 1 && param->Dilations()[0] == 1 && - param->Output()->dims()[1] >= 16 && - param->Input()->dims()[1] >= 16 && - param->Input()->dims()[2] <= 140 /* refered from ncnn */) { - param->ExecMode() = ConvParam::EXEC_WINOGRAD3X3_FLOAT; - // transform weight - param->transformed_filter_ = new framework::Tensor; - operators::math::winograd_transform_weight<8, 3>( - *param->Filter(), param->transformed_filter_); -#endif - } else { - param->ExecMode() = ConvParam::EXEC_GEMM_FLOAT; - } - } - return true; -} - -template <> -void ConvKernel::Compute(const ConvParam ¶m) { - switch (param.ExecMode()) { - case ConvParam::EXEC_GEMM_INT8: - GemmConv(param); - break; -#ifndef __aarch64__ - case ConvParam::EXEC_DEPTHWISE3x3_INT8: - DepthwiseConv3x3(param); - break; - case ConvParam::EXEC_DEPTHWISE5x5_INT8: - DepthwiseConv5x5(param); - break; -#endif // __aarch64__ - case ConvParam::EXEC_DEPTHWISE3x3S1P1_FLOAT: - math::DepthwiseConv3x3s1p1(param.Input(), param.Filter(), param.Output(), - nullptr, false, false); - break; - case ConvParam::EXEC_DEPTHWISE3x3S2P1_FLOAT: - math::DepthwiseConv3x3s2p1v2(param.Input(), param.Filter(), - param.Output(), nullptr, false, false); - break; - case ConvParam::EXEC_DEPTHWISE3x3S2P0_FLOAT: - math::DepthwiseConv3x3s2p0(param.Input(), param.Filter(), param.Output(), - nullptr, false, false); - break; - case ConvParam::EXEC_DEPTHWISE3x3_FLOAT: - math::DepthwiseConv3x3(param.Input(), param.Strides(), param.Paddings(), - param.Filter(), nullptr, param.Output(), false); - break; -#ifndef __aarch64__ - case ConvParam::EXEC_DEPTHWISE5x5_FLOAT: - DepthwiseConv5x5(param); - break; - case ConvParam::EXEC_WINOGRAD3X3_FLOAT: - WinogradConv3x3<8, 3>(param); - break; -#endif // __aarch64__ - case ConvParam::EXEC_GEMM_FLOAT: - GemmConv(param); - break; - default: - PADDLE_MOBILE_THROW_EXCEPTION("Invalid convolution execute mode %d", - param.ExecMode()); - } -} - -template class ConvKernel; - -} // namespace operators -} // namespace paddle_mobile - -#endif diff --git a/src/operators/kernel/arm/conv_add_bn_relu_kernel.cpp b/src/operators/kernel/arm/convolution/conv_add_bn_relu_kernel.cpp similarity index 65% rename from src/operators/kernel/arm/conv_add_bn_relu_kernel.cpp rename to src/operators/kernel/arm/convolution/conv_add_bn_relu_kernel.cpp index 6489140fdd7d01627c51200f2b839866bd911d69..9c70d1e2c899567f523f98fc87963e73ab3fa6a1 100644 --- a/src/operators/kernel/arm/conv_add_bn_relu_kernel.cpp +++ b/src/operators/kernel/arm/convolution/conv_add_bn_relu_kernel.cpp @@ -16,7 +16,9 @@ limitations under the License. 
*/ #include "operators/kernel/conv_add_bn_relu_kernel.h" #include -#include "operators/kernel/central-arm-func/conv_add_bn_relu_arm_func.h" +#include "operators/kernel/arm/convolution/conv_common.h" +#include "operators/kernel/central-arm-func/conv_arm_func.h" +#include "operators/math/channel_wise.h" namespace paddle_mobile { namespace operators { @@ -43,9 +45,9 @@ bool ConvAddBNReluKernel::Init( } // Tensor *new_scale = new Tensor(); // Tensor *new_bias = new Tensor(); + auto *new_scale = param->CreateNewScale(); + auto *new_bias = param->CreateNewBiase(); - Tensor *new_scale = param->CreateNewScale(); - Tensor *new_bias = param->CreateNewBiase(); auto new_scale_ptr = new_scale->mutable_data({C}); auto new_bias_ptr = new_bias->mutable_data({C}); for (int i = 0; i < C; i++) { @@ -54,14 +56,36 @@ bool ConvAddBNReluKernel::Init( } param->SetNewScale(new_scale); param->SetNewBias(new_bias); + + InitBaseConvKernel(param); return true; } template <> void ConvAddBNReluKernel::Compute( const FusionConvAddBNReluParam ¶m) { - ConvAddBNReluCompute(param); + switch (param.ExecMode()) { + case ConvParam::EXEC_DEPTHWISE3x3S1_FLOAT: + case ConvParam::EXEC_DEPTHWISE3x3S2_FLOAT: + DepthwiseConv3x3(param); + break; + case ConvParam::EXEC_DEPTHWISE5x5_FLOAT: + DepthwiseConv5x5(param); + break; + case ConvParam::EXEC_WINOGRAD3X3_FLOAT: + WinogradConv3x3<8, 3>(param); + break; + case ConvParam::EXEC_GEMM_FLOAT: + GemmConv(param); + break; + default: + PADDLE_MOBILE_THROW_EXCEPTION("Invalid convolution execute mode %d", + param.ExecMode()); + } + math::ScaleAddChannelWise(param.Output(), param.NewScale(), + param.NewBias(), param.Output()); } + template class ConvAddBNReluKernel; } // namespace operators diff --git a/src/operators/kernel/arm/convolution/conv_add_kernel.cpp b/src/operators/kernel/arm/convolution/conv_add_kernel.cpp new file mode 100644 index 0000000000000000000000000000000000000000..5a44b083a37b19637c053655e23196385d432971 --- /dev/null +++ b/src/operators/kernel/arm/convolution/conv_add_kernel.cpp @@ -0,0 +1,61 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/
+
+#ifdef FUSION_CONVADD_OP
+
+#include "operators/kernel/conv_add_kernel.h"
+#include "operators/kernel/arm/convolution/conv_common.h"
+#include "operators/kernel/central-arm-func/conv_arm_func.h"
+#include "operators/math/channel_wise.h"
+
+namespace paddle_mobile {
+namespace operators {
+
+template <>
+bool ConvAddKernel<CPU, float>::Init(FusionConvAddParam<CPU> *param) {
+  InitBaseConvKernel(param);
+  return true;
+}
+
+template <>
+void ConvAddKernel<CPU, float>::Compute(const FusionConvAddParam<CPU> &param) {
+  switch (param.ExecMode()) {
+    case ConvParam<CPU>::EXEC_DEPTHWISE3x3S1_FLOAT:
+      // assumed symmetric to the stride-2 case below; an empty case would
+      // leave the output uncomputed
+      math::DepthwiseConv3x3S1<float, float>(*param.Input(), *param.Filter(),
+                                             param.Paddings(), param.Output());
+      break;
+    case ConvParam<CPU>::EXEC_DEPTHWISE3x3S2_FLOAT:
+      math::DepthwiseConv3x3S2<float, float>(*param.Input(), *param.Filter(),
+                                             param.Paddings(), param.Output());
+      break;
+    case ConvParam<CPU>::EXEC_DEPTHWISE5x5_FLOAT:
+      DepthwiseConv5x5<float, float>(param);
+      break;
+    case ConvParam<CPU>::EXEC_WINOGRAD3X3_FLOAT:
+      WinogradConv3x3<8, 3>(param);
+      break;
+    case ConvParam<CPU>::EXEC_GEMM_FLOAT:
+      GemmConv<float, float>(param);
+      break;
+    default:
+      PADDLE_MOBILE_THROW_EXCEPTION("Invalid convolution execute mode %d",
+                                    param.ExecMode());
+  }
+  math::AddChannelWise<IDENTITY>(param.Output(), param.Bias(), param.Output());
+}
+
+template class ConvAddKernel<CPU, float>;
+
+}  // namespace operators
+}  // namespace paddle_mobile
+
+#endif
diff --git a/src/operators/kernel/arm/convolution/conv_add_relu_kernel.cpp b/src/operators/kernel/arm/convolution/conv_add_relu_kernel.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..a9efae96e94afa24b48ed46214ff1fdd8ec50d83
--- /dev/null
+++ b/src/operators/kernel/arm/convolution/conv_add_relu_kernel.cpp
@@ -0,0 +1,60 @@
+/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
*/ + +#ifdef FUSION_CONVADDRELU_OP + +#include "operators/kernel/conv_add_relu_kernel.h" +#include "operators/kernel/arm/convolution/conv_common.h" +#include "operators/kernel/central-arm-func/conv_arm_func.h" +#include "operators/math/channel_wise.h" + +namespace paddle_mobile { +namespace operators { + +template <> +bool ConvAddReluKernel::Init(FusionConvAddReluParam *param) { + InitBaseConvKernel(param); + return true; +} + +template <> +void ConvAddReluKernel::Compute( + const FusionConvAddReluParam ¶m) { + switch (param.ExecMode()) { + case ConvParam::EXEC_DEPTHWISE3x3S1_FLOAT: + case ConvParam::EXEC_DEPTHWISE3x3S2_FLOAT: + DepthwiseConv3x3(param); + break; + case ConvParam::EXEC_DEPTHWISE5x5_FLOAT: + DepthwiseConv5x5(param); + break; + case ConvParam::EXEC_WINOGRAD3X3_FLOAT: + WinogradConv3x3<8, 3>(param); + break; + case ConvParam::EXEC_GEMM_FLOAT: + GemmConv(param); + break; + default: + PADDLE_MOBILE_THROW_EXCEPTION("Invalid convolution execute mode %d", + param.ExecMode()); + } + math::AddChannelWise(param.Output(), param.Bias(), param.Output()); +} + +template class ConvAddReluKernel; + +} // namespace operators +} // namespace paddle_mobile + +#endif diff --git a/src/operators/kernel/arm/conv_bn_add_relu_kernel.cpp b/src/operators/kernel/arm/convolution/conv_bn_add_relu_kernel.cpp similarity index 64% rename from src/operators/kernel/arm/conv_bn_add_relu_kernel.cpp rename to src/operators/kernel/arm/convolution/conv_bn_add_relu_kernel.cpp index 17c9fbd315451af7391ddf0b219fc7f6b7ff22e0..0f52df8b18004da8327caa24ffcd0c599c4f0680 100644 --- a/src/operators/kernel/arm/conv_bn_add_relu_kernel.cpp +++ b/src/operators/kernel/arm/convolution/conv_bn_add_relu_kernel.cpp @@ -16,7 +16,9 @@ limitations under the License. */ #include "operators/kernel/conv_bn_add_relu_kernel.h" #include -#include "operators/kernel/central-arm-func/conv_bn_add_relu_arm_func.h" +#include "operators/kernel/arm/convolution/conv_common.h" +#include "operators/kernel/central-arm-func/conv_arm_func.h" +#include "operators/math/channel_wise.h" namespace paddle_mobile { namespace operators { @@ -41,8 +43,9 @@ bool ConvBNAddReluKernel::Init( inv_std_ptr[i] = 1 / static_cast(pow((variance_ptr[i] + epsilon), 0.5)); } - Tensor *new_scale = new Tensor(); - Tensor *new_bias = new Tensor(); + + auto *new_scale = param->CreateNewScale(); + auto *new_bias = param->CreateNewBiase(); auto new_scale_ptr = new_scale->mutable_data({C}); auto new_bias_ptr = new_bias->mutable_data({C}); for (int i = 0; i < C; i++) { @@ -51,13 +54,34 @@ bool ConvBNAddReluKernel::Init( } param->SetNewScale(new_scale); param->SetNewBias(new_bias); + + InitBaseConvKernel(param); return true; } template <> void ConvBNAddReluKernel::Compute( const FusionConvBNAddReluParam ¶m) { - ConvBNAddReluCompute(param); + switch (param.ExecMode()) { + case ConvParam::EXEC_DEPTHWISE3x3S1_FLOAT: + case ConvParam::EXEC_DEPTHWISE3x3S2_FLOAT: + DepthwiseConv3x3(param); + break; + case ConvParam::EXEC_DEPTHWISE5x5_FLOAT: + DepthwiseConv5x5(param); + break; + case ConvParam::EXEC_WINOGRAD3X3_FLOAT: + WinogradConv3x3<8, 3>(param); + break; + case ConvParam::EXEC_GEMM_FLOAT: + GemmConv(param); + break; + default: + PADDLE_MOBILE_THROW_EXCEPTION("Invalid convolution execute mode %d", + param.ExecMode()); + } + math::ScaleAddChannelWise(param.Output(), param.NewScale(), + param.NewBias(), param.Output()); } template class ConvBNAddReluKernel; diff --git a/src/operators/kernel/arm/conv_bn_relu_kernel.cpp b/src/operators/kernel/arm/convolution/conv_bn_relu_kernel.cpp similarity 
index 64% rename from src/operators/kernel/arm/conv_bn_relu_kernel.cpp rename to src/operators/kernel/arm/convolution/conv_bn_relu_kernel.cpp index b178c7c599c80185ccf3eb201d47fef015a36330..1be0c943976dc5e77a7aa867095a923d9d1093ab 100644 --- a/src/operators/kernel/arm/conv_bn_relu_kernel.cpp +++ b/src/operators/kernel/arm/convolution/conv_bn_relu_kernel.cpp @@ -16,7 +16,9 @@ limitations under the License. */ #include "operators/kernel/conv_bn_relu_kernel.h" #include -#include "operators/kernel/central-arm-func/conv_bn_relu_arm_func.h" +#include "operators/kernel/arm/convolution/conv_common.h" +#include "operators/kernel/central-arm-func/conv_arm_func.h" +#include "operators/math/channel_wise.h" namespace paddle_mobile { namespace operators { @@ -29,8 +31,6 @@ bool ConvBNReluKernel::Init(FusionConvBNReluParam *param) { const Tensor *bias = param->InputBias(); const float epsilon = param->Epsilon(); - // DLOG << "variance: " << *variance; - auto mean_ptr = mean->data(); auto variance_ptr = variance->data(); auto scale_ptr = scale->data(); @@ -42,24 +42,45 @@ bool ConvBNReluKernel::Init(FusionConvBNReluParam *param) { inv_std_ptr[i] = 1 / static_cast(pow((variance_ptr[i] + epsilon), 0.5)); } - Tensor *new_scale = param->CreateNewScale(); - Tensor *new_bias = param->CreateNewBiase(); + + auto *new_scale = param->CreateNewScale(); + auto *new_bias = param->CreateNewBiase(); auto new_scale_ptr = new_scale->mutable_data({C}); auto new_bias_ptr = new_bias->mutable_data({C}); for (int i = 0; i < C; i++) { new_scale_ptr[i] = inv_std_ptr[i] * scale_ptr[i]; new_bias_ptr[i] = bias_ptr[i] - mean_ptr[i] * inv_std_ptr[i] * scale_ptr[i]; } - param->SetNewScale(new_scale); param->SetNewBias(new_bias); + + InitBaseConvKernel(param); return true; } template <> void ConvBNReluKernel::Compute( const FusionConvBNReluParam ¶m) { - ConvBNReluCompute(param); + switch (param.ExecMode()) { + case ConvParam::EXEC_DEPTHWISE3x3S1_FLOAT: + case ConvParam::EXEC_DEPTHWISE3x3S2_FLOAT: + DepthwiseConv3x3(param); + break; + case ConvParam::EXEC_DEPTHWISE5x5_FLOAT: + DepthwiseConv5x5(param); + break; + case ConvParam::EXEC_WINOGRAD3X3_FLOAT: + WinogradConv3x3<8, 3>(param); + break; + case ConvParam::EXEC_GEMM_FLOAT: + GemmConv(param); + break; + default: + PADDLE_MOBILE_THROW_EXCEPTION("Invalid convolution execute mode %d", + param.ExecMode()); + } + math::ScaleAddChannelWise(param.Output(), param.NewScale(), + param.NewBias(), param.Output()); } template class ConvBNReluKernel; diff --git a/src/operators/kernel/arm/convolution/conv_common.cpp b/src/operators/kernel/arm/convolution/conv_common.cpp new file mode 100644 index 0000000000000000000000000000000000000000..8db3b36cf43b08c65091035370728db0975709ac --- /dev/null +++ b/src/operators/kernel/arm/convolution/conv_common.cpp @@ -0,0 +1,77 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/
+
+#include "operators/kernel/arm/convolution/conv_common.h"
+#include "operators/math/winograd/winograd_transform.h"
+
+namespace paddle_mobile {
+namespace operators {
+
+void InitBaseConvKernel(ConvParam<CPU> *param) {
+  bool conv3x3 = param->Filter()->dims()[2] == param->Filter()->dims()[3] &&
+                 param->Filter()->dims()[2] == 3;
+  bool conv5x5 = param->Filter()->dims()[2] == param->Filter()->dims()[3] &&
+                 param->Filter()->dims()[2] == 5;
+  bool depth3x3 = conv3x3 && param->Groups() == param->Input()->dims()[1] &&
+                  param->Input()->dims()[1] == param->Output()->dims()[1];
+
+  bool depth5x5 = conv5x5 && param->Groups() == param->Input()->dims()[1] &&
+                  param->Input()->dims()[1] == param->Output()->dims()[1];
+  if (param->Filter()->type() == typeid(int8_t)) {
+#ifndef __aarch64__
+    if (depth3x3 && param->Strides()[0] < 3 &&
+        param->Strides()[0] == param->Strides()[1]) {
+      param->ExecMode() = ConvParam<CPU>::EXEC_DEPTHWISE3x3_INT8;
+    } else if (depth5x5 && param->Strides()[0] < 2 &&
+               param->Strides()[0] == param->Strides()[1]) {
+      param->ExecMode() = ConvParam<CPU>::EXEC_DEPTHWISE5x5_INT8;
+    } else {
+#endif  // __aarch64__
+      param->ExecMode() = ConvParam<CPU>::EXEC_GEMM_INT8;
+#ifndef __aarch64__
+    }
+#endif  // __aarch64__
+  } else {
+    if (depth3x3 && param->Strides()[0] == param->Strides()[1] &&
+        param->Strides()[0] == 1) {
+      param->ExecMode() = ConvParam<CPU>::EXEC_DEPTHWISE3x3S1_FLOAT;
+    } else if (depth3x3 && param->Strides()[0] == param->Strides()[1] &&
+               param->Strides()[0] == 2) {
+      param->ExecMode() = ConvParam<CPU>::EXEC_DEPTHWISE3x3S2_FLOAT;
+    } else if (depth5x5 && param->Strides()[0] == param->Strides()[1] &&
+               param->Strides()[0] == 1) {
+      param->ExecMode() = ConvParam<CPU>::EXEC_DEPTHWISE5x5_FLOAT;
+    } else if (conv3x3 && !depth3x3 &&
+               param->Strides()[0] == param->Strides()[1] &&
+               param->Dilations()[0] == param->Dilations()[1] &&
+               param->Strides()[0] == 1 && param->Dilations()[0] == 1
+#if 0
+               && param->Output()->dims()[1] >= 16 &&
+               param->Input()->dims()[1] >= 16 &&
+               param->Input()->dims()[2] <= 140 /* referred from ncnn */
+#endif
+    ) {
+      param->ExecMode() = ConvParam<CPU>::EXEC_WINOGRAD3X3_FLOAT;
+      // transform weight
+      param->transformed_filter_ = new framework::LoDTensor;
+      operators::math::winograd_transform_weight<8, 3>(
+          *param->Filter(), param->transformed_filter_);
+    } else {
+      param->ExecMode() = ConvParam<CPU>::EXEC_GEMM_FLOAT;
+    }
+  }
+}
+
+}  // namespace operators
+}  // namespace paddle_mobile
diff --git a/src/operators/kernel/arm/convolution/conv_common.h b/src/operators/kernel/arm/convolution/conv_common.h
new file mode 100644
index 0000000000000000000000000000000000000000..4db37715c4302439fa0e43446bd62ef68675276e
--- /dev/null
+++ b/src/operators/kernel/arm/convolution/conv_common.h
@@ -0,0 +1,25 @@
+/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
*/ + +#pragma once + +#include "operators/op_param.h" + +namespace paddle_mobile { +namespace operators { + +void InitBaseConvKernel(ConvParam *param); + +} // namespace operators +} // namespace paddle_mobile diff --git a/src/operators/kernel/arm/convolution/conv_kernel.cpp b/src/operators/kernel/arm/convolution/conv_kernel.cpp new file mode 100644 index 0000000000000000000000000000000000000000..a819aa50216f06387e24d864fed07674f621b9eb --- /dev/null +++ b/src/operators/kernel/arm/convolution/conv_kernel.cpp @@ -0,0 +1,68 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#ifdef CONV_OP + +#include "operators/kernel/conv_kernel.h" +#include "operators/kernel/arm/convolution/conv_common.h" +#include "operators/kernel/central-arm-func/conv_arm_func.h" + +namespace paddle_mobile { +namespace operators { + +template <> +bool ConvKernel::Init(ConvParam *param) { + InitBaseConvKernel(param); + return true; +} + +template <> +void ConvKernel::Compute(const ConvParam ¶m) { + switch (param.ExecMode()) { +#ifndef __aarch64__ + case ConvParam::EXEC_GEMM_INT8: + GemmConv(param); + break; + case ConvParam::EXEC_DEPTHWISE3x3_INT8: + DepthwiseConv3x3(param); + break; + case ConvParam::EXEC_DEPTHWISE5x5_INT8: + DepthwiseConv5x5(param); + break; +#endif // __aarch64__ + case ConvParam::EXEC_DEPTHWISE3x3S1_FLOAT: + case ConvParam::EXEC_DEPTHWISE3x3S2_FLOAT: + DepthwiseConv3x3(param); + break; + case ConvParam::EXEC_DEPTHWISE5x5_FLOAT: + DepthwiseConv5x5(param); + break; + case ConvParam::EXEC_WINOGRAD3X3_FLOAT: + WinogradConv3x3<8, 3>(param); + break; + case ConvParam::EXEC_GEMM_FLOAT: + GemmConv(param); + break; + default: + PADDLE_MOBILE_THROW_EXCEPTION("Invalid convolution execute mode %d", + param.ExecMode()); + } +} + +template class ConvKernel; + +} // namespace operators +} // namespace paddle_mobile + +#endif diff --git a/src/operators/kernel/arm/conv_transpose_kernel.cpp b/src/operators/kernel/arm/convolution/conv_transpose_kernel.cpp similarity index 100% rename from src/operators/kernel/arm/conv_transpose_kernel.cpp rename to src/operators/kernel/arm/convolution/conv_transpose_kernel.cpp diff --git a/src/operators/kernel/arm/dwconv_bn_relu_kernel.cpp b/src/operators/kernel/arm/convolution/dwconv_bn_relu_kernel.cpp similarity index 65% rename from src/operators/kernel/arm/dwconv_bn_relu_kernel.cpp rename to src/operators/kernel/arm/convolution/dwconv_bn_relu_kernel.cpp index 60d35302942ec398597c5e02157581c19fe0547c..063d51330eb5cc03c5596a8e209480ba0505009f 100644 --- a/src/operators/kernel/arm/dwconv_bn_relu_kernel.cpp +++ b/src/operators/kernel/arm/convolution/dwconv_bn_relu_kernel.cpp @@ -16,7 +16,9 @@ limitations under the License. 
*/ #include "operators/kernel/dwconv_bn_relu_kernel.h" #include -#include "operators/kernel/central-arm-func/dwconv_bn_relu_arm_func.h" +#include "operators/kernel/arm/convolution/conv_common.h" +#include "operators/kernel/central-arm-func/conv_arm_func.h" +#include "operators/math/channel_wise.h" namespace paddle_mobile { namespace operators { @@ -40,8 +42,8 @@ bool DWConvBNReluKernel::Init(FusionDWConvBNReluParam *param) { inv_std_ptr[i] = 1 / static_cast(pow((variance_ptr[i] + epsilon), 0.5)); } - Tensor *new_scale = new Tensor(); - Tensor *new_bias = new Tensor(); + LoDTensor *new_scale = new LoDTensor(); + LoDTensor *new_bias = new LoDTensor(); auto new_scale_ptr = new_scale->mutable_data({C}); auto new_bias_ptr = new_bias->mutable_data({C}); for (int i = 0; i < C; i++) { @@ -50,14 +52,36 @@ bool DWConvBNReluKernel::Init(FusionDWConvBNReluParam *param) { } param->SetNewScale(new_scale); param->SetNewBias(new_bias); + + InitBaseConvKernel(param); return true; } template <> void DWConvBNReluKernel::Compute( const FusionDWConvBNReluParam ¶m) { - DWConvBNReluCompute(param); + switch (param.ExecMode()) { + case ConvParam::EXEC_DEPTHWISE3x3S1_FLOAT: + case ConvParam::EXEC_DEPTHWISE3x3S2_FLOAT: + DepthwiseConv3x3(param); + break; + case ConvParam::EXEC_DEPTHWISE5x5_FLOAT: + DepthwiseConv5x5(param); + break; + case ConvParam::EXEC_WINOGRAD3X3_FLOAT: + WinogradConv3x3<8, 3>(param); + break; + case ConvParam::EXEC_GEMM_FLOAT: + GemmConv(param); + break; + default: + PADDLE_MOBILE_THROW_EXCEPTION("Invalid convolution execute mode %d", + param.ExecMode()); + } + math::ScaleAddChannelWise(param.Output(), param.NewScale(), + param.NewBias(), param.Output()); } + template class DWConvBNReluKernel; } // namespace operators diff --git a/src/operators/kernel/arm/feed_kernel.cpp b/src/operators/kernel/arm/feed_kernel.cpp index 598f6df01b16683f4d6e06f6418a2930a7ec8736..26ea2ac5f7d806aa6e69dfe9697ed84b61347c0e 100644 --- a/src/operators/kernel/arm/feed_kernel.cpp +++ b/src/operators/kernel/arm/feed_kernel.cpp @@ -24,8 +24,9 @@ bool FeedKernel::Init(FeedParam *param) { template <> void FeedKernel::Compute(const FeedParam ¶m) { - param.Out()->ShareDataWith(*(param.InputX())); - param.Out()->set_lod(param.InputX()->lod()); + int col = param.Col(); + param.Out()->ShareDataWith(param.InputX()->at(col)); + param.Out()->set_lod(param.InputX()->at(col).lod()); } template class FeedKernel; diff --git a/src/operators/kernel/arm/fetch_kernel.cpp b/src/operators/kernel/arm/fetch_kernel.cpp index 6c25514857dee9029afa3a7a80d5c89a97bbe9be..8a97fa934b45bf93d7cc727cadbea7cf5ab310f1 100644 --- a/src/operators/kernel/arm/fetch_kernel.cpp +++ b/src/operators/kernel/arm/fetch_kernel.cpp @@ -8,17 +8,24 @@ distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ + #include "operators/kernel/fetch_kernel.h" + namespace paddle_mobile { namespace operators { + template <> bool FetchKernel::Init(FetchParam *param) { return true; } + template <> void FetchKernel::Compute(const FetchParam ¶m) { - param.Out()->ShareDataWith(*(param.InputX())); + int col = param.Col(); + param.Out()->at(col).ShareDataWith(*(param.InputX())); } + template class FetchKernel; + } // namespace operators } // namespace paddle_mobile diff --git a/src/operators/kernel/arm/one_hot_kernel.cpp b/src/operators/kernel/arm/one_hot_kernel.cpp new file mode 100644 index 0000000000000000000000000000000000000000..208b34ea2cd7a2357870c08d27fdcfd164380d0c --- /dev/null +++ b/src/operators/kernel/arm/one_hot_kernel.cpp @@ -0,0 +1,61 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#ifdef ONE_HOT_OP + +#include "operators/kernel/one_hot_kernel.h" +#include "framework/data_type.h" + +namespace paddle_mobile { +namespace operators { + +template +struct OnehotOpFunctor { + const framework::LoDTensor* in_; + framework::LoDTensor* out_; + int depth_; + + OnehotOpFunctor(const framework::LoDTensor* in, framework::LoDTensor* out, + int depth) + : in_(in), out_(out), depth_(depth) {} + + template + void apply() const { + auto* p_in_data = in_->data(); + auto numel = in_->numel(); + auto* p_out_data = out_->mutable_data(); + memset(p_out_data, 0, out_->numel() * sizeof(OutT)); + + for (int i = 0; i < numel; ++i) { + *(p_out_data + i * depth_ + p_in_data[i]) = 1.0; + } + } +}; + +template <> +bool OnehotKernel::Init(OnehotParam* param) { + return true; +} + +template <> +void OnehotKernel::Compute(const OnehotParam& param) { + framework::VisitDataType( + framework::ToDataType(param.dtype_), + OnehotOpFunctor(param.input_, param.output_, param.depth_)); +} + +} // namespace operators +} // namespace paddle_mobile + +#endif // ONE_HOT_OP diff --git a/src/operators/kernel/arm/pad2d_kernel.cpp b/src/operators/kernel/arm/pad2d_kernel.cpp new file mode 100644 index 0000000000000000000000000000000000000000..6e2505316f8aeddb07d3e631fbb830d3522b01aa --- /dev/null +++ b/src/operators/kernel/arm/pad2d_kernel.cpp @@ -0,0 +1,45 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#ifdef PAD2D_OP + +#include "operators/kernel/pad2d_kernel.h" +#include "operators/math/pad.h" + +namespace paddle_mobile { +namespace operators { + +template <> +bool Pad2DKernel::Init(Pad2DParam *param) { + return true; +} + +template <> +void Pad2DKernel::Compute(const Pad2DParam ¶m) { + const auto *input = param.input_; + auto *output = param.output_; + const auto &paddings = param.paddings_; + // if (param.mode_ == "constant" && param.pad_value_ == 0) { + math::PadFunctor pad; + pad(*input, paddings[0], paddings[1], paddings[2], paddings[3], output); + // } else { + // PADDLE_MOBILE_THROW_EXCEPTION("Pad2D has not been implemented."); + // } + output->set_lod(input->lod()); +} + +} // namespace operators +} // namespace paddle_mobile + +#endif // PAD2D_OP diff --git a/src/operators/kernel/arm/sequence_expand_kernel.cpp b/src/operators/kernel/arm/sequence_expand_kernel.cpp index f3fb01eed8fbee0af3caff152e7c46749694a43e..82941ff0d566d00aeab404ef96f9b4550d92bb14 100644 --- a/src/operators/kernel/arm/sequence_expand_kernel.cpp +++ b/src/operators/kernel/arm/sequence_expand_kernel.cpp @@ -100,8 +100,8 @@ class SequenceExpandKernel out_lod.push_back(out_lod.back() + x_seq_len); } } + output->set_lod({out_lod}); } - output->set_lod({out_lod}); SequenceExpandImpl(*input_x, y_lod[ref_level], output); } }; diff --git a/src/operators/kernel/arm/sequence_softmax_kernel.cpp b/src/operators/kernel/arm/sequence_softmax_kernel.cpp index ecbc39c4ccf4592308dc07d994535273d4a636f1..b0df21fac560e67f1e1dfe4b42491ee84a384152 100644 --- a/src/operators/kernel/arm/sequence_softmax_kernel.cpp +++ b/src/operators/kernel/arm/sequence_softmax_kernel.cpp @@ -28,6 +28,7 @@ class SequenceSoftmaxKernel bool Init(SoftmaxParam *param) { return true; } void Compute(const SoftmaxParam ¶m) { + param.Out()->mutable_data(); const framework::LoDTensor *input = param.InputX(); framework::LoDTensor *output = param.Out(); math::SequenceSoftmaxFuntor sequence_softmax; diff --git a/src/operators/kernel/arm/tensor_array_read_write_kernel.cpp b/src/operators/kernel/arm/tensor_array_read_write_kernel.cpp index 5d6b86c34c963808e72a05db08940db97b0212b9..72fbb4cadbac4ef824692cfcaf55dd2d5c5a1166 100644 --- a/src/operators/kernel/arm/tensor_array_read_write_kernel.cpp +++ b/src/operators/kernel/arm/tensor_array_read_write_kernel.cpp @@ -28,8 +28,9 @@ void WriteToArrayKernel::Compute( const WriteToArrayParam ¶m) { int64_t offset = param.index_->data()[0]; if (offset >= param.output_->size()) { - param.output_->resize(offset); + param.output_->resize(offset + 1); } + framework::LoDTensor *out_tensor = &(param.output_->at(offset)); out_tensor->set_lod(param.input_->lod()); if (param.input_->memory_size() > 0) { @@ -50,6 +51,11 @@ void ReadFromArrayKernel::Compute( int64_t offset = param.index_->data()[0]; if (offset < param.input_->size()) { TensorCopy(param.input_->at(offset), param.output_); + param.output_->set_lod(param.input_->at(offset).lod()); + } else { + PADDLE_MOBILE_THROW_EXCEPTION( + "Can not read tensor which index is `%d` since it only has `%d` inputs", + offset, param.input_->size()); } } #endif // READ_FROM_ARRAY_OP diff --git a/src/operators/kernel/arm/while_kernel.cpp b/src/operators/kernel/arm/while_kernel.cpp index f27a897ffcadbb1ade759f324766ccbfb8dd49d5..63cd150ec977a5acf1fc24b1dd9bae75bf57b580 100644 --- a/src/operators/kernel/arm/while_kernel.cpp +++ b/src/operators/kernel/arm/while_kernel.cpp @@ -12,12 +12,45 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the License for the specific language governing permissions and limitations under the License. */ +#ifdef WHILE_OP + #include "operators/kernel/while_kernel.h" +#include "framework/op_registry.h" +#include "framework/operator.h" namespace paddle_mobile { namespace operators { -#ifdef WHILE_OP +class StepExecutor { + typedef std::shared_ptr> OperatorPtr; + + public: + StepExecutor(const framework::BlockDesc *block, framework::Scope *scope) + : scope_(scope) { + std::vector> ops = block->Ops(); + ops_of_block_.resize(ops.size()); + for (int i = 0; i < ops.size(); ++i) { + std::shared_ptr op_desc = ops[i]; + DLOG << "create op: " << op_desc->Type(); + auto op_handler = framework::OpRegistry::CreateOp( + op_desc->Type(), op_desc->GetInputs(), op_desc->GetOutputs(), + op_desc->GetAttrMap(), scope_); + ops_of_block_[i] = op_handler; + } + } + + void Run() { + for (auto &op_handler : ops_of_block_) { + op_handler->InferShape(); + op_handler->Run(); + } + } + + private: + framework::Scope *scope_; + std::vector ops_of_block_; +}; + template <> bool WhileKernel::Init(WhileParam *param) { return true; @@ -25,9 +58,15 @@ bool WhileKernel::Init(WhileParam *param) { template <> void WhileKernel::Compute(const WhileParam ¶m) { - // TODO(hjchen2) + auto ¤t_scope = param.scope_->NewScope(); + StepExecutor executor(param.sub_block_, ¤t_scope); + while (param.cond_->data()[0]) { + executor.Run(); + } + param.scope_->DeleteScope(¤t_scope); } -#endif // WHILE_OP } // namespace operators } // namespace paddle_mobile + +#endif // WHILE_OP diff --git a/src/operators/kernel/beam_search_decode_kernel.h b/src/operators/kernel/beam_search_decode_kernel.h new file mode 100644 index 0000000000000000000000000000000000000000..36cc7f9f2d1b62bc37e0683417f5f5adfe0edfcc --- /dev/null +++ b/src/operators/kernel/beam_search_decode_kernel.h @@ -0,0 +1,58 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#ifdef BEAM_SEARCH_DECODE_OP + +#pragma once + +#include "framework/operator.h" +#include "operators/op_param.h" + +namespace paddle_mobile { +namespace operators { + +template +class BeamSearchDecodeParam : public OpParam { + public: + BeamSearchDecodeParam(const VariableNameMap &inputs, + const VariableNameMap &outputs, + const AttributeMap &attrs, Scope *scope) + : OpParam(inputs, outputs, attrs, scope) { + ids_ = + OpParam::GetVarValue("Ids", inputs, *scope); + scores_ = OpParam::GetVarValue("Scores", inputs, + *scope); + sentence_ids_ = OpParam::GetVarValue("SentenceIds", + outputs, *scope); + sentence_scores_ = OpParam::GetVarValue( + "SentenceScores", outputs, *scope); + beam_size_ = OpParam::GetAttr("beam_size", attrs); + end_id_ = OpParam::GetAttr("end_id", attrs); + } + + public: + framework::LoDTensorArray *ids_; + framework::LoDTensorArray *scores_; + framework::LoDTensor *sentence_ids_; + framework::LoDTensor *sentence_scores_; + int beam_size_; + int end_id_; +}; + +DECLARE_KERNEL(BeamSearchDecode, BeamSearchDecodeParam); + +} // namespace operators +} // namespace paddle_mobile + +#endif // BEAM_SEARCH_DECODE_OP diff --git a/src/operators/kernel/beam_search_kernel.h b/src/operators/kernel/beam_search_kernel.h new file mode 100644 index 0000000000000000000000000000000000000000..38fe162b249b5045d4113ac3db906840702e7791 --- /dev/null +++ b/src/operators/kernel/beam_search_kernel.h @@ -0,0 +1,77 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#ifdef BEAM_SEARCH_OP + +#pragma once + +#include "framework/operator.h" +#include "operators/op_param.h" + +namespace paddle_mobile { +namespace operators { + +#define GET_VAR_AS_LOD_TENSOR(name, name_dict, scope) \ + OpParam::GetVarValue(name, name_dict, scope) + +template +class BeamSearchParam : public OpParam { + public: + BeamSearchParam(const VariableNameMap &inputs, const VariableNameMap &outputs, + const AttributeMap &attrs, Scope *scope) + : OpParam(inputs, outputs, attrs, scope) { + pre_ids_ = GET_VAR_AS_LOD_TENSOR("pre_ids", inputs, *scope); + pre_scores_ = GET_VAR_AS_LOD_TENSOR("pre_scores", inputs, *scope); + ids_ = GET_VAR_AS_LOD_TENSOR("ids", inputs, *scope); + scores_ = GET_VAR_AS_LOD_TENSOR("scores", inputs, *scope); + + selected_ids_ = GET_VAR_AS_LOD_TENSOR("selected_ids", outputs, *scope); + selected_scores_ = + GET_VAR_AS_LOD_TENSOR("selected_scores", outputs, *scope); + if (outputs.count("parent_idx")) { + parent_idx_ = GET_VAR_AS_LOD_TENSOR("parent_idx", outputs, *scope); + } else { + parent_idx_ = new framework::Tensor(); + } + + level_ = OpParam::GetAttr("level", attrs); + beam_size_ = OpParam::GetAttr("beam_size", attrs); + end_id_ = OpParam::GetAttr("end_id", attrs); + if (OpParam::HasAttr("is_accumulated", attrs)) { + is_accumulated_ = OpParam::GetAttr("is_accumulated", attrs); + } + } + + public: + framework::LoDTensor *pre_ids_; + framework::LoDTensor *pre_scores_; + framework::LoDTensor *ids_; + framework::LoDTensor *scores_; + + framework::LoDTensor *selected_ids_; + framework::LoDTensor *selected_scores_; + framework::Tensor *parent_idx_; + + int level_; + int beam_size_; + int end_id_; + bool is_accumulated_ = true; +}; + +DECLARE_KERNEL(BeamSearch, BeamSearchParam); + +} // namespace operators +} // namespace paddle_mobile + +#endif // BEAM_SEARCH_OP diff --git a/src/operators/kernel/central-arm-func/conv_add_add_prelu_arm_func.h b/src/operators/kernel/central-arm-func/conv_add_add_prelu_arm_func.h deleted file mode 100644 index 4c9ca6e3e8ef995e9cce6f565aafece17ac51b10..0000000000000000000000000000000000000000 --- a/src/operators/kernel/central-arm-func/conv_add_add_prelu_arm_func.h +++ /dev/null @@ -1,128 +0,0 @@ -/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
*/ - -#ifdef FUSION_CONVADDADDPRELU_OP -#pragma once - -#include -#include -#include "operators/math/conv_func.h" -#include "operators/math/im2col.h" -#include "operators/math/math_function.h" -#include "operators/math/vol2col.h" -#include "operators/op_param.h" - -namespace paddle_mobile { -namespace operators { - -template -void ConvAddAddPReluCompute(const FusionConvAddAddPReluParam ¶m) { - const Tensor *input = param.Input(); - Tensor filter = *param.Filter(); - Tensor bias = *param.Bias(); - Tensor bias1 = *param.Bias1(); - Tensor *output = param.Output(); - output->mutable_data(); - float *biase_data = bias.data(); - - int axis = param.Axis(); - int groups = param.Groups(); - std::vector strides = param.Strides(); - std::vector paddings = param.Paddings(); - std::vector dilations = param.Dilations(); - Tensor aa = *param.InputAlpha(); - float *p = aa.data(); - - std::string mode = param.Mode(); - const int batch_size = static_cast(input->dims()[0]); - - std::vector filter_shape_vec(framework::vectorize(filter.dims())); - - std::vector output_shape_vec(framework::vectorize(output->dims())); - size_t data_dim = filter_shape_vec.size() - 2; - std::vector col_shape_vec(1 + 2 * data_dim); - col_shape_vec[0] = input->dims()[1] / groups; - for (size_t j = 0; j < data_dim; ++j) { - col_shape_vec[j + 1] = filter_shape_vec[j + 2]; - col_shape_vec[j + 1 + data_dim] = output_shape_vec[j + 2]; - } - framework::DDim col_shape(framework::make_ddim(col_shape_vec)); - - framework::DDim col_matrix_shape = - framework::flatten_to_2d(col_shape, data_dim + 1); - - bool is_expand = - math::IsExpand(filter_shape_vec, strides, paddings, dilations); - Tensor col; - Tensor col_matrix; - if (is_expand) { - col.mutable_data(col_shape); - col_matrix.ShareDataWith(col); - col_matrix.Resize(col_matrix_shape); - } - - framework::DDim input_shape = framework::slice_ddim( - input->dims(), 1, static_cast(input->dims().size())); - - framework::DDim filter_matrix_shape = {filter.dims()[0], - filter.numel() / filter.dims()[0]}; - filter.Resize(filter_matrix_shape); - framework::DDim output_matrix_shape = { - output->dims()[1], - output->numel() / (output->dims()[0] * output->dims()[1])}; - - // convolution operator: im2col(or vol2col) + gemm - int in_step = static_cast(input->dims()[1]) / groups; - int out_step = static_cast(output->dims()[1]) / groups; - - math::Vol2ColFunctor vol2col; - math::Im2ColFunctor im2col; - - for (int i = 0; i < batch_size; i++) { - Tensor in_batch = input->Slice(i, i + 1).Resize(input_shape); - Tensor out_batch = output->Slice(i, i + 1).Resize(output_matrix_shape); - Tensor bias1_batch = bias1.Slice(i, i + 1).Resize(output_matrix_shape); - for (int g = 0; g < groups; g++) { - Tensor in_slice = in_batch.Slice(g * in_step, (g + 1) * in_step); - - if (!is_expand) { - col.ShareDataWith(in_slice); - col_matrix.ShareDataWith(col); - col_matrix.Resize(col_matrix_shape); - } else if (data_dim == 2U) { - // im2col - im2col(in_slice, dilations, strides, - std::vector{paddings[0], paddings[1], paddings[0], - paddings[1]}, - &col); - } else if (data_dim == 3U) { - // vol2col - vol2col(in_slice, dilations, strides, paddings, &col); - } - - // gemm - Tensor out_slice = out_batch.Slice(g * out_step, (g + 1) * out_step); - Tensor filter_slice = filter.Slice(g * out_step, (g + 1) * out_step); - Tensor bias1_slice = bias1_batch.Slice(g * out_step, (g + 1) * out_step); - float *biase_data1 = bias1_slice.data(); - math::MatMulWithPRelu(filter_slice, false, col_matrix, false, &out_slice, - p, mode, biase_data, 
biase_data1); - } - } -} - -} // namespace operators -} // namespace paddle_mobile - -#endif // FUSION_CONVADDADDPRELU_OP diff --git a/src/operators/kernel/central-arm-func/conv_add_prelu_arm_func.h b/src/operators/kernel/central-arm-func/conv_add_prelu_arm_func.h deleted file mode 100644 index d11a8442acdd275c95aaa96b2c3e1855e44746e9..0000000000000000000000000000000000000000 --- a/src/operators/kernel/central-arm-func/conv_add_prelu_arm_func.h +++ /dev/null @@ -1,124 +0,0 @@ -/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#ifdef FUSION_CONVADDPRELU_OP -#pragma once - -#include -#include -#include "operators/math/conv_func.h" -#include "operators/math/im2col.h" -#include "operators/math/math_function.h" -#include "operators/math/vol2col.h" -#include "operators/op_param.h" - -namespace paddle_mobile { -namespace operators { - -template -void ConvAddPReluCompute(const FusionConvAddPReluParam ¶m) { - const Tensor *input = param.Input(); - Tensor filter = *param.Filter(); - Tensor bias = *param.Bias(); - Tensor *output = param.Output(); - output->mutable_data(); - float *biase_data = bias.data(); - - int axis = param.Axis(); - int groups = param.Groups(); - std::vector strides = param.Strides(); - std::vector paddings = param.Paddings(); - std::vector dilations = param.Dilations(); - Tensor aa = *param.InputAlpha(); - float *p = aa.data(); - std::string mode = param.Mode(); - const int batch_size = static_cast(input->dims()[0]); - - std::vector filter_shape_vec(framework::vectorize(filter.dims())); - - std::vector output_shape_vec(framework::vectorize(output->dims())); - size_t data_dim = filter_shape_vec.size() - 2; - std::vector col_shape_vec(1 + 2 * data_dim); - col_shape_vec[0] = input->dims()[1] / groups; - for (size_t j = 0; j < data_dim; ++j) { - col_shape_vec[j + 1] = filter_shape_vec[j + 2]; - col_shape_vec[j + 1 + data_dim] = output_shape_vec[j + 2]; - } - framework::DDim col_shape(framework::make_ddim(col_shape_vec)); - - framework::DDim col_matrix_shape = - framework::flatten_to_2d(col_shape, data_dim + 1); - - bool is_expand = - math::IsExpand(filter_shape_vec, strides, paddings, dilations); - Tensor col; - Tensor col_matrix; - if (is_expand) { - col.mutable_data(col_shape); - col_matrix.ShareDataWith(col); - col_matrix.Resize(col_matrix_shape); - } - - framework::DDim input_shape = framework::slice_ddim( - input->dims(), 1, static_cast(input->dims().size())); - - framework::DDim filter_matrix_shape = {filter.dims()[0], - filter.numel() / filter.dims()[0]}; - filter.Resize(filter_matrix_shape); - framework::DDim output_matrix_shape = { - output->dims()[1], - output->numel() / (output->dims()[0] * output->dims()[1])}; - - // convolution operator: im2col(or vol2col) + gemm - int in_step = static_cast(input->dims()[1]) / groups; - int out_step = static_cast(output->dims()[1]) / groups; - - math::Vol2ColFunctor vol2col; - math::Im2ColFunctor im2col; - - for (int i = 0; i < batch_size; i++) { - Tensor in_batch = input->Slice(i, i + 
1).Resize(input_shape); - Tensor out_batch = output->Slice(i, i + 1).Resize(output_matrix_shape); - - for (int g = 0; g < groups; g++) { - Tensor in_slice = in_batch.Slice(g * in_step, (g + 1) * in_step); - - if (!is_expand) { - col.ShareDataWith(in_slice); - col_matrix.ShareDataWith(col); - col_matrix.Resize(col_matrix_shape); - } else if (data_dim == 2U) { - // im2col - im2col(in_slice, dilations, strides, - std::vector{paddings[0], paddings[1], paddings[0], - paddings[1]}, - &col); - } else if (data_dim == 3U) { - // vol2col - vol2col(in_slice, dilations, strides, paddings, &col); - } - - // gemm - Tensor out_slice = out_batch.Slice(g * out_step, (g + 1) * out_step); - Tensor filter_slice = filter.Slice(g * out_step, (g + 1) * out_step); - math::MatMulWithPRelu(filter_slice, false, col_matrix, false, &out_slice, - p, mode, biase_data, nullptr); - } - } -} - -} // namespace operators -} // namespace paddle_mobile - -#endif // FUSION_CONVADDPRELU_OP diff --git a/src/operators/kernel/central-arm-func/conv_arm_func.cpp b/src/operators/kernel/central-arm-func/conv_arm_func.cpp new file mode 100644 index 0000000000000000000000000000000000000000..2c3166720652a77d3b628d2e5fd5d227a1a7fc33 --- /dev/null +++ b/src/operators/kernel/central-arm-func/conv_arm_func.cpp @@ -0,0 +1,248 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
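The new conv_arm_func.cpp that follows moves the convolution helpers out of the header into a translation unit with explicit instantiations. Its ConvOutputSize is the standard dilated-convolution size formula, output = (input + 2 * padding - (dilation * (kernel - 1) + 1)) / stride + 1. A quick self-check with illustrative values:

#include <cassert>

// Same formula as ConvOutputSize below; the numbers are illustrative.
int conv_output_size(int input, int kernel, int dilation, int padding,
                     int stride) {
  const int dkernel = dilation * (kernel - 1) + 1;
  return (input + 2 * padding - dkernel) / stride + 1;
}

int main() {
  assert(conv_output_size(224, 3, 1, 1, 2) == 112);  // stride-2 3x3 conv
  assert(conv_output_size(14, 3, 2, 2, 1) == 14);    // dilated "same" conv
  return 0;
}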
*/ + +#include "operators/kernel/central-arm-func/conv_arm_func.h" +#include +#include "operators/math/depthwise_conv3x3.h" +#include "operators/math/depthwise_conv5x5.h" +#include "operators/math/im2col.h" +#include "operators/math/math_function.h" +#include "operators/math/pad.h" +#include "operators/math/vol2col.h" +#include "operators/math/winograd/winograd_transform.h" +#include "operators/op_param.h" + +namespace paddle_mobile { +namespace operators { + +int ConvOutputSize(int input_size, int filter_size, int dilation, int padding, + int stride) { + const int dkernel = dilation * (filter_size - 1) + 1; + int output_size = (input_size + 2 * padding - dkernel) / stride + 1; + return output_size; +} + +bool IsExpand(const std::vector &filter_dim, + const std::vector &strides, const std::vector &paddings, + const std::vector &dilations) { + bool filter_1 = true, strides_1 = true, padding_0 = true, dilation_1 = true; + for (size_t j = 0; j < strides.size(); ++j) { + filter_1 = filter_1 && (static_cast(filter_dim[j + 2]) == 1); + strides_1 = strides_1 && (strides[j] == 1); + padding_0 = padding_0 && (paddings[j] == 0); + dilation_1 = dilation_1 && (dilations[j] == 1); + } + + return !(filter_1 && strides_1 && padding_0 && dilation_1); +} + +#ifdef PADDLE_MOBILE_CPU +template +void GemmConv(const ConvParam ¶m) { + const Tensor *input = param.Input(); + Tensor filter = *param.Filter(); + Tensor *output = param.Output(); + output->mutable_data(); + + int groups = param.Groups(); + const std::vector strides = param.Strides(); + const std::vector paddings = param.Paddings(); + const std::vector dilations = param.Dilations(); + + std::vector filter_shape_vec(framework::vectorize(filter.dims())); + std::vector output_shape_vec(framework::vectorize(output->dims())); + size_t data_dim = filter_shape_vec.size() - 2; + std::vector col_shape_vec(1 + 2 * data_dim); + col_shape_vec[0] = input->dims()[1] / groups; + for (size_t j = 0; j < data_dim; ++j) { + col_shape_vec[j + 1] = filter_shape_vec[j + 2]; + col_shape_vec[j + 1 + data_dim] = output_shape_vec[j + 2]; + } + framework::DDim col_shape(framework::make_ddim(col_shape_vec)); + + framework::DDim col_matrix_shape = + framework::flatten_to_2d(col_shape, data_dim + 1); + + bool is_expand = IsExpand(filter_shape_vec, strides, paddings, dilations); + Tensor col; + Tensor col_matrix; + if (is_expand) { + col.mutable_data(col_shape); + col_matrix.ShareDataWith(col); + col_matrix.Resize(col_matrix_shape); + } + + framework::DDim input_shape = framework::slice_ddim( + input->dims(), 1, static_cast(input->dims().size())); + + framework::DDim filter_matrix_shape = {filter.dims()[0], + filter.numel() / filter.dims()[0]}; + filter.Resize(filter_matrix_shape); + framework::DDim output_matrix_shape = { + output->dims()[1], + output->numel() / (output->dims()[0] * output->dims()[1])}; + + // convolution operator: im2col(or vol2col) + gemm + int in_step = static_cast(input->dims()[1]) / groups; + int out_step = static_cast(output->dims()[1]) / groups; + + math::Vol2ColFunctor vol2col; + math::Im2ColFunctor im2col; + + const int batch_size = static_cast(input->dims()[0]); + for (int i = 0; i < batch_size; i++) { + Tensor in_batch = input->Slice(i, i + 1).Resize(input_shape); + Tensor out_batch = output->Slice(i, i + 1).Resize(output_matrix_shape); + + for (int g = 0; g < groups; g++) { + Tensor in_slice = in_batch.Slice(g * in_step, (g + 1) * in_step); + + if (!is_expand) { + // col_matrix.ShareDataWith(in_slice); + col_matrix = in_slice; + 
col_matrix.Resize(col_matrix_shape); + } else if (data_dim == 2U) { + // im2col + im2col(in_slice, dilations, strides, + std::vector{paddings[0], paddings[1], paddings[0], + paddings[1]}, + &col); + } else if (data_dim == 3U) { + // vol2col + vol2col(in_slice, dilations, strides, paddings, &col); + } + + // gemm + Tensor out_slice = out_batch.Slice(g * out_step, (g + 1) * out_step); + Tensor filter_slice = filter.Slice(g * out_step, (g + 1) * out_step); + + math::MatMul(filter_slice, false, col_matrix, false, + static_cast(1), &out_slice, + static_cast(0), false, + static_cast(nullptr)); + } + } +} + +template +void WinogradConv3x3(const ConvParam ¶m) { + const Tensor *input = param.Input(); + const Tensor *filter = param.transformed_filter_; + Tensor *output = param.Output(); + output->mutable_data(); + int batch_size = input->dims()[0]; + int groups = param.Groups(); + const std::vector &paddings = param.Paddings(); + + auto winograd_pad = [&](int width, int pad) { + int output_tile = tile - kernel + 1; + // int tiles = (width + pad - kernel) / output_tile + 1; + // return (tiles - 1) * output_tile + tile - width; + int pad_width = (width + 2 * pad - kernel) / output_tile * output_tile; + return pad_width + tile - width; + }; + + math::PadFunctor pad; + Tensor input_pad; + framework::Tensor transformed_input; + for (int i = 0; i < batch_size; ++i) { + Tensor in_batch = input->Slice(i, i + 1); + Tensor out_batch = output->Slice(i, i + 1); + // int pad_bottom = winograd_pad(in_batch.dims()[2], paddings[0]); + // int pad_right = winograd_pad(in_batch.dims()[3], paddings[1]); + int pad_bottom = paddings[0]; + int pad_right = paddings[1]; + if (paddings[0] || paddings[1] || pad_bottom || pad_right) { + framework::DDim pad_shape = in_batch.dims(); + pad_shape[2] += paddings[0] + pad_bottom; + pad_shape[3] += paddings[1] + pad_right; + input_pad.mutable_data(pad_shape); + pad(in_batch, paddings[0], pad_bottom, paddings[1], pad_right, + &input_pad); + } else { + input_pad = in_batch; + } + // tile input and transform + math::winograd_transform_input(input_pad, &transformed_input); + // caculate output + math::winograd_transform_output(transformed_input, *filter, + output); + } +} + +template +void DepthwiseConv3x3(const ConvParam ¶m) { + const Tensor *input = param.Input(); + const Tensor *filter = param.Filter(); + const std::vector &paddings = param.Paddings(); + const std::vector &strides = param.Strides(); + const int batch_size = input->dims()[0]; + Tensor *output = param.Output(); + output->mutable_data(); + + if (strides[0] == 1) { + for (int i = 0; i < batch_size; i++) { + Tensor in_batch = input->Slice(i, i + 1); + Tensor out_batch = output->Slice(i, i + 1); + math::DepthwiseConv3x3S1(in_batch, *filter, paddings, + &out_batch); + } + } else if (strides[0] == 2) { + for (int i = 0; i < batch_size; i++) { + Tensor in_batch = input->Slice(i, i + 1); + Tensor out_batch = output->Slice(i, i + 1); + math::DepthwiseConv3x3S2(in_batch, *filter, paddings, + &out_batch); + } + } else { + GemmConv(param); + } +} + +template +void DepthwiseConv5x5(const ConvParam ¶m) { + const Tensor *input = param.Input(); + const Tensor *filter = param.Filter(); + const std::vector &paddings = param.Paddings(); + const std::vector &strides = param.Strides(); + const int batch_size = input->dims()[0]; + Tensor *output = param.Output(); + output->mutable_data(); + + if (strides[0] == 1) { + for (int i = 0; i < batch_size; i++) { + Tensor in_batch = input->Slice(i, i + 1); + Tensor out_batch = output->Slice(i, i + 
1); + math::DepthwiseConv5x5S1(in_batch, *filter, paddings, + &out_batch); + } + } else { + GemmConv(param); + } +} + +template void GemmConv(const ConvParam ¶m); +template void WinogradConv3x3<8, 3>(const ConvParam ¶m); +template void DepthwiseConv3x3(const ConvParam ¶m); +template void DepthwiseConv5x5(const ConvParam ¶m); + +#ifndef __aarch64__ +template void GemmConv(const ConvParam ¶m); +template void DepthwiseConv3x3(const ConvParam ¶m); +template void DepthwiseConv5x5(const ConvParam ¶m); +#endif +#endif + +} // namespace operators +} // namespace paddle_mobile diff --git a/src/operators/kernel/central-arm-func/conv_arm_func.h b/src/operators/kernel/central-arm-func/conv_arm_func.h index 93be71f554d86b554e6b8ba07b2341b675052785..52bcbbb7c6f76e7e68da4c8a10271bb1bac35adf 100644 --- a/src/operators/kernel/central-arm-func/conv_arm_func.h +++ b/src/operators/kernel/central-arm-func/conv_arm_func.h @@ -15,200 +15,31 @@ limitations under the License. */ #ifdef CONV_OP #pragma once + #include -#include "operators/math/conv_func.h" -#include "operators/math/depthwise_conv3x3.h" -#include "operators/math/depthwise_conv5x5.h" -#include "operators/math/im2col.h" -#include "operators/math/math_function.h" -#include "operators/math/pad.h" -#include "operators/math/vol2col.h" -#include "operators/math/winograd/winograd_transform.h" #include "operators/op_param.h" namespace paddle_mobile { namespace operators { -template -inline void GemmConv(const ConvParam ¶m) { - const Tensor *input = param.Input(); - Tensor filter = *param.Filter(); - Tensor *output = param.Output(); - output->mutable_data(); - int groups = param.Groups(); - const std::vector strides = param.Strides(); - const std::vector paddings = param.Paddings(); - const std::vector dilations = param.Dilations(); - - std::vector filter_shape_vec(framework::vectorize(filter.dims())); - std::vector output_shape_vec(framework::vectorize(output->dims())); - size_t data_dim = filter_shape_vec.size() - 2; - std::vector col_shape_vec(1 + 2 * data_dim); - col_shape_vec[0] = input->dims()[1] / groups; - for (size_t j = 0; j < data_dim; ++j) { - col_shape_vec[j + 1] = filter_shape_vec[j + 2]; - col_shape_vec[j + 1 + data_dim] = output_shape_vec[j + 2]; - } - framework::DDim col_shape(framework::make_ddim(col_shape_vec)); - - framework::DDim col_matrix_shape = - framework::flatten_to_2d(col_shape, data_dim + 1); - - bool is_expand = - math::IsExpand(filter_shape_vec, strides, paddings, dilations); - Tensor col; - Tensor col_matrix; - if (is_expand) { - col.mutable_data(col_shape); - col_matrix.ShareDataWith(col); - col_matrix.Resize(col_matrix_shape); - } - - framework::DDim input_shape = framework::slice_ddim( - input->dims(), 1, static_cast(input->dims().size())); - - framework::DDim filter_matrix_shape = {filter.dims()[0], - filter.numel() / filter.dims()[0]}; - filter.Resize(filter_matrix_shape); - framework::DDim output_matrix_shape = { - output->dims()[1], - output->numel() / (output->dims()[0] * output->dims()[1])}; - - // convolution operator: im2col(or vol2col) + gemm - int in_step = static_cast(input->dims()[1]) / groups; - int out_step = static_cast(output->dims()[1]) / groups; +int ConvOutputSize(int input_size, int filter_size, int dilation, int padding, + int stride); - math::Vol2ColFunctor vol2col; - math::Im2ColFunctor im2col; +bool IsExpand(const std::vector &filter_dim, + const std::vector &strides, const std::vector &paddings, + const std::vector &dilations); - const int batch_size = static_cast(input->dims()[0]); - for (int i = 0; 
i < batch_size; i++) { - Tensor in_batch = input->Slice(i, i + 1).Resize(input_shape); - Tensor out_batch = output->Slice(i, i + 1).Resize(output_matrix_shape); - - for (int g = 0; g < groups; g++) { - Tensor in_slice = in_batch.Slice(g * in_step, (g + 1) * in_step); - - if (!is_expand) { - // col_matrix.ShareDataWith(in_slice); - col_matrix = in_slice; - col_matrix.Resize(col_matrix_shape); - } else if (data_dim == 2U) { - // im2col - im2col(in_slice, dilations, strides, - std::vector{paddings[0], paddings[1], paddings[0], - paddings[1]}, - &col); - } else if (data_dim == 3U) { - // vol2col - vol2col(in_slice, dilations, strides, paddings, &col); - } - - // gemm - Tensor out_slice = out_batch.Slice(g * out_step, (g + 1) * out_step); - Tensor filter_slice = filter.Slice(g * out_step, (g + 1) * out_step); - math::MatMul(filter_slice, false, col_matrix, false, - static_cast(1), &out_slice, - static_cast(0), false, - static_cast(nullptr)); - } - } -} +template +void GemmConv(const ConvParam ¶m); template -inline void WinogradConv3x3(const ConvParam ¶m) { - const Tensor *input = param.Input(); - const Tensor *filter = param.transformed_filter_; - Tensor *output = param.Output(); - output->mutable_data(); - int batch_size = input->dims()[0]; - int groups = param.Groups(); - const std::vector &paddings = param.Paddings(); +void WinogradConv3x3(const ConvParam ¶m); - auto winograd_pad = [&](int width, int pad) { - int output_tile = tile - kernel + 1; - // int tiles = (width + pad - kernel) / output_tile + 1; - // return (tiles - 1) * output_tile + tile - width; - int pad_width = (width + 2 * pad - kernel) / output_tile * output_tile; - return pad_width + tile - width; - }; - - math::PadFunctor pad; - Tensor input_pad; - framework::Tensor transformed_input; - for (int i = 0; i < batch_size; ++i) { - Tensor in_batch = input->Slice(i, i + 1); - Tensor out_batch = output->Slice(i, i + 1); - // int pad_bottom = winograd_pad(in_batch.dims()[2], paddings[0]); - // int pad_right = winograd_pad(in_batch.dims()[3], paddings[1]); - int pad_bottom = paddings[0]; - int pad_right = paddings[1]; - if (paddings[0] || paddings[1] || pad_bottom || pad_right) { - framework::DDim pad_shape = in_batch.dims(); - pad_shape[2] += paddings[0] + pad_bottom; - pad_shape[3] += paddings[1] + pad_right; - input_pad.mutable_data(pad_shape); - pad(in_batch, paddings[0], pad_bottom, paddings[1], pad_right, - &input_pad); - } else { - input_pad = in_batch; - } - // tile input and transform - math::winograd_transform_input(input_pad, &transformed_input); - // caculate output - math::winograd_transform_output(transformed_input, *filter, - output); - } -} - -#ifndef __aarch64__ template -inline void DepthwiseConv3x3(const ConvParam ¶m) { - const Tensor *input = param.Input(); - const Tensor *filter = param.Filter(); - const std::vector &paddings = param.Paddings(); - const std::vector &strides = param.Strides(); - const int batch_size = input->dims()[0]; - Tensor *output = param.Output(); - output->mutable_data(); - - for (int i = 0; i < batch_size; i++) { - Tensor in_batch = input->Slice(i, i + 1); - Tensor out_batch = output->Slice(i, i + 1); - if (strides[0] == 1) { - math::DepthwiseConv3x3S1(in_batch, *filter, paddings, - &out_batch); - } else if (strides[0] == 2) { - math::DepthwiseConv3x3S2(in_batch, *filter, paddings, - &out_batch); - } else { - GemmConv(param); - } - } -} +void DepthwiseConv3x3(const ConvParam ¶m); template -inline void DepthwiseConv5x5(const ConvParam ¶m) { - const Tensor *input = param.Input(); - const 
Tensor *filter = param.Filter(); - const std::vector &paddings = param.Paddings(); - const std::vector &strides = param.Strides(); - const int batch_size = input->dims()[0]; - Tensor *output = param.Output(); - output->mutable_data(); - - if (strides[0] == 1) { - for (int i = 0; i < batch_size; i++) { - Tensor in_batch = input->Slice(i, i + 1); - Tensor out_batch = output->Slice(i, i + 1); - math::DepthwiseConv5x5S1(in_batch, *filter, paddings, - &out_batch); - } - } else { - GemmConv(param); - } -} -#endif // __aarch64__ +void DepthwiseConv5x5(const ConvParam ¶m); } // namespace operators } // namespace paddle_mobile diff --git a/src/operators/kernel/central-arm-func/conv_transpose_arm_func.h b/src/operators/kernel/central-arm-func/conv_transpose_arm_func.h index 34e9e120ae3a92e67ed4566aeae0ebd938e5dbf6..33ceefadd85e98da76cc90292bf2f066bd3caace 100644 --- a/src/operators/kernel/central-arm-func/conv_transpose_arm_func.h +++ b/src/operators/kernel/central-arm-func/conv_transpose_arm_func.h @@ -99,7 +99,6 @@ void ConvTransposeCompute(const ConvTransposeParam ¶m) { std::vector{paddings[0], paddings[1], paddings[0], paddings[1]}, &out_slice); - } else if (data_dim == 3U) { col2vol(col, dilations, strides, paddings, &out_slice); } diff --git a/src/operators/kernel/central-arm-func/elementwise_sub_arm_func.h b/src/operators/kernel/central-arm-func/elementwise_sub_arm_func.h index 663c65c83a0f5b76e292925ea8cb0994b0f99ad1..cb5bbc91c3b2cede812d28c77e669ddbe46078bf 100644 --- a/src/operators/kernel/central-arm-func/elementwise_sub_arm_func.h +++ b/src/operators/kernel/central-arm-func/elementwise_sub_arm_func.h @@ -15,6 +15,8 @@ limitations under the License. */ #ifdef ELEMENTWISESUB_OP #pragma once + +#include "framework/data_type.h" #include "operators/math/elementwise_op_function.h" #include "operators/op_param.h" @@ -26,15 +28,33 @@ struct SubFunctor { inline T operator()(T a, T b) const { return a - b; } }; +struct SubOpFunctor { + const framework::Tensor* x_; + const framework::Tensor* y_; + const int axis_; + framework::Tensor* out_; + + SubOpFunctor(const framework::Tensor* x, const framework::Tensor* y, + framework::Tensor* out, const int axis) + : x_(x), y_(y), out_(out), axis_(axis) {} + + template + void apply() const { + out_->mutable_data(); + ElementwiseComputeEx, T>(x_, y_, axis_, SubFunctor(), + out_); + } +}; + template -void ElementwiseSubCompute(const ElementwiseSubParam ¶m) { - const Tensor *input_x = param.InputX(); - const Tensor *input_y = param.InputY(); - Tensor *Out = param.Out(); - Out->mutable_data(); +void ElementwiseSubCompute(const ElementwiseSubParam& param) { + const Tensor* input_x = param.InputX(); + const Tensor* input_y = param.InputY(); + Tensor* out = param.Out(); + int axis = param.Axis(); - ElementwiseComputeEx, float>(input_x, input_y, axis, - SubFunctor(), Out); + framework::VisitDataType(framework::ToDataType(input_x->type()), + SubOpFunctor(input_x, input_y, out, axis)); } template class ElementwiseSubKernel; diff --git a/src/operators/kernel/central-arm-func/gru_unit_arm_func.h b/src/operators/kernel/central-arm-func/gru_unit_arm_func.h index 603592505d2a2bcb9d6557e2a2369d5e8625a764..568273e8738acbc5735397127148a507ab8ae26e 100644 --- a/src/operators/kernel/central-arm-func/gru_unit_arm_func.h +++ b/src/operators/kernel/central-arm-func/gru_unit_arm_func.h @@ -27,30 +27,39 @@ namespace operators { template void GruUnitCompute(const GruUnitParam& param) { + // inputs auto* input = param.InputInput(); auto* hidden_prev = param.InputHiddenPrev(); auto* 
weight = param.InputWeight(); auto* bias = param.InputBias(); + // outputs auto* gate = param.OutGate(); + gate->mutable_data<P>(); auto* reset_hidden_prev = param.OutResetHiddenPrev(); + reset_hidden_prev->mutable_data<P>(); auto* hidden = param.OutHidden(); + hidden->mutable_data<P>(); + // add bias if (bias) { math::RowwiseAdd add_bias; - add_bias(*gate, *bias, gate); + add_bias(*input, *bias, gate); } int batch_size = input->dims()[0]; int frame_size = hidden_prev->dims()[1]; const P* weight_data = weight->data<P>(); + math::GRUMetaValue<P> gru_value; gru_value.gate_weight = const_cast<P*>(weight_data); gru_value.state_weight = const_cast<P*>(weight_data + 2 * frame_size * frame_size); - gru_value.output_value = hidden->data<P>(); gru_value.prev_out_value = const_cast<P*>(hidden_prev->data<P>()); + + gru_value.output_value = hidden->data<P>(); gru_value.gate_value = gate->data<P>(); gru_value.reset_output_value = reset_hidden_prev->data<P>(); + auto active_node = math::GetActivationType(param.Activation()); auto active_gate = math::GetActivationType(param.GateActivation()); math::GRUUnitFunctor::compute(gru_value, frame_size, batch_size, diff --git a/src/operators/kernel/central-arm-func/increment_arm_func.h b/src/operators/kernel/central-arm-func/increment_arm_func.h index 44465b2a2f10ad0ca9cb2b6166d14429197a1e30..96473fef81da3e29b70270bed8456d408b31f736 100644 --- a/src/operators/kernel/central-arm-func/increment_arm_func.h +++ b/src/operators/kernel/central-arm-func/increment_arm_func.h @@ -25,11 +25,11 @@ template <typename P> void IncrementCompute(const IncrementParam<CPU> &param) { const framework::Tensor *input = param.InputX(); framework::Tensor *out = param.Out(); - int step = param.Step(); + float step = param.Step(); - out->mutable_data<P>(); - const P *input_data = input->data<P>(); - P *out_data = out->data<P>
(); + out->mutable_data(); + const int64_t *input_data = input->data(); + int64_t *out_data = out->data(); *out_data = *input_data + step; } diff --git a/src/operators/kernel/conv_add_bn_kernel.h b/src/operators/kernel/conv_add_bn_kernel.h index 7a921ecc7d0f4498cae80fbb9cea1b13e4c94101..757664eb536f871811964608c8ad709c416d126c 100644 --- a/src/operators/kernel/conv_add_bn_kernel.h +++ b/src/operators/kernel/conv_add_bn_kernel.h @@ -19,7 +19,6 @@ limitations under the License. */ #include #include "framework/ddim.h" #include "framework/operator.h" -#include "operators/math/conv_func.h" #include "operators/math/im2col.h" #include "operators/math/math_function.h" #include "operators/math/vol2col.h" diff --git a/src/operators/kernel/conv_add_bn_relu_kernel.h b/src/operators/kernel/conv_add_bn_relu_kernel.h index 3f088528fc901987873038c7e1dd779dcc2019e7..919c66106eda1159f14c40e768325f1f5dcf5ff6 100644 --- a/src/operators/kernel/conv_add_bn_relu_kernel.h +++ b/src/operators/kernel/conv_add_bn_relu_kernel.h @@ -19,7 +19,6 @@ limitations under the License. */ #include #include "framework/ddim.h" #include "framework/operator.h" -#include "operators/math/conv_func.h" #include "operators/math/im2col.h" #include "operators/math/math_function.h" #include "operators/math/vol2col.h" diff --git a/src/operators/kernel/conv_add_kernel.h b/src/operators/kernel/conv_add_kernel.h index 140d0475a8ee2f017a7c587c38429ccbb2edd387..fd3f279a7829a5803da6e08c0280435443425ad0 100644 --- a/src/operators/kernel/conv_add_kernel.h +++ b/src/operators/kernel/conv_add_kernel.h @@ -23,7 +23,6 @@ limitations under the License. */ #include "common/common.h" #include "framework/ddim.h" #include "framework/operator.h" -#include "operators/math/conv_func.h" #include "operators/math/depthwise_conv3x3.h" #include "operators/math/im2col.h" #include "operators/math/math_function.h" diff --git a/src/operators/kernel/conv_add_relu_kernel.h b/src/operators/kernel/conv_add_relu_kernel.h index e001926b361da96ec3ff76e120bc3d1ad13714fa..8cfc92ef19937650f1835e16eb26c1bf59f2d345 100644 --- a/src/operators/kernel/conv_add_relu_kernel.h +++ b/src/operators/kernel/conv_add_relu_kernel.h @@ -19,7 +19,6 @@ limitations under the License. */ #include #include "framework/ddim.h" #include "framework/operator.h" -#include "operators/math/conv_func.h" #include "operators/math/im2col.h" #include "operators/math/math_function.h" #include "operators/math/vol2col.h" diff --git a/src/operators/kernel/conv_bn_add_relu_kernel.h b/src/operators/kernel/conv_bn_add_relu_kernel.h index dcd8fecf07fbb4ea75b382f5315e24e64e26e939..63a86b56538a259b783a6a99536b6c5be15d915a 100644 --- a/src/operators/kernel/conv_bn_add_relu_kernel.h +++ b/src/operators/kernel/conv_bn_add_relu_kernel.h @@ -19,7 +19,6 @@ limitations under the License. */ #include #include "framework/ddim.h" #include "framework/operator.h" -#include "operators/math/conv_func.h" #include "operators/math/im2col.h" #include "operators/math/math_function.h" #include "operators/math/vol2col.h" diff --git a/src/operators/kernel/conv_bn_kernel.h b/src/operators/kernel/conv_bn_kernel.h index e669f3bdd85dbd89e3a48d417dcd0cd6b9706062..1fb0d680cf4584e2433af254cca25bc52a3b9e03 100644 --- a/src/operators/kernel/conv_bn_kernel.h +++ b/src/operators/kernel/conv_bn_kernel.h @@ -19,7 +19,6 @@ limitations under the License. 
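The kernel headers above all drop operators/math/conv_func.h, whose widely used piece was math::IsExpand; the equivalent predicate now lives beside GemmConv in conv_arm_func.cpp, and conv_func.h itself is deleted at the end of this patch. The idea: im2col/vol2col can be skipped only for a 1x1 filter with stride 1, zero padding, and dilation 1, because only then is the input slice already laid out as the GEMM right-hand matrix. A standalone restatement of that test:

#include <cstdint>
#include <vector>

// Restates the IsExpand test from conv_arm_func.cpp: im2col is needed
// unless the conv is 1x1 / stride-1 / pad-0 / dilation-1. filter_dim is in
// the framework's OIHW layout, so spatial sizes start at index 2.
bool NeedsIm2Col(const std::vector<int64_t> &filter_dim,
                 const std::vector<int> &strides,
                 const std::vector<int> &paddings,
                 const std::vector<int> &dilations) {
  bool identity = true;
  for (size_t j = 0; j < strides.size(); ++j) {
    identity = identity && filter_dim[j + 2] == 1 && strides[j] == 1 &&
               paddings[j] == 0 && dilations[j] == 1;
  }
  return !identity;
}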
*/ #include #include "framework/ddim.h" #include "framework/operator.h" -#include "operators/math/conv_func.h" #include "operators/math/im2col.h" #include "operators/math/math_function.h" #include "operators/math/vol2col.h" diff --git a/src/operators/kernel/conv_bn_relu_kernel.h b/src/operators/kernel/conv_bn_relu_kernel.h index 91b3413116ae22a8e212cf149c4e0c2a8924664a..f63b61ab09f90c8c40738cbe94ec6ebcff9420ff 100644 --- a/src/operators/kernel/conv_bn_relu_kernel.h +++ b/src/operators/kernel/conv_bn_relu_kernel.h @@ -19,7 +19,6 @@ limitations under the License. */ #include #include "framework/ddim.h" #include "framework/operator.h" -#include "operators/math/conv_func.h" #include "operators/math/im2col.h" #include "operators/math/math_function.h" #include "operators/math/vol2col.h" diff --git a/src/operators/kernel/dwconv_bn_relu_kernel.h b/src/operators/kernel/dwconv_bn_relu_kernel.h index f2e4c0afbd0aaafff5339816764f9e30592f122c..3bd8093adb539d8fc0f6ea4b400b9ff864e1b664 100644 --- a/src/operators/kernel/dwconv_bn_relu_kernel.h +++ b/src/operators/kernel/dwconv_bn_relu_kernel.h @@ -19,7 +19,6 @@ limitations under the License. */ #include #include "framework/ddim.h" #include "framework/operator.h" -#include "operators/math/conv_func.h" #include "operators/math/im2col.h" #include "operators/math/math_function.h" #include "operators/math/vol2col.h" diff --git a/src/operators/kernel/fpga/V1/conv_add_bn_kernel.cpp b/src/operators/kernel/fpga/V1/conv_add_bn_kernel.cpp index ecebe2fd91d62c29966b7726846c81b78f68ae52..c052805dfdc361965c4fc5068ab386367f087797 100644 --- a/src/operators/kernel/fpga/V1/conv_add_bn_kernel.cpp +++ b/src/operators/kernel/fpga/V1/conv_add_bn_kernel.cpp @@ -26,11 +26,11 @@ bool ConvAddBNKernel::Init(FusionConvAddBNParam *param) { paddle_mobile::fpga::ActivationType activation_enable = paddle_mobile::fpga::NONE; int16_t leaky_relu_negative_slope = 0; - auto input = const_cast(param->Input()); + auto input = const_cast(param->Input()); auto bias = param->Bias(); auto bias_ptr = bias->data(); - auto filter = const_cast(param->Filter()); + auto filter = const_cast(param->Filter()); auto out = param->Output(); diff --git a/src/operators/kernel/fpga/V1/conv_add_bn_relu_kernel.cpp b/src/operators/kernel/fpga/V1/conv_add_bn_relu_kernel.cpp index 38d469fa7054193d24faf3d0981de4d87e0d32a5..a7a93de9baed8711a66665ac9510094811ca44d9 100755 --- a/src/operators/kernel/fpga/V1/conv_add_bn_relu_kernel.cpp +++ b/src/operators/kernel/fpga/V1/conv_add_bn_relu_kernel.cpp @@ -27,10 +27,10 @@ bool ConvAddBNReluKernel::Init( paddle_mobile::fpga::ActivationType activation_enable = paddle_mobile::fpga::LEAKYRELU; int16_t leaky_relu_negative_slope = 0; - auto input = const_cast(param->Input()); + auto input = const_cast(param->Input()); auto bias = param->Bias(); auto bias_ptr = bias->data(); - auto filter = const_cast(param->Filter()); + auto filter = const_cast(param->Filter()); auto out = param->Output(); vector paddings = param->Paddings(); diff --git a/src/operators/kernel/fpga/V1/conv_add_kernel.cpp b/src/operators/kernel/fpga/V1/conv_add_kernel.cpp index 153be5a4f888c2a39a7b05b9a7fbb72e305acb8d..da16af58f117b2fbb0e4b6442f9496ea9b824317 100644 --- a/src/operators/kernel/fpga/V1/conv_add_kernel.cpp +++ b/src/operators/kernel/fpga/V1/conv_add_kernel.cpp @@ -25,10 +25,10 @@ bool ConvAddKernel::Init(FusionConvAddParam *param) { paddle_mobile::fpga::ActivationType activation_enable = paddle_mobile::fpga::NONE; int16_t leaky_relu_negative_slope = 0; - auto input = const_cast(param->Input()); + auto 
input = const_cast(param->Input()); const Tensor *bias = param->Bias(); auto bias_ptr = bias->data(); - auto filter = const_cast(param->Filter()); + auto filter = const_cast(param->Filter()); auto out = param->Output(); PADDLE_MOBILE_ENFORCE(out->dims()[1] == bias->dims()[0], diff --git a/src/operators/kernel/fpga/V1/conv_add_relu_kernel.cpp b/src/operators/kernel/fpga/V1/conv_add_relu_kernel.cpp index eef35bf74b6b28e3ec0c49d6b7ace0a350f3f194..f1f61da4217d4ecf3ce12e75b9fba3d3447cb4f6 100644 --- a/src/operators/kernel/fpga/V1/conv_add_relu_kernel.cpp +++ b/src/operators/kernel/fpga/V1/conv_add_relu_kernel.cpp @@ -25,10 +25,10 @@ bool ConvAddReluKernel::Init(FusionConvAddReluParam *param) { paddle_mobile::fpga::ActivationType activation_enable = paddle_mobile::fpga::LEAKYRELU; int16_t leaky_relu_negative_slope = 0; - auto input = const_cast(param->Input()); + auto input = const_cast(param->Input()); const Tensor *bias = param->Bias(); auto bias_ptr = bias->data(); - auto filter = const_cast(param->Filter()); + auto filter = const_cast(param->Filter()); auto out = param->Output(); PADDLE_MOBILE_ENFORCE(out->dims()[1] == bias->dims()[0], diff --git a/src/operators/kernel/fpga/V1/conv_bn_kernel.cpp b/src/operators/kernel/fpga/V1/conv_bn_kernel.cpp index 10ea54e380f8d9a585f03427ced1e569f0849b52..54d99f22d185b0252ad4b5b5b48ceaa1e424b1c6 100644 --- a/src/operators/kernel/fpga/V1/conv_bn_kernel.cpp +++ b/src/operators/kernel/fpga/V1/conv_bn_kernel.cpp @@ -26,8 +26,8 @@ bool ConvBNKernel::Init(FusionConvBNParam *param) { paddle_mobile::fpga::ActivationType activation_enable = paddle_mobile::fpga::NONE; int16_t leaky_relu_negative_slope = 0; - auto input = const_cast(param->Input()); - auto filter = const_cast(param->Filter()); + auto input = const_cast(param->Input()); + auto filter = const_cast(param->Filter()); auto out = param->Output(); auto bn_mean_ptr = param->InputMean()->data(); auto bn_var_ptr = param->InputVariance()->data(); diff --git a/src/operators/kernel/fpga/V1/conv_bn_relu_kernel.cpp b/src/operators/kernel/fpga/V1/conv_bn_relu_kernel.cpp index 5f8f85278e81911d67f1e072b390e6cd74149ee4..4ce8265f7f780d5ea4291783e309cd9507bf18b6 100644 --- a/src/operators/kernel/fpga/V1/conv_bn_relu_kernel.cpp +++ b/src/operators/kernel/fpga/V1/conv_bn_relu_kernel.cpp @@ -23,8 +23,8 @@ bool ConvBNReluKernel::Init(FusionConvBNReluParam *param) { paddle_mobile::fpga::ActivationType activation_enable = paddle_mobile::fpga::LEAKYRELU; int16_t leaky_relu_negative_slope = 0; - auto input = const_cast(param->Input()); - auto filter = const_cast(param->Filter()); + auto input = const_cast(param->Input()); + auto filter = const_cast(param->Filter()); auto out = param->Output(); auto bn_mean_ptr = param->InputMean()->data(); auto bn_var_ptr = param->InputVariance()->data(); diff --git a/src/operators/kernel/fpga/V1/conv_kernel.cpp b/src/operators/kernel/fpga/V1/conv_kernel.cpp index 73722820bd90b54abd64dd01b157c74c6a1069e8..57b5eb754e327160399bee728d0689101fac1134 100644 --- a/src/operators/kernel/fpga/V1/conv_kernel.cpp +++ b/src/operators/kernel/fpga/V1/conv_kernel.cpp @@ -24,8 +24,8 @@ bool ConvKernel::Init(ConvParam *param) { paddle_mobile::fpga::ActivationType activation_enable = paddle_mobile::fpga::NONE; int16_t leaky_relu_negative_slope = 0; - auto input = const_cast(param->Input()); - auto filter = const_cast(param->Filter()); + auto input = const_cast(param->Input()); + auto filter = const_cast(param->Filter()); auto out = param->Output(); int channel = out->dims()[1]; auto bs_ptr = diff --git 
a/src/operators/kernel/fpga/V1/conv_transpose_kernel.cpp b/src/operators/kernel/fpga/V1/conv_transpose_kernel.cpp index 788504df5d2ea1005cfaa76f12b58e61c0218391..1597885e43e01895b6acd425031341af70d5eaf7 100644 --- a/src/operators/kernel/fpga/V1/conv_transpose_kernel.cpp +++ b/src/operators/kernel/fpga/V1/conv_transpose_kernel.cpp @@ -27,10 +27,10 @@ bool ConvTransposeKernel::Init(ConvTransposeParam *param) { paddle_mobile::fpga::ActivationType activation_enable = paddle_mobile::fpga::NONE; int16_t leaky_relu_negative_slope = 0; - auto input = const_cast(param->Input()); + auto input = const_cast(param->Input()); // const Tensor *bias = param->Bias(); // auto bias_ptr = bias->data(); - auto filter = const_cast(param->Filter()); + auto filter = const_cast(param->Filter()); auto out = param->Output(); // PADDLE_MOBILE_ENFORCE(out->dims()[1] == bias->dims()[0], diff --git a/src/operators/kernel/fpga/V1/deconv_add_bn_kernel.cpp b/src/operators/kernel/fpga/V1/deconv_add_bn_kernel.cpp index 4239ac1e5da421cb0e2421a8919d8d15e40348af..a8205df3c9c1052055ba15ca58fd215f1d49ba0e 100644 --- a/src/operators/kernel/fpga/V1/deconv_add_bn_kernel.cpp +++ b/src/operators/kernel/fpga/V1/deconv_add_bn_kernel.cpp @@ -27,10 +27,10 @@ bool DeconvAddBNKernel::Init(FusionDeconvAddBNParam *param) { paddle_mobile::fpga::ActivationType activation_enable = paddle_mobile::fpga::NONE; int16_t leaky_relu_negative_slope = 0; - auto input = const_cast(param->Input()); + auto input = const_cast(param->Input()); const Tensor *bias = param->InputBias(); auto bias_ptr = bias->data(); - auto filter = const_cast(param->Filter()); + auto filter = const_cast(param->Filter()); auto out = param->Output(); PADDLE_MOBILE_ENFORCE(out->dims()[1] == bias->dims()[0], diff --git a/src/operators/kernel/fpga/V1/deconv_add_bn_relu_kernel.cpp b/src/operators/kernel/fpga/V1/deconv_add_bn_relu_kernel.cpp index 28b8c83198a5517ed0dc9732e0033030a876a7da..b27f5cf870d2e3220bec31ee63bb27361cb2c8cf 100755 --- a/src/operators/kernel/fpga/V1/deconv_add_bn_relu_kernel.cpp +++ b/src/operators/kernel/fpga/V1/deconv_add_bn_relu_kernel.cpp @@ -28,10 +28,10 @@ bool DeconvAddBNReluKernel::Init( paddle_mobile::fpga::ActivationType activation_enable = paddle_mobile::fpga::LEAKYRELU; int16_t leaky_relu_negative_slope = 0; - auto input = const_cast(param->Input()); + auto input = const_cast(param->Input()); const Tensor *bias = param->InputBias(); auto bias_ptr = bias->data(); - auto filter = const_cast(param->Filter()); + auto filter = const_cast(param->Filter()); auto out = param->Output(); PADDLE_MOBILE_ENFORCE(out->dims()[1] == bias->dims()[0], diff --git a/src/operators/kernel/fpga/V1/deconv_add_kernel.cpp b/src/operators/kernel/fpga/V1/deconv_add_kernel.cpp index 97a4d5516b52939a3a1d90a22c8050679810d405..41844d008b2c8313fc8f1ac75a00d9864b5a20a5 100644 --- a/src/operators/kernel/fpga/V1/deconv_add_kernel.cpp +++ b/src/operators/kernel/fpga/V1/deconv_add_kernel.cpp @@ -27,10 +27,10 @@ bool DeconvAddKernel::Init(FusionDeconvAddParam *param) { paddle_mobile::fpga::ActivationType activation_enable = paddle_mobile::fpga::NONE; int16_t leaky_relu_negative_slope = 0; - auto input = const_cast(param->Input()); + auto input = const_cast(param->Input()); const Tensor *bias = param->Bias(); auto bias_ptr = bias->data(); - auto filter = const_cast(param->Filter()); + auto filter = const_cast(param->Filter()); auto out = param->Output(); PADDLE_MOBILE_ENFORCE(out->dims()[1] == bias->dims()[0], diff --git a/src/operators/kernel/fpga/V1/deconv_add_relu_kernel.cpp 
b/src/operators/kernel/fpga/V1/deconv_add_relu_kernel.cpp index f0b29943d7731d716a19cff1e3cfc904d7610c0b..c6fc9d195511ae3218450fa58393ba420444eb92 100644 --- a/src/operators/kernel/fpga/V1/deconv_add_relu_kernel.cpp +++ b/src/operators/kernel/fpga/V1/deconv_add_relu_kernel.cpp @@ -28,10 +28,10 @@ bool DeconvAddReluKernel::Init( paddle_mobile::fpga::ActivationType activation_enable = paddle_mobile::fpga::LEAKYRELU; int16_t leaky_relu_negative_slope = 0; - auto input = const_cast(param->Input()); + auto input = const_cast(param->Input()); const Tensor *bias = param->Bias(); auto bias_ptr = bias->data(); - auto filter = const_cast(param->Filter()); + auto filter = const_cast(param->Filter()); auto out = param->Output(); PADDLE_MOBILE_ENFORCE(out->dims()[1] == bias->dims()[0], diff --git a/src/operators/kernel/fpga/V1/deconv_bn_relu_kernel.cpp b/src/operators/kernel/fpga/V1/deconv_bn_relu_kernel.cpp index f166587109e5f63e30203a940aa3baa8ae87f844..75597f0ecd570b6b21894a2f9a0ff0ad91a54ea4 100644 --- a/src/operators/kernel/fpga/V1/deconv_bn_relu_kernel.cpp +++ b/src/operators/kernel/fpga/V1/deconv_bn_relu_kernel.cpp @@ -29,10 +29,10 @@ bool DeconvBNReluKernel::Init( paddle_mobile::fpga::ActivationType activation_enable = paddle_mobile::fpga::LEAKYRELU; int16_t leaky_relu_negative_slope = 0; - auto input = const_cast(param->Input()); + auto input = const_cast(param->Input()); const Tensor *bias = param->InputBias(); auto bias_ptr = bias->data(); - auto filter = const_cast(param->Filter()); + auto filter = const_cast(param->Filter()); auto out = param->Output(); auto bn_mean_ptr = param->InputMean()->data(); auto bn_var_ptr = param->InputVariance()->data(); diff --git a/src/operators/kernel/fpga/V1/feed_kernel.cpp b/src/operators/kernel/fpga/V1/feed_kernel.cpp index 0e16de63777c92f5bf89e28b3c521f1854a0496d..a661cd642c51a1baff2ac6ec97933831bd034c40 100644 --- a/src/operators/kernel/fpga/V1/feed_kernel.cpp +++ b/src/operators/kernel/fpga/V1/feed_kernel.cpp @@ -20,6 +20,11 @@ namespace operators { template <> bool FeedKernel::Init(FeedParam *param) { auto output = param->Out(); + int col = param->Col(); + auto input = const_cast(¶m->InputX()->at(col)); + input->init(typeid(float)); + input->Resize(output->dims()); + if (output->dims().size() != 4) { return true; } @@ -31,7 +36,8 @@ bool FeedKernel::Init(FeedParam *param) { template <> void FeedKernel::Compute(const FeedParam ¶m) { auto output = param.Out(); - auto input = const_cast(param.InputX()); + int col = param.Col(); + auto input = const_cast(¶m.InputX()->at(col)); std::type_index input_type = input->type(); if (input_type == typeid(float)) { diff --git a/src/operators/kernel/fpga/V1/fetch_kernel.cpp b/src/operators/kernel/fpga/V1/fetch_kernel.cpp index 5010ac7ad41850e41b58a897ae2e969b7831e90d..b128c8e3430b8a359a5ad9dbcba397ad0f2b6568 100644 --- a/src/operators/kernel/fpga/V1/fetch_kernel.cpp +++ b/src/operators/kernel/fpga/V1/fetch_kernel.cpp @@ -17,8 +17,9 @@ namespace operators { template <> bool FetchKernel::Init(FetchParam *param) { - auto input = const_cast(param->InputX()); - auto output = param->Out(); + auto input = const_cast(param->InputX()); + int col = param->Col(); + auto output = &(param->Out()->at(col)); if (input->type() == typeid(float)) { return true; } @@ -56,12 +57,9 @@ void dealign(float *src, float *dst, int input_c, int input_h, int input_w) { } template <> void FetchKernel::Compute(const FetchParam ¶m) { - auto input = const_cast(param.InputX()); - if (input->type() == typeid(float)) { - auto output = param.Out(); - 
output->ShareDataWith(*input); - return; - } + auto input = const_cast(param.InputX()); + int col = param.Col(); + LoDTensor *out = ¶m.Out()->at(col); fpga::BypassArgs args = param.fpga_bypass_args; auto input_address = (input->data()); @@ -69,7 +67,7 @@ void FetchKernel::Compute(const FetchParam ¶m) { float *outdata_ptr = reinterpret_cast(param.fpga_bypass_args.output.address); const int num_th = 32; - if ((param.Out()->fpga_data_num) < num_th) { + if ((out->fpga_data_num) < num_th) { fpga::fpga_invalidate(input_address, (input->fpga_data_num) * sizeof(half)); for (int idx = 0; idx < product(input->dims()); ++idx) { @@ -79,14 +77,14 @@ void FetchKernel::Compute(const FetchParam ¶m) { } fpga::PerformBypass(args); - auto outC = param.Out()->dims()[1]; - auto outH = param.Out()->dims()[2]; - auto outW = param.Out()->dims()[3]; + auto outC = out->dims()[1]; + auto outH = out->dims()[2]; + auto outW = out->dims()[3]; fpga::fpga_invalidate(param.fpga_bypass_args.output.address, - param.Out()->fpga_data_num * sizeof(float)); + out->fpga_data_num * sizeof(float)); - if (param.Out()->fpga_data_num != product(input->dims())) { + if (out->fpga_data_num != product(input->dims())) { float *data_tmp = reinterpret_cast(malloc(outC * outH * outW * sizeof(float))); dealign(outdata_ptr, data_tmp, outC, outH, outW); diff --git a/src/operators/kernel/fpga/V1/fusion_fc_kernel.cpp b/src/operators/kernel/fpga/V1/fusion_fc_kernel.cpp index 6669ff2ccdcea028000a2e12b82721cc442d9271..3a29104d0fe0e3c69c9369fb1137b2c94ef04e43 100644 --- a/src/operators/kernel/fpga/V1/fusion_fc_kernel.cpp +++ b/src/operators/kernel/fpga/V1/fusion_fc_kernel.cpp @@ -25,7 +25,7 @@ bool FusionFcKernel::Init(FusionFcParam *param) { paddle_mobile::fpga::NONE; int16_t leaky_relu_negative_slope = 0; auto input_x = const_cast(param->InputX()); - auto filter = const_cast(param->InputY()); + auto filter = const_cast(param->InputY()); const Tensor *input_z = param->InputZ(); auto input_z_ptr = input_z->data(); auto out = param->Out(); diff --git a/src/operators/kernel/fpga/V1/fusion_fc_relu_kernel.cpp b/src/operators/kernel/fpga/V1/fusion_fc_relu_kernel.cpp index 6fbeb63fe606aac014f76088210c74a4118e6c78..fef370515e9e9ffa1d90c184e62919235533b8a5 100644 --- a/src/operators/kernel/fpga/V1/fusion_fc_relu_kernel.cpp +++ b/src/operators/kernel/fpga/V1/fusion_fc_relu_kernel.cpp @@ -25,7 +25,7 @@ bool FusionFcReluKernel::Init(FusionFcReluParam *param) { paddle_mobile::fpga::LEAKYRELU; int16_t leaky_relu_negative_slope = 0; auto input_x = const_cast(param->InputX()); - auto filter = const_cast(param->InputY()); + auto filter = const_cast(param->InputY()); const Tensor *input_z = param->InputZ(); auto input_z_ptr = input_z->data(); auto out = param->Out(); diff --git a/src/operators/kernel/fpga/V1/pad2d_kernel.cpp b/src/operators/kernel/fpga/V1/pad2d_kernel.cpp index f47a585ee412316ce65084c5fa10a622ffb93a4f..5d81f71c3608d19f5be5c46699b8379ebb279982 100644 --- a/src/operators/kernel/fpga/V1/pad2d_kernel.cpp +++ b/src/operators/kernel/fpga/V1/pad2d_kernel.cpp @@ -16,8 +16,8 @@ limitations under the License. 
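The feed and fetch kernels above stop treating their input and output as single tensors and instead index one slot of the feed/fetch list via Col(), so several feed and fetch ops can coexist in one program. Stripped of the FPGA details, the access pattern is just the following (std::vector and the helper names are stand-ins for the framework's actual list type):

#include <vector>

// Hypothetical helpers showing only the col-indexed slot access.
struct LoDTensor {};
using TensorList = std::vector<LoDTensor>;

const LoDTensor *FeedSlot(const TensorList &feed_list, int col) {
  return &feed_list.at(col);  // each feed op reads its own slot
}

LoDTensor *FetchSlot(TensorList *fetch_list, int col) {
  return &fetch_list->at(col);  // each fetch op writes its own slot
}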
*/ namespace paddle_mobile { namespace operators { template <> -bool Pad2dKernel::Init(Pad2dParam *param) { - Tensor *output = param->Out(); +bool Pad2DKernel::Init(Pad2DParam *param) { + Tensor *output = param->output_; fpga::format_fp16_ofm(output); return true; } @@ -39,9 +39,9 @@ void pad2dFunc(const framework::Tensor *input, framework::Tensor *output) { } } template <> -void Pad2dKernel::Compute(const Pad2dParam ¶m) { - auto in_x = param.InputX(); - auto out = param.Out(); +void Pad2DKernel::Compute(const Pad2DParam ¶m) { + auto in_x = param.input_; + auto out = param.output_; fpga::fpga_invalidate((void *)in_x->data(), // NOLINT in_x->numel() * sizeof(half)); pad2dFunc(in_x, out); diff --git a/src/operators/kernel/fpga/V1/pool_kernel.cpp b/src/operators/kernel/fpga/V1/pool_kernel.cpp index 0bba15be7757ed3170402a47780e40cb94b9cfa0..994fa151621956aa791d36cc0f4cd829dc88f3d1 100644 --- a/src/operators/kernel/fpga/V1/pool_kernel.cpp +++ b/src/operators/kernel/fpga/V1/pool_kernel.cpp @@ -21,7 +21,7 @@ namespace operators { template <> bool PoolKernel::Init(PoolParam *param) { - auto *input = const_cast(param->Input()); + auto *input = const_cast(param->Input()); auto *output = param->Output(); vector ksize = param->Ksize(); vector strides = param->Strides(); @@ -68,7 +68,7 @@ bool PoolKernel::Init(PoolParam *param) { template <> void PoolKernel::Compute(const PoolParam ¶m) { - auto *input = const_cast(param.Input()); + auto *input = const_cast(param.Input()); if (input->type() == typeid(float)) { auto *output = param.Output(); diff --git a/src/operators/kernel/fpga/V1/proposal_kernel.cpp b/src/operators/kernel/fpga/V1/proposal_kernel.cpp index a86e011b61d6dab79159cbb4d34cabab747a1d3b..772c68059ddb85958279639626bfb9e2b36fb91b 100644 --- a/src/operators/kernel/fpga/V1/proposal_kernel.cpp +++ b/src/operators/kernel/fpga/V1/proposal_kernel.cpp @@ -15,11 +15,13 @@ limitations under the License. 
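Pad2DKernel::Compute above invalidates the FPGA cache for the input and then pads on the host via pad2dFunc. For reference, a minimal constant-value 2-D pad over one NCHW sample looks like the following; this is a float sketch (the FPGA path works on half-precision data), and the {top, bottom, left, right} padding order is assumed from the Paddle pad2d attribute:

#include <vector>

// Minimal constant pad of one NCHW feature map; paddings follow the
// "paddings" attribute read by Pad2DParam. Standalone illustration only.
void Pad2DConstant(const float *in, int channels, int height, int width,
                   const std::vector<int> &paddings, float pad_value,
                   float *out) {
  const int out_h = height + paddings[0] + paddings[1];
  const int out_w = width + paddings[2] + paddings[3];
  for (int c = 0; c < channels; ++c) {
    for (int oh = 0; oh < out_h; ++oh) {
      for (int ow = 0; ow < out_w; ++ow) {
        const int ih = oh - paddings[0];
        const int iw = ow - paddings[2];
        const bool inside = ih >= 0 && ih < height && iw >= 0 && iw < width;
        out[(c * out_h + oh) * out_w + ow] =
            inside ? in[(c * height + ih) * width + iw] : pad_value;
      }
    }
  }
}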
*/ #ifdef PROPOSAL_OP #include +#include #include #include "operators/kernel/detection_kernel.h" namespace paddle_mobile { namespace operators { + static const double kBBoxClipDefault = std::log(1000.0 / 16.0); template <> diff --git a/src/operators/kernel/fpga/V1/sigmoid_kernel.cpp b/src/operators/kernel/fpga/V1/sigmoid_kernel.cpp index bf36873a1fb442a4d5ff6f57056515009d275cd6..bb9eb3d6e8acf3d59ce3c4541f8c553fe7cb1cc2 100644 --- a/src/operators/kernel/fpga/V1/sigmoid_kernel.cpp +++ b/src/operators/kernel/fpga/V1/sigmoid_kernel.cpp @@ -24,7 +24,7 @@ bool SigmoidKernel::Init(SigmoidParam *param) { paddle_mobile::fpga::ActivationType activation_enable = paddle_mobile::fpga::SIGMOID; int16_t leaky_relu_negative_slope = 0; - auto input = const_cast(param->InputX()); + auto input = const_cast(param->InputX()); auto input_ptr = input->data(); auto out = param->Out(); fpga::format_fp16_ofm(out); diff --git a/src/operators/kernel/fpga/V1/softmax_kernel.cpp b/src/operators/kernel/fpga/V1/softmax_kernel.cpp index 116a9594ee45ce862d8d4f58990637a062dfb092..5537565bc2a4dc7563148617daf47eaa9a50ba91 100644 --- a/src/operators/kernel/fpga/V1/softmax_kernel.cpp +++ b/src/operators/kernel/fpga/V1/softmax_kernel.cpp @@ -33,7 +33,7 @@ bool SoftmaxKernel::Init(SoftmaxParam *param) { input_ptr = input->data(); } - auto float_input = new Tensor; + auto float_input = new LoDTensor; PADDLE_MOBILE_ENFORCE(input->dims().size() == 4, "Softmax should have 4-order input"); diff --git a/src/operators/kernel/fpga/V1/split_kernel.cpp b/src/operators/kernel/fpga/V1/split_kernel.cpp index 2aef8d018c3480e396c22ac6f6c953a1387c331d..584cb41fb30b02c757430bd748d4672cc870b591 100644 --- a/src/operators/kernel/fpga/V1/split_kernel.cpp +++ b/src/operators/kernel/fpga/V1/split_kernel.cpp @@ -20,7 +20,7 @@ namespace paddle_mobile { namespace operators { template <> bool SplitKernel::Init(SplitParam *param) { - auto *in = const_cast(param->InputX()); + auto *in = const_cast(param->InputX()); auto outs = param->Outs(); auto sections = param->Sections(); int axis = param->Axis(); diff --git a/src/operators/kernel/fpga/V1/tanh_kernel.cpp b/src/operators/kernel/fpga/V1/tanh_kernel.cpp index 7b5e2153ceb777325bd80c445b96bcfc2d631303..d7bbc5f0435aaca53be01d6c82d919a2df072ce2 100644 --- a/src/operators/kernel/fpga/V1/tanh_kernel.cpp +++ b/src/operators/kernel/fpga/V1/tanh_kernel.cpp @@ -21,10 +21,10 @@ namespace operators { template <> bool TanhKernel::Init(TanhParam *param) { - auto input = const_cast(param->InputX()); + auto input = const_cast(param->InputX()); DLOG << "input: " << input; auto input_ptr = input->data(); - auto float_input = new Tensor; + auto float_input = new LoDTensor; float_input->mutable_data( {1, input->dims()[1], input->dims()[2], input->dims()[3]}); diff --git a/src/operators/kernel/lrn_kernel.h b/src/operators/kernel/lrn_kernel.h index 99dbfe2d658cde17e6399f8ea4bc5b945092cde5..486c828acab6d24741baae5804f09bc3e850b02f 100644 --- a/src/operators/kernel/lrn_kernel.h +++ b/src/operators/kernel/lrn_kernel.h @@ -15,24 +15,21 @@ limitations under the License. 
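kBBoxClipDefault above, log(1000/16), caps the predicted width and height deltas so std::exp cannot grow a proposal by more than a factor of 62.5 per side. A generic sketch of the anchor-delta decoding where that clamp is applied (the usual Faster R-CNN transform, not this file's exact code):

#include <algorithm>
#include <cmath>

static const double kBBoxClipDefault = std::log(1000.0 / 16.0);

// Decode predicted deltas against an anchor; clamping dw/dh keeps the
// exponential bounded. All names here are illustrative.
void DecodeBox(float ax, float ay, float aw, float ah,  // anchor center/size
               float dx, float dy, float dw, float dh,  // predicted deltas
               float *x, float *y, float *w, float *h) {
  dw = std::min(dw, static_cast<float>(kBBoxClipDefault));
  dh = std::min(dh, static_cast<float>(kBBoxClipDefault));
  *x = ax + dx * aw;
  *y = ay + dy * ah;
  *w = aw * std::exp(dw);
  *h = ah * std::exp(dh);
}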
*/ #pragma once #ifdef LRN_OP + +#include #ifdef _OPENMP #include #endif -#include "framework/operator.h" -#include "operators/op_param.h" - -#include - #ifdef __ARM_NEON -#include "arm_neon.h" -#include "operators/math/math_func_neon.h" +#include +#include "operators/math/math.h" #endif +#include "framework/operator.h" +#include "operators/op_param.h" namespace paddle_mobile { namespace operators { -using namespace framework; - template struct LRNFunctor { void operator()(const framework::Tensor &input, framework::Tensor *out, int N, diff --git a/src/operators/kernel/conv_add_prelu_kernel.h b/src/operators/kernel/one_hot_kernel.h similarity index 50% rename from src/operators/kernel/conv_add_prelu_kernel.h rename to src/operators/kernel/one_hot_kernel.h index 631982789b09c57d0d21186d0a30df7368d2955f..fd883cabee2138df0602980b18c154730f8b2ca9 100644 --- a/src/operators/kernel/conv_add_prelu_kernel.h +++ b/src/operators/kernel/one_hot_kernel.h @@ -12,34 +12,43 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#pragma once +#ifdef ONE_HOT_OP -#ifdef FUSION_CONVADDPRELU_OP +#pragma once -#include -#include "framework/ddim.h" #include "framework/operator.h" -#include "operators/math/conv_func.h" -#include "operators/math/im2col.h" -#include "operators/math/math_function.h" -#include "operators/math/vol2col.h" #include "operators/op_param.h" namespace paddle_mobile { namespace operators { -using framework::DDim; -using framework::OpKernelBase; +#define GET_VAR_AS_LOD_TENSOR(name, name_dict, scope) \ + OpParam::GetVarValue(name, name_dict, scope) -template -class ConvAddPReluKernel - : public OpKernelBase> { +template +class OnehotParam : public OpParam { public: - void Compute(const FusionConvAddPReluParam ¶m); - bool Init(FusionConvAddPReluParam *param); + OnehotParam(const VariableNameMap &inputs, const VariableNameMap &outputs, + const AttributeMap &attrs, Scope *scope) + : OpParam(inputs, outputs, attrs, scope) { + input_ = GET_VAR_AS_LOD_TENSOR("X", inputs, *scope); + output_ = GET_VAR_AS_LOD_TENSOR("Out", outputs, *scope); + + depth_ = OpParam::GetAttr("depth", attrs); + dtype_ = OpParam::GetAttr("dtype", attrs); + } + + public: + framework::LoDTensor *input_; + framework::LoDTensor *output_; + + int depth_; + int dtype_; }; +DECLARE_KERNEL(Onehot, OnehotParam); + } // namespace operators } // namespace paddle_mobile -#endif +#endif // ONE_HOT_OP diff --git a/src/operators/kernel/pad2d_kernel.h b/src/operators/kernel/pad2d_kernel.h index 58b8c1a15884b00dc0c309c99da7de0706524cdd..0834cbc0cff246c7738405cf44f40b3ecc2c5e70 100644 --- a/src/operators/kernel/pad2d_kernel.h +++ b/src/operators/kernel/pad2d_kernel.h @@ -12,21 +12,43 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
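OnehotParam above carries the input ids, the depth, and a dtype attribute that selects the output type. The expansion itself is tiny; a float-output sketch, assuming every id lies in [0, depth):

#include <cstdint>
#include <vector>

// Minimal one-hot expansion matching OnehotParam's contract: each id
// becomes a row of `depth` values with a single 1.0. The real kernel
// dispatches on dtype_; float output is assumed here.
std::vector<float> OneHot(const std::vector<int64_t> &ids, int depth) {
  std::vector<float> out(ids.size() * depth, 0.f);
  for (size_t i = 0; i < ids.size(); ++i) {
    out[i * depth + static_cast<size_t>(ids[i])] = 1.f;
  }
  return out;
}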
*/ +#ifdef PAD2D_OP + #pragma once +#include +#include #include "framework/operator.h" #include "operators/op_param.h" namespace paddle_mobile { namespace operators { -template -class Pad2dKernel - : public framework::OpKernelBase> { +template +class Pad2DParam : public OpParam { + public: + Pad2DParam(const VariableNameMap &inputs, const VariableNameMap &outputs, + const AttributeMap &attrs, Scope *scope) + : OpParam(inputs, outputs, attrs, scope) { + input_ = OpParam::GetVarValue("X", inputs, *scope); + output_ = + OpParam::GetVarValue("Out", outputs, *scope); + paddings_ = OpParam::GetAttr>("paddings", attrs); + pad_value_ = OpParam::GetAttr("pad_value", attrs); + mode_ = OpParam::GetStringAttr("mode", attrs); + } + public: - void Compute(const Pad2dParam ¶m); - bool Init(Pad2dParam *param); + framework::LoDTensor *input_; + framework::LoDTensor *output_; + std::vector paddings_; + float pad_value_; + std::string mode_; }; +DECLARE_KERNEL(Pad2D, Pad2DParam); + } // namespace operators } // namespace paddle_mobile + +#endif // PAD2D_OP diff --git a/src/operators/kernel/while_kernel.h b/src/operators/kernel/while_kernel.h index 45b83d79669127ba9d5a65f1de83edf82dde96d8..ba014a9079f6877f726d7697c487653b848f9dbb 100644 --- a/src/operators/kernel/while_kernel.h +++ b/src/operators/kernel/while_kernel.h @@ -26,21 +26,16 @@ class WhileParam : public OpParam { public: WhileParam(const VariableNameMap &inputs, const VariableNameMap &outputs, const AttributeMap &attrs, Scope *scope) - : inputs_(inputs), - outputs_(outputs), - scope_(*scope), - OpParam(inputs, outputs, attrs, scope) { + : scope_(scope), OpParam(inputs, outputs, attrs, scope) { cond_ = OpParam::GetVarValue("Condition", inputs, *scope); - sub_block_ = OpParam::GetAttr("sub_block", attrs); + sub_block_ = OpParam::GetAttr("sub_block", attrs); } public: + const Scope *scope_; framework::LoDTensor *cond_; - int sub_block_; - const VariableNameMap inputs_; - const VariableNameMap outputs_; - const Scope scope_; + framework::BlockDesc *sub_block_; }; DECLARE_KERNEL(While, WhileParam); diff --git a/src/operators/lookup_op.h b/src/operators/lookup_op.h index b5c3886cf46c9641e919aee32e7af30c6528309a..e99936a71146328d77853ad88cd7a3fc5d4faf13 100644 --- a/src/operators/lookup_op.h +++ b/src/operators/lookup_op.h @@ -33,7 +33,7 @@ class LookupOp : public framework::OperatorWithKernel< public: LookupOp(const std::string &type, const VariableNameMap &inputs, const VariableNameMap &outputs, const framework::AttributeMap &attrs, - std::shared_ptr scope) + framework::Scope *scope) : framework::OperatorWithKernel, operators::LookupKernel>( type, inputs, outputs, attrs, scope) {} diff --git a/src/operators/lrn_op.h b/src/operators/lrn_op.h index 3e1e92bfe6d9b888f100d07edaabfe0f8c6eaca5..dde4b968af7481f2c8ffaff5542d0fceecc06825 100644 --- a/src/operators/lrn_op.h +++ b/src/operators/lrn_op.h @@ -31,7 +31,7 @@ class LrnOp : public framework::OperatorWithKernel< public: LrnOp(const string &type, const VariableNameMap &inputs, const VariableNameMap &outputs, const framework::AttributeMap &attrs, - std::shared_ptr scope) + framework::Scope *scope) : framework::OperatorWithKernel, operators::LrnKernel>( type, inputs, outputs, attrs, scope) {} diff --git a/src/operators/math/activation.h b/src/operators/math/activation.h index 90b9ab4c3a558a994370ea80693e1d31687bb44e..fb90a35516d8c461a05328f65bce24a2b8aa519f 100644 --- a/src/operators/math/activation.h +++ b/src/operators/math/activation.h @@ -21,7 +21,7 @@ limitations under the License. 
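The WhileParam cleanup above stops copying the inputs, outputs, and the entire Scope into the param, and resolves sub_block to a framework::BlockDesc* instead of an integer index. A sketch of how a kernel might consume the new fields; the executor interface below is a hypothetical assumption, only the field usage mirrors the patch:

#include "operators/kernel/while_kernel.h"  // WhileParam as patched above

namespace paddle_mobile {
namespace operators {

// Assumed runner interface; the real kernel's executor is not in this patch.
class StepExecutor {
 public:
  virtual ~StepExecutor() = default;
  virtual void Run(framework::BlockDesc *block,
                   const framework::Scope *scope) = 0;
};

inline void RunWhile(const WhileParam &param, StepExecutor *exec) {
  // cond_ holds a single bool; sub_block_ is now the BlockDesc itself.
  while (param.cond_->data<bool>()[0]) {
    exec->Run(param.sub_block_, param.scope_);
  }
}

}  // namespace operators
}  // namespace paddle_mobile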
diff --git a/src/operators/kernel/while_kernel.h b/src/operators/kernel/while_kernel.h
index 45b83d79669127ba9d5a65f1de83edf82dde96d8..ba014a9079f6877f726d7697c487653b848f9dbb 100644
--- a/src/operators/kernel/while_kernel.h
+++ b/src/operators/kernel/while_kernel.h
@@ -26,21 +26,16 @@ class WhileParam : public OpParam {
  public:
   WhileParam(const VariableNameMap &inputs, const VariableNameMap &outputs,
              const AttributeMap &attrs, Scope *scope)
-      : inputs_(inputs),
-        outputs_(outputs),
-        scope_(*scope),
-        OpParam(inputs, outputs, attrs, scope) {
+      : scope_(scope), OpParam(inputs, outputs, attrs, scope) {
     cond_ =
         OpParam::GetVarValue<framework::LoDTensor>("Condition", inputs, *scope);
-    sub_block_ = OpParam::GetAttr<int>("sub_block", attrs);
+    sub_block_ = OpParam::GetAttr<framework::BlockDesc *>("sub_block", attrs);
   }
 
  public:
+  const Scope *scope_;
   framework::LoDTensor *cond_;
-  int sub_block_;
-  const VariableNameMap inputs_;
-  const VariableNameMap outputs_;
-  const Scope scope_;
+  framework::BlockDesc *sub_block_;
 };
 
 DECLARE_KERNEL(While, WhileParam);
diff --git a/src/operators/lookup_op.h b/src/operators/lookup_op.h
index b5c3886cf46c9641e919aee32e7af30c6528309a..e99936a71146328d77853ad88cd7a3fc5d4faf13 100644
--- a/src/operators/lookup_op.h
+++ b/src/operators/lookup_op.h
@@ -33,7 +33,7 @@ class LookupOp : public framework::OperatorWithKernel<
  public:
   LookupOp(const std::string &type, const VariableNameMap &inputs,
            const VariableNameMap &outputs, const framework::AttributeMap &attrs,
-           std::shared_ptr<framework::Scope> scope)
+           framework::Scope *scope)
       : framework::OperatorWithKernel<DeviceType, LookupParam<DeviceType>,
                                       operators::LookupKernel<DeviceType, T>>(
             type, inputs, outputs, attrs, scope) {}
diff --git a/src/operators/lrn_op.h b/src/operators/lrn_op.h
index 3e1e92bfe6d9b888f100d07edaabfe0f8c6eaca5..dde4b968af7481f2c8ffaff5542d0fceecc06825 100644
--- a/src/operators/lrn_op.h
+++ b/src/operators/lrn_op.h
@@ -31,7 +31,7 @@ class LrnOp : public framework::OperatorWithKernel<
  public:
   LrnOp(const string &type, const VariableNameMap &inputs,
         const VariableNameMap &outputs, const framework::AttributeMap &attrs,
-        std::shared_ptr<framework::Scope> scope)
+        framework::Scope *scope)
       : framework::OperatorWithKernel<DeviceType, LrnParam<DeviceType>,
                                       operators::LrnKernel<DeviceType, T>>(
             type, inputs, outputs, attrs, scope) {}
diff --git a/src/operators/math/activation.h b/src/operators/math/activation.h
index 90b9ab4c3a558a994370ea80693e1d31687bb44e..fb90a35516d8c461a05328f65bce24a2b8aa519f 100644
--- a/src/operators/math/activation.h
+++ b/src/operators/math/activation.h
@@ -21,7 +21,7 @@ limitations under the License. */
 #include "common/types.h"
 #if defined(__ARM_NEON__) || defined(__ARM_NEON)
 #include <arm_neon.h>
-#include "operators/math/math_func_neon.h"
+#include "operators/math/math.h"
 #endif
 
 namespace paddle_mobile {
diff --git a/src/operators/math/channel_wise.h b/src/operators/math/channel_wise.h
new file mode 100644
index 0000000000000000000000000000000000000000..e4c0cbe05bfabde42df7f33a71882aa8ec08c477
--- /dev/null
+++ b/src/operators/math/channel_wise.h
@@ -0,0 +1,138 @@
+/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+
+#include "framework/tensor.h"
+#include "operators/math/activation.h"
+#ifdef __ARM_NEON
+#include <arm_neon.h>
+#endif
+
+namespace paddle_mobile {
+namespace operators {
+namespace math {
+
+template <ActivationType Act>
+void AddChannelWise(const framework::Tensor *input,
+                    const framework::Tensor *bias, framework::Tensor *output) {
+  const float *input_ptr = input->data<float>();
+  const float *bias_ptr = bias->data<float>();
+  float *output_ptr = output->mutable_data<float>();
+  // maybe check shape
+  int batch_size = input->dims()[0];
+  int channels = input->dims()[1];
+  int spatial_size = input->dims()[2] * input->dims()[3];
+
+  for (int batch = 0; batch < batch_size; ++batch) {
+    for (int channel = 0; channel < channels; ++channel) {
+      size_t offset = (batch * channels + channel) * spatial_size;
+      const float *x = input_ptr + offset;
+      float *y = output_ptr + offset;
+      float beta = bias_ptr[channel];
+      int j = 0;
+#if defined(__ARM_NEON__) || defined(__ARM_NEON)
+      float32x4_t __bias = vdupq_n_f32(beta);
+      for (; j < spatial_size - 15; j += 16, x += 16, y += 16) {
+        float32x4_t in0 = vld1q_f32(x);
+        float32x4_t in1 = vld1q_f32(x + 4);
+        float32x4_t in2 = vld1q_f32(x + 8);
+        float32x4_t in3 = vld1q_f32(x + 12);
+        in0 = vaddq_f32(__bias, in0);
+        in1 = vaddq_f32(__bias, in1);
+        in2 = vaddq_f32(__bias, in2);
+        in3 = vaddq_f32(__bias, in3);
+        in0 = math::vActiveq_f32<Act>(in0);
+        in1 = math::vActiveq_f32<Act>(in1);
+        in2 = math::vActiveq_f32<Act>(in2);
+        in3 = math::vActiveq_f32<Act>(in3);
+        vst1q_f32(y, in0);
+        vst1q_f32(y + 4, in1);
+        vst1q_f32(y + 8, in2);
+        vst1q_f32(y + 12, in3);
+      }
+      for (; j < spatial_size - 3; j += 4, x += 4, y += 4) {
+        float32x4_t in0 = vld1q_f32(x);
+        in0 = vaddq_f32(__bias, in0);
+        in0 = math::vActiveq_f32<Act>(in0);
+        vst1q_f32(y, in0);
+      }
+#endif
+      for (; j < spatial_size; ++j, ++x, ++y) {
+        *y = math::Active<Act>((*x) + beta);
+      }
+    }
+  }
+}
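The NEON body above is a straightforward vectorization: 16 floats per iteration, then 4, then a scalar tail. Its contract, written out in scalar form with ReLU as the example activation (a sketch for orientation, not part of the header):

#include <algorithm>

// y[c][i] = act(x[c][i] + bias[c]) for one image; here act = ReLU.
void AddChannelWiseScalar(const float *x, const float *bias, int channels,
                          int spatial_size, float *y) {
  for (int c = 0; c < channels; ++c) {
    for (int i = 0; i < spatial_size; ++i) {
      y[c * spatial_size + i] =
          std::max(x[c * spatial_size + i] + bias[c], 0.f);
    }
  }
}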
+
+template <ActivationType Act>
+void ScaleAddChannelWise(const framework::Tensor *input,
+                         const framework::Tensor *scale,
+                         const framework::Tensor *bias,
+                         framework::Tensor *output) {
+  const float *input_ptr = input->data<float>();
+  const float *scale_ptr = scale->data<float>();
+  const float *bias_ptr = bias->data<float>();
+  float *output_ptr = output->mutable_data<float>();
+  // maybe check shape
+  int batch_size = input->dims()[0];
+  int channels = input->dims()[1];
+  int spatial_size = input->dims()[2] * input->dims()[3];
+
+  for (int batch = 0; batch < batch_size; ++batch) {
+    for (int channel = 0; channel < channels; ++channel) {
+      size_t offset = (batch * channels + channel) * spatial_size;
+      const float *x = input_ptr + offset;
+      float *y = output_ptr + offset;
+      float alpha = scale_ptr[channel];
+      float beta = bias_ptr[channel];
+      int j = 0;
+#if defined(__ARM_NEON__) || defined(__ARM_NEON)
+      float32x4_t __scale = vdupq_n_f32(alpha);
+      float32x4_t __bias = vdupq_n_f32(beta);
+      for (; j < spatial_size - 15; j += 16, x += 16, y += 16) {
+        float32x4_t in0 = vld1q_f32(x);
+        float32x4_t in1 = vld1q_f32(x + 4);
+        float32x4_t in2 = vld1q_f32(x + 8);
+        float32x4_t in3 = vld1q_f32(x + 12);
+        in0 = vmlaq_f32(__bias, __scale, in0);
+        in1 = vmlaq_f32(__bias, __scale, in1);
+        in2 = vmlaq_f32(__bias, __scale, in2);
+        in3 = vmlaq_f32(__bias, __scale, in3);
+        in0 = math::vActiveq_f32<Act>(in0);
+        in1 = math::vActiveq_f32<Act>(in1);
+        in2 = math::vActiveq_f32<Act>(in2);
+        in3 = math::vActiveq_f32<Act>(in3);
+        vst1q_f32(y, in0);
+        vst1q_f32(y + 4, in1);
+        vst1q_f32(y + 8, in2);
+        vst1q_f32(y + 12, in3);
+      }
+      for (; j < spatial_size - 3; j += 4, x += 4, y += 4) {
+        float32x4_t in0 = vld1q_f32(x);
+        in0 = vmlaq_f32(__bias, __scale, in0);
+        in0 = math::vActiveq_f32<Act>(in0);
+        vst1q_f32(y, in0);
+      }
+#endif
+      for (; j < spatial_size; ++j, ++x, ++y) {
+        *y = math::Active<Act>(alpha * (*x) + beta);
+      }
+    }
+  }
+}
+
+}  // namespace math
+}  // namespace operators
+}  // namespace paddle_mobile
diff --git a/src/operators/math/conv_func.h b/src/operators/math/conv_func.h
deleted file mode 100644
index d9e2da0db5c50e0b0f9b11d5584bfce8b75777cd..0000000000000000000000000000000000000000
--- a/src/operators/math/conv_func.h
+++ /dev/null
@@ -1,103 +0,0 @@
-/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License. */
-
-#pragma once
-
-#ifdef __ARM_NEON
-#include <arm_neon.h>
-#endif
-
-#include "framework/ddim.h"
-#include "framework/tensor.h"
-
-namespace paddle_mobile {
-namespace operators {
-namespace math {
-
-using framework::DDim;
-using framework::Tensor;
-
-inline int ConvOutputSize(int input_size, int filter_size, int dilation,
-                          int padding, int stride) {
-  const int dkernel = dilation * (filter_size - 1) + 1;
-  int output_size = (input_size + 2 * padding - dkernel) / stride + 1;
-  return output_size;
-}
-
-inline void expand_bias(Tensor &bias, int axis, const DDim &dDim) {
-  auto bias_ptr = bias.data<float>();
-  const DDim bias_ddim = bias.dims();
-  PADDLE_MOBILE_ENFORCE(bias.dims().size() == 1,
-                        "the bias tensor's dims size != 1")
-  DDim outer_ddim = paddle_mobile::framework::slice_ddim(dDim, 0, axis + 1);
-  DDim inner_ddim =
-      paddle_mobile::framework::slice_ddim(dDim, axis + 1, dDim.size());
-  int outer_size = paddle_mobile::framework::product(outer_ddim);
-  int inner_size = paddle_mobile::framework::product(inner_ddim);
-  bias.Resize(dDim);
-  auto new_ptr = bias.mutable_data<float>();
-  int axis_size = dDim[axis];
-
-#ifdef __ARM_NEON
-  for (int i = 0; i < outer_size; ++i) {
-    int inner_num = inner_size >> 4;
-    int remain = inner_size - (inner_num << 4);
-    float v_bias = bias_ptr[i * axis_size / outer_size];
-    for (; inner_num > 0; inner_num--) {
-      float32x4_t v_newptr1 = vdupq_n_f32(v_bias);
-      float32x4_t v_newptr2 = vdupq_n_f32(v_bias);
-      float32x4_t v_newptr3 = vdupq_n_f32(v_bias);
-      float32x4_t v_newptr4 = vdupq_n_f32(v_bias);
-      vst1q_f32(new_ptr, v_newptr1);
-      new_ptr += 4;
-      vst1q_f32(new_ptr, v_newptr2);
-      new_ptr += 4;
-      vst1q_f32(new_ptr, v_newptr3);
-      new_ptr += 4;
-      vst1q_f32(new_ptr, v_newptr4);
-      new_ptr += 4;
-    }
-    for (; remain > 0; remain--) {
-      *new_ptr = v_bias;
-      new_ptr++;
-    }
-  }
-#else
-  for (int i = 0; i < outer_size; ++i) {
-    float v_bias = bias_ptr[i * axis_size / outer_size];
-    for (int j = 0; j < inner_size; ++j) {
-      new_ptr[i * inner_size + j] = v_bias;
-    }
-  }
-#endif
-}
-
-inline bool IsExpand(const std::vector<int64_t> &filter_dim,
-                     const std::vector<int> &strides,
-                     const std::vector<int> &paddings,
-                     const std::vector<int> &dilations) {
-  bool filter_1 = true, strides_1 = true, padding_0 = true, dilation_1 = true;
-  for (size_t j = 0; j < strides.size(); ++j) {
-    filter_1 = filter_1 && (static_cast<int>(filter_dim[j + 2]) == 1);
-    strides_1 = strides_1 && (strides[j] == 1);
-    padding_0 = padding_0 && (paddings[j] == 0);
-    dilation_1 = dilation_1 && (dilations[j] == 1);
-  }
-
-  return !(filter_1 && strides_1 && padding_0 && dilation_1);
-}
-
-}  // namespace math
-}  // namespace operators
-}  // namespace paddle_mobile
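conv_func.h is deleted wholesale, but its ConvOutputSize arithmetic is worth keeping in mind when reading the kernels below: output = (input + 2*padding - dilated_kernel) / stride + 1, where dilated_kernel = dilation*(k-1) + 1. A checked worked example of the same arithmetic (standalone sketch, not repository code):

#include <cassert>

// Same formula as the deleted helper, kept here as a worked example.
inline int ConvOutputSizeExample(int input, int k, int dilation, int pad,
                                 int stride) {
  const int dkernel = dilation * (k - 1) + 1;
  return (input + 2 * pad - dkernel) / stride + 1;
}

int main() {
  assert(ConvOutputSizeExample(224, 3, 1, 1, 2) == 112);  // stride-2 3x3: (224+2-3)/2+1
  assert(ConvOutputSizeExample(112, 3, 1, 1, 1) == 112);  // "same" stride-1 3x3
  return 0;
}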
diff --git a/src/operators/math/depthwise_conv3x3.cpp b/src/operators/math/depthwise_conv3x3.cpp
index 8220e20429ef3b26acb1f0f130ecd41f2954a3c2..62fae35060c97e52142143fcc87b7571b13b1959 100644
--- a/src/operators/math/depthwise_conv3x3.cpp
+++ b/src/operators/math/depthwise_conv3x3.cpp
@@ -12,2070 +12,1047 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
+#if defined(__ARM_NEON__) || defined(__ARM_NEON)
+
 #include "operators/math/depthwise_conv3x3.h"
-#include <vector>
-#if __ARM_NEON
 #include <arm_neon.h>
-#endif
 
 namespace paddle_mobile {
 namespace operators {
 namespace math {
 
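Everything below replaces the old hand-rolled depthwise kernels. For orientation: a 3x3 depthwise convolution applies one 3x3 filter per channel (groups == channels). A naive reference for the valid region at stride 1 with no padding, which is what the NEON paths later in this file vectorize (a sketch, not part of the patch):

// Naive per-channel reference: out is (h-2) x (w-2); filter9 is one
// channel's 3x3 kernel in row-major order.
void DepthwiseConv3x3NaiveRef(const float *in, int h, int w,
                              const float *filter9, float *out) {
  const int out_h = h - 2, out_w = w - 2;
  for (int i = 0; i < out_h; ++i) {
    for (int j = 0; j < out_w; ++j) {
      float sum = 0.f;
      for (int fh = 0; fh < 3; ++fh) {
        for (int fw = 0; fw < 3; ++fw) {
          sum += filter9[fh * 3 + fw] * in[(i + fh) * w + (j + fw)];
        }
      }
      out[i * out_w + j] = sum;
    }
  }
}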
-void DepthwiseConv3x3(const framework::Tensor *input,
-                      const std::vector<int> &strides,
-                      const std::vector<int> &paddings,
-                      const framework::Tensor *filter, framework::Tensor *bias,
-                      framework::Tensor *output, bool if_bias) {
-  const int batch_size = input->dims()[0];
-
-  const int input_height = input->dims()[2];
-
-  const int input_width = input->dims()[3];
-
-  const int output_channels = output->dims()[1];
-
-  const int output_height = output->dims()[2];
-  const int output_width = output->dims()[3];
-  const int _kernel_size = 3;
-  const int stride_height = strides[0];
-  const int stride_width = strides[1];
-  const int padding_height = paddings[0];
-  const int padding_width = paddings[1];
-  const float zero = 0;
-  const int input_channel_stride = input_height * input_width;
-  const int output_channel_stride = output_height * output_width;
-  const int filter_channel_stride = 9;
-
-  const float *input_data = input->data<float>();
-  const float *filter_data = filter->data<float>();
-  if (if_bias) {
-    math::expand_bias(*bias, 1, output->dims());
-    output->ShareDataWith(*bias);
-  }
-  float *output_data = output->mutable_data<float>();
-
-  const int input_batch_stride = output_channels * input_channel_stride;
-  const int output_batch_stride = output_channels * output_channel_stride;
-  const int filter_batch_stride = output_channels * output_channel_stride;
-  const float *pos1, *pos2, *pos3, *filter1, *filter2, *filter3, *output_ptr;
-  int hstart, wstart, hend, wend;
-  float result;
-  for (int i = 0; i < batch_size; ++i) {
-    for (int c = 0; c < output_channels; ++c) {
-      filter1 = filter_data;
-      filter2 = filter1 + 3;
-      filter3 = filter2 + 3;
-
-      for (int ph = 0; ph < output_height; ph++) {
-        for (int pw = 0; pw < output_width; pw++) {
-          hstart = ph * stride_height - padding_height;
-          wstart = pw * stride_width - padding_width;
-          hend = std::min(hstart + _kernel_size, input_height + padding_height);
-          wend = std::min(wstart + _kernel_size, input_width + padding_width);
-          hstart = std::max(hstart, 0);
-          wstart = std::max(wstart, 0);
-          hend = std::min(hend, input_height);
-          wend = std::min(wend, input_width);
-          pos1 = input_data + hstart * input_width + wstart;
-          pos2 = input_data + (hstart + 1) * input_width + wstart;
-          pos3 = input_data + (hstart + 2) * input_width + wstart;
-          output_ptr = output_data + ph * output_width + pw;
-
-          if (hend - hstart != 3 || wend - wstart != 3) {
-            result = 0;
-            float fake_input[9] = {0};
-            if (hstart == 0 && wstart == 0) {
-              // upper-left corner
-              for (int j = 0; j < 3; ++j) {
-                for (int k = 0; k < 3; ++k) {
-                  if (j >= 3 - hend && k >= 3 - wend) {
-                    fake_input[3 * j + k] =
-                        input_data[(j - (3 - hend)) * input_width + k -
-                                   (3 - wend)];
-                  }
-                }
-              }
-            } else if (hstart == 0 && wend == input_width) {
-              // upper-right corner
-              for (int j = 0; j < 3; ++j) {
-                for (int k = 0; k < 3; ++k) {
-                  if (j >= 3 - hend && k <= input_width - wstart - 1) {
-                    fake_input[3 * j + k] =
-                        input_data[(j - (3 - hend)) * input_width + k + wstart];
-                  }
-                }
-              }
-
-            } else if (hend == input_height && wstart == 0) {
-              // lower-left corner
-
-              for (int j = 0; j < 3; ++j) {
-                for (int k = 0; k < 3; ++k) {
-                  if (j <= input_height - 1 - hstart && k >= 3 - wend) {
-                    fake_input[3 * j + k] =
-                        input_data[(j + hstart) * input_width + k - (3 - wend)];
-                  }
-                }
-              }
-            } else if (hend == input_height && wend == input_width) {
-              // lower-right corner
-              for (int j = 0; j < 3; ++j) {
-                for (int k = 0; k < 3; ++k) {
-                  if (j <= input_height - hstart - 1 &&
-                      k <= input_width - wstart - 1) {
-                    fake_input[3 * j + k] =
-                        input_data[(j + hstart) * input_width + k + wstart];
-                  }
-                }
-              }
-            } else if (hstart == 0) {
-              // top edge
-              for (int j = 0; j < 3; ++j) {
-                for (int k = 0; k < 3; ++k) {
-                  if (j >= 3 - hend) {
-                    fake_input[3 * j + k] =
-                        input_data[(j - (3 - hend)) * input_width + k + wstart];
-                  }
-                }
-              }
-
-            } else if (hend == input_height) {
-              // bottom edge
-              for (int j = 0; j < 3; ++j) {
-                for (int k = 0; k < 3; ++k) {
-                  if (j <= input_height - hstart - 1) {
-                    fake_input[3 * j + k] =
-                        input_data[(j + hstart) * input_width + k + wstart];
-                  }
-                }
-              }
-
-            } else if (wstart == 0) {
-              // left edge
-              for (int j = 0; j < 3; ++j) {
-                for (int k = 0; k < 3; ++k) {
-                  if (k >= 3 - wend) {
-                    fake_input[3 * j + k] =
-                        input_data[(j + hstart) * input_width +
-                                   (k - (3 - wend))];
-                  }
-                }
-              }
-
-            } else if (wend == input_width) {
-              // right edge
-              for (int j = 0; j < 3; ++j) {
-                for (int k = 0; k < 3; ++k) {
-                  if (k <= input_width - wstart - 1) {
-                    fake_input[3 * j + k] =
-                        input_data[(j + hstart) * input_width + k + wstart];
-                  }
-                }
-              }
-            }
-            for (int l = 0; l < 9; ++l) {
-              result += fake_input[l] * filter1[l];
-            }
-            if (if_bias) {
-              output_data[ph * output_width + pw] += result;
-            } else {
-              output_data[ph * output_width + pw] = result;
-            }
-
-          } else {
-#if __ARM_NEON
-#if __aarch64__
-            const float32x4_t data1 = vld1q_f32(pos1);
-            const float32x4_t data2 = vld1q_f32(pos2);
-            const float32x4_t data3 = vld1q_f32(pos3);
-
-            const float32x4_t v_filter1 = vld1q_f32(filter1);
-            const float32x4_t v_filter2 = vld1q_f32(filter2);
-            const float32x4_t v_filter3 = vld1q_f32(filter3);
-            float32x4_t mula = vmulq_f32(data1, v_filter1);
-            mula = vmlaq_f32(mula, data2, v_filter2);
-            mula = vmlaq_f32(mula, data3, v_filter3);
-            float32x2_t res = vpadd_f32(
-                vget_high_f32(vsetq_lane_f32(0, mula, 3)), vget_low_f32(mula));
-            res = vpadd_f32(res, res);
-            if (if_bias) {
-              output_data[ph * output_width + pw] += vget_lane_f32(res, 0);
-            } else {
-              output_data[ph * output_width + pw] = vget_lane_f32(res, 0);
-            }
-#else
-            asm volatile(
-
-                "vld1.32  {q1}, [%[pos1]]         \n\t"
-                "vld1.32  {q4}, [%[filter1]]      \n\t"
-                "vmov.f32 q0,   #0.0              \n\t"
-
-                "vld1.32  {q2}, [%[pos2]]         \n\t"
-                "vld1.32  {q5}, [%[filter2]]      \n\t"
-                "vmla.f32 q0, q1, q4              \n\t"
-
-                "vld1.32  {q3}, [%[pos3]]         \n\t"
-                "vld1.32  {q6}, [%[filter3]]      \n\t"
-
-                "vmla.f32 q0, q2, q5              \n\t"
-                "vmla.f32 q0, q3, q6              \n\t"
-
-                "vmov.f32 d1[1], %[zero]          \n\t"
-
-                "vadd.f32 d4, d0, d1              \n\t"
-                "vadd.f32 s10, s8, s9             \n\t"
-                "vst1.32 {d5[0]},[%[output_ptr]]  \n\t"
-                :
-                : [input_data] "r"(input_data), [pos1] "r"(pos1),
-                  [pos2] "r"(pos2), [pos3] "r"(pos3), [filter1] "r"(filter1),
-                  [filter2] "r"(filter2), [filter3] "r"(filter3),
-                  [output_ptr] "r"(output_ptr), [zero] "r"(zero)
-                : "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6");
-#endif  // __aarch64__
-#else
-
-#endif  // __ARM_NEON
-          }
-        }
-      }
-      input_data += input_channel_stride;
-      output_data += output_channel_stride;
-      filter_data += filter_channel_stride;
-    }
-    input_data += input_batch_stride;
-    output_data += output_batch_stride;
-  }
-}
 
+#ifndef __aarch64__
+inline float32x4_t vpaddq_f32(float32x4_t r0, float32x4_t r1) {
+  float32x2_t sum0 = vpadd_f32(vget_low_f32(r0), vget_high_f32(r0));
+  float32x2_t sum1 = vpadd_f32(vget_low_f32(r1), vget_high_f32(r1));
+  return vcombine_f32(sum0, sum1);
+}
+#endif
 
-void DepthwiseConv3x3s1p1(const framework::Tensor *input,
-                          const framework::Tensor *filter,
-                          framework::Tensor *output, framework::Tensor *bias,
-
bool if_bias, bool if_relu) { -#if __ARM_NEON - const int batch_size = static_cast(input->dims()[0]); - const int c = static_cast(input->dims()[1]); - const int h = static_cast(input->dims()[2]); - const int w = static_cast(input->dims()[3]); - const int hxw = h * w; - // const int l = h; - - // leftTop, rightTop, leftBottom, rightBottom - const int lt = 0; - const int rt = w - 1; - const int lb = (h - 1) * w; - const int rb = h * w - 1; - - const float *bias_data; - if (if_bias) { - bias_data = bias->data(); - } - float32x4_t zero = vdupq_n_f32(0.0); - - for (int b = 0; b < batch_size; ++b) { -#pragma omp parallel for - for (int j = 0; j < c; ++j) { - const float *filter_data_tmp = filter->data() + j * 9; - const float *input_data = input->data() + j * hxw; - float *output_data = output->mutable_data() + j * hxw; - float32x4_t vbias; - if (if_bias) { - vbias = vdupq_n_f32(bias_data[j]); - } - - int w_mid = w - 2; // l=1->l_mid=-1,l=2->l_mid=0 - float w00 = filter_data_tmp[0]; - float w01 = filter_data_tmp[1]; - float w02 = filter_data_tmp[2]; - float w10 = filter_data_tmp[3]; - float w11 = filter_data_tmp[4]; - float w12 = filter_data_tmp[5]; - float w20 = filter_data_tmp[6]; - float w21 = filter_data_tmp[7]; - float w22 = filter_data_tmp[8]; - - output_data[lt] = w11 * input_data[0] + w12 * input_data[1] + - w21 * input_data[w] + w22 * input_data[w + 1]; - output_data[rt] = w10 * input_data[w - 2] + w11 * input_data[w - 1] + - w20 * input_data[2 * w - 2] + - w21 * input_data[2 * w - 1]; - output_data[lb] = - w01 * input_data[(h - 2) * w] + w02 * input_data[(h - 2) * w + 1] + - w11 * input_data[(h - 1) * w] + w12 * input_data[(h - 1) * w + 1]; - output_data[rb] = - w00 * input_data[h * w - w - 2] + w01 * input_data[h * w - w - 1] + - w10 * input_data[h * w - 2] + w11 * input_data[h * w - 1]; - if (if_bias) { - output_data[lt] += bias_data[j]; - output_data[rt] += bias_data[j]; - output_data[lb] += bias_data[j]; - output_data[rb] += bias_data[j]; - } - if (if_relu) { - output_data[lt] = output_data[lt] < 0 ? 0 : output_data[lt]; - output_data[rt] = output_data[rt] < 0 ? 0 : output_data[rt]; - output_data[lb] = output_data[lb] < 0 ? 0 : output_data[lb]; - output_data[rb] = output_data[rb] < 0 ? 0 : output_data[rb]; - } - - for (int i = 1; i < h - 1; ++i) { - int left = i * w; - int right = i * w + w - 1; - output_data[left] = - w01 * input_data[i * w - w] + w02 * input_data[i * w - w + 1] + - w11 * input_data[i * w] + w12 * input_data[i * w + 1] + - w21 * input_data[i * w + w] + w22 * input_data[i * w + w + 1]; - - output_data[right] = w00 * input_data[i * w + w - 1 - w - 1] + - w01 * input_data[i * w + w - 1 - w] + - w10 * input_data[i * w + w - 1 - 1] + - w11 * input_data[i * w + w - 1] + - w20 * input_data[i * w + w - 1 + w - 1] + - w21 * input_data[i * w + w - 1 + w]; - if (if_bias) { - output_data[left] += bias_data[j]; - output_data[right] += bias_data[j]; - } - if (if_relu) { - output_data[left] = output_data[left] < 0 ? 0 : output_data[left]; - output_data[right] = output_data[right] < 0 ? 
0 : output_data[right]; - } - } - - // top 1 row and bottom 1 row - const float *input_tmp = input_data; - - float32x4_t in0, in1, in2, in3, in4, in5, in6, in7, tmp0, tmp1, tmp2, - tmp3, tmp4, tmp5, out0; - in0 = vld1q_f32(input_tmp); - in2 = vld1q_f32(input_tmp + w); - const float *input_tmp_end = input_tmp + (h - 2) * w; - in4 = vld1q_f32(input_tmp_end); - in6 = vld1q_f32(input_tmp_end + w); - int c_mid = w_mid; - auto output_ptr = output_data + 1; - for (; c_mid > 3; c_mid -= 4) { - in1 = vld1q_f32(input_tmp + 4); - in3 = vld1q_f32(input_tmp + w + 4); - - tmp0 = vextq_f32(in0, in1, 1); - tmp1 = vextq_f32(in0, in1, 2); - - tmp2 = vextq_f32(in2, in3, 1); - tmp3 = vextq_f32(in2, in3, 2); - - out0 = vmulq_n_f32(in0, w10); - out0 = vmlaq_n_f32(out0, tmp0, w11); - out0 = vmlaq_n_f32(out0, tmp1, w12); - out0 = vmlaq_n_f32(out0, in2, w20); - out0 = vmlaq_n_f32(out0, tmp2, w21); - out0 = vmlaq_n_f32(out0, tmp3, w22); - out0 = vaddq_f32(out0, vbias); - if (if_relu) { - out0 = vmaxq_f32(out0, zero); - } - vst1q_f32(output_ptr, out0); - - in5 = vld1q_f32(input_tmp_end + 4); - in7 = vld1q_f32(input_tmp_end + w + 4); - - tmp0 = vextq_f32(in4, in5, 1); - tmp1 = vextq_f32(in4, in5, 2); - tmp2 = vextq_f32(in6, in7, 1); - tmp3 = vextq_f32(in6, in7, 2); - - out0 = vmulq_n_f32(in4, w00); - out0 = vmlaq_n_f32(out0, tmp0, w01); - out0 = vmlaq_n_f32(out0, tmp1, w02); - out0 = vmlaq_n_f32(out0, in6, w10); - out0 = vmlaq_n_f32(out0, tmp2, w11); - out0 = vmlaq_n_f32(out0, tmp3, w12); - out0 = vaddq_f32(out0, vbias); - if (if_relu) { - out0 = vmaxq_f32(out0, zero); - } - vst1q_f32(output_ptr + (h - 1) * w, out0); - - // can optimize to each 8 stride. - input_tmp += 4; - input_tmp_end += 4; - output_ptr += 4; - in0 = in1; - in2 = in3; - in4 = in5; - in6 = in7; - } - - // top right pad - float32x4_t pad0 = vdupq_n_f32(input_data[w - 1]); - float32x4_t pad1 = vdupq_n_f32(input_data[2 * w - 1]); - - tmp0 = vextq_f32(in0, pad0, 1); - tmp1 = vextq_f32(in0, pad0, 2); - tmp2 = vextq_f32(in2, pad1, 1); - tmp3 = vextq_f32(in2, pad1, 2); - - out0 = vmulq_n_f32(in0, w10); - out0 = vmlaq_n_f32(out0, tmp0, w11); - out0 = vmlaq_n_f32(out0, tmp1, w12); - out0 = vmlaq_n_f32(out0, in2, w20); - out0 = vmlaq_n_f32(out0, tmp2, w21); - out0 = vmlaq_n_f32(out0, tmp3, w22); - out0 = vaddq_f32(out0, vbias); - if (if_relu) { - out0 = vmaxq_f32(out0, zero); - } - - for (int i = 0; i < c_mid; ++i) { - if (i == 0) { - vst1q_lane_f32(output_ptr + i, out0, 0); - } - if (i == 1) { - vst1q_lane_f32(output_ptr + i, out0, 1); - } - if (i == 2) { - vst1q_lane_f32(output_ptr + i, out0, 2); - } - } - - // bottom right pad - float32x4_t pad2 = vdupq_n_f32(input_data[h * w - 1 - w]); - float32x4_t pad3 = vdupq_n_f32(input_data[h * w - 1]); - - tmp0 = vextq_f32(in4, pad2, 1); - tmp1 = vextq_f32(in4, pad2, 2); - tmp2 = vextq_f32(in6, pad3, 1); - tmp3 = vextq_f32(in6, pad3, 2); - - out0 = vmulq_n_f32(in4, w00); - out0 = vmlaq_n_f32(out0, tmp0, w01); - out0 = vmlaq_n_f32(out0, tmp1, w02); - out0 = vmlaq_n_f32(out0, in6, w10); - out0 = vmlaq_n_f32(out0, tmp2, w11); - out0 = vmlaq_n_f32(out0, tmp3, w12); - out0 = vaddq_f32(out0, vbias); - if (if_relu) { - out0 = vmaxq_f32(out0, zero); - } - - for (int i = 0; i < c_mid; ++i) { - if (i == 0) { - vst1q_lane_f32(output_ptr + (h - 1) * w + i, out0, 0); - } - if (i == 1) { - vst1q_lane_f32(output_ptr + (h - 1) * w + i, out0, 1); - } - if (i == 2) { - vst1q_lane_f32(output_ptr + (h - 1) * w + i, out0, 2); - } - } - // mid - - for (int i = 0; i < h - 2; ++i) { - auto output_ptr = output_data + (i + 1) * w + 1; - 
input_tmp = input_data + i * w; - auto in0_tmp = vld1q_f32(input_tmp); - auto in2_tmp = vld1q_f32(input_tmp + w); - auto in4_tmp = vld1q_f32(input_tmp + w + w); - c_mid = w_mid; - for (; c_mid > 3; c_mid -= 4) { - auto in1_tmp = vld1q_f32(input_tmp + 4); - auto in3_tmp = vld1q_f32(input_tmp + w + 4); - auto in5_tmp = vld1q_f32(input_tmp + w + w + 4); - - tmp0 = vextq_f32(in0_tmp, in1_tmp, 1); - tmp1 = vextq_f32(in0_tmp, in1_tmp, 2); - tmp2 = vextq_f32(in2_tmp, in3_tmp, 1); - tmp3 = vextq_f32(in2_tmp, in3_tmp, 2); - tmp4 = vextq_f32(in4_tmp, in5_tmp, 1); - tmp5 = vextq_f32(in4_tmp, in5_tmp, 2); - - out0 = vmulq_n_f32(in0_tmp, w00); - out0 = vmlaq_n_f32(out0, tmp0, w01); - out0 = vmlaq_n_f32(out0, tmp1, w02); - out0 = vmlaq_n_f32(out0, in2_tmp, w10); - out0 = vmlaq_n_f32(out0, tmp2, w11); - out0 = vmlaq_n_f32(out0, tmp3, w12); - out0 = vmlaq_n_f32(out0, in4_tmp, w20); - out0 = vmlaq_n_f32(out0, tmp4, w21); - out0 = vmlaq_n_f32(out0, tmp5, w22); - out0 = vaddq_f32(out0, vbias); - if (if_relu) { - out0 = vmaxq_f32(out0, zero); - } - - vst1q_f32(output_ptr, out0); +template +inline void Depth3x3NormalRowLoadInput(const float *input, float32x4_t *y) { + y[0] = vld1q_f32(input); + y[2] = vld1q_f32(input + 4); + y[1] = vextq_f32(y[0], y[2], 1); + y[2] = vextq_f32(y[0], y[2], 2); +} - output_ptr += 4; - input_tmp += 4; - in0_tmp = in1_tmp; - in2_tmp = in3_tmp; - in4_tmp = in5_tmp; - } +template <> +inline void Depth3x3NormalRowLoadInput<2>(const float *input, float32x4_t *y) { + float32x4x2_t x = vld2q_f32(input); + y[0] = x.val[0]; + y[1] = x.val[1]; + y[2] = vextq_f32(y[0], y[0], 1); + y[2] = vsetq_lane_f32(input[8], y[2], 3); +} - float32x4_t pad0 = vdupq_n_f32(input_data[i * w + w - 1]); - float32x4_t pad1 = vdupq_n_f32(input_data[i * w + w - 1 + w]); - float32x4_t pad2 = vdupq_n_f32(input_data[i * w + w - 1 + w + w]); - - tmp0 = vextq_f32(in0_tmp, pad0, 1); - tmp1 = vextq_f32(in0_tmp, pad0, 2); - tmp2 = vextq_f32(in2_tmp, pad1, 1); - tmp3 = vextq_f32(in2_tmp, pad1, 2); - tmp4 = vextq_f32(in4_tmp, pad2, 1); - tmp5 = vextq_f32(in4_tmp, pad2, 2); - - out0 = vmulq_n_f32(in0_tmp, w00); - out0 = vmlaq_n_f32(out0, tmp0, w01); - out0 = vmlaq_n_f32(out0, tmp1, w02); - out0 = vmlaq_n_f32(out0, in2_tmp, w10); - out0 = vmlaq_n_f32(out0, tmp2, w11); - out0 = vmlaq_n_f32(out0, tmp3, w12); - out0 = vmlaq_n_f32(out0, in4_tmp, w20); - out0 = vmlaq_n_f32(out0, tmp4, w21); - out0 = vmlaq_n_f32(out0, tmp5, w22); - out0 = vaddq_f32(out0, vbias); - if (if_relu) { - out0 = vmaxq_f32(out0, zero); - } +#define DEPTHWISE_CONV3X3_NORMAL_BORDER(start, end) \ + for (int w = start; w < end; ++w) { \ + const int w_in_start = -padding_w + w * Stride_w; \ + const int w_in_end = w_in_start + 3; \ + const int w_start = w_in_start > 0 ? w_in_start : 0; \ + const int w_end = w_in_end < input_w ? 
w_in_end : input_w; \ + float value = 0; \ + for (int h_in = h_start; h_in < h_end; ++h_in) { \ + for (int w_in = w_start; w_in < w_end; ++w_in) { \ + value += filter[(h_in - h_in_start) * 3 + (w_in - w_in_start)] * \ + input[h_in * input_w + w_in]; \ + } \ + } \ + output_ptr[w] = value; \ + } - for (int i = 0; i < c_mid; ++i) { - if (i == 0) { - vst1q_lane_f32(output_ptr + i, out0, 0); - } - if (i == 1) { - vst1q_lane_f32(output_ptr + i, out0, 1); - } - if (i == 2) { - vst1q_lane_f32(output_ptr + i, out0, 2); - } - } - } +template +inline void DepthwiseConv3x3NormalRow(const float *input, const float *filter, + const int h_output, const int input_h, + const int input_w, const int padding_h, + const int padding_w, const int output_w, + float *output, float32x4_t *ker) { + const int h_in_start = -padding_h + h_output * Stride_h; + const int h_in_end = h_in_start + 3; + const int h_start = h_in_start > 0 ? h_in_start : 0; + const int h_end = h_in_end < input_h ? h_in_end : input_h; + + int valid_w_start = (padding_w + Stride_w - 1) / Stride_w; + int valid_w_end = (input_w + padding_w - 3) / Stride_w + 1; + if (valid_w_end < valid_w_start) { + valid_w_end = valid_w_start; + } + // const int valid_w_end = output_w - valid_w_start; + float *output_ptr = output + h_output * output_w; + // border left + DEPTHWISE_CONV3X3_NORMAL_BORDER(0, valid_w_start) + // middle + int output_tiles = (valid_w_end - valid_w_start) >> 2; + float32x4_t _sum, _x[3]; + // valid w + for (int w = 0; w < output_tiles * 4; w += 4) { + _sum = vdupq_n_f32(0.f); + int output_offset = valid_w_start + w; + int input_w_offset = output_offset * Stride_w - padding_w; + for (int h_in = h_start; h_in < h_end; ++h_in) { + int index = h_in - h_in_start; + Depth3x3NormalRowLoadInput( + input + h_in * input_w + input_w_offset, _x); + _sum = vmlaq_lane_f32(_sum, _x[0], vget_low_f32(ker[index]), 0); + _sum = vmlaq_lane_f32(_sum, _x[1], vget_low_f32(ker[index]), 1); + _sum = vmlaq_lane_f32(_sum, _x[2], vget_high_f32(ker[index]), 0); } + vst1q_f32(output_ptr + output_offset, _sum); } -#endif + // remain valid w + int remain = (valid_w_end - valid_w_start) & 0x3; + if (remain > 0) { + _sum = vdupq_n_f32(0.f); + int remain_start = valid_w_start + (output_tiles << 2); + int input_w_offset = remain_start * Stride_w - padding_w; + float *output_ptr0 = output_ptr + remain_start; + + for (int h_in = h_start; h_in < h_end; ++h_in) { + int index = h_in - h_in_start; + Depth3x3NormalRowLoadInput( + input + h_in * input_w + input_w_offset, _x); + _sum = vmlaq_lane_f32(_sum, _x[0], vget_low_f32(ker[index]), 0); + _sum = vmlaq_lane_f32(_sum, _x[1], vget_low_f32(ker[index]), 1); + _sum = vmlaq_lane_f32(_sum, _x[2], vget_high_f32(ker[index]), 0); + } + switch (remain) { + case 3: + vst1q_lane_f32(output_ptr0 + 2, _sum, 2); + case 2: + vst1_f32(output_ptr0, vget_low_f32(_sum)); + break; + case 1: + vst1q_lane_f32(output_ptr0, _sum, 0); + break; + } + } + // border right + DEPTHWISE_CONV3X3_NORMAL_BORDER(valid_w_end, output_w) } -void DepthwiseConvAddBNRelu3x3s1p1(const framework::Tensor *input, - const framework::Tensor *filter, - framework::Tensor *output, - const framework::Tensor *new_scale, - const framework::Tensor *new_bias, - bool if_relu) { -#if __ARM_NEON - const float *input_data = input->data(); - const float *filter_data = filter->data(); - float *output_data = output->mutable_data(); - const float *newscale_data = new_scale->data(); - const float *newbias_data = new_bias->data(); - - const int batch_size = static_cast(input->dims()[0]); - 
const int input_channel = static_cast(input->dims()[1]); - - const int input_height = static_cast(input->dims()[2]); - const int input_width = static_cast(input->dims()[3]); - const int output_height = static_cast(output->dims()[2]); - const int output_width = static_cast(output->dims()[3]); - - const int hxw = input_height * input_width; - - // const int l = input_height; - const int h = input_height; - const int w = input_width; - float32x4_t vzero = vdupq_n_f32(0); - - for (int b = 0; b < batch_size; b++) { -#pragma omp parallel for - for (int c = 0; c < input_channel; c++) { - const float *filter_data = filter->data() + c * 9; - const float *input_data = input->data() + c * hxw; - float *output_data = output->data() + c * hxw; - float32x4_t vnewbias = vdupq_n_f32(newbias_data[c]); - float32x4_t vnewscale = vdupq_n_f32(newscale_data[c]); - - float w00 = filter_data[0]; - float w01 = filter_data[1]; - float w02 = filter_data[2]; - float w10 = filter_data[3]; - float w11 = filter_data[4]; - float w12 = filter_data[5]; - float w20 = filter_data[6]; - float w21 = filter_data[7]; - float w22 = filter_data[8]; - - for (int i = 1; i < output_height - 1; i++) { - float *output_ptr; - float32x4_t in0, in1, in2, in3, in4, in5, tmp0, tmp1, tmp2, tmp3, tmp4, - tmp5, out0; - for (int m = 1; m < output_width - 4; m += 4) { - output_ptr = output_data + i * output_width + m; - in0 = vld1q_f32(input_data + (i - 1) * input_width + m - 1); - in1 = vld1q_f32(input_data + (i - 1) * input_width + m + 3); - in2 = vld1q_f32(input_data + i * input_width + m - 1); - in3 = vld1q_f32(input_data + i * input_width + m + 3); - in4 = vld1q_f32(input_data + (i + 1) * input_width + m - 1); - in5 = vld1q_f32(input_data + (i + 1) * input_width + m + 3); - - tmp0 = vextq_f32(in0, in1, 1); - tmp1 = vextq_f32(in0, in1, 2); - tmp2 = vextq_f32(in2, in3, 1); - tmp3 = vextq_f32(in2, in3, 2); - tmp4 = vextq_f32(in4, in5, 1); - tmp5 = vextq_f32(in4, in5, 2); - - out0 = vmulq_n_f32(in0, w00); - out0 = vmlaq_n_f32(out0, tmp0, w01); - out0 = vmlaq_n_f32(out0, tmp1, w02); - out0 = vmlaq_n_f32(out0, in2, w10); - out0 = vmlaq_n_f32(out0, tmp2, w11); - out0 = vmlaq_n_f32(out0, tmp3, w12); - out0 = vmlaq_n_f32(out0, in4, w20); - out0 = vmlaq_n_f32(out0, tmp4, w21); - out0 = vmlaq_n_f32(out0, tmp5, w22); - - out0 = vmlaq_f32(vnewbias, vnewscale, out0); - if (if_relu) { - out0 = vmaxq_f32(out0, vzero); - } - vst1q_f32(output_ptr, out0); - } - int m; - for (m = 1; (m + 3) < output_width - 1; m = m + 4) { - } +template <> +void DepthwiseConv3x3S1(const framework::Tensor &input, + const framework::Tensor &filter, + const std::vector &paddings, + framework::Tensor *output) { + const float *input_data = input.data(); + const float *filter_data = filter.data(); + float *out_data = output->mutable_data(); + + const int input_h = input.dims()[2]; + const int input_w = input.dims()[3]; + const int output_h = output->dims()[2]; + const int output_w = output->dims()[3]; + const int padding_h = paddings[0]; + const int padding_w = paddings[1]; + const int image_size = input_h * input_w; + const int out_image_size = output_h * output_w; + const int valid_h_start = padding_h; + const int valid_h_end = output_h - valid_h_start; + const int valid_h = valid_h_end - valid_h_start; + const int valid_w_start = padding_w; + const int valid_w_end = output_w - valid_w_start; + const int valid_w = valid_w_end - valid_w_start; + + #pragma omp parallel for + for (int g = 0; g < input.dims()[1]; ++g) { + const float *input_ptr = input_data + g * image_size; + const 
float *filter_ptr = filter_data + g * 9; + float *output_ptr = out_data + g * out_image_size; + + const float *filter_ptr0 = filter_ptr; + const float *filter_ptr1 = filter_ptr0 + 3; + const float *filter_ptr2 = filter_ptr1 + 3; + float32x4_t _ker[3]; + _ker[0] = vld1q_f32(filter_ptr0); + _ker[1] = vld1q_f32(filter_ptr1); + _ker[2] = vld1q_f32(filter_ptr2); + + // pad top + for (int h = 0; h < valid_h_start; ++h) { + DepthwiseConv3x3NormalRow<1, 1>(input_ptr, filter_ptr, h, input_h, + input_w, padding_h, padding_w, output_w, + output_ptr, _ker); + } - for (int j = m; j < output_width - 1; j++) { - output_data[i * output_width + j] = - input_data[(i - 1) * input_width + j - 1] * w00 + - input_data[(i - 1) * input_width + j] * w01 + - input_data[(i - 1) * input_width + j + 1] * w02 + - input_data[(i)*input_width + j - 1] * w10 + - input_data[(i)*input_width + j] * w11 + - input_data[(i)*input_width + j + 1] * w12 + - input_data[(i + 1) * input_width + j - 1] * w20 + - input_data[(i + 1) * input_width + j] * w21 + - input_data[(i + 1) * input_width + j + 1] * w22; - output_data[i * output_width + j] = - newscale_data[c] * output_data[i * output_width + j] + - newbias_data[c]; - if (if_relu) { - output_data[i * output_width + j] = - output_data[i * output_width + j] < 0 - ? 0 - : output_data[i * output_width + j]; + // output 2x6 + int output_w_tiles = valid_w / 6; + int output_w_remain = valid_w - output_w_tiles * 6; + for (int h = valid_h_start; h < valid_h_end - 1; h += 2) { + const float *input_ptr0 = input_ptr + (h - padding_h) * input_w; + const float *input_ptr1 = input_ptr0 + input_w; + const float *input_ptr2 = input_ptr1 + input_w; + const float *input_ptr3 = input_ptr2 + input_w; + float *output_ptr0 = output_ptr + h * output_w; + float *output_ptr1 = output_ptr0 + output_w; + // pad left + if (padding_w) { + float32x4_t row0 = vld1q_f32(input_ptr0); + float32x4_t row1 = vld1q_f32(input_ptr1); + float32x4_t row2 = vld1q_f32(input_ptr2); + float32x4_t row3 = vld1q_f32(input_ptr3); + float32x4_t zero = vdupq_n_f32(0.f); + row0 = vextq_f32(zero, row0, 3); + row1 = vextq_f32(zero, row1, 3); + row2 = vextq_f32(zero, row2, 3); + row3 = vextq_f32(zero, row3, 3); + float32x4_t acc0, acc1; + for (int w = valid_w_start - 1; w >= 0; --w) { + int padding = padding_w - w; + if (padding >= 3) { + output_ptr0[w] = 0.f; + output_ptr1[w] = 0.f; + } else { + acc0 = vmulq_f32(row0, _ker[0]); + acc0 = vmlaq_f32(acc0, row1, _ker[1]); + acc0 = vmlaq_f32(acc0, row2, _ker[2]); + acc0 = vextq_f32(acc0, acc0, 1); + acc1 = vmulq_f32(row1, _ker[0]); + acc1 = vmlaq_f32(acc1, row2, _ker[1]); + acc1 = vmlaq_f32(acc1, row3, _ker[2]); + acc1 = vextq_f32(acc1, acc1, 1); + float32x2_t sum = vpadd_f32(vget_low_f32(acc0), vget_low_f32(acc1)); + vst1_lane_f32(output_ptr0 + w, sum, 0); + vst1_lane_f32(output_ptr1 + w, sum, 1); + + row0 = vextq_f32(zero, row0, 3); + row1 = vextq_f32(zero, row1, 3); + row2 = vextq_f32(zero, row2, 3); + row3 = vextq_f32(zero, row3, 3); } } + output_ptr0 += valid_w_start; + output_ptr1 += valid_w_start; } - - output_data[0] = w11 * input_data[0] + w12 * input_data[1] + - w21 * input_data[w] + w22 * input_data[w + 1]; - output_data[w - 1] = w10 * input_data[w - 2] + w11 * input_data[w - 1] + - w20 * input_data[2 * w - 2] + - w21 * input_data[2 * w - 1]; - output_data[(h - 1) * w] = - w01 * input_data[(h - 2) * w] + w02 * input_data[(h - 2) * w + 1] + - w11 * input_data[(h - 1) * w] + w12 * input_data[(h - 1) * w + 1]; - output_data[h * w - 1] = - w00 * input_data[h * w - w - 2] + w01 * 
input_data[h * w - w - 1] + - w10 * input_data[h * w - 2] + w11 * input_data[h * w - 1]; - output_data[0] = output_data[0] * newscale_data[c] + newbias_data[c]; - output_data[w - 1] = - output_data[w - 1] * newscale_data[c] + newbias_data[c]; - output_data[(h - 1) * w] = - output_data[(h - 1) * w] * newscale_data[c] + newbias_data[c]; - output_data[h * w - 1] = - output_data[h * w - 1] * newscale_data[c] + newbias_data[c]; - - if (if_relu) { - output_data[0] = output_data[0] < 0 ? 0 : output_data[0]; - output_data[w - 1] = output_data[w - 1] < 0 ? 0 : output_data[w - 1]; - output_data[(h - 1) * w] = - output_data[(h - 1) * w] < 0 ? 0 : output_data[(h - 1) * w]; - output_data[h * w - 1] = - output_data[h * w - 1] < 0 ? 0 : output_data[h * w - 1]; - } - for (int i = 1; i < h - 1; ++i) { - output_data[i * w] = - w01 * input_data[i * w - w] + w02 * input_data[i * w - w + 1] + - w11 * input_data[i * w] + w12 * input_data[i * w + 1] + - w21 * input_data[i * w + w] + w22 * input_data[i * w + w + 1]; - - output_data[i * w + w - 1] = w00 * input_data[i * w + w - 1 - w - 1] + - w01 * input_data[i * w + w - 1 - w] + - w10 * input_data[i * w + w - 1 - 1] + - w11 * input_data[i * w + w - 1] + - w20 * input_data[i * w + w - 1 + w - 1] + - w21 * input_data[i * w + w - 1 + w]; - output_data[i * w] = - output_data[i * w] * newscale_data[c] + newbias_data[c]; - output_data[i * w + w - 1] = - output_data[i * w + w - 1] * newscale_data[c] + newbias_data[c]; - - if (if_relu) { - output_data[i * w] = output_data[i * w] < 0 ? 0 : output_data[i * w]; - output_data[i * w + w - 1] = - output_data[i * w + w - 1] < 0 ? 0 : output_data[i * w + w - 1]; - } + // valid + float32x4_t _result0, _result1, _result2, _result3; + for (int loop = 0; loop < output_w_tiles; ++loop) { + float32x4_t _row00 = vld1q_f32(input_ptr0); + float32x4_t _row01 = vld1q_f32(input_ptr0 + 4); + float32x4_t _row10 = vld1q_f32(input_ptr1); + float32x4_t _row11 = vld1q_f32(input_ptr1 + 4); + + float32x4_t _ext01 = vextq_f32(_row00, _row01, 1); + float32x4_t _ext02 = vextq_f32(_row00, _row01, 2); + float32x4_t _ext03 = vextq_f32(_row01, _row01, 1); + float32x4_t _ext04 = vextq_f32(_row01, _row01, 2); + + _result0 = vmulq_lane_f32(_row00, vget_low_f32(_ker[0]), 0); + _result0 = vmlaq_lane_f32(_result0, _ext01, vget_low_f32(_ker[0]), 1); + _result0 = vmlaq_lane_f32(_result0, _ext02, vget_high_f32(_ker[0]), 0); + _result1 = vmulq_lane_f32(_row01, vget_low_f32(_ker[0]), 0); + _result1 = vmlaq_lane_f32(_result1, _ext03, vget_low_f32(_ker[0]), 1); + _result1 = vmlaq_lane_f32(_result1, _ext04, vget_high_f32(_ker[0]), 0); + + _ext01 = vextq_f32(_row10, _row11, 1); + _ext02 = vextq_f32(_row10, _row11, 2); + _ext03 = vextq_f32(_row11, _row11, 1); + _ext04 = vextq_f32(_row11, _row11, 2); + + _result0 = vmlaq_lane_f32(_result0, _row10, vget_low_f32(_ker[1]), 0); + _result0 = vmlaq_lane_f32(_result0, _ext01, vget_low_f32(_ker[1]), 1); + _result0 = vmlaq_lane_f32(_result0, _ext02, vget_high_f32(_ker[1]), 0); + _result1 = vmlaq_lane_f32(_result1, _row11, vget_low_f32(_ker[1]), 0); + _result1 = vmlaq_lane_f32(_result1, _ext03, vget_low_f32(_ker[1]), 1); + _result1 = vmlaq_lane_f32(_result1, _ext04, vget_high_f32(_ker[1]), 0); + + _result2 = vmulq_lane_f32(_row10, vget_low_f32(_ker[0]), 0); + _result2 = vmlaq_lane_f32(_result2, _ext01, vget_low_f32(_ker[0]), 1); + _result2 = vmlaq_lane_f32(_result2, _ext02, vget_high_f32(_ker[0]), 0); + _result3 = vmulq_lane_f32(_row11, vget_low_f32(_ker[0]), 0); + _result3 = vmlaq_lane_f32(_result3, _ext03, vget_low_f32(_ker[0]), 
1); + _result3 = vmlaq_lane_f32(_result3, _ext04, vget_high_f32(_ker[0]), 0); + + _row00 = vld1q_f32(input_ptr2); + _row01 = vld1q_f32(input_ptr2 + 4); + _row10 = vld1q_f32(input_ptr3); + _row11 = vld1q_f32(input_ptr3 + 4); + + _ext01 = vextq_f32(_row00, _row01, 1); + _ext02 = vextq_f32(_row00, _row01, 2); + _ext03 = vextq_f32(_row01, _row01, 1); + _ext04 = vextq_f32(_row01, _row01, 2); + + _result0 = vmlaq_lane_f32(_result0, _row00, vget_low_f32(_ker[2]), 0); + _result0 = vmlaq_lane_f32(_result0, _ext01, vget_low_f32(_ker[2]), 1); + _result0 = vmlaq_lane_f32(_result0, _ext02, vget_high_f32(_ker[2]), 0); + _result1 = vmlaq_lane_f32(_result1, _row01, vget_low_f32(_ker[2]), 0); + _result1 = vmlaq_lane_f32(_result1, _ext03, vget_low_f32(_ker[2]), 1); + _result1 = vmlaq_lane_f32(_result1, _ext04, vget_high_f32(_ker[2]), 0); + + _result2 = vmlaq_lane_f32(_result2, _row00, vget_low_f32(_ker[1]), 0); + _result2 = vmlaq_lane_f32(_result2, _ext01, vget_low_f32(_ker[1]), 1); + _result2 = vmlaq_lane_f32(_result2, _ext02, vget_high_f32(_ker[1]), 0); + _result3 = vmlaq_lane_f32(_result3, _row01, vget_low_f32(_ker[1]), 0); + _result3 = vmlaq_lane_f32(_result3, _ext03, vget_low_f32(_ker[1]), 1); + _result3 = vmlaq_lane_f32(_result3, _ext04, vget_high_f32(_ker[1]), 0); + + _ext01 = vextq_f32(_row10, _row11, 1); + _ext02 = vextq_f32(_row10, _row11, 2); + _ext03 = vextq_f32(_row11, _row11, 1); + _ext04 = vextq_f32(_row11, _row11, 2); + + _result2 = vmlaq_lane_f32(_result2, _row10, vget_low_f32(_ker[2]), 0); + _result2 = vmlaq_lane_f32(_result2, _ext01, vget_low_f32(_ker[2]), 1); + _result2 = vmlaq_lane_f32(_result2, _ext02, vget_high_f32(_ker[2]), 0); + _result3 = vmlaq_lane_f32(_result3, _row11, vget_low_f32(_ker[2]), 0); + _result3 = vmlaq_lane_f32(_result3, _ext03, vget_low_f32(_ker[2]), 1); + _result3 = vmlaq_lane_f32(_result3, _ext04, vget_high_f32(_ker[2]), 0); + + vst1q_f32(output_ptr0, _result0); + vst1_f32(output_ptr0 + 4, vget_low_f32(_result1)); + vst1q_f32(output_ptr1, _result2); + vst1_f32(output_ptr1 + 4, vget_low_f32(_result3)); + + input_ptr0 += 6; + input_ptr1 += 6; + input_ptr2 += 6; + input_ptr3 += 6; + output_ptr0 += 6; + output_ptr1 += 6; } - - int m; - for (m = 1; m < output_width - 4; m += 4) { - float *output_ptr = output_data + m; - float32x4_t in0, in1, in2, in3, tmp0, tmp1, tmp2, tmp3, out0; - in0 = vld1q_f32(input_data + m - 1); - in1 = vld1q_f32(input_data + m + 3); - in2 = vld1q_f32(input_data + input_width + m - 1); - in3 = vld1q_f32(input_data + input_width + m + 3); - tmp0 = vextq_f32(in0, in1, 1); - tmp1 = vextq_f32(in0, in1, 2); - tmp2 = vextq_f32(in2, in3, 1); - tmp3 = vextq_f32(in2, in3, 2); - out0 = vmulq_n_f32(in0, w10); - out0 = vmlaq_n_f32(out0, tmp0, w11); - out0 = vmlaq_n_f32(out0, tmp1, w12); - out0 = vmlaq_n_f32(out0, in2, w20); - out0 = vmlaq_n_f32(out0, tmp2, w21); - out0 = vmlaq_n_f32(out0, tmp3, w22); - out0 = vmlaq_f32(vnewbias, vnewscale, out0); - if (if_relu) { - out0 = vmaxq_f32(out0, vzero); + // remain w + if (output_w_remain > 0) { + float32x4_t _row00 = vld1q_f32(input_ptr0); + float32x4_t _row01 = vld1q_f32(input_ptr0 + 4); + float32x4_t _row10 = vld1q_f32(input_ptr1); + float32x4_t _row11 = vld1q_f32(input_ptr1 + 4); + + float32x4_t _ext01 = vextq_f32(_row00, _row01, 1); + float32x4_t _ext02 = vextq_f32(_row00, _row01, 2); + float32x4_t _ext03 = vextq_f32(_row01, _row01, 1); + float32x4_t _ext04 = vextq_f32(_row01, _row01, 2); + + _result0 = vmulq_lane_f32(_row00, vget_low_f32(_ker[0]), 0); + _result0 = vmlaq_lane_f32(_result0, _ext01, 
vget_low_f32(_ker[0]), 1); + _result0 = vmlaq_lane_f32(_result0, _ext02, vget_high_f32(_ker[0]), 0); + _result1 = vmulq_lane_f32(_row01, vget_low_f32(_ker[0]), 0); + _result1 = vmlaq_lane_f32(_result1, _ext03, vget_low_f32(_ker[0]), 1); + _result1 = vmlaq_lane_f32(_result1, _ext04, vget_high_f32(_ker[0]), 0); + + _ext01 = vextq_f32(_row10, _row11, 1); + _ext02 = vextq_f32(_row10, _row11, 2); + _ext03 = vextq_f32(_row11, _row11, 1); + _ext04 = vextq_f32(_row11, _row11, 2); + + _result0 = vmlaq_lane_f32(_result0, _row10, vget_low_f32(_ker[1]), 0); + _result0 = vmlaq_lane_f32(_result0, _ext01, vget_low_f32(_ker[1]), 1); + _result0 = vmlaq_lane_f32(_result0, _ext02, vget_high_f32(_ker[1]), 0); + _result1 = vmlaq_lane_f32(_result1, _row11, vget_low_f32(_ker[1]), 0); + _result1 = vmlaq_lane_f32(_result1, _ext03, vget_low_f32(_ker[1]), 1); + _result1 = vmlaq_lane_f32(_result1, _ext04, vget_high_f32(_ker[1]), 0); + + _result2 = vmulq_lane_f32(_row10, vget_low_f32(_ker[0]), 0); + _result2 = vmlaq_lane_f32(_result2, _ext01, vget_low_f32(_ker[0]), 1); + _result2 = vmlaq_lane_f32(_result2, _ext02, vget_high_f32(_ker[0]), 0); + _result3 = vmulq_lane_f32(_row11, vget_low_f32(_ker[0]), 0); + _result3 = vmlaq_lane_f32(_result3, _ext03, vget_low_f32(_ker[0]), 1); + _result3 = vmlaq_lane_f32(_result3, _ext04, vget_high_f32(_ker[0]), 0); + + _row00 = vld1q_f32(input_ptr2); + _row01 = vld1q_f32(input_ptr2 + 4); + _row10 = vld1q_f32(input_ptr3); + _row11 = vld1q_f32(input_ptr3 + 4); + + _ext01 = vextq_f32(_row00, _row01, 1); + _ext02 = vextq_f32(_row00, _row01, 2); + _ext03 = vextq_f32(_row01, _row01, 1); + _ext04 = vextq_f32(_row01, _row01, 2); + + _result0 = vmlaq_lane_f32(_result0, _row00, vget_low_f32(_ker[2]), 0); + _result0 = vmlaq_lane_f32(_result0, _ext01, vget_low_f32(_ker[2]), 1); + _result0 = vmlaq_lane_f32(_result0, _ext02, vget_high_f32(_ker[2]), 0); + _result1 = vmlaq_lane_f32(_result1, _row01, vget_low_f32(_ker[2]), 0); + _result1 = vmlaq_lane_f32(_result1, _ext03, vget_low_f32(_ker[2]), 1); + _result1 = vmlaq_lane_f32(_result1, _ext04, vget_high_f32(_ker[2]), 0); + + _result2 = vmlaq_lane_f32(_result2, _row00, vget_low_f32(_ker[1]), 0); + _result2 = vmlaq_lane_f32(_result2, _ext01, vget_low_f32(_ker[1]), 1); + _result2 = vmlaq_lane_f32(_result2, _ext02, vget_high_f32(_ker[1]), 0); + _result3 = vmlaq_lane_f32(_result3, _row01, vget_low_f32(_ker[1]), 0); + _result3 = vmlaq_lane_f32(_result3, _ext03, vget_low_f32(_ker[1]), 1); + _result3 = vmlaq_lane_f32(_result3, _ext04, vget_high_f32(_ker[1]), 0); + + _ext01 = vextq_f32(_row10, _row11, 1); + _ext02 = vextq_f32(_row10, _row11, 2); + _ext03 = vextq_f32(_row11, _row11, 1); + _ext04 = vextq_f32(_row11, _row11, 2); + + _result2 = vmlaq_lane_f32(_result2, _row10, vget_low_f32(_ker[2]), 0); + _result2 = vmlaq_lane_f32(_result2, _ext01, vget_low_f32(_ker[2]), 1); + _result2 = vmlaq_lane_f32(_result2, _ext02, vget_high_f32(_ker[2]), 0); + _result3 = vmlaq_lane_f32(_result3, _row11, vget_low_f32(_ker[2]), 0); + _result3 = vmlaq_lane_f32(_result3, _ext03, vget_low_f32(_ker[2]), 1); + _result3 = vmlaq_lane_f32(_result3, _ext04, vget_high_f32(_ker[2]), 0); + + switch (output_w_remain) { + case 5: + vst1q_lane_f32(output_ptr0 + 4, _result1, 0); + vst1q_lane_f32(output_ptr1 + 4, _result3, 0); + case 4: + vst1q_f32(output_ptr0, _result0); + vst1q_f32(output_ptr1, _result2); + break; + case 3: + vst1q_lane_f32(output_ptr0 + 2, _result0, 2); + vst1q_lane_f32(output_ptr1 + 2, _result2, 2); + case 2: + vst1_f32(output_ptr0, vget_low_f32(_result0)); + 
vst1_f32(output_ptr1, vget_low_f32(_result2)); + break; + case 1: + vst1q_lane_f32(output_ptr0, _result0, 0); + vst1q_lane_f32(output_ptr1, _result2, 0); + break; } - vst1q_f32(output_ptr, out0); - } - for (m = 1; (m + 3) < output_width - 1; m += 4) { + input_ptr0 += output_w_remain; + input_ptr1 += output_w_remain; + input_ptr2 += output_w_remain; + input_ptr3 += output_w_remain; + output_ptr0 += output_w_remain; + output_ptr1 += output_w_remain; } - for (int j = m; j < output_width - 1; j++) { - output_data[j] = input_data[j - 1] * w10 + input_data[j] * w11 + - input_data[j + 1] * w12 + - input_data[input_width + j - 1] * w20 + - input_data[input_width + j] * w21 + - input_data[input_width + j + 1] * w22; - output_data[j] = output_data[j] * newscale_data[c] + newbias_data[c]; - - if (if_relu) { - output_data[j] = output_data[j] < 0 ? 0 : output_data[j]; + // pad right + if (padding_w) { + float32x2_t row0 = vld1_f32(input_ptr0); + float32x2_t row1 = vld1_f32(input_ptr1); + float32x2_t row2 = vld1_f32(input_ptr2); + float32x2_t row3 = vld1_f32(input_ptr3); + float32x2_t zero = vdup_n_f32(0.f); + float32x2_t acc0, acc1; + for (int w = valid_w_end; w < output_w; ++w) { + int padding = w + 3 - (padding_w + input_w); + if (padding >= 3) { + *output_ptr0 = 0.f; + *output_ptr1 = 0.f; + } else { + acc0 = vmul_f32(row0, vget_low_f32(_ker[0])); + acc0 = vmla_f32(acc0, row1, vget_low_f32(_ker[1])); + acc0 = vmla_f32(acc0, row2, vget_low_f32(_ker[2])); + acc1 = vmul_f32(row1, vget_low_f32(_ker[0])); + acc1 = vmla_f32(acc1, row2, vget_low_f32(_ker[1])); + acc1 = vmla_f32(acc1, row3, vget_low_f32(_ker[2])); + float32x2_t sum = vpadd_f32(acc0, acc1); + vst1_lane_f32(output_ptr0, sum, 0); + vst1_lane_f32(output_ptr1, sum, 1); + row0 = vext_f32(row0, zero, 1); + row1 = vext_f32(row1, zero, 1); + row2 = vext_f32(row2, zero, 1); + row3 = vext_f32(row3, zero, 1); + } + output_ptr0++; + output_ptr1++; } } - - for (m = 1; m < output_width - 4; m += 4) { - float *output_ptr = - output_data + (output_height - 1) * output_width + m; - - float32x4_t in0, in1, in2, in3, tmp0, tmp1, tmp2, tmp3, out0; - in0 = vld1q_f32(input_data + (output_height - 2) * input_width + m - 1); - in1 = vld1q_f32(input_data + (output_height - 2) * input_width + m + 3); - in2 = vld1q_f32(input_data + (output_height - 1) * input_width + m - 1); - in3 = vld1q_f32(input_data + (output_height - 1) * input_width + m + 3); - tmp0 = vextq_f32(in0, in1, 1); - tmp1 = vextq_f32(in0, in1, 2); - tmp2 = vextq_f32(in2, in3, 1); - tmp3 = vextq_f32(in2, in3, 2); - out0 = vmulq_n_f32(in0, w00); - out0 = vmlaq_n_f32(out0, tmp0, w01); - out0 = vmlaq_n_f32(out0, tmp1, w02); - out0 = vmlaq_n_f32(out0, in2, w10); - out0 = vmlaq_n_f32(out0, tmp2, w11); - out0 = vmlaq_n_f32(out0, tmp3, w12); - out0 = vmlaq_f32(vnewbias, vnewscale, out0); - if (if_relu) { - out0 = vmaxq_f32(out0, vzero); + } + // remain height + int start_h = valid_h_start + (valid_h & 0xfffe); + if (start_h < valid_h_end) { + const float *input_ptr0 = input_ptr + (start_h - padding_h) * input_w; + const float *input_ptr1 = input_ptr0 + input_w; + const float *input_ptr2 = input_ptr1 + input_w; + float *output_ptr0 = output_ptr + start_h * output_w; + // pad left + if (padding_w) { + float32x4_t row0 = vld1q_f32(input_ptr0); + float32x4_t row1 = vld1q_f32(input_ptr1); + float32x4_t row2 = vld1q_f32(input_ptr2); + float32x4_t zero = vdupq_n_f32(0.f); + row0 = vextq_f32(zero, row0, 3); + row1 = vextq_f32(zero, row1, 3); + row2 = vextq_f32(zero, row2, 3); + float32x4_t acc; + for (int w = 
valid_w_start - 1; w >= 0; --w) { + int padding = padding_w - w; + if (padding >= 3) { + output_ptr0[w] = 0.f; + } else { + acc = vmulq_f32(row0, _ker[0]); + acc = vmlaq_f32(acc, row1, _ker[1]); + acc = vmlaq_f32(acc, row2, _ker[2]); + acc = vextq_f32(acc, acc, 1); + float32x2_t sum = vpadd_f32(vget_low_f32(acc), vget_low_f32(acc)); + vst1_lane_f32(output_ptr0 + w, sum, 0); + + row0 = vextq_f32(zero, row0, 3); + row1 = vextq_f32(zero, row1, 3); + row2 = vextq_f32(zero, row2, 3); + } } - vst1q_f32(output_ptr, out0); + output_ptr0 += valid_w_start; } - for (m = 1; (m + 3) < output_width - 1; m = m + 4) { + // valid + float32x4_t _result0, _result1; + for (int loop = 0; loop < output_w_tiles; ++loop) { + float32x4_t _row00 = vld1q_f32(input_ptr0); + float32x4_t _row01 = vld1q_f32(input_ptr0 + 4); + float32x4_t _row10 = vld1q_f32(input_ptr1); + float32x4_t _row11 = vld1q_f32(input_ptr1 + 4); + + float32x4_t _ext01 = vextq_f32(_row00, _row01, 1); + float32x4_t _ext02 = vextq_f32(_row00, _row01, 2); + float32x4_t _ext03 = vextq_f32(_row01, _row01, 1); + float32x4_t _ext04 = vextq_f32(_row01, _row01, 2); + + _result0 = vmulq_lane_f32(_row00, vget_low_f32(_ker[0]), 0); + _result0 = vmlaq_lane_f32(_result0, _ext01, vget_low_f32(_ker[0]), 1); + _result0 = vmlaq_lane_f32(_result0, _ext02, vget_high_f32(_ker[0]), 0); + _result1 = vmulq_lane_f32(_row01, vget_low_f32(_ker[0]), 0); + _result1 = vmlaq_lane_f32(_result1, _ext03, vget_low_f32(_ker[0]), 1); + _result1 = vmlaq_lane_f32(_result1, _ext04, vget_high_f32(_ker[0]), 0); + + _ext01 = vextq_f32(_row10, _row11, 1); + _ext02 = vextq_f32(_row10, _row11, 2); + _ext03 = vextq_f32(_row11, _row11, 1); + _ext04 = vextq_f32(_row11, _row11, 2); + + _result0 = vmlaq_lane_f32(_result0, _row10, vget_low_f32(_ker[1]), 0); + _result0 = vmlaq_lane_f32(_result0, _ext01, vget_low_f32(_ker[1]), 1); + _result0 = vmlaq_lane_f32(_result0, _ext02, vget_high_f32(_ker[1]), 0); + _result1 = vmlaq_lane_f32(_result1, _row11, vget_low_f32(_ker[1]), 0); + _result1 = vmlaq_lane_f32(_result1, _ext03, vget_low_f32(_ker[1]), 1); + _result1 = vmlaq_lane_f32(_result1, _ext04, vget_high_f32(_ker[1]), 0); + + _row00 = vld1q_f32(input_ptr2); + _row01 = vld1q_f32(input_ptr2 + 4); + + _ext01 = vextq_f32(_row00, _row01, 1); + _ext02 = vextq_f32(_row00, _row01, 2); + _ext03 = vextq_f32(_row01, _row01, 1); + _ext04 = vextq_f32(_row01, _row01, 2); + + _result0 = vmlaq_lane_f32(_result0, _row00, vget_low_f32(_ker[2]), 0); + _result0 = vmlaq_lane_f32(_result0, _ext01, vget_low_f32(_ker[2]), 1); + _result0 = vmlaq_lane_f32(_result0, _ext02, vget_high_f32(_ker[2]), 0); + _result1 = vmlaq_lane_f32(_result1, _row01, vget_low_f32(_ker[2]), 0); + _result1 = vmlaq_lane_f32(_result1, _ext03, vget_low_f32(_ker[2]), 1); + _result1 = vmlaq_lane_f32(_result1, _ext04, vget_high_f32(_ker[2]), 0); + + vst1q_f32(output_ptr0, _result0); + vst1_f32(output_ptr0 + 4, vget_low_f32(_result1)); + + input_ptr0 += 6; + input_ptr1 += 6; + input_ptr2 += 6; + output_ptr0 += 6; } - for (int j = m; j < output_width - 1; j++) { - output_data[(output_height - 1) * input_width + j] = - input_data[(output_height - 2) * input_width + j - 1] * w00 + - input_data[(output_height - 2) * input_width + j] * w01 + - input_data[(output_height - 2) * input_width + j + 1] * w02 + - input_data[(output_height - 1) * input_width + j - 1] * w10 + - input_data[(output_height - 1) * input_width + j] * w11 + - input_data[(output_height - 1) * input_width + j + 1] * w12; - output_data[(output_height - 1) * output_width + j] = - 
output_data[(output_height - 1) * output_width + j] * - newscale_data[c] + - newbias_data[c]; - - if (if_relu) { - output_data[(output_height - 1) * output_width + j] = - output_data[(output_height - 1) * output_width + j] < 0 - ? 0 - : output_data[(output_height - 1) * output_width + j]; - } - } - } - } - - /* - const float *input_data = input->data(); - const float *filter_data = filter->data(); - float *output_data = output->data(); - const float *newscale_data = new_scale->data(); - const float *newbias_data = new_bias->data(); - - const int h = static_cast(input->dims()[2]); - const int w = static_cast(input->dims()[3]); -// const int l = h; - - const int batch_size = static_cast(input->dims()[0]); - const int c = static_cast(input->dims()[1]); - const int hxw = h * w; - float32x4_t vnewbias = vdupq_n_f32(0.0); - float32x4_t vnewscale = vdupq_n_f32(1.0); - float32x4_t vzero = vdupq_n_f32(0); - - for (int b = 0; b < batch_size; ++b) { - const float *filter_data_tmp = filter_data; - - for (int j = 0; j < c; ++j) { - vnewbias = vdupq_n_f32(newbias_data[j]); - vnewscale = vdupq_n_f32(newscale_data[j]); - - int w_mid = w - 2; // l=1->l_mid=-1,l=2->l_mid=0 - float w00 = filter_data_tmp[0]; - float w01 = filter_data_tmp[1]; - float w02 = filter_data_tmp[2]; - float w10 = filter_data_tmp[3]; - float w11 = filter_data_tmp[4]; - float w12 = filter_data_tmp[5]; - float w20 = filter_data_tmp[6]; - float w21 = filter_data_tmp[7]; - float w22 = filter_data_tmp[8]; - - output_data[0] = w11 * input_data[0] + w12 * input_data[1] + - w21 * input_data[w] + w22 * input_data[w + 1]; - - output_data[w - 1] = w10 * input_data[w - 2] + w11 * input_data[w - - 1] + w20 * input_data[2 * w - 2] + w21 * input_data[2 * w - 1]; - - output_data[(h - 1) * w] = - w01 * input_data[(h - 2) * w] + w02 * input_data[(h - 2) * w + - 1] + w11 * input_data[(h - 1) * w] + w12 * input_data[(h - 1) * w + 1]; - output_data[h * w - 1] = w00 * input_data[h*w-w-2] + - w01 * input_data[h*w-w-1] + - w10 * input_data[h * w - 2] + - w11 * input_data[h * w - 1]; - output_data[0] = output_data[0] * newscale_data[j] + - newbias_data[j]; output_data[w - 1] = output_data[w - 1] * - newscale_data[j] + newbias_data[j]; output_data[(h - 1) * w] = - output_data[(h - 1) * w] * newscale_data[j] + newbias_data[j]; - output_data[h * w - 1] = - output_data[h * w - 1] * newscale_data[j] + newbias_data[j]; - - if (if_relu) { - output_data[0] = output_data[0] < 0 ? 0 : output_data[0]; - output_data[w - 1] = output_data[w - 1] < 0 ? 0 : output_data[w - - 1]; output_data[(h - 1) * w] = output_data[(h - 1) * w] < 0 ? 0 : - output_data[(h - 1) * w]; output_data[h * w - 1] = output_data[h * w - 1] - < 0 ? 0 : output_data[h * w - 1]; - } - for (int i = 1; i < h - 1; ++i) { - output_data[i * w] = - w01 * input_data[i * w - w] + w02 * input_data[i * w - w + 1] - + w11 * input_data[i * w] + w12 * input_data[i * w + 1] + w21 * - input_data[i * w + w] + w22 * input_data[i * w + w + 1]; output_data[i * - w + w - 1] = w00 * input_data[i * w + w - 1 - w - 1] + w01 * input_data[i - * w + w - 1 - w] + w10 * input_data[i * w + w - 1 - 1] + w11 * - input_data[i * w + w - 1] + w20 * input_data[i * w + w - 1 + w - 1] + w21 - * input_data[i * w + w - 1 + w]; output_data[i * w] = output_data[i * w] - * newscale_data[j] + newbias_data[j]; output_data[i * w + w - 1] = - output_data[i * w + w - 1] * newscale_data[j] + - newbias_data[j]; - - if (if_relu) { - output_data[i * w] = output_data[i * w] < 0 ? 
0 : output_data[i - * w]; output_data[i * w + w - 1] = output_data[i * w + w - 1] < 0 ? 0 : - output_data[i * w + w - 1]; - } - } - // top 1 row and bottom 1 row - const float *input_tmp = input_data; - - float32x4_t in0, in1, in2, in3, in4, in5, in6, in7, tmp0, tmp1, - tmp2, tmp3, tmp4, tmp5, out0; in0 = vld1q_f32(input_tmp); in2 = - vld1q_f32(input_tmp + w); const float *input_tmp_end = input_tmp + (h - - 2) * w; in4 = vld1q_f32(input_tmp_end); in6 = vld1q_f32(input_tmp_end + - w); int c_mid = w_mid; auto output_ptr = output_data + 1; for (; c_mid > - 3; c_mid -= 4) { in1 = vld1q_f32(input_tmp + 4); in3 = - vld1q_f32(input_tmp + w + 4); - - tmp0 = vextq_f32(in0, in1, 1); - tmp1 = vextq_f32(in0, in1, 2); - - tmp2 = vextq_f32(in2, in3, 1); - tmp3 = vextq_f32(in2, in3, 2); - - out0 = vmulq_n_f32(in0, w10); - out0 = vmlaq_n_f32(out0, tmp0, w11); - out0 = vmlaq_n_f32(out0, tmp1, w12); - out0 = vmlaq_n_f32(out0, in2, w20); - out0 = vmlaq_n_f32(out0, tmp2, w21); - out0 = vmlaq_n_f32(out0, tmp3, w22); - out0 = vmlaq_f32(vnewbias, vnewscale, out0); - if (if_relu) { - out0 = vmaxq_f32(out0, vzero); - } - vst1q_f32(output_ptr, out0); - - in5 = vld1q_f32(input_tmp_end + 4); - in7 = vld1q_f32(input_tmp_end + w + 4); - - tmp0 = vextq_f32(in4, in5, 1); - tmp1 = vextq_f32(in4, in5, 2); - tmp2 = vextq_f32(in6, in7, 1); - tmp3 = vextq_f32(in6, in7, 2); - - out0 = vmulq_n_f32(in4, w00); - out0 = vmlaq_n_f32(out0, tmp0, w01); - out0 = vmlaq_n_f32(out0, tmp1, w02); - out0 = vmlaq_n_f32(out0, in6, w10); - out0 = vmlaq_n_f32(out0, tmp2, w11); - out0 = vmlaq_n_f32(out0, tmp3, w12); - out0 = vmlaq_f32(vnewbias, vnewscale, out0); - if (if_relu) { - out0 = vmaxq_f32(out0, vzero); - } - vst1q_f32(output_ptr + (h - 1) * w, out0); - - // can optimize to each 8 stride. - input_tmp += 4; - input_tmp_end += 4; - output_ptr += 4; - in0 = in1; - in2 = in3; - in4 = in5; - in6 = in7; - } - - // top right pad - float32x4_t pad0 = vdupq_n_f32(input_data[w - 1]); - float32x4_t pad1 = vdupq_n_f32(input_data[2 * w - 1]); - - tmp0 = vextq_f32(in0, pad0, 1); - tmp1 = vextq_f32(in0, pad0, 2); - tmp2 = vextq_f32(in2, pad1, 1); - tmp3 = vextq_f32(in2, pad1, 2); - - out0 = vmulq_n_f32(in0, w10); - out0 = vmlaq_n_f32(out0, tmp0, w11); - out0 = vmlaq_n_f32(out0, tmp1, w12); - out0 = vmlaq_n_f32(out0, in2, w20); - out0 = vmlaq_n_f32(out0, tmp2, w21); - out0 = vmlaq_n_f32(out0, tmp3, w22); - out0 = vmlaq_f32(vnewbias, vnewscale, out0); - if (if_relu) { - out0 = vmaxq_f32(out0, vzero); - } - for (int i = 0; i < c_mid; ++i) { - if (i == 0) { - vst1q_lane_f32(output_ptr + i, out0, 0); - } - if (i == 1) { - vst1q_lane_f32(output_ptr + i, out0, 1); - } - if (i == 2) { - vst1q_lane_f32(output_ptr + i, out0, 2); - } - } - - // bottom right pad - float32x4_t pad2 = vdupq_n_f32(input_data[h * w - 1 - w]); - float32x4_t pad3 = vdupq_n_f32(input_data[h * w - 1]); - - tmp0 = vextq_f32(in4, pad2, 1); - tmp1 = vextq_f32(in4, pad2, 2); - tmp2 = vextq_f32(in6, pad3, 1); - tmp3 = vextq_f32(in6, pad3, 2); - - out0 = vmulq_n_f32(in4, w00); - out0 = vmlaq_n_f32(out0, tmp0, w01); - out0 = vmlaq_n_f32(out0, tmp1, w02); - out0 = vmlaq_n_f32(out0, in6, w10); - out0 = vmlaq_n_f32(out0, tmp2, w11); - out0 = vmlaq_n_f32(out0, tmp3, w12); - out0 = vmlaq_f32(vnewbias, vnewscale, out0); - if (if_relu) { - out0 = vmaxq_f32(out0, vzero); - } - for (int i = 0; i < c_mid; ++i) { - if (i == 0) { - vst1q_lane_f32(output_ptr + (h - 1) * w + i, out0, 0); - } - if (i == 1) { - vst1q_lane_f32(output_ptr + (h - 1) * w + i, out0, 1); - } - if (i == 2) { - 
vst1q_lane_f32(output_ptr + (h - 1) * w + i, out0, 2); - } - } - // mid - - - for (int i = 0; i < h - 2; ++i) { - auto output_ptr = output_data + (i + 1) * w + 1; - input_tmp = input_data + i * w; - auto in0_tmp = vld1q_f32(input_tmp); - auto in2_tmp = vld1q_f32(input_tmp + w); - auto in4_tmp = vld1q_f32(input_tmp + w + w); - c_mid = w_mid; - for (; c_mid > 3; c_mid -= 4) { - auto in1_tmp = vld1q_f32(input_tmp + 4); - auto in3_tmp = vld1q_f32(input_tmp + w + 4); - auto in5_tmp = vld1q_f32(input_tmp + w + w + 4); - - tmp0 = vextq_f32(in0_tmp, in1_tmp, 1); - tmp1 = vextq_f32(in0_tmp, in1_tmp, 2); - tmp2 = vextq_f32(in2_tmp, in3_tmp, 1); - tmp3 = vextq_f32(in2_tmp, in3_tmp, 2); - tmp4 = vextq_f32(in4_tmp, in5_tmp, 1); - tmp5 = vextq_f32(in4_tmp, in5_tmp, 2); - - out0 = vmulq_n_f32(in0_tmp, w00); - out0 = vmlaq_n_f32(out0, tmp0, w01); - out0 = vmlaq_n_f32(out0, tmp1, w02); - out0 = vmlaq_n_f32(out0, in2_tmp, w10); - out0 = vmlaq_n_f32(out0, tmp2, w11); - out0 = vmlaq_n_f32(out0, tmp3, w12); - out0 = vmlaq_n_f32(out0, in4_tmp, w20); - out0 = vmlaq_n_f32(out0, tmp4, w21); - out0 = vmlaq_n_f32(out0, tmp5, w22); - out0 = vmlaq_f32(vnewbias, vnewscale, out0); - if (if_relu) { - out0 = vmaxq_f32(out0, vzero); - } - vst1q_f32(output_ptr, out0); - - output_ptr += 4; - input_tmp += 4; - in0_tmp = in1_tmp; - in2_tmp = in3_tmp; - in4_tmp = in5_tmp; - } - - float32x4_t pad0 = vdupq_n_f32(input_data[i * w + w - 1]); - float32x4_t pad1 = vdupq_n_f32(input_data[i * w + w - 1 + w]); - float32x4_t pad2 = vdupq_n_f32(input_data[i * w + w - 1 + w + w]); - - tmp0 = vextq_f32(in0_tmp, pad0, 1); - tmp1 = vextq_f32(in0_tmp, pad0, 2); - tmp2 = vextq_f32(in2_tmp, pad1, 1); - tmp3 = vextq_f32(in2_tmp, pad1, 2); - tmp4 = vextq_f32(in4_tmp, pad2, 1); - tmp5 = vextq_f32(in4_tmp, pad2, 2); - - out0 = vmulq_n_f32(in0_tmp, w00); - out0 = vmlaq_n_f32(out0, tmp0, w01); - out0 = vmlaq_n_f32(out0, tmp1, w02); - out0 = vmlaq_n_f32(out0, in2_tmp, w10); - out0 = vmlaq_n_f32(out0, tmp2, w11); - out0 = vmlaq_n_f32(out0, tmp3, w12); - out0 = vmlaq_n_f32(out0, in4_tmp, w20); - out0 = vmlaq_n_f32(out0, tmp4, w21); - out0 = vmlaq_n_f32(out0, tmp5, w22); - out0 = vmlaq_f32(vnewbias, vnewscale, out0); - if (if_relu) { - out0 = vmaxq_f32(out0, vzero); - } - for (int i = 0; i < c_mid; ++i) { - if (i == 0) { - vst1q_lane_f32(output_ptr + i, out0, 0); - } - if (i == 1) { - vst1q_lane_f32(output_ptr + i, out0, 1); - } - if (i == 2) { - vst1q_lane_f32(output_ptr + i, out0, 2); - } - } - } - output_data += hxw; - input_data += hxw; - filter_data_tmp += 9; - } + if (output_w_remain > 0) { + float32x4_t _row00 = vld1q_f32(input_ptr0); + float32x4_t _row01 = vld1q_f32(input_ptr0 + 4); + float32x4_t _row10 = vld1q_f32(input_ptr1); + float32x4_t _row11 = vld1q_f32(input_ptr1 + 4); + + float32x4_t _ext01 = vextq_f32(_row00, _row01, 1); + float32x4_t _ext02 = vextq_f32(_row00, _row01, 2); + float32x4_t _ext03 = vextq_f32(_row01, _row01, 1); + float32x4_t _ext04 = vextq_f32(_row01, _row01, 2); + + _result0 = vmulq_lane_f32(_row00, vget_low_f32(_ker[0]), 0); + _result0 = vmlaq_lane_f32(_result0, _ext01, vget_low_f32(_ker[0]), 1); + _result0 = vmlaq_lane_f32(_result0, _ext02, vget_high_f32(_ker[0]), 0); + _result1 = vmulq_lane_f32(_row01, vget_low_f32(_ker[0]), 0); + _result1 = vmlaq_lane_f32(_result1, _ext03, vget_low_f32(_ker[0]), 1); + _result1 = vmlaq_lane_f32(_result1, _ext04, vget_high_f32(_ker[0]), 0); + + _ext01 = vextq_f32(_row10, _row11, 1); + _ext02 = vextq_f32(_row10, _row11, 2); + _ext03 = vextq_f32(_row11, _row11, 1); + _ext04 = 
vextq_f32(_row11, _row11, 2); + + _result0 = vmlaq_lane_f32(_result0, _row10, vget_low_f32(_ker[1]), 0); + _result0 = vmlaq_lane_f32(_result0, _ext01, vget_low_f32(_ker[1]), 1); + _result0 = vmlaq_lane_f32(_result0, _ext02, vget_high_f32(_ker[1]), 0); + _result1 = vmlaq_lane_f32(_result1, _row11, vget_low_f32(_ker[1]), 0); + _result1 = vmlaq_lane_f32(_result1, _ext03, vget_low_f32(_ker[1]), 1); + _result1 = vmlaq_lane_f32(_result1, _ext04, vget_high_f32(_ker[1]), 0); + + _row00 = vld1q_f32(input_ptr2); + _row01 = vld1q_f32(input_ptr2 + 4); + + _ext01 = vextq_f32(_row00, _row01, 1); + _ext02 = vextq_f32(_row00, _row01, 2); + _ext03 = vextq_f32(_row01, _row01, 1); + _ext04 = vextq_f32(_row01, _row01, 2); + + _result0 = vmlaq_lane_f32(_result0, _row00, vget_low_f32(_ker[2]), 0); + _result0 = vmlaq_lane_f32(_result0, _ext01, vget_low_f32(_ker[2]), 1); + _result0 = vmlaq_lane_f32(_result0, _ext02, vget_high_f32(_ker[2]), 0); + _result1 = vmlaq_lane_f32(_result1, _row01, vget_low_f32(_ker[2]), 0); + _result1 = vmlaq_lane_f32(_result1, _ext03, vget_low_f32(_ker[2]), 1); + _result1 = vmlaq_lane_f32(_result1, _ext04, vget_high_f32(_ker[2]), 0); + + switch (output_w_remain) { + case 5: + vst1q_lane_f32(output_ptr0 + 4, _result1, 0); + case 4: + vst1q_f32(output_ptr0, _result0); + break; + case 3: + vst1q_lane_f32(output_ptr0 + 2, _result0, 2); + case 2: + vst1_f32(output_ptr0, vget_low_f32(_result0)); + break; + case 1: + vst1q_lane_f32(output_ptr0, _result0, 0); + break; } - */ - -#endif -} -/// w!=h not fix -void DepthwiseConvAddBNRelu3x3s2p1(const framework::Tensor *input, - const framework::Tensor *filter, - framework::Tensor *output, - const framework::Tensor *new_scale, - const framework::Tensor *new_bias, - bool if_relu) { -#if __ARM_NEON - - const int batch_size = input->dims()[0]; - - const int input_height = input->dims()[2]; - - const int input_width = input->dims()[3]; - - const int output_channels = output->dims()[1]; - - const int output_height = output->dims()[2]; - const int output_width = output->dims()[3]; - const int _kernel_size = 3; - const int stride_height = 2; - const int stride_width = 2; - const int padding_height = 1; - const int padding_width = 1; - const float zero = 0; - const int input_channel_stride = input_height * input_width; - const int output_channel_stride = output_height * output_width; - const int filter_channel_stride = 9; - const float *newscale_data = new_scale->data(); - const float *newbias_data = new_bias->data(); - - const float *input_data = input->data(); - const float *filter_data = filter->data(); - - float *output_data = output->mutable_data(); - - const int input_batch_stride = output_channels * input_channel_stride; - const int output_batch_stride = output_channels * output_channel_stride; - const int filter_batch_stride = output_channels * output_channel_stride; - const float *pos1, *pos2, *pos3, *filter1, *filter2, *filter3, *output_ptr; - int hstart, wstart, hend, wend; - float result; - for (int i = 0; i < batch_size; ++i) { - for (int c = 0; c < output_channels; ++c) { - filter1 = filter_data; - filter2 = filter1 + 3; - filter3 = filter2 + 3; - - for (int ph = 0; ph < output_height; ph++) { - for (int pw = 0; pw < output_width; pw++) { - hstart = ph * stride_height - padding_height; - wstart = pw * stride_width - padding_width; - hend = std::min(hstart + _kernel_size, input_height + padding_height); - wend = std::min(wstart + _kernel_size, input_width + padding_width); - hstart = std::max(hstart, 0); - wstart = std::max(wstart, 0); - hend = 
std::min(hend, input_height);
-          wend = std::min(wend, input_width);
-          pos1 = input_data + hstart * input_width + wstart;
-          pos2 = input_data + (hstart + 1) * input_width + wstart;
-          pos3 = input_data + (hstart + 2) * input_width + wstart;
-          output_ptr = output_data + ph * output_width + pw;
-
-          if (hend - hstart != 3 || wend - wstart != 3) {
-            result = 0;
-            float fake_input[9] = {0};
-            if (hstart == 0 && wstart == 0) {
-              // top-left corner
-              for (int j = 0; j < 3; ++j) {
-                for (int k = 0; k < 3; ++k) {
-                  if (j >= 3 - hend && k >= 3 - wend) {
-                    fake_input[3 * j + k] =
-                        input_data[(j - (3 - hend)) * input_width + k -
-                                   (3 - wend)];
-                  }
-                }
-              }
-            } else if (hstart == 0 && wend == input_width) {
-              // top-right corner
-              for (int j = 0; j < 3; ++j) {
-                for (int k = 0; k < 3; ++k) {
-                  if (j >= 3 - hend && k <= input_width - wstart - 1) {
-                    fake_input[3 * j + k] =
-                        input_data[(j - (3 - hend)) * input_width + k + wstart];
-                  }
-                }
-              }
-
-            } else if (hend == input_height && wstart == 0) {
-              // bottom-left corner
-
-              for (int j = 0; j < 3; ++j) {
-                for (int k = 0; k < 3; ++k) {
-                  if (j <= input_height - 1 - hstart && k >= 3 - wend) {
-                    fake_input[3 * j + k] =
-                        input_data[(j + hstart) * input_width + k - (3 - wend)];
-                  }
-                }
-              }
-            } else if (hend == input_height && wend == input_width) {
-              // bottom-right corner
-              for (int j = 0; j < 3; ++j) {
-                for (int k = 0; k < 3; ++k) {
-                  if (j <= input_height - hstart - 1 &&
-                      k <= input_width - wstart - 1) {
-                    fake_input[3 * j + k] =
-                        input_data[(j + hstart) * input_width + k + wstart];
-                  }
-                }
-              }
-            } else if (hstart == 0) {
-              // top edge
-              for (int j = 0; j < 3; ++j) {
-                for (int k = 0; k < 3; ++k) {
-                  if (j >= 3 - hend) {
-                    fake_input[3 * j + k] =
-                        input_data[(j - (3 - hend)) * input_width + k + wstart];
-                  }
-                }
-              }
-
-            } else if (hend == input_height) {
-              // bottom edge
-              for (int j = 0; j < 3; ++j) {
-                for (int k = 0; k < 3; ++k) {
-                  if (j <= input_height - hstart - 1) {
-                    fake_input[3 * j + k] =
-                        input_data[(j + hstart) * input_width + k + wstart];
-                  }
-                }
-              }
-
-            } else if (wstart == 0) {
-              // left edge
-              for (int j = 0; j < 3; ++j) {
-                for (int k = 0; k < 3; ++k) {
-                  if (k >= 3 - wend) {
-                    fake_input[3 * j + k] =
-                        input_data[(j + hstart) * input_width +
-                                   (k - (3 - wend))];
-                  }
-                }
-              }
-
-            } else if (wend == input_width) {
-              // right edge
-              for (int j = 0; j < 3; ++j) {
-                for (int k = 0; k < 3; ++k) {
-                  if (k <= input_width - wstart - 1) {
-                    fake_input[3 * j + k] =
-                        input_data[(j + hstart) * input_width + k + wstart];
-                  }
-                }
-              }
-            }
-            for (int l = 0; l < 9; ++l) {
-              result += fake_input[l] * filter1[l];
-            }
-            output_data[ph * output_width + pw] =
-                newscale_data[c] * result + newbias_data[c];
-
-            if (if_relu) {
-              output_data[ph * output_width + pw] =
-                  output_data[ph * output_width + pw] < 0
-                      ?
0 - : output_data[ph * output_width + pw]; - } + input_ptr0 += output_w_remain; + input_ptr1 += output_w_remain; + input_ptr2 += output_w_remain; + output_ptr0 += output_w_remain; + } + // pad right + if (padding_w) { + float32x2_t row0 = vld1_f32(input_ptr0); + float32x2_t row1 = vld1_f32(input_ptr1); + float32x2_t row2 = vld1_f32(input_ptr2); + float32x2_t zero = vdup_n_f32(0.f); + float32x2_t acc; + for (int w = valid_w_end; w < output_w; ++w) { + int padding = w + 3 - (padding_w + input_w); + if (padding >= 3) { + *output_ptr0 = 0.f; } else { - const float32x4_t data1 = vld1q_f32(pos1); - const float32x4_t data2 = vld1q_f32(pos2); - const float32x4_t data3 = vld1q_f32(pos3); - - const float32x4_t v_filter1 = vld1q_f32(filter1); - const float32x4_t v_filter2 = vld1q_f32(filter2); - const float32x4_t v_filter3 = vld1q_f32(filter3); - float32x4_t mula = vmulq_f32(data1, v_filter1); - mula = vmlaq_f32(mula, data2, v_filter2); - mula = vmlaq_f32(mula, data3, v_filter3); - float32x2_t res = vpadd_f32( - vget_high_f32(vsetq_lane_f32(0, mula, 3)), vget_low_f32(mula)); - res = vpadd_f32(res, res); - output_data[ph * output_width + pw] = - vget_lane_f32(res, 0) * newscale_data[c] + newbias_data[c]; - - if (if_relu) { - output_data[ph * output_width + pw] = - output_data[ph * output_width + pw] < 0 - ? 0 - : output_data[ph * output_width + pw]; - } + acc = vmul_f32(row0, vget_low_f32(_ker[0])); + acc = vmla_f32(acc, row1, vget_low_f32(_ker[1])); + acc = vmla_f32(acc, row2, vget_low_f32(_ker[2])); + float32x2_t sum = vpadd_f32(acc, acc); + vst1_lane_f32(output_ptr0, sum, 0); + row0 = vext_f32(row0, zero, 1); + row1 = vext_f32(row1, zero, 1); + row2 = vext_f32(row2, zero, 1); } + output_ptr0++; } } - input_data += input_channel_stride; - output_data += output_channel_stride; - filter_data += filter_channel_stride; } - input_data += input_batch_stride; - output_data += output_batch_stride; + // pad bottom + for (int h = valid_h_end; h < output_h; ++h) { + DepthwiseConv3x3NormalRow<1, 1>(input_ptr, filter_ptr, h, input_h, + input_w, padding_h, padding_w, output_w, + output_ptr, _ker); + } } -#endif } -void DepthwiseConv3x3s2p1v2(const framework::Tensor *input, - const framework::Tensor *filter, - framework::Tensor *output, framework::Tensor *bias, - bool if_bias, bool if_relu) { -#if __ARM_NEON - const float *input_data = input->data(); - const float *filter_data = filter->data(); - float *output_data = output->mutable_data(); - const float *bias_data; - if (if_bias) { - bias_data = bias->data(); - } - - const int in_h = static_cast(input->dims()[2]); - const int in_w = static_cast(input->dims()[3]); - const int out_h = static_cast(output->dims()[2]); - const int out_w = static_cast(output->dims()[3]); - const int out_l = out_h; - const int in_l = in_h; - const int inhxw = in_h * in_w; - const int outhxw = out_h * out_w; - /// todo : fix if_pad when w != h - const int if_pad_r = in_w - 1 == (out_w - 1) * 2 ? 1 : 0; - const int if_pad_b = in_h - 1 == (out_h - 1) * 2 ? 
1 : 0;
-  const int batch_size = static_cast<int>(input->dims()[0]);
-  const int c = static_cast<int>(input->dims()[1]);
-  const float *input_row_ptr;
-  float *output_row_ptr;
-
-  const int w_times = (out_w - 2) / 3;
-
-  float32x4_t vbias = vdupq_n_f32(0.0);
-
-  float32x4x2_t input_buff_mid{}, input_buff_bottom[w_times + 1];
-  float32x4_t elewise_res0, elewise_res1, elewise_res2, res3;
-  int out2in_mid;
-  float32x4_t zero = vdupq_n_f32(0.0);
-  for (int b = batch_size; b > 0; --b) {
-    const float *filter_data_tmp = filter_data;
-    for (int j = 0; j < c; ++j) {
-      auto output_data_tmp = output_data + j * out_h * out_w;
-      auto input_data_tmp = input_data + j * in_h * in_w;
-      auto input_const = input_data_tmp;
-
-      if (if_bias) {
-        vbias = vdupq_n_f32(bias_data[j]);
-      }
-
-      float w00 = filter_data_tmp[0];
-      float w01 = filter_data_tmp[1];
-      float w02 = filter_data_tmp[2];
-      float w10 = filter_data_tmp[3];
-      float w11 = filter_data_tmp[4];
-      float w12 = filter_data_tmp[5];
-      float w20 = filter_data_tmp[6];
-      float w21 = filter_data_tmp[7];
-      float w22 = filter_data_tmp[8];
-
-      int h_mid = 0;
-
-      for (; h_mid < out_h - 1; h_mid++) {
-        input_row_ptr = input_data_tmp + 1 + h_mid * 2 * in_w;
-        output_row_ptr = output_data_tmp + 1 + h_mid * out_w;
-
-        for (int w4 = 0; w4 < w_times + 1; w4++) {
-          if (h_mid == 0) {
-            elewise_res1 = zero;
-            elewise_res0 = zero;
-            elewise_res2 = zero;
+template <>
+void DepthwiseConv3x3S2<float, float>(const framework::Tensor &input,
+                                      const framework::Tensor &filter,
+                                      const std::vector<int> &paddings,
+                                      framework::Tensor *output) {
+  const float *input_data = input.data<float>();
+  const float *filter_data = filter.data<float>();
+  float *out_data = output->mutable_data<float>();
+
+  const int input_h = input.dims()[2];
+  const int input_w = input.dims()[3];
+  const int output_h = output->dims()[2];
+  const int output_w = output->dims()[3];
+  const int padding_h = paddings[0];
+  const int padding_w = paddings[1];
+  const int image_size = input_h * input_w;
+  const int out_image_size = output_h * output_w;
+  const int valid_h_start = (padding_h + 1) / 2;
+  const int valid_h_end = (input_h + padding_h - 1) / 2;
+  const int valid_h = valid_h_end - valid_h_start;
+  const int valid_w_start = (padding_w + 1) / 2;
+  const int valid_w_end = (input_w + padding_w - 1) / 2;
+  const int valid_w = valid_w_end - valid_w_start;
+  const int input_w_start = 2 * valid_w_start - padding_w;
+
+  #pragma omp parallel for
+  for (int g = 0; g < input.dims()[1]; ++g) {
+    const float *input_ptr = input_data + g * image_size;
+    const float *filter_ptr = filter_data + g * 9;
+    float *output_ptr = out_data + g * out_image_size;
+
+    const float *filter_ptr0 = filter_ptr;
+    const float *filter_ptr1 = filter_ptr0 + 3;
+    const float *filter_ptr2 = filter_ptr1 + 3;
+    float32x4_t _ker[3];
+    _ker[0] = vld1q_f32(filter_ptr0);
+    _ker[1] = vld1q_f32(filter_ptr1);
+    _ker[2] = vld1q_f32(filter_ptr2);
+
+    // pad top
+    for (int h = 0; h < valid_h_start; ++h) {
+      DepthwiseConv3x3NormalRow<2, 2>(input_ptr, filter_ptr, h, input_h,
+                                      input_w, padding_h, padding_w, output_w,
+                                      output_ptr, _ker);
+    }
+    // valid 2x4
+    int output_w_tiles = valid_w / 4;
+    int output_w_remain = valid_w - output_w_tiles * 4;
+    for (int h = valid_h_start; h < valid_h_end - 1; h += 2) {
+      const float *input_ptr0 = input_ptr + (2 * h - padding_h) * input_w;
+      const float *input_ptr1 = input_ptr0 + input_w;
+      const float *input_ptr2 = input_ptr1 + input_w;
+      const float *input_ptr3 = input_ptr2 + input_w;
+      const float *input_ptr4 = input_ptr3 + input_w;
+      float *output_ptr0 = output_ptr + h *
output_w; + float *output_ptr1 = output_ptr0 + output_w; + // pad left + if (padding_w) { + for (int w = valid_w_start - 1; w >= 0; --w) { + int padding = padding_w - (w << 1); + if (padding >= 3) { + output_ptr0[w] = 0; + output_ptr1[w] = 0; } else { - elewise_res1 = vmulq_n_f32(input_buff_bottom[w4].val[1], w01); - elewise_res0 = vmulq_n_f32(input_buff_bottom[w4].val[0], w00); - elewise_res2 = vmulq_n_f32(input_buff_bottom[w4].val[0], w02); - } - input_buff_mid = vld2q_f32(input_row_ptr); - input_buff_bottom[w4] = vld2q_f32(input_row_ptr + in_w); - - elewise_res1 = vmlaq_n_f32(elewise_res1, input_buff_mid.val[1], w11); - elewise_res0 = vmlaq_n_f32(elewise_res0, input_buff_mid.val[0], w10); - elewise_res2 = vmlaq_n_f32(elewise_res2, input_buff_mid.val[0], w12); - - elewise_res1 = - vmlaq_n_f32(elewise_res1, input_buff_bottom[w4].val[1], w21); - elewise_res0 = - vmlaq_n_f32(elewise_res0, input_buff_bottom[w4].val[0], w20); - elewise_res2 = - vmlaq_n_f32(elewise_res2, input_buff_bottom[w4].val[0], w22); - - res3 = vaddq_f32(vextq_f32(elewise_res2, zero, 1), - vaddq_f32(elewise_res0, elewise_res1)); - res3 = vaddq_f32(res3, vbias); - if (if_relu) { - res3 = vmaxq_f32(res3, zero); - } - vst1q_f32(output_row_ptr, res3); - - input_row_ptr += 6; - output_row_ptr += 3; - } - } - clock(); - - input_row_ptr = input_data_tmp + 1 + h_mid * 2 * in_w; - output_row_ptr = output_data_tmp + 1 + h_mid * out_w; - - for (int w4 = 0; w4 < w_times + 1; w4++) { - elewise_res1 = vmulq_n_f32(input_buff_bottom[w4].val[1], w01); - elewise_res0 = vmulq_n_f32(input_buff_bottom[w4].val[0], w00); - elewise_res2 = vmulq_n_f32(input_buff_bottom[w4].val[0], w02); - - input_buff_mid = vld2q_f32(input_row_ptr); - input_buff_bottom[w4] = vld2q_f32(input_row_ptr + in_w); - - elewise_res1 = vmlaq_n_f32(elewise_res1, input_buff_mid.val[1], w11); - elewise_res0 = vmlaq_n_f32(elewise_res0, input_buff_mid.val[0], w10); - elewise_res2 = vmlaq_n_f32(elewise_res2, input_buff_mid.val[0], w12); - - if (!if_pad_b) { - elewise_res1 = - vmlaq_n_f32(elewise_res1, input_buff_bottom[w4].val[1], w21); - elewise_res0 = - vmlaq_n_f32(elewise_res0, input_buff_bottom[w4].val[0], w20); - elewise_res2 = - vmlaq_n_f32(elewise_res2, input_buff_bottom[w4].val[0], w22); - } - res3 = vaddq_f32(vextq_f32(elewise_res2, zero, 1), - vaddq_f32(elewise_res0, elewise_res1)); - res3 = vaddq_f32(res3, vbias); - if (if_relu) { - res3 = vmaxq_f32(res3, zero); - } - - if ((w4 != w_times)) { - vst1q_f32(output_row_ptr, res3); - } else { - if (out_w - 2 - w_times * 3 == 1) { - vst1q_lane_f32(output_row_ptr, res3, 0); - } else if (out_w - 2 - w_times * 3 == 2) { - vst1q_lane_f32(output_row_ptr, res3, 0); - vst1q_lane_f32(output_row_ptr + 1, res3, 1); + float32x4_t row0 = vld1q_f32(input_ptr0 - padding); + float32x4_t row1 = vld1q_f32(input_ptr1 - padding); + float32x4_t row2 = vld1q_f32(input_ptr2 - padding); + float32x4_t row3 = vld1q_f32(input_ptr3 - padding); + float32x4_t row4 = vld1q_f32(input_ptr4 - padding); + float32x4_t acc0 = vmulq_f32(row0, _ker[0]); + float32x4_t acc1 = vmulq_f32(row2, _ker[0]); + acc0 = vmlaq_f32(acc0, row1, _ker[1]); + acc1 = vmlaq_f32(acc1, row3, _ker[1]); + acc0 = vmlaq_f32(acc0, row2, _ker[2]); + acc1 = vmlaq_f32(acc1, row4, _ker[2]); + float sum0 = vgetq_lane_f32(acc0, 2); + float sum1 = vgetq_lane_f32(acc1, 2); + if (padding == 1) { + sum0 += vgetq_lane_f32(acc0, 1); + sum1 += vgetq_lane_f32(acc1, 1); + } + output_ptr0[w] = sum0; + output_ptr1[w] = sum1; } } - input_row_ptr += 6; - output_row_ptr += 3; + input_ptr0 += input_w_start; 
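+      // input_w_start (= 2 * valid_w_start - padding_w) is the first input
+      // column that feeds a fully in-bounds output column; advance the
+      // remaining row pointers past the left padding handled above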
+ input_ptr1 += input_w_start; + input_ptr2 += input_w_start; + input_ptr3 += input_w_start; + input_ptr4 += input_w_start; + output_ptr0 += valid_w_start; + output_ptr1 += valid_w_start; } - - // leftTop, rightTop, leftBottom, rightBottom - int lt = 0; - int rt = out_w - 1; - int lb = out_w * (out_h - 1); - int rb = out_h * out_w - 1; - - output_data_tmp[lt] = input_const[0] * w11 + input_const[1] * w12 + - input_const[in_w] * w21 + - input_const[in_w + 1] * w22; - - out2in_mid = (out_w - 1) * 2; - output_data_tmp[rt] = - w10 * input_const[out2in_mid - 1] + w11 * input_const[out2in_mid] + - w20 * input_const[out2in_mid + in_w - 1] + - w21 * input_const[out2in_mid + in_w] + - (1 - if_pad_r) * (w12 * input_const[out2in_mid + 1] + - w22 * input_const[out2in_mid + in_w + 1]); - - out2in_mid = (out_h - 1) * 2 * in_w; - - output_data_tmp[lb] = - w01 * input_const[out2in_mid - in_w] + - w02 * input_const[out2in_mid - in_w + 1] + - w11 * input_const[out2in_mid] + w12 * input_const[out2in_mid + 1] + - (1 - if_pad_b) * (w21 * input_const[out2in_mid + in_w] + - w22 * input_const[out2in_mid + in_w + 1]); - out2in_mid = (out_h - 1) * 2 * in_w + (out_w - 1) * 2; - - output_data_tmp[rb] = - w00 * input_const[out2in_mid - in_w - 1] + - w01 * input_const[out2in_mid - in_w] + - w10 * input_const[out2in_mid - 1] + w11 * input_const[out2in_mid] + - (1 - if_pad_r) * (w20 * input_const[out2in_mid + in_w - 1] + - w21 * input_const[out2in_mid + in_w]) + - (1 - if_pad_b) * (w02 * input_const[out2in_mid - in_w + 1] + - w12 * input_const[out2in_mid + 1]) + - (1 - if_pad_r) * (1 - if_pad_b) * w22 * - input_const[out2in_mid + in_w + 1]; - if (if_bias) { - output_data_tmp[lt] += bias_data[j]; - output_data_tmp[rt] += bias_data[j]; - output_data_tmp[lb] += bias_data[j]; - output_data_tmp[rb] += bias_data[j]; + // valid + float32x4_t _result0, _result1, _ext; + for (int loop = 0; loop < output_w_tiles; ++loop) { + float32x4x2_t _row0 = vld2q_f32(input_ptr0); + float32x4x2_t _row1 = vld2q_f32(input_ptr1); + + _ext = vextq_f32(_row0.val[0], _ext, 1); + _ext = vsetq_lane_f32(input_ptr0[8], _ext, 3); + _result0 = vmulq_lane_f32(_row0.val[0], vget_low_f32(_ker[0]), 0); + _result0 = + vmlaq_lane_f32(_result0, _row0.val[1], vget_low_f32(_ker[0]), 1); + _result0 = vmlaq_lane_f32(_result0, _ext, vget_high_f32(_ker[0]), 0); + + _ext = vextq_f32(_row1.val[0], _ext, 1); + _ext = vsetq_lane_f32(input_ptr1[8], _ext, 3); + _result0 = + vmlaq_lane_f32(_result0, _row1.val[0], vget_low_f32(_ker[1]), 0); + _result0 = + vmlaq_lane_f32(_result0, _row1.val[1], vget_low_f32(_ker[1]), 1); + _result0 = vmlaq_lane_f32(_result0, _ext, vget_high_f32(_ker[1]), 0); + + _row0 = vld2q_f32(input_ptr2); + _row1 = vld2q_f32(input_ptr3); + + _ext = vextq_f32(_row0.val[0], _ext, 1); + _ext = vsetq_lane_f32(input_ptr2[8], _ext, 3); + _result0 = + vmlaq_lane_f32(_result0, _row0.val[0], vget_low_f32(_ker[2]), 0); + _result0 = + vmlaq_lane_f32(_result0, _row0.val[1], vget_low_f32(_ker[2]), 1); + _result0 = vmlaq_lane_f32(_result0, _ext, vget_high_f32(_ker[2]), 0); + _result1 = vmulq_lane_f32(_row0.val[0], vget_low_f32(_ker[0]), 0); + _result1 = + vmlaq_lane_f32(_result1, _row0.val[1], vget_low_f32(_ker[0]), 1); + _result1 = vmlaq_lane_f32(_result1, _ext, vget_high_f32(_ker[0]), 0); + + _ext = vextq_f32(_row1.val[0], _ext, 1); + _ext = vsetq_lane_f32(input_ptr3[8], _ext, 3); + _result1 = + vmlaq_lane_f32(_result1, _row1.val[0], vget_low_f32(_ker[1]), 0); + _result1 = + vmlaq_lane_f32(_result1, _row1.val[1], vget_low_f32(_ker[1]), 1); + _result1 = 
vmlaq_lane_f32(_result1, _ext, vget_high_f32(_ker[1]), 0); + + _row0 = vld2q_f32(input_ptr4); + + _ext = vextq_f32(_row0.val[0], _ext, 1); + _ext = vsetq_lane_f32(input_ptr4[8], _ext, 3); + _result1 = + vmlaq_lane_f32(_result1, _row0.val[0], vget_low_f32(_ker[2]), 0); + _result1 = + vmlaq_lane_f32(_result1, _row0.val[1], vget_low_f32(_ker[2]), 1); + _result1 = vmlaq_lane_f32(_result1, _ext, vget_high_f32(_ker[2]), 0); + + vst1q_f32(output_ptr0, _result0); + vst1q_f32(output_ptr1, _result1); + + input_ptr0 += 8; + input_ptr1 += 8; + input_ptr2 += 8; + input_ptr3 += 8; + input_ptr4 += 8; + output_ptr0 += 4; + output_ptr1 += 4; } - if (if_relu) { - output_data_tmp[lt] = output_data_tmp[lt] < 0 ? 0 : output_data_tmp[lt]; - output_data_tmp[rt] = output_data_tmp[rt] < 0 ? 0 : output_data_tmp[rt]; - output_data_tmp[lb] = output_data_tmp[lb] < 0 ? 0 : output_data_tmp[lb]; - output_data_tmp[rb] = output_data_tmp[rb] < 0 ? 0 : output_data_tmp[rb]; - } - for (int i = 1; i < out_h - 1; i++) { - out2in_mid = i * 2 * in_w; - int left = i * out_w; - output_data_tmp[left] = w01 * input_const[out2in_mid - in_w] + - w02 * input_const[out2in_mid - in_w + 1] + - w11 * input_const[out2in_mid] + - w12 * input_const[out2in_mid + 1] + - w21 * input_const[out2in_mid + in_w] + - w22 * input_const[out2in_mid + in_w + 1]; - - out2in_mid = i * 2 * in_w + (out_w - 1) * 2; - int right = i * out_w + out_w - 1; - output_data_tmp[right] = - w00 * input_const[out2in_mid - in_w - 1] + - w01 * input_const[out2in_mid - in_w] + - w10 * input_const[out2in_mid - 1] + w11 * input_const[out2in_mid] + - w20 * input_const[out2in_mid + in_w - 1] + - w21 * input_const[out2in_mid + in_w] + - (1 - if_pad_r) * (w02 * input_const[out2in_mid - in_w + 1] + - w12 * input_const[out2in_mid + 1] + - w22 * input_const[out2in_mid + in_w + 1]); - if (if_bias) { - output_data_tmp[left] += bias_data[j]; - output_data_tmp[right] += bias_data[j]; - } - if (if_relu) { - output_data_tmp[left] = - output_data_tmp[left] < 0 ? 0 : output_data_tmp[left]; - output_data_tmp[right] = - output_data_tmp[right] < 0 ? 
0 : output_data_tmp[right]; + // remain w + if (output_w_remain > 0) { + float32x4x2_t _row0 = vld2q_f32(input_ptr0); + float32x4x2_t _row1 = vld2q_f32(input_ptr1); + + _ext = vextq_f32(_row0.val[0], _ext, 1); + _ext = vsetq_lane_f32(input_ptr0[8], _ext, 3); + _result0 = vmulq_lane_f32(_row0.val[0], vget_low_f32(_ker[0]), 0); + _result0 = + vmlaq_lane_f32(_result0, _row0.val[1], vget_low_f32(_ker[0]), 1); + _result0 = vmlaq_lane_f32(_result0, _ext, vget_high_f32(_ker[0]), 0); + + _ext = vextq_f32(_row1.val[0], _ext, 1); + _ext = vsetq_lane_f32(input_ptr1[8], _ext, 3); + _result0 = + vmlaq_lane_f32(_result0, _row1.val[0], vget_low_f32(_ker[1]), 0); + _result0 = + vmlaq_lane_f32(_result0, _row1.val[1], vget_low_f32(_ker[1]), 1); + _result0 = vmlaq_lane_f32(_result0, _ext, vget_high_f32(_ker[1]), 0); + + _row0 = vld2q_f32(input_ptr2); + _row1 = vld2q_f32(input_ptr3); + + _ext = vextq_f32(_row0.val[0], _ext, 1); + _ext = vsetq_lane_f32(input_ptr2[8], _ext, 3); + _result0 = + vmlaq_lane_f32(_result0, _row0.val[0], vget_low_f32(_ker[2]), 0); + _result0 = + vmlaq_lane_f32(_result0, _row0.val[1], vget_low_f32(_ker[2]), 1); + _result0 = vmlaq_lane_f32(_result0, _ext, vget_high_f32(_ker[2]), 0); + _result1 = vmulq_lane_f32(_row0.val[0], vget_low_f32(_ker[0]), 0); + _result1 = + vmlaq_lane_f32(_result1, _row0.val[1], vget_low_f32(_ker[0]), 1); + _result1 = vmlaq_lane_f32(_result1, _ext, vget_high_f32(_ker[0]), 0); + + _ext = vextq_f32(_row1.val[0], _ext, 1); + _ext = vsetq_lane_f32(input_ptr3[8], _ext, 3); + _result1 = + vmlaq_lane_f32(_result1, _row1.val[0], vget_low_f32(_ker[1]), 0); + _result1 = + vmlaq_lane_f32(_result1, _row1.val[1], vget_low_f32(_ker[1]), 1); + _result1 = vmlaq_lane_f32(_result1, _ext, vget_high_f32(_ker[1]), 0); + + _row0 = vld2q_f32(input_ptr4); + + _ext = vextq_f32(_row0.val[0], _ext, 1); + _ext = vsetq_lane_f32(input_ptr4[8], _ext, 3); + _result1 = + vmlaq_lane_f32(_result1, _row0.val[0], vget_low_f32(_ker[2]), 0); + _result1 = + vmlaq_lane_f32(_result1, _row0.val[1], vget_low_f32(_ker[2]), 1); + _result1 = vmlaq_lane_f32(_result1, _ext, vget_high_f32(_ker[2]), 0); + + switch (output_w_remain) { + case 3: + vst1q_lane_f32(output_ptr0 + 2, _result0, 2); + vst1q_lane_f32(output_ptr1 + 2, _result1, 2); + case 2: + vst1_f32(output_ptr0, vget_low_f32(_result0)); + vst1_f32(output_ptr1, vget_low_f32(_result1)); + break; + case 1: + vst1q_lane_f32(output_ptr0, _result0, 0); + vst1q_lane_f32(output_ptr1, _result1, 0); + break; } + input_ptr0 += output_w_remain * 2; + input_ptr1 += output_w_remain * 2; + input_ptr2 += output_w_remain * 2; + input_ptr3 += output_w_remain * 2; + input_ptr4 += output_w_remain * 2; + output_ptr0 += output_w_remain; + output_ptr1 += output_w_remain; } - filter_data_tmp += 9; - } - input_data += inhxw * c; - output_data += outhxw * c; - } -#endif -} - -void DepthwiseConvAddBNRelu3x3s2p1v2(const framework::Tensor *input, - const framework::Tensor *filter, - framework::Tensor *output, - const framework::Tensor *new_scale, - const framework::Tensor *new_bias, - bool if_relu) { -#if __ARM_NEON - // #ifdef _OPENMP - // const float *newscale_data = new_scale->data(); - // const float *newbias_data = new_bias->data(); - // - // const int batch_size = static_cast(input->dims()[0]); - // const int input_channel = static_cast(input->dims()[1]); - // - // const int input_height = static_cast(input->dims()[2]); - // const int input_width = static_cast(input->dims()[3]); - // const int output_height = static_cast(output->dims()[2]); - // const int output_width = 
static_cast(output->dims()[3]); - // const int inhxw = input_height * input_width; - // const int outhxw = output_height * output_width; - // - // float32x4_t zero = vdupq_n_f32(0.0); - // for (int b = 0; b < batch_size; b++) { - // #pragma omp parallel for - // for (int c = 0; c < input_channel; c++) { - // const float *filter_data = filter->data() + c * 9; - // const float *input_data = input->data() + c * inhxw; - // float *output_data = output->data() + c * outhxw; - // float32x4_t vnewbias = vdupq_n_f32(newbias_data[c]); - // float32x4_t vnewscale = vdupq_n_f32(newscale_data[c]); - // - // float w00 = filter_data[0]; - // float w01 = filter_data[1]; - // float w02 = filter_data[2]; - // float w10 = filter_data[3]; - // float w11 = filter_data[4]; - // float w12 = filter_data[5]; - // float w20 = filter_data[6]; - // float w21 = filter_data[7]; - // float w22 = filter_data[8]; - // - // int m; - // for (m = 1; m < output_width - 2; m = m + 3) { - // float *output_ptr = output_data + m; - // float32x4x2_t input_buff_mid{}, input_buff_bottom{}; - // float32x4_t in0, in1, in2, in3, tmp0, tmp1, tmp2, tmp3, out0; - // input_buff_mid = vld2q_f32(input_data + (2 * m - 1)); - // input_buff_bottom = vld2q_f32(input_data + input_width + (2 * m - - // 1)); - // - // in0 = input_buff_mid.val[0]; - // tmp0 = input_buff_mid.val[1]; - // tmp1 = vextq_f32(in0, zero, 1); - // - // in2 = input_buff_bottom.val[0]; - // tmp2 = input_buff_bottom.val[1]; - // tmp3 = vextq_f32(in2, zero, 1); - // - // out0 = vmulq_n_f32(in0, w10); - // out0 = vmlaq_n_f32(out0, tmp0, w11); - // out0 = vmlaq_n_f32(out0, tmp1, w12); - // out0 = vmlaq_n_f32(out0, in2, w20); - // out0 = vmlaq_n_f32(out0, tmp2, w21); - // out0 = vmlaq_n_f32(out0, tmp3, w22); - // out0 = vmlaq_f32(vnewbias, vnewscale, out0); - // if (if_relu) { - // out0 = vmaxq_f32(out0, zero); - // } - // vst1q_lane_f32(output_ptr, out0, 0); - // vst1q_lane_f32(output_ptr + 1, out0, 1); - // vst1q_lane_f32(output_ptr + 2, out0, 2); - // } - // for (m = 1; m < output_width - 2; m += 3) { - // } - // for (int j = m; j < output_width; j++) { - // output_data[j] = input_data[2 * j - 1] * w10 + input_data[2 * j] * - // w11 + - // input_data[2 * j + 1] * w12 + - // input_data[2 * j - 1 + input_width] * w20 + - // input_data[2 * j + input_width] * w21 + - // input_data[2 * j + 1 + input_width] * w22; - // output_data[j] = newscale_data[c] * output_data[j] + - // newbias_data[c]; if (if_relu) { - // output_data[j] = output_data[j] < 0 ? 
0 : output_data[j]; - // } - // } - // - // for (int i = 1; i < output_height; i += 1) { - // for (int m = 1; m < output_width - 2; m += 3) { - // float *output_ptr = output_data + i * output_width + m; - // float32x4x2_t input_buff_top{}, input_buff_mid{}, - // input_buff_bottom{}; float32x4_t in0, in1, in2, in3, in4, in5, - // tmp0, tmp1, tmp2, tmp3, - // tmp4, tmp5, out0; - // input_buff_top = - // vld2q_f32(input_data + (2 * i - 1) * input_width + (2 * m - - // 1)); - // input_buff_mid = - // vld2q_f32(input_data + (2 * i) * input_width + (2 * m - 1)); - // input_buff_bottom = - // vld2q_f32(input_data + (2 * i + 1) * input_width + (2 * m - - // 1)); - // - // in0 = input_buff_top.val[0]; - // tmp0 = input_buff_top.val[1]; - // tmp1 = vextq_f32(in0, zero, 1); - // - // in2 = input_buff_mid.val[0]; - // tmp2 = input_buff_mid.val[1]; - // tmp3 = vextq_f32(in2, zero, 1); - // - // in4 = input_buff_bottom.val[0]; - // tmp4 = input_buff_bottom.val[1]; - // tmp5 = vextq_f32(in4, zero, 1); - // - // out0 = vmulq_n_f32(in0, w00); - // out0 = vmlaq_n_f32(out0, tmp0, w01); - // out0 = vmlaq_n_f32(out0, tmp1, w02); - // out0 = vmlaq_n_f32(out0, in2, w10); - // out0 = vmlaq_n_f32(out0, tmp2, w11); - // out0 = vmlaq_n_f32(out0, tmp3, w12); - // out0 = vmlaq_n_f32(out0, in4, w20); - // out0 = vmlaq_n_f32(out0, tmp4, w21); - // out0 = vmlaq_n_f32(out0, tmp5, w22); - // out0 = vmlaq_f32(vnewbias, vnewscale, out0); - // if (if_relu) { - // out0 = vmaxq_f32(out0, zero); - // } - // vst1q_lane_f32(output_ptr, out0, 0); - // vst1q_lane_f32(output_ptr + 1, out0, 1); - // vst1q_lane_f32(output_ptr + 2, out0, 2); - // } - // int m; - // for (m = 1; m < output_width - 2; m += 3) { - // } - // for (int j = m; j < output_width; j++) { - // output_data[i * output_width + j] = - // input_data[(2 * i - 1) * input_width + 2 * j - 1] * w00 + - // input_data[(2 * i - 1) * input_width + 2 * j] * w01 + - // input_data[(2 * i - 1) * input_width + 2 * j + 1] * w02 + - // input_data[(2 * i) * input_width + 2 * j - 1] * w10 + - // input_data[(2 * i) * input_width + 2 * j] * w11 + - // input_data[(2 * i) * input_width + 2 * j + 1] * w12 + - // input_data[(2 * i + 1) * input_width + 2 * j - 1] * w20 + - // input_data[(2 * i + 1) * input_width + 2 * j] * w21 + - // input_data[(2 * i + 1) * input_width + 2 * j + 1] * w22; - // output_data[i * output_width + j] = - // newscale_data[c] * output_data[i * output_width + j] + - // newbias_data[c]; - // if (if_relu) { - // output_data[i * output_width + j] = - // output_data[i * output_width + j] < 0 - // ? 0 - // : output_data[i * output_width + j]; - // } - // } - // } - // output_data[0] = input_data[0] * w11 + input_data[1] * w12 + - // input_data[input_height] * w21 + - // input_data[input_height + 1] * w22; - // - // output_data[0] = newscale_data[c] * output_data[0] + newbias_data[c]; - // if (if_relu) { - // output_data[0] = output_data[0] < 0 ? 0 : output_data[0]; - // } - // for (int i = 1; i < output_height; i++) { - // output_data[i * output_width] = - // input_data[(2 * i - 1) * input_width] * w01 + - // input_data[(2 * i - 1) * input_width + 1] * w02 + - // input_data[(2 * i) * input_width] * w11 + - // input_data[(2 * i) * input_width + 1] * w12 + - // input_data[(2 * i + 1) * input_width] * w21 + - // input_data[(2 * i + 1) * input_width + 1] * w22; - // - // output_data[i * output_width] = - // newscale_data[c] * output_data[i * output_width] + - // newbias_data[c]; - // if (if_relu) { - // output_data[i * output_width] = output_data[i * output_width] < 0 - // ? 
0 - // : output_data[i * - // output_width]; - // } - // } - // } - // } - // - // #else - - const float *input_data = input->data(); - const float *filter_data = filter->data(); - float *output_data = output->mutable_data(); - const float *newscale_data = new_scale->data(); - const float *newbias_data = new_bias->data(); - - const int in_h = static_cast(input->dims()[2]); - const int in_w = static_cast(input->dims()[3]); - const int out_h = static_cast(output->dims()[2]); - const int out_w = static_cast(output->dims()[3]); - // const int out_l = out_h; - // const int in_l = in_h; - const int inhxw = in_h * in_w; - const int outhxw = out_h * out_w; - /// todo : fix if_pad when w != h - const int if_pad_r = in_w - 1 == (out_w - 1) * 2 ? 1 : 0; - const int if_pad_b = in_h - 1 == (out_h - 1) * 2 ? 1 : 0; - const int batch_size = static_cast(input->dims()[0]); - const int c = static_cast(input->dims()[1]); - const int w_times = (out_w - 2) / 3; - float32x4_t zero = vdupq_n_f32(0.0); - for (int b = batch_size; b > 0; --b) { -#pragma omp parallel for - for (int j = 0; j < c; j++) { - const float *input_row_ptr; - float *output_row_ptr; - float32x4x2_t input_buff_mid{}, input_buff_bottom[w_times + 1]; - float32x4_t elewise_res0, elewise_res1, elewise_res2, res3; - int out2in_mid; - float32x4_t vnewbias = vdupq_n_f32(0.0); - float32x4_t vnewscale = vdupq_n_f32(1.0); - auto output_data_tmp = output_data + j * out_h * out_w; - auto input_data_tmp = input_data + j * in_h * in_w; - auto input_const = input_data_tmp; - const float *filter_data_tmp = filter_data + 9 * j; - vnewbias = vdupq_n_f32(newbias_data[j]); - vnewscale = vdupq_n_f32(newscale_data[j]); - - float w00 = filter_data_tmp[0]; - float w01 = filter_data_tmp[1]; - float w02 = filter_data_tmp[2]; - float w10 = filter_data_tmp[3]; - float w11 = filter_data_tmp[4]; - float w12 = filter_data_tmp[5]; - float w20 = filter_data_tmp[6]; - float w21 = filter_data_tmp[7]; - float w22 = filter_data_tmp[8]; - - int h_mid = 0; - - for (; h_mid < out_h - 1; h_mid++) { - input_row_ptr = input_data_tmp + 1 + h_mid * 2 * in_w; - output_row_ptr = output_data_tmp + 1 + h_mid * out_w; - - for (int w4 = 0; w4 < w_times + 1; w4++) { - if (h_mid == 0) { - elewise_res1 = zero; - elewise_res0 = zero; - elewise_res2 = zero; + // pad right + if (padding_w > 0) { + float32x4_t row0 = vld1q_f32(input_ptr0); + float32x4_t row1 = vld1q_f32(input_ptr1); + float32x4_t row2 = vld1q_f32(input_ptr2); + float32x4_t row3 = vld1q_f32(input_ptr3); + float32x4_t row4 = vld1q_f32(input_ptr4); + float32x4_t acc0, acc1; + for (int w = valid_w_end; w < output_w; ++w) { + int padding = 2 * w + 3 - (padding_w + input_w); + if (padding >= 3) { + *output_ptr0 = 0; + *output_ptr1 = 0; } else { - elewise_res1 = vmulq_n_f32(input_buff_bottom[w4].val[1], w01); - elewise_res0 = vmulq_n_f32(input_buff_bottom[w4].val[0], w00); - elewise_res2 = vmulq_n_f32(input_buff_bottom[w4].val[0], w02); - } - input_buff_mid = vld2q_f32(input_row_ptr); - input_buff_bottom[w4] = vld2q_f32(input_row_ptr + in_w); - - elewise_res1 = vmlaq_n_f32(elewise_res1, input_buff_mid.val[1], w11); - elewise_res0 = vmlaq_n_f32(elewise_res0, input_buff_mid.val[0], w10); - elewise_res2 = vmlaq_n_f32(elewise_res2, input_buff_mid.val[0], w12); - - elewise_res1 = - vmlaq_n_f32(elewise_res1, input_buff_bottom[w4].val[1], w21); - elewise_res0 = - vmlaq_n_f32(elewise_res0, input_buff_bottom[w4].val[0], w20); - elewise_res2 = - vmlaq_n_f32(elewise_res2, input_buff_bottom[w4].val[0], w22); - - res3 = vaddq_f32(vextq_f32(elewise_res2, 
zero, 1), - vaddq_f32(elewise_res0, elewise_res1)); - res3 = vmlaq_f32(vnewbias, vnewscale, res3); - - if (if_relu) { - res3 = vmaxq_f32(res3, zero); + acc0 = vmulq_f32(row0, _ker[0]); + acc1 = vmulq_f32(row2, _ker[0]); + acc0 = vmlaq_f32(acc0, row1, _ker[1]); + acc1 = vmlaq_f32(acc1, row3, _ker[1]); + acc0 = vmlaq_f32(acc0, row2, _ker[2]); + acc1 = vmlaq_f32(acc1, row4, _ker[2]); + float sum0 = vgetq_lane_f32(acc0, 0); + float sum1 = vgetq_lane_f32(acc1, 0); + if (padding == 1) { + sum0 += vgetq_lane_f32(acc0, 1); + sum1 += vgetq_lane_f32(acc1, 1); + } + *output_ptr0 = sum0; + *output_ptr1 = sum1; } - vst1q_lane_f32(output_row_ptr, res3, 0); - vst1q_lane_f32(output_row_ptr + 1, res3, 1); - vst1q_lane_f32(output_row_ptr + 2, res3, 2); - - input_row_ptr += 6; - output_row_ptr += 3; + output_ptr0++; + output_ptr1++; } } - clock(); - - input_row_ptr = input_data_tmp + 1 + h_mid * 2 * in_w; - output_row_ptr = output_data_tmp + 1 + h_mid * out_w; - - for (int w4 = 0; w4 < w_times + 1; w4++) { - elewise_res1 = vmulq_n_f32(input_buff_bottom[w4].val[1], w01); - elewise_res0 = vmulq_n_f32(input_buff_bottom[w4].val[0], w00); - elewise_res2 = vmulq_n_f32(input_buff_bottom[w4].val[0], w02); - - input_buff_mid = vld2q_f32(input_row_ptr); - input_buff_bottom[w4] = vld2q_f32(input_row_ptr + in_w); - - elewise_res1 = vmlaq_n_f32(elewise_res1, input_buff_mid.val[1], w11); - elewise_res0 = vmlaq_n_f32(elewise_res0, input_buff_mid.val[0], w10); - elewise_res2 = vmlaq_n_f32(elewise_res2, input_buff_mid.val[0], w12); - - if (!if_pad_b) { - elewise_res1 = - vmlaq_n_f32(elewise_res1, input_buff_bottom[w4].val[1], w21); - elewise_res0 = - vmlaq_n_f32(elewise_res0, input_buff_bottom[w4].val[0], w20); - elewise_res2 = - vmlaq_n_f32(elewise_res2, input_buff_bottom[w4].val[0], w22); - } - res3 = vaddq_f32(vextq_f32(elewise_res2, zero, 1), - vaddq_f32(elewise_res0, elewise_res1)); - res3 = vmlaq_f32(vnewbias, vnewscale, res3); - - if (if_relu) { - res3 = vmaxq_f32(res3, zero); - } - if ((w4 != w_times)) { - vst1q_lane_f32(output_row_ptr, res3, 0); - vst1q_lane_f32(output_row_ptr + 1, res3, 1); - vst1q_lane_f32(output_row_ptr + 2, res3, 2); - } else { - if (out_w - 2 - w_times * 3 == 1) { - vst1q_lane_f32(output_row_ptr, res3, 0); - } else if (out_w - 2 - w_times * 3 == 2) { - vst1q_lane_f32(output_row_ptr, res3, 0); - vst1q_lane_f32(output_row_ptr + 1, res3, 1); + } + // remain height + int start_h = valid_h_start + (valid_h & 0xfffe); + if (start_h < valid_h_end) { + const float *input_ptr0 = input_ptr + (2 * start_h - padding_h) * input_w; + const float *input_ptr1 = input_ptr0 + input_w; + const float *input_ptr2 = input_ptr1 + input_w; + float *output_ptr0 = output_ptr + start_h * output_w; + // pad left + if (padding_w) { + for (int w = valid_w_start - 1; w >= 0; --w) { + int padding = padding_w - (w << 1); + if (padding >= 3) { + output_ptr0[w] = 0; + } else { + float32x4_t row0 = vld1q_f32(input_ptr0 - padding); + float32x4_t row1 = vld1q_f32(input_ptr1 - padding); + float32x4_t row2 = vld1q_f32(input_ptr2 - padding); + float32x4_t acc0 = vmulq_f32(row0, _ker[0]); + acc0 = vmlaq_f32(acc0, row1, _ker[1]); + acc0 = vmlaq_f32(acc0, row2, _ker[2]); + float sum0 = vgetq_lane_f32(acc0, 2); + if (padding == 1) { + sum0 += vgetq_lane_f32(acc0, 1); + } + output_ptr0[w] = sum0; } } - input_row_ptr += 6; - output_row_ptr += 3; + input_ptr0 += input_w_start; + input_ptr1 += input_w_start; + input_ptr2 += input_w_start; + output_ptr0 += valid_w_start; } - - output_data_tmp[0] = input_const[0] * w11 + input_const[1] * w12 + 
- input_const[in_w] * w21 + - input_const[in_w + 1] * w22; - - out2in_mid = (out_w - 1) * 2; - output_data_tmp[out_w - 1] = - w10 * input_const[out2in_mid - 1] + w11 * input_const[out2in_mid] + - w20 * input_const[out2in_mid + in_w - 1] + - w21 * input_const[out2in_mid + in_w] + - (1 - if_pad_r) * (w12 * input_const[out2in_mid + 1] + - w22 * input_const[out2in_mid + in_w + 1]); - - out2in_mid = (out_h - 1) * 2 * in_w; - - output_data_tmp[out_w * (out_h - 1)] = - w01 * input_const[out2in_mid - in_w] + - w02 * input_const[out2in_mid - in_w + 1] + - w11 * input_const[out2in_mid] + w12 * input_const[out2in_mid + 1] + - (1 - if_pad_b) * (w21 * input_const[out2in_mid + in_w] + - w22 * input_const[out2in_mid + in_w + 1]); - out2in_mid = (out_h - 1) * 2 * in_w + (out_w - 1) * 2; - - output_data_tmp[out_h * out_w - 1] = - w00 * input_const[out2in_mid - in_w - 1] + - w01 * input_const[out2in_mid - in_w] + - w10 * input_const[out2in_mid - 1] + w11 * input_const[out2in_mid] + - (1 - if_pad_r) * (w20 * input_const[out2in_mid + in_w - 1] + - w21 * input_const[out2in_mid + in_w]) + - (1 - if_pad_b) * (w02 * input_const[out2in_mid - in_w + 1] + - w12 * input_const[out2in_mid + 1]) + - (1 - if_pad_r) * (1 - if_pad_b) * w22 * - input_const[out2in_mid + in_w + 1]; - output_data_tmp[0] = - output_data_tmp[0] * newscale_data[j] + newbias_data[j]; - output_data_tmp[out_w - 1] = - output_data_tmp[out_w - 1] * newscale_data[j] + newbias_data[j]; - output_data_tmp[out_w * (out_h - 1)] = - output_data_tmp[out_w * (out_h - 1)] * newscale_data[j] + - newbias_data[j]; - output_data_tmp[out_h * out_w - 1] = - output_data_tmp[out_h * out_w - 1] * newscale_data[j] + - newbias_data[j]; - if (if_relu) { - output_data_tmp[0] = output_data_tmp[0] < 0 ? 0 : output_data_tmp[0]; - output_data_tmp[out_w - 1] = - output_data_tmp[out_w - 1] < 0 ? 0 : output_data_tmp[out_w - 1]; - output_data_tmp[out_w * (out_h - 1)] = - output_data_tmp[out_w * (out_h - 1)] < 0 - ? 0 - : output_data_tmp[out_w * (out_h - 1)]; - output_data_tmp[out_h * out_w - 1] = - output_data_tmp[out_h * out_w - 1] < 0 - ? 
0 - : output_data_tmp[out_h * out_w - 1]; + // valid + float32x4_t _result0, _ext; + for (int loop = 0; loop < output_w_tiles; ++loop) { + float32x4x2_t _row0 = vld2q_f32(input_ptr0); + float32x4x2_t _row1 = vld2q_f32(input_ptr1); + float32x4x2_t _row2 = vld2q_f32(input_ptr2); + + _ext = vextq_f32(_row0.val[0], _ext, 1); + _ext = vsetq_lane_f32(input_ptr0[8], _ext, 3); + _result0 = vmulq_lane_f32(_row0.val[0], vget_low_f32(_ker[0]), 0); + _result0 = + vmlaq_lane_f32(_result0, _row0.val[1], vget_low_f32(_ker[0]), 1); + _result0 = vmlaq_lane_f32(_result0, _ext, vget_high_f32(_ker[0]), 0); + + _ext = vextq_f32(_row1.val[0], _ext, 1); + _ext = vsetq_lane_f32(input_ptr1[8], _ext, 3); + _result0 = + vmlaq_lane_f32(_result0, _row1.val[0], vget_low_f32(_ker[1]), 0); + _result0 = + vmlaq_lane_f32(_result0, _row1.val[1], vget_low_f32(_ker[1]), 1); + _result0 = vmlaq_lane_f32(_result0, _ext, vget_high_f32(_ker[1]), 0); + + _ext = vextq_f32(_row2.val[0], _ext, 1); + _ext = vsetq_lane_f32(input_ptr2[8], _ext, 3); + _result0 = + vmlaq_lane_f32(_result0, _row2.val[0], vget_low_f32(_ker[2]), 0); + _result0 = + vmlaq_lane_f32(_result0, _row2.val[1], vget_low_f32(_ker[2]), 1); + _result0 = vmlaq_lane_f32(_result0, _ext, vget_high_f32(_ker[2]), 0); + + vst1q_f32(output_ptr0, _result0); + + input_ptr0 += 8; + input_ptr1 += 8; + input_ptr2 += 8; + output_ptr0 += 4; } - for (int i = 1; i < out_h - 1; i++) { - out2in_mid = i * 2 * in_w; - output_data_tmp[i * out_w] = w01 * input_const[out2in_mid - in_w] + - w02 * input_const[out2in_mid - in_w + 1] + - w11 * input_const[out2in_mid] + - w12 * input_const[out2in_mid + 1] + - w21 * input_const[out2in_mid + in_w] + - w22 * input_const[out2in_mid + in_w + 1]; - - out2in_mid = i * 2 * in_w + (out_w - 1) * 2; - output_data_tmp[i * out_w + out_w - 1] = - w00 * input_const[out2in_mid - in_w - 1] + - w01 * input_const[out2in_mid - in_w] + - w10 * input_const[out2in_mid - 1] + w11 * input_const[out2in_mid] + - w20 * input_const[out2in_mid + in_w - 1] + - w21 * input_const[out2in_mid + in_w] + - (1 - if_pad_r) * (w02 * input_const[out2in_mid - in_w + 1] + - w12 * input_const[out2in_mid + 1] + - w22 * input_const[out2in_mid + in_w + 1]); - output_data_tmp[i * out_w] = - output_data_tmp[i * out_w] * newscale_data[j] + newbias_data[j]; - output_data_tmp[i * out_w + out_w - 1] = - output_data_tmp[i * out_w + out_w - 1] * newscale_data[j] + - newbias_data[j]; - if (if_relu) { - output_data_tmp[i * out_w] = - output_data_tmp[i * out_w] < 0 ? 0 : output_data_tmp[i * out_w]; - output_data_tmp[i * out_w + out_w - 1] = - output_data_tmp[i * out_w + out_w - 1] < 0 - ? 
0 - : output_data_tmp[i * out_w + out_w - 1]; + // remain w + if (output_w_remain > 0) { + float32x4x2_t _row0 = vld2q_f32(input_ptr0); + float32x4x2_t _row1 = vld2q_f32(input_ptr1); + float32x4x2_t _row2 = vld2q_f32(input_ptr2); + + _ext = vextq_f32(_row0.val[0], _ext, 1); + _ext = vsetq_lane_f32(input_ptr0[8], _ext, 3); + _result0 = vmulq_lane_f32(_row0.val[0], vget_low_f32(_ker[0]), 0); + _result0 = + vmlaq_lane_f32(_result0, _row0.val[1], vget_low_f32(_ker[0]), 1); + _result0 = vmlaq_lane_f32(_result0, _ext, vget_high_f32(_ker[0]), 0); + + _ext = vextq_f32(_row1.val[0], _ext, 1); + _ext = vsetq_lane_f32(input_ptr1[8], _ext, 3); + _result0 = + vmlaq_lane_f32(_result0, _row1.val[0], vget_low_f32(_ker[1]), 0); + _result0 = + vmlaq_lane_f32(_result0, _row1.val[1], vget_low_f32(_ker[1]), 1); + _result0 = vmlaq_lane_f32(_result0, _ext, vget_high_f32(_ker[1]), 0); + + _ext = vextq_f32(_row2.val[0], _ext, 1); + _ext = vsetq_lane_f32(input_ptr2[8], _ext, 3); + _result0 = + vmlaq_lane_f32(_result0, _row2.val[0], vget_low_f32(_ker[2]), 0); + _result0 = + vmlaq_lane_f32(_result0, _row2.val[1], vget_low_f32(_ker[2]), 1); + _result0 = vmlaq_lane_f32(_result0, _ext, vget_high_f32(_ker[2]), 0); + + switch (output_w_remain) { + case 3: + vst1q_lane_f32(output_ptr0 + 2, _result0, 2); + case 2: + vst1_f32(output_ptr0, vget_low_f32(_result0)); + break; + case 1: + vst1q_lane_f32(output_ptr0, _result0, 0); + break; } + input_ptr0 += output_w_remain * 2; + input_ptr1 += output_w_remain * 2; + input_ptr2 += output_w_remain * 2; + output_ptr0 += output_w_remain; } - } - input_data += inhxw * c; - output_data += outhxw * c; - } -// #endif -#endif -} - -void DepthwiseConv3x3s2p0(const framework::Tensor *input, - const framework::Tensor *filter, - framework::Tensor *output, framework::Tensor *bias, - bool if_bias, bool if_relu) { -#if __ARM_NEON - const int batch_size = static_cast(input->dims()[0]); - const int input_channel = static_cast(input->dims()[1]); - - const int input_height = static_cast(input->dims()[2]); - const int input_width = static_cast(input->dims()[3]); - const int output_height = static_cast(output->dims()[2]); - const int output_width = static_cast(output->dims()[3]); - const int inhxw = input_height * input_width; - const int outhxw = output_height * output_width; - output->mutable_data(); - - float32x4_t zero = vdupq_n_f32(0.0); - for (int b = 0; b < batch_size; b++) { -#pragma omp parallel for - for (int c = 0; c < input_channel; c++) { - const float *filter_data = filter->data() + c * 9; - const float *input_data = input->data() + c * inhxw; - const float *bias_data; - float32x4_t biasv; - if (if_bias) { - bias_data = bias->data() + c; - biasv = vld1q_dup_f32(bias_data); - } - float *output_data = output->data() + c * outhxw; - float w00 = filter_data[0]; - float w01 = filter_data[1]; - float w02 = filter_data[2]; - float w10 = filter_data[3]; - float w11 = filter_data[4]; - float w12 = filter_data[5]; - float w20 = filter_data[6]; - float w21 = filter_data[7]; - float w22 = filter_data[8]; - for (int i = 0; i < output_height; i += 1) { - for (int m = 0; m < output_width - 2; m += 3) { - float *output_ptr = output_data + i * output_width + m; - float32x4x2_t input_buff_top{}, input_buff_mid{}, input_buff_bottom{}; - float32x4_t in0, in1, in2, in3, in4, in5, tmp0, tmp1, tmp2, tmp3, - tmp4, tmp5, out0; - input_buff_top = - vld2q_f32(input_data + (2 * i) * input_width + (2 * m)); - input_buff_mid = - vld2q_f32(input_data + (2 * i + 1) * input_width + (2 * m)); - input_buff_bottom = - 
vld2q_f32(input_data + (2 * i + 2) * input_width + (2 * m)); - - in0 = input_buff_top.val[0]; - tmp0 = input_buff_top.val[1]; - tmp1 = vextq_f32(in0, zero, 1); - - in2 = input_buff_mid.val[0]; - tmp2 = input_buff_mid.val[1]; - tmp3 = vextq_f32(in2, zero, 1); - - in4 = input_buff_bottom.val[0]; - tmp4 = input_buff_bottom.val[1]; - tmp5 = vextq_f32(in4, zero, 1); - - out0 = vmulq_n_f32(in0, w00); - out0 = vmlaq_n_f32(out0, tmp0, w01); - out0 = vmlaq_n_f32(out0, tmp1, w02); - out0 = vmlaq_n_f32(out0, in2, w10); - out0 = vmlaq_n_f32(out0, tmp2, w11); - out0 = vmlaq_n_f32(out0, tmp3, w12); - out0 = vmlaq_n_f32(out0, in4, w20); - out0 = vmlaq_n_f32(out0, tmp4, w21); - out0 = vmlaq_n_f32(out0, tmp5, w22); - if (if_bias) { - out0 = vaddq_f32(out0, biasv); - } - if (if_relu) { - out0 = vmaxq_f32(out0, zero); - } - vst1q_lane_f32(output_ptr, out0, 0); - vst1q_lane_f32(output_ptr + 1, out0, 1); - vst1q_lane_f32(output_ptr + 2, out0, 2); - } - int m; - for (m = 0; m < output_width - 2; m += 3) { - } - for (int j = m; j < output_width; j++) { - int index = i * output_width + j; - output_data[index] = - input_data[(2 * i) * input_width + 2 * j] * w00 + - input_data[(2 * i) * input_width + 2 * j + 1] * w01 + - input_data[(2 * i) * input_width + 2 * j + 2] * w02 + - input_data[(2 * i + 1) * input_width + 2 * j] * w10 + - input_data[(2 * i + 1) * input_width + 2 * j + 1] * w11 + - input_data[(2 * i + 1) * input_width + 2 * j + 2] * w12 + - input_data[(2 * i + 2) * input_width + 2 * j] * w20 + - input_data[(2 * i + 2) * input_width + 2 * j + 1] * w21 + - input_data[(2 * i + 2) * input_width + 2 * j + 2] * w22; - if (if_bias) { - output_data[index] += *bias_data; - } - if (if_relu) { - output_data[index] = - output_data[index] < 0 ? 0 : output_data[index]; + // pad right + if (padding_w) { + float32x4_t row0 = vld1q_f32(input_ptr0); + float32x4_t row1 = vld1q_f32(input_ptr1); + float32x4_t row2 = vld1q_f32(input_ptr2); + float32x4_t acc0; + for (int w = valid_w_end; w < output_w; ++w) { + int padding = 2 * w + 3 - (padding_w + input_w); + if (padding >= 3) { + *output_ptr0 = 0; + } else { + acc0 = vmulq_f32(row0, _ker[0]); + acc0 = vmlaq_f32(acc0, row1, _ker[1]); + acc0 = vmlaq_f32(acc0, row2, _ker[2]); + float sum0 = vgetq_lane_f32(acc0, 0); + if (padding == 1) { + sum0 += vgetq_lane_f32(acc0, 1); + } + *output_ptr0 = sum0; } + output_ptr0++; } } } + // pad bottom + for (int h = valid_h_end; h < output_h; ++h) { + DepthwiseConv3x3NormalRow<2, 2>(input_ptr, filter_ptr, h, input_h, + input_w, padding_h, padding_w, output_w, + output_ptr, _ker); + } } - -#endif } } // namespace math } // namespace operators } // namespace paddle_mobile + +#endif // __ARM_NEON__ diff --git a/src/operators/math/depthwise_conv3x3.h b/src/operators/math/depthwise_conv3x3.h index fde5d878c8a62c167af7a3359a991f77d3d3fce5..1f145c4f94bf2061fb9db74aec84684387809854 100644 --- a/src/operators/math/depthwise_conv3x3.h +++ b/src/operators/math/depthwise_conv3x3.h @@ -17,54 +17,11 @@ limitations under the License. 
*/ #include #include #include "framework/tensor.h" -#include "operators/math/conv_func.h" namespace paddle_mobile { namespace operators { namespace math { -void DepthwiseConv3x3(const framework::Tensor *input, - const std::vector &strides, - const std::vector &paddings, - const framework::Tensor *filter, framework::Tensor *bias, - framework::Tensor *output, bool if_bias); - -void DepthwiseConv3x3s1p1(const framework::Tensor *input, - const framework::Tensor *filter, - framework::Tensor *output, framework::Tensor *bias, - bool if_bias, bool if_relu); - -void DepthwiseConvAddBNRelu3x3s1p1(const framework::Tensor *input, - const framework::Tensor *filter, - framework::Tensor *output, - const framework::Tensor *new_scale, - const framework::Tensor *new_bias, - bool if_relu); - -void DepthwiseConvAddBNRelu3x3s2p1(const framework::Tensor *input, - const framework::Tensor *filter, - framework::Tensor *output, - const framework::Tensor *new_scale, - const framework::Tensor *new_bias, - bool if_relu); - -void DepthwiseConv3x3s2p1v2(const framework::Tensor *input, - const framework::Tensor *filter, - framework::Tensor *output, framework::Tensor *bias, - bool if_bias, bool if_relu); - -void DepthwiseConvAddBNRelu3x3s2p1v2(const framework::Tensor *input, - const framework::Tensor *filter, - framework::Tensor *output, - const framework::Tensor *new_scale, - const framework::Tensor *new_bias, - bool if_relu); - -void DepthwiseConv3x3s2p0(const framework::Tensor *input, - const framework::Tensor *filter, - framework::Tensor *output, framework::Tensor *bias, - bool if_bias, bool if_relu); - // TODO(hjchen2) need to be implemented // template // void DepthwiseConv3x3(const framework::Tensor *input, diff --git a/src/operators/math/depthwise_conv3x3_int8.cpp b/src/operators/math/depthwise_conv3x3_int8.cpp index 91e682c14590a10fc393aaefb5d37c015065fc0a..b8d7939badbfafb0f5c3ee2034320bf817eb5c32 100644 --- a/src/operators/math/depthwise_conv3x3_int8.cpp +++ b/src/operators/math/depthwise_conv3x3_int8.cpp @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
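The guard rewrite in this int8 file is the load-bearing change for 64-bit builds: 32-bit ARM GCC predefines __ARM_NEON__, while AArch64 toolchains predefine the ACLE spelling __ARM_NEON, so the old test defined(__ARM_NEON__) && !defined(__aarch64__) compiled these kernels out of every arm64 build. Accepting either spelling enables the file on both architectures:

// Enable NEON paths on both ARMv7 (__ARM_NEON__) and AArch64 (__ARM_NEON).
#if defined(__ARM_NEON__) || defined(__ARM_NEON)
#include <arm_neon.h>
#endif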
*/ -#if defined(__ARM_NEON__) && !defined(__aarch64__) +#if defined(__ARM_NEON__) || defined(__ARM_NEON) #include #include "operators/math/depthwise_conv3x3.h" @@ -69,10 +69,8 @@ inline void DepthwiseConv3x3NormalRow(const int8_t *input, const int8_t *filter, // border left DEPTHWISE_CONV_NORMAL_BORDER(0, valid_w_start) // middle - int remain_start = valid_w_start; -#ifdef __ARM_NEON__ int output_tiles = (valid_w_end - valid_w_start) / 6; - remain_start = valid_w_start + output_tiles * 6; + int remain_start = valid_w_start + output_tiles * 6; int32x4_t _sum0, _sum1; int16x8_t _y[3]; for (int w = 0; w < output_tiles * 6; w += 6) { @@ -94,7 +92,6 @@ inline void DepthwiseConv3x3NormalRow(const int8_t *input, const int8_t *filter, vst1q_s32(output_ptr + output_offset, _sum0); vst1_s32(output_ptr + output_offset + 4, vget_low_s32(_sum1)); } -#endif // __ARM_NEON__ for (int w = remain_start; w < valid_w_end; ++w) { int32_t value = 0; int input_start = -padding_w + w * Stride_w; @@ -215,6 +212,8 @@ void DepthwiseConv3x3S1(const framework::Tensor &input, output_ptr2 += valid_w_start; output_ptr3 += valid_w_start; } +#if __aarch64__ +#else // valid int loop = output_w_tiles; asm volatile( @@ -525,6 +524,7 @@ void DepthwiseConv3x3S1(const framework::Tensor &input, : [remain] "r"(output_w_remain), [ker0] "w"(_ker0), [ker1] "w"(_ker1) : "cc", "memory", "q4", "q5", "q6", "q7", "q8", "q9", "q10", "q11", "q12", "q13", "q14", "q15", "r0"); +#endif // __aarch64__ // pad right if (padding_w) { int16x4_t row0 = vget_low_s16(vmovl_s8(vld1_s8(input_ptr0 - 2))); @@ -618,7 +618,9 @@ void DepthwiseConv3x3S1(const framework::Tensor &input, output_ptr0 += valid_w_start; output_ptr1 += valid_w_start; } - // valid + // valid +#if __aarch64__ +#else int loop = output_w_tiles; asm volatile( "cmp %[loop], #0 \n" @@ -804,6 +806,7 @@ void DepthwiseConv3x3S1(const framework::Tensor &input, : [remain] "r"(output_w_remain), [ker0] "w"(_ker0), [ker1] "w"(_ker1) : "cc", "memory", "q4", "q5", "q6", "q7", "q8", "q9", "q10", "q11", "q12", "q13", "q14", "q15", "r0"); +#endif // __aarch64__ // pad right if (padding_w) { int16x4_t row0 = vget_low_s16(vmovl_s8(vld1_s8(input_ptr0 - 2))); @@ -869,7 +872,9 @@ void DepthwiseConv3x3S1(const framework::Tensor &input, } output_ptr0 += valid_w_start; } - // valid + // valid +#if __aarch64__ +#else int loop = output_w_tiles; asm volatile( "cmp %[loop], #0 \n" @@ -993,6 +998,7 @@ void DepthwiseConv3x3S1(const framework::Tensor &input, : [remain] "r"(output_w_remain), [ker0] "w"(_ker0), [ker1] "w"(_ker1) : "cc", "memory", "q4", "q5", "q6", "q7", "q8", "q9", "q10", "q11", "q12", "q13", "q14", "q15", "r0"); +#endif // __aarch64__ // pad right if (padding_w) { int16x4_t row0 = vget_low_s16(vmovl_s8(vld1_s8(input_ptr0 - 2))); @@ -1152,7 +1158,9 @@ void DepthwiseConv3x3S2(const framework::Tensor &input, output_ptr1 += valid_w_start; output_ptr2 += valid_w_start; } - // valid + // valid +#if __aarch64__ +#else int loop = output_w_tiles; asm volatile( "cmp %[loop], #0 \n" @@ -1411,6 +1419,7 @@ void DepthwiseConv3x3S2(const framework::Tensor &input, : [remain] "r"(output_w_remain), [ker0] "w"(_ker0), [ker1] "w"(_ker1) : "cc", "memory", "q4", "q5", "q6", "q7", "q8", "q9", "q10", "q11", "q12", "q13", "q14", "q15", "r0"); +#endif // __aarch64__ // pad right if (padding_w > 0) { int16x4_t row0 = vget_low_s16(vmovl_s8(vld1_s8(input_ptr0))); @@ -1490,7 +1499,9 @@ void DepthwiseConv3x3S2(const framework::Tensor &input, input_ptr2 += valid_input_w_start; output_ptr0 += valid_w_start; } - // valid + // valid 
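Each #if __aarch64__ / #else pair added above brackets an ARMv7 asm volatile block: the v7 mnemonics and q0-q15 clobber lists do not assemble on AArch64, so the guard compiles the tiled middle path out there (a v8 replacement appears to be left as future work, matching the deletion of the arm64 stub file below). The shape of the guard, purely illustrative with the asm body elided:

static inline void ValidTileLoop(int output_w_tiles) {
#if __aarch64__
  // No v8 inline asm yet; the tiled middle is compiled out on arm64.
  (void)output_w_tiles;
#else
  int loop = output_w_tiles;
  asm volatile(
      "cmp %[loop], #0 \n"
      /* ... ARMv7 NEON body elided ... */
      : [loop] "+r"(loop)
      :
      : "cc", "memory");
#endif  // __aarch64__
}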
+#if __aarch64__ +#else int loop = output_w_tiles; asm volatile( "cmp %[loop], #0 \n" @@ -1608,6 +1619,7 @@ void DepthwiseConv3x3S2(const framework::Tensor &input, : [remain] "r"(output_w_remain), [ker0] "w"(_ker0), [ker1] "w"(_ker1) : "cc", "memory", "q4", "q5", "q6", "q7", "q8", "q9", "q10", "q11", "q12", "q13", "q14", "q15", "r0"); +#endif // __aarch64__ // pad right if (padding_w > 0) { int16x4_t row0 = vget_low_s16(vmovl_s8(vld1_s8(input_ptr0))); @@ -1645,4 +1657,4 @@ void DepthwiseConv3x3S2(const framework::Tensor &input, } // namespace operators } // namespace paddle_mobile -#endif +#endif // __ARM_NEON__ diff --git a/src/operators/math/depthwise_conv3x3_int8_arm64.cpp b/src/operators/math/depthwise_conv3x3_int8_arm64.cpp deleted file mode 100644 index e2c01838442b01dee10cd8d85126429277d8c672..0000000000000000000000000000000000000000 --- a/src/operators/math/depthwise_conv3x3_int8_arm64.cpp +++ /dev/null @@ -1,56 +0,0 @@ -/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#if defined(__ARM_NEON__) && defined(__aarch64__) - -#include "operators/math/depthwise_conv3x3.h" -#ifdef __ARM_NEON__ -#include -#endif - -namespace paddle_mobile { -namespace operators { -namespace math { - -// template<> -// void DepthwiseConv3x3( -// const framework::Tensor *input, const framework::Tensor *filter, -// const std::vector &strides, framework::Tensor *output) { -// PADDLE_MOBILE_THROW_EXCEPTION( -// "Depthwise conv with generic strides has not been implemented."); -// } - -template <> -void DepthwiseConv3x3S1(const framework::Tensor &input, - const framework::Tensor &filter, - const std::vector &paddings, - framework::Tensor *output) { - PADDLE_MOBILE_THROW_EXCEPTION( - "Depthwise conv3x3 with stride 1 for arm v8 has not been implemented."); -} - -template <> -void DepthwiseConv3x3S2(const framework::Tensor &input, - const framework::Tensor &filter, - const std::vector &paddings, - framework::Tensor *output) { - PADDLE_MOBILE_THROW_EXCEPTION( - "Depthwise conv3x3 with stride 2 for arm v8 has not been implemented."); -} - -} // namespace math -} // namespace operators -} // namespace paddle_mobile - -#endif diff --git a/src/operators/math/depthwise_conv5x5.cpp b/src/operators/math/depthwise_conv5x5.cpp index 792a98659e7b03d4220b7e2ded540782ce880931..99c08c185cc8b28a4159226c2f0502794e0a0c37 100644 --- a/src/operators/math/depthwise_conv5x5.cpp +++ b/src/operators/math/depthwise_conv5x5.cpp @@ -12,10 +12,11 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ -#if defined(__ARM_NEON__) && !defined(__aarch64__) +#if defined(__ARM_NEON__) || defined(__ARM_NEON) #include "operators/math/depthwise_conv5x5.h" #include +#include namespace paddle_mobile { namespace operators { @@ -48,7 +49,7 @@ inline void Depth5x5NormalRowLoadInput<2>(const float *input, float32x4_t *y) { y[4] = vextq_f32(y[0], y[0], 2); } -#define DEPTHWISE_CONV_NORMAL_BORDER(start, end) \ +#define DEPTHWISE_CONV5X5_NORMAL_BORDER(start, end) \ for (int w = start; w < end; ++w) { \ const int w_in_start = -padding_w + w * Stride_w; \ const int w_in_end = w_in_start + 5; \ @@ -77,10 +78,14 @@ inline void DepthwiseConv5x5NormalRow(const float *input, const float *filter, const int h_end = h_in_end < input_h ? h_in_end : input_h; int valid_w_start = (padding_w + Stride_w - 1) / Stride_w; - int valid_w_end = output_w - valid_w_start; + int valid_w_end = (input_w + padding_w - 5) / Stride_w + 1; + if (valid_w_end < valid_w_start) { + valid_w_end = valid_w_start; + } float *output_ptr = output + h_output * output_w; + // border left - DEPTHWISE_CONV_NORMAL_BORDER(0, valid_w_start) + DEPTHWISE_CONV5X5_NORMAL_BORDER(0, valid_w_start) // middle int output_tiles = (valid_w_end - valid_w_start) >> 2; float32x4_t _sum, _x[5]; @@ -120,20 +125,18 @@ inline void DepthwiseConv5x5NormalRow(const float *input, const float *filter, _sum = vmlaq_lane_f32(_sum, _x[4], vget_high_f32(ker[index]), 1); } switch (remain) { - case 1: - vst1_lane_f32(output_ptr0, vget_low_f32(_sum), 0); - break; + case 3: + vst1q_lane_f32(output_ptr0 + 2, _sum, 2); case 2: vst1_f32(output_ptr0, vget_low_f32(_sum)); break; - case 3: - vst1_f32(output_ptr0, vget_low_f32(_sum)); - vst1_lane_f32(output_ptr0 + 2, vget_high_f32(_sum), 0); + case 1: + vst1q_lane_f32(output_ptr0, _sum, 0); break; } } // border right - DEPTHWISE_CONV_NORMAL_BORDER(valid_w_end, output_w) + DEPTHWISE_CONV5X5_NORMAL_BORDER(valid_w_end, output_w) } template <> @@ -144,23 +147,24 @@ void DepthwiseConv5x5S1(const framework::Tensor &input, const float *input_data = input.data(); const float *filter_data = filter.data(); float *out_data = output->mutable_data(); - int input_h = input.dims()[2]; - int input_w = input.dims()[3]; - int output_h = output->dims()[2]; - int output_w = output->dims()[3]; - int padding_h = paddings[0]; - int padding_w = paddings[1]; - int image_size = input_h * input_w; - int out_image_size = output_h * output_w; - int valid_h_start = padding_h; - int valid_h_end = output_h - valid_h_start; - int valid_h = valid_h_end - valid_h_start; - int valid_w_start = padding_w; - int valid_w_end = output_w - valid_w_start; - int valid_w = valid_w_end - valid_w_start; + + const int input_h = input.dims()[2]; + const int input_w = input.dims()[3]; + const int output_h = output->dims()[2]; + const int output_w = output->dims()[3]; + const int padding_h = paddings[0]; + const int padding_w = paddings[1]; + const int image_size = input_h * input_w; + const int out_image_size = output_h * output_w; + const int valid_h_start = padding_h; + const int valid_h_end = output_h - valid_h_start; + const int valid_h = valid_h_end - valid_h_start; + const int valid_w_start = padding_w; + const int valid_w_end = output_w - valid_w_start; + const int valid_w = valid_w_end - valid_w_start; #pragma omp parallel for - for (int g = 0; g < input.dims()[1]; ++g) { + for (int g = 0; g < output->dims()[1]; ++g) { const float *input_ptr = input_data + g * image_size; const float *filter_ptr = filter_data + g * 25; float *output_ptr = out_data + g * out_image_size; @@ 
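Two behavioral fixes ride along in these 5x5 hunks. The border macro gets a 5x5-specific name, presumably to avoid clashing with the identically named macro in the 3x3 kernels. More importantly, valid_w_end is now derived rather than mirrored from valid_w_start: output column w is fully in bounds while w*Stride_w - padding_w + 5 <= input_w, i.e. w <= (input_w + padding_w - 5) / Stride_w, so the exclusive end of the valid range is that quotient plus one, clamped so it never drops below valid_w_start; the old output_w - valid_w_start form over-counts whenever Stride_w > 1. As a checkable sketch:

// Exclusive end of the fully-valid output range for a 5-wide window.
static inline int ValidWEnd5(int input_w, int padding_w, int stride_w,
                             int valid_w_start) {
  int end = (input_w + padding_w - 5) / stride_w + 1;
  return end < valid_w_start ? valid_w_start : end;
}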
-243,7 +247,223 @@ void DepthwiseConv5x5S1(const framework::Tensor &input, output_ptr0 += valid_w_start; output_ptr1 += valid_w_start; } - // valid + // valid +#if __aarch64__ + float32x4_t _q14, _q15; + for (int loop = 0; loop < output_w_tiles; ++loop) { + float32x4_t _q7 = vld1q_f32(input_ptr0); + float32x4_t _q8 = vld1q_f32(input_ptr0 + 4); + float32x4_t _q9 = vld1q_f32(input_ptr1); + float32x4_t _q10 = vld1q_f32(input_ptr1 + 4); + float32x4_t _q11 = vld1q_f32(input_ptr2); + float32x4_t _q12 = vld1q_f32(input_ptr2 + 4); + + _q14 = vmulq_lane_f32(_q7, vget_low_f32(_ker[5]), 0); + float32x4_t _q13 = vextq_f32(_q7, _q8, 1); + _q14 = vmlaq_lane_f32(_q14, _q13, vget_low_f32(_ker[0]), 0); + _q13 = vextq_f32(_q7, _q8, 2); + _q14 = vmlaq_lane_f32(_q14, _q13, vget_low_f32(_ker[0]), 1); + _q13 = vextq_f32(_q7, _q8, 3); + _q14 = vmlaq_lane_f32(_q14, _q13, vget_high_f32(_ker[0]), 0); + _q14 = vmlaq_lane_f32(_q14, _q8, vget_high_f32(_ker[0]), 1); + + _q14 = vmlaq_lane_f32(_q14, _q9, vget_low_f32(_ker[5]), 1); + _q15 = vmulq_lane_f32(_q9, vget_low_f32(_ker[5]), 0); + _q13 = vextq_f32(_q9, _q10, 1); + _q14 = vmlaq_lane_f32(_q14, _q13, vget_low_f32(_ker[1]), 0); + _q15 = vmlaq_lane_f32(_q15, _q13, vget_low_f32(_ker[0]), 0); + _q13 = vextq_f32(_q9, _q10, 2); + _q14 = vmlaq_lane_f32(_q14, _q13, vget_low_f32(_ker[1]), 1); + _q15 = vmlaq_lane_f32(_q15, _q13, vget_low_f32(_ker[0]), 1); + _q13 = vextq_f32(_q9, _q10, 3); + _q14 = vmlaq_lane_f32(_q14, _q13, vget_high_f32(_ker[1]), 0); + _q15 = vmlaq_lane_f32(_q15, _q13, vget_high_f32(_ker[0]), 0); + _q14 = vmlaq_lane_f32(_q14, _q10, vget_high_f32(_ker[1]), 1); + _q15 = vmlaq_lane_f32(_q15, _q10, vget_high_f32(_ker[0]), 1); + + _q14 = vmlaq_lane_f32(_q14, _q11, vget_high_f32(_ker[5]), 0); + _q15 = vmlaq_lane_f32(_q15, _q11, vget_low_f32(_ker[5]), 1); + _q13 = vextq_f32(_q11, _q12, 1); + _q14 = vmlaq_lane_f32(_q14, _q13, vget_low_f32(_ker[2]), 0); + _q15 = vmlaq_lane_f32(_q15, _q13, vget_low_f32(_ker[1]), 0); + _q13 = vextq_f32(_q11, _q12, 2); + _q14 = vmlaq_lane_f32(_q14, _q13, vget_low_f32(_ker[2]), 1); + _q15 = vmlaq_lane_f32(_q15, _q13, vget_low_f32(_ker[1]), 1); + _q13 = vextq_f32(_q11, _q12, 3); + _q14 = vmlaq_lane_f32(_q14, _q13, vget_high_f32(_ker[2]), 0); + _q15 = vmlaq_lane_f32(_q15, _q13, vget_high_f32(_ker[1]), 0); + _q14 = vmlaq_lane_f32(_q14, _q12, vget_high_f32(_ker[2]), 1); + _q15 = vmlaq_lane_f32(_q15, _q12, vget_high_f32(_ker[1]), 1); + + _q7 = vld1q_f32(input_ptr3); + _q8 = vld1q_f32(input_ptr3 + 4); + _q9 = vld1q_f32(input_ptr4); + _q10 = vld1q_f32(input_ptr4 + 4); + _q11 = vld1q_f32(input_ptr5); + _q12 = vld1q_f32(input_ptr5 + 4); + + _q14 = vmlaq_lane_f32(_q14, _q7, vget_high_f32(_ker[5]), 1); + _q15 = vmlaq_lane_f32(_q15, _q7, vget_high_f32(_ker[5]), 0); + _q13 = vextq_f32(_q7, _q8, 1); + _q14 = vmlaq_lane_f32(_q14, _q13, vget_low_f32(_ker[3]), 0); + _q15 = vmlaq_lane_f32(_q15, _q13, vget_low_f32(_ker[2]), 0); + _q13 = vextq_f32(_q7, _q8, 2); + _q14 = vmlaq_lane_f32(_q14, _q13, vget_low_f32(_ker[3]), 1); + _q15 = vmlaq_lane_f32(_q15, _q13, vget_low_f32(_ker[2]), 1); + _q13 = vextq_f32(_q7, _q8, 3); + _q14 = vmlaq_lane_f32(_q14, _q13, vget_high_f32(_ker[3]), 0); + _q15 = vmlaq_lane_f32(_q15, _q13, vget_high_f32(_ker[2]), 0); + _q14 = vmlaq_lane_f32(_q14, _q8, vget_high_f32(_ker[3]), 1); + _q15 = vmlaq_lane_f32(_q15, _q8, vget_high_f32(_ker[2]), 1); + + _q14 = vmlaq_lane_f32(_q14, _q9, vget_low_f32(_ker[6]), 0); + _q15 = vmlaq_lane_f32(_q15, _q9, vget_high_f32(_ker[5]), 1); + _q13 = vextq_f32(_q9, _q10, 1); + _q14 = vmlaq_lane_f32(_q14, 
_q13, vget_low_f32(_ker[4]), 0); + _q15 = vmlaq_lane_f32(_q15, _q13, vget_low_f32(_ker[3]), 0); + _q13 = vextq_f32(_q9, _q10, 2); + _q14 = vmlaq_lane_f32(_q14, _q13, vget_low_f32(_ker[4]), 1); + _q15 = vmlaq_lane_f32(_q15, _q13, vget_low_f32(_ker[3]), 1); + _q13 = vextq_f32(_q9, _q10, 3); + _q14 = vmlaq_lane_f32(_q14, _q13, vget_high_f32(_ker[4]), 0); + _q15 = vmlaq_lane_f32(_q15, _q13, vget_high_f32(_ker[3]), 0); + _q14 = vmlaq_lane_f32(_q14, _q10, vget_high_f32(_ker[4]), 1); + _q15 = vmlaq_lane_f32(_q15, _q10, vget_high_f32(_ker[3]), 1); + + _q15 = vmlaq_lane_f32(_q15, _q11, vget_low_f32(_ker[6]), 0); + _q13 = vextq_f32(_q11, _q12, 1); + _q15 = vmlaq_lane_f32(_q15, _q13, vget_low_f32(_ker[4]), 0); + _q13 = vextq_f32(_q11, _q12, 2); + _q15 = vmlaq_lane_f32(_q15, _q13, vget_low_f32(_ker[4]), 1); + _q13 = vextq_f32(_q11, _q12, 3); + _q15 = vmlaq_lane_f32(_q15, _q13, vget_high_f32(_ker[4]), 0); + _q15 = vmlaq_lane_f32(_q15, _q12, vget_high_f32(_ker[4]), 1); + + vst1q_f32(output_ptr0, _q14); + vst1q_f32(output_ptr1, _q15); + + input_ptr0 += 4; + input_ptr1 += 4; + input_ptr2 += 4; + input_ptr3 += 4; + input_ptr4 += 4; + input_ptr5 += 4; + output_ptr0 += 4; + output_ptr1 += 4; + } + // remain w + if (output_w_remain > 0) { + float32x4_t _q7 = vld1q_f32(input_ptr0); + float32x4_t _q8 = vld1q_f32(input_ptr0 + 4); + float32x4_t _q9 = vld1q_f32(input_ptr1); + float32x4_t _q10 = vld1q_f32(input_ptr1 + 4); + float32x4_t _q11 = vld1q_f32(input_ptr2); + float32x4_t _q12 = vld1q_f32(input_ptr2 + 4); + + _q14 = vmulq_lane_f32(_q7, vget_low_f32(_ker[5]), 0); + float32x4_t _q13 = vextq_f32(_q7, _q8, 1); + _q14 = vmlaq_lane_f32(_q14, _q13, vget_low_f32(_ker[0]), 0); + _q13 = vextq_f32(_q7, _q8, 2); + _q14 = vmlaq_lane_f32(_q14, _q13, vget_low_f32(_ker[0]), 1); + _q13 = vextq_f32(_q7, _q8, 3); + _q14 = vmlaq_lane_f32(_q14, _q13, vget_high_f32(_ker[0]), 0); + _q14 = vmlaq_lane_f32(_q14, _q8, vget_high_f32(_ker[0]), 1); + + _q14 = vmlaq_lane_f32(_q14, _q9, vget_low_f32(_ker[5]), 1); + _q15 = vmulq_lane_f32(_q9, vget_low_f32(_ker[5]), 0); + _q13 = vextq_f32(_q9, _q10, 1); + _q14 = vmlaq_lane_f32(_q14, _q13, vget_low_f32(_ker[1]), 0); + _q15 = vmlaq_lane_f32(_q15, _q13, vget_low_f32(_ker[0]), 0); + _q13 = vextq_f32(_q9, _q10, 2); + _q14 = vmlaq_lane_f32(_q14, _q13, vget_low_f32(_ker[1]), 1); + _q15 = vmlaq_lane_f32(_q15, _q13, vget_low_f32(_ker[0]), 1); + _q13 = vextq_f32(_q9, _q10, 3); + _q14 = vmlaq_lane_f32(_q14, _q13, vget_high_f32(_ker[1]), 0); + _q15 = vmlaq_lane_f32(_q15, _q13, vget_high_f32(_ker[0]), 0); + _q14 = vmlaq_lane_f32(_q14, _q10, vget_high_f32(_ker[1]), 1); + _q15 = vmlaq_lane_f32(_q15, _q10, vget_high_f32(_ker[0]), 1); + + _q14 = vmlaq_lane_f32(_q14, _q11, vget_high_f32(_ker[5]), 0); + _q15 = vmlaq_lane_f32(_q15, _q11, vget_low_f32(_ker[5]), 1); + _q13 = vextq_f32(_q11, _q12, 1); + _q14 = vmlaq_lane_f32(_q14, _q13, vget_low_f32(_ker[2]), 0); + _q15 = vmlaq_lane_f32(_q15, _q13, vget_low_f32(_ker[1]), 0); + _q13 = vextq_f32(_q11, _q12, 2); + _q14 = vmlaq_lane_f32(_q14, _q13, vget_low_f32(_ker[2]), 1); + _q15 = vmlaq_lane_f32(_q15, _q13, vget_low_f32(_ker[1]), 1); + _q13 = vextq_f32(_q11, _q12, 3); + _q14 = vmlaq_lane_f32(_q14, _q13, vget_high_f32(_ker[2]), 0); + _q15 = vmlaq_lane_f32(_q15, _q13, vget_high_f32(_ker[1]), 0); + _q14 = vmlaq_lane_f32(_q14, _q12, vget_high_f32(_ker[2]), 1); + _q15 = vmlaq_lane_f32(_q15, _q12, vget_high_f32(_ker[1]), 1); + + _q7 = vld1q_f32(input_ptr3); + _q8 = vld1q_f32(input_ptr3 + 4); + _q9 = vld1q_f32(input_ptr4); + _q10 = vld1q_f32(input_ptr4 + 4); + _q11 = 
vld1q_f32(input_ptr5); + _q12 = vld1q_f32(input_ptr5 + 4); + + _q14 = vmlaq_lane_f32(_q14, _q7, vget_high_f32(_ker[5]), 1); + _q15 = vmlaq_lane_f32(_q15, _q7, vget_high_f32(_ker[5]), 0); + _q13 = vextq_f32(_q7, _q8, 1); + _q14 = vmlaq_lane_f32(_q14, _q13, vget_low_f32(_ker[3]), 0); + _q15 = vmlaq_lane_f32(_q15, _q13, vget_low_f32(_ker[2]), 0); + _q13 = vextq_f32(_q7, _q8, 2); + _q14 = vmlaq_lane_f32(_q14, _q13, vget_low_f32(_ker[3]), 1); + _q15 = vmlaq_lane_f32(_q15, _q13, vget_low_f32(_ker[2]), 1); + _q13 = vextq_f32(_q7, _q8, 3); + _q14 = vmlaq_lane_f32(_q14, _q13, vget_high_f32(_ker[3]), 0); + _q15 = vmlaq_lane_f32(_q15, _q13, vget_high_f32(_ker[2]), 0); + _q14 = vmlaq_lane_f32(_q14, _q8, vget_high_f32(_ker[3]), 1); + _q15 = vmlaq_lane_f32(_q15, _q8, vget_high_f32(_ker[2]), 1); + + _q14 = vmlaq_lane_f32(_q14, _q9, vget_low_f32(_ker[6]), 0); + _q15 = vmlaq_lane_f32(_q15, _q9, vget_high_f32(_ker[5]), 1); + _q13 = vextq_f32(_q9, _q10, 1); + _q14 = vmlaq_lane_f32(_q14, _q13, vget_low_f32(_ker[4]), 0); + _q15 = vmlaq_lane_f32(_q15, _q13, vget_low_f32(_ker[3]), 0); + _q13 = vextq_f32(_q9, _q10, 2); + _q14 = vmlaq_lane_f32(_q14, _q13, vget_low_f32(_ker[4]), 1); + _q15 = vmlaq_lane_f32(_q15, _q13, vget_low_f32(_ker[3]), 1); + _q13 = vextq_f32(_q9, _q10, 3); + _q14 = vmlaq_lane_f32(_q14, _q13, vget_high_f32(_ker[4]), 0); + _q15 = vmlaq_lane_f32(_q15, _q13, vget_high_f32(_ker[3]), 0); + _q14 = vmlaq_lane_f32(_q14, _q10, vget_high_f32(_ker[4]), 1); + _q15 = vmlaq_lane_f32(_q15, _q10, vget_high_f32(_ker[3]), 1); + + _q15 = vmlaq_lane_f32(_q15, _q11, vget_low_f32(_ker[6]), 0); + _q13 = vextq_f32(_q11, _q12, 1); + _q15 = vmlaq_lane_f32(_q15, _q13, vget_low_f32(_ker[4]), 0); + _q13 = vextq_f32(_q11, _q12, 2); + _q15 = vmlaq_lane_f32(_q15, _q13, vget_low_f32(_ker[4]), 1); + _q13 = vextq_f32(_q11, _q12, 3); + _q15 = vmlaq_lane_f32(_q15, _q13, vget_high_f32(_ker[4]), 0); + _q15 = vmlaq_lane_f32(_q15, _q12, vget_high_f32(_ker[4]), 1); + + switch (output_w_remain) { + case 3: + vst1q_lane_f32(output_ptr0 + 2, _q14, 2); + vst1q_lane_f32(output_ptr1 + 2, _q15, 2); + case 2: + vst1_f32(output_ptr0, vget_low_f32(_q14)); + vst1_f32(output_ptr1, vget_low_f32(_q15)); + break; + case 1: + vst1q_lane_f32(output_ptr0, _q14, 0); + vst1q_lane_f32(output_ptr1, _q15, 0); + break; + } + input_ptr0 += output_w_remain; + input_ptr1 += output_w_remain; + input_ptr2 += output_w_remain; + input_ptr3 += output_w_remain; + input_ptr4 += output_w_remain; + input_ptr5 += output_w_remain; + output_ptr0 += output_w_remain; + output_ptr1 += output_w_remain; + } +#else int loop = output_w_tiles; asm volatile( "cmp %[loop], #0 \n" @@ -443,6 +663,7 @@ void DepthwiseConv5x5S1(const framework::Tensor &input, [kr4] "w"(_ker[4]), [ker0] "w"(_ker[5]), [ker1] "w"(_ker[6]) : "cc", "memory", "q7", "q8", "q9", "q10", "q11", "q12", "q13", "q14", "q15", "r0"); +#endif // __aarch64__ // pad right if (padding_w) { float32x4_t row0 = vld1q_f32(input_ptr0); @@ -540,7 +761,154 @@ void DepthwiseConv5x5S1(const framework::Tensor &input, } output_ptr0 += valid_w_start; } - // valid + // valid +#if __aarch64__ + float32x4_t _q14; + for (int loop = 0; loop < output_w_tiles; ++loop) { + float32x4_t _q7 = vld1q_f32(input_ptr0); + float32x4_t _q8 = vld1q_f32(input_ptr0 + 4); + float32x4_t _q9 = vld1q_f32(input_ptr1); + float32x4_t _q10 = vld1q_f32(input_ptr1 + 4); + float32x4_t _q11 = vld1q_f32(input_ptr2); + float32x4_t _q12 = vld1q_f32(input_ptr2 + 4); + + _q14 = vmulq_lane_f32(_q7, vget_low_f32(_ker[5]), 0); + float32x4_t _q13 = vextq_f32(_q7, _q8, 1); + 
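The arm64 intrinsics above replace the v7 asm while keeping the whole 5x5 filter resident in seven q-registers. As the lane indexing reads, _ker[r] for r = 0..4 carries taps 1..4 of filter row r, _ker[5] packs tap 0 of rows 0..3, and lane 0 of _ker[6] holds tap 0 of row 4. A sketch of loading that layout (helper name illustrative; filter points at 25 row-major weights):

#include <arm_neon.h>

static inline void LoadFilter5x5(const float *filter, float32x4_t ker[7]) {
  for (int r = 0; r < 5; ++r) {
    ker[r] = vld1q_f32(filter + 5 * r + 1);  // taps 1..4 of row r
  }
  const float tap0[4] = {filter[0], filter[5], filter[10], filter[15]};
  ker[5] = vld1q_f32(tap0);                  // tap 0 of rows 0..3
  ker[6] = vdupq_n_f32(filter[20]);          // tap 0 of row 4 (lane 0 used)
}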
_q14 = vmlaq_lane_f32(_q14, _q13, vget_low_f32(_ker[0]), 0); + _q13 = vextq_f32(_q7, _q8, 2); + _q14 = vmlaq_lane_f32(_q14, _q13, vget_low_f32(_ker[0]), 1); + _q13 = vextq_f32(_q7, _q8, 3); + _q14 = vmlaq_lane_f32(_q14, _q13, vget_high_f32(_ker[0]), 0); + _q14 = vmlaq_lane_f32(_q14, _q8, vget_high_f32(_ker[0]), 1); + + _q14 = vmlaq_lane_f32(_q14, _q9, vget_low_f32(_ker[5]), 1); + _q13 = vextq_f32(_q9, _q10, 1); + _q14 = vmlaq_lane_f32(_q14, _q13, vget_low_f32(_ker[1]), 0); + _q13 = vextq_f32(_q9, _q10, 2); + _q14 = vmlaq_lane_f32(_q14, _q13, vget_low_f32(_ker[1]), 1); + _q13 = vextq_f32(_q9, _q10, 3); + _q14 = vmlaq_lane_f32(_q14, _q13, vget_high_f32(_ker[1]), 0); + _q14 = vmlaq_lane_f32(_q14, _q10, vget_high_f32(_ker[1]), 1); + + _q14 = vmlaq_lane_f32(_q14, _q11, vget_high_f32(_ker[5]), 0); + _q13 = vextq_f32(_q11, _q12, 1); + _q14 = vmlaq_lane_f32(_q14, _q13, vget_low_f32(_ker[2]), 0); + _q13 = vextq_f32(_q11, _q12, 2); + _q14 = vmlaq_lane_f32(_q14, _q13, vget_low_f32(_ker[2]), 1); + _q13 = vextq_f32(_q11, _q12, 3); + _q14 = vmlaq_lane_f32(_q14, _q13, vget_high_f32(_ker[2]), 0); + _q14 = vmlaq_lane_f32(_q14, _q12, vget_high_f32(_ker[2]), 1); + + _q7 = vld1q_f32(input_ptr3); + _q8 = vld1q_f32(input_ptr3 + 4); + _q9 = vld1q_f32(input_ptr4); + _q10 = vld1q_f32(input_ptr4 + 4); + + _q14 = vmlaq_lane_f32(_q14, _q7, vget_high_f32(_ker[5]), 1); + _q13 = vextq_f32(_q7, _q8, 1); + _q14 = vmlaq_lane_f32(_q14, _q13, vget_low_f32(_ker[3]), 0); + _q13 = vextq_f32(_q7, _q8, 2); + _q14 = vmlaq_lane_f32(_q14, _q13, vget_low_f32(_ker[3]), 1); + _q13 = vextq_f32(_q7, _q8, 3); + _q14 = vmlaq_lane_f32(_q14, _q13, vget_high_f32(_ker[3]), 0); + _q14 = vmlaq_lane_f32(_q14, _q8, vget_high_f32(_ker[3]), 1); + + _q14 = vmlaq_lane_f32(_q14, _q9, vget_low_f32(_ker[6]), 0); + _q13 = vextq_f32(_q9, _q10, 1); + _q14 = vmlaq_lane_f32(_q14, _q13, vget_low_f32(_ker[4]), 0); + _q13 = vextq_f32(_q9, _q10, 2); + _q14 = vmlaq_lane_f32(_q14, _q13, vget_low_f32(_ker[4]), 1); + _q13 = vextq_f32(_q9, _q10, 3); + _q14 = vmlaq_lane_f32(_q14, _q13, vget_high_f32(_ker[4]), 0); + _q14 = vmlaq_lane_f32(_q14, _q10, vget_high_f32(_ker[4]), 1); + + vst1q_f32(output_ptr0, _q14); + + input_ptr0 += 4; + input_ptr1 += 4; + input_ptr2 += 4; + input_ptr3 += 4; + input_ptr4 += 4; + output_ptr0 += 4; + } + // remain w + if (output_w_remain > 0) { + float32x4_t _q7 = vld1q_f32(input_ptr0); + float32x4_t _q8 = vld1q_f32(input_ptr0 + 4); + float32x4_t _q9 = vld1q_f32(input_ptr1); + float32x4_t _q10 = vld1q_f32(input_ptr1 + 4); + float32x4_t _q11 = vld1q_f32(input_ptr2); + float32x4_t _q12 = vld1q_f32(input_ptr2 + 4); + + _q14 = vmulq_lane_f32(_q7, vget_low_f32(_ker[5]), 0); + float32x4_t _q13 = vextq_f32(_q7, _q8, 1); + _q14 = vmlaq_lane_f32(_q14, _q13, vget_low_f32(_ker[0]), 0); + _q13 = vextq_f32(_q7, _q8, 2); + _q14 = vmlaq_lane_f32(_q14, _q13, vget_low_f32(_ker[0]), 1); + _q13 = vextq_f32(_q7, _q8, 3); + _q14 = vmlaq_lane_f32(_q14, _q13, vget_high_f32(_ker[0]), 0); + _q14 = vmlaq_lane_f32(_q14, _q8, vget_high_f32(_ker[0]), 1); + + _q14 = vmlaq_lane_f32(_q14, _q9, vget_low_f32(_ker[5]), 1); + _q13 = vextq_f32(_q9, _q10, 1); + _q14 = vmlaq_lane_f32(_q14, _q13, vget_low_f32(_ker[1]), 0); + _q13 = vextq_f32(_q9, _q10, 2); + _q14 = vmlaq_lane_f32(_q14, _q13, vget_low_f32(_ker[1]), 1); + _q13 = vextq_f32(_q9, _q10, 3); + _q14 = vmlaq_lane_f32(_q14, _q13, vget_high_f32(_ker[1]), 0); + _q14 = vmlaq_lane_f32(_q14, _q10, vget_high_f32(_ker[1]), 1); + + _q14 = vmlaq_lane_f32(_q14, _q11, vget_high_f32(_ker[5]), 0); + _q13 = vextq_f32(_q11, _q12, 1); + _q14 
= vmlaq_lane_f32(_q14, _q13, vget_low_f32(_ker[2]), 0); + _q13 = vextq_f32(_q11, _q12, 2); + _q14 = vmlaq_lane_f32(_q14, _q13, vget_low_f32(_ker[2]), 1); + _q13 = vextq_f32(_q11, _q12, 3); + _q14 = vmlaq_lane_f32(_q14, _q13, vget_high_f32(_ker[2]), 0); + _q14 = vmlaq_lane_f32(_q14, _q12, vget_high_f32(_ker[2]), 1); + + _q7 = vld1q_f32(input_ptr3); + _q8 = vld1q_f32(input_ptr3 + 4); + _q9 = vld1q_f32(input_ptr4); + _q10 = vld1q_f32(input_ptr4 + 4); + + _q14 = vmlaq_lane_f32(_q14, _q7, vget_high_f32(_ker[5]), 1); + _q13 = vextq_f32(_q7, _q8, 1); + _q14 = vmlaq_lane_f32(_q14, _q13, vget_low_f32(_ker[3]), 0); + _q13 = vextq_f32(_q7, _q8, 2); + _q14 = vmlaq_lane_f32(_q14, _q13, vget_low_f32(_ker[3]), 1); + _q13 = vextq_f32(_q7, _q8, 3); + _q14 = vmlaq_lane_f32(_q14, _q13, vget_high_f32(_ker[3]), 0); + _q14 = vmlaq_lane_f32(_q14, _q8, vget_high_f32(_ker[3]), 1); + + _q14 = vmlaq_lane_f32(_q14, _q9, vget_low_f32(_ker[6]), 0); + _q13 = vextq_f32(_q9, _q10, 1); + _q14 = vmlaq_lane_f32(_q14, _q13, vget_low_f32(_ker[4]), 0); + _q13 = vextq_f32(_q9, _q10, 2); + _q14 = vmlaq_lane_f32(_q14, _q13, vget_low_f32(_ker[4]), 1); + _q13 = vextq_f32(_q9, _q10, 3); + _q14 = vmlaq_lane_f32(_q14, _q13, vget_high_f32(_ker[4]), 0); + _q14 = vmlaq_lane_f32(_q14, _q10, vget_high_f32(_ker[4]), 1); + + switch (output_w_remain) { + case 3: + vst1q_lane_f32(output_ptr0 + 2, _q14, 2); + case 2: + vst1_f32(output_ptr0, vget_low_f32(_q14)); + break; + case 1: + vst1q_lane_f32(output_ptr0, _q14, 0); + break; + } + + input_ptr0 += output_w_remain; + input_ptr1 += output_w_remain; + input_ptr2 += output_w_remain; + input_ptr3 += output_w_remain; + input_ptr4 += output_w_remain; + output_ptr0 += output_w_remain; + } +#else int loop = output_w_tiles; asm volatile( "cmp %[loop], #0 \n" @@ -676,6 +1044,7 @@ void DepthwiseConv5x5S1(const framework::Tensor &input, [kr4] "w"(_ker[4]), [ker0] "w"(_ker[5]), [ker1] "w"(_ker[6]) : "cc", "memory", "q7", "q8", "q9", "q10", "q11", "q12", "q13", "q14", "q15", "r0"); +#endif // __aarch64__ // pad right if (padding_w) { float32x4_t row0 = vld1q_f32(input_ptr0); diff --git a/src/operators/math/depthwise_conv5x5.h b/src/operators/math/depthwise_conv5x5.h index d047bbfa1ac179e0ef0b1b6705e349890b25e800..11d96b078ac7314ef0f3de98614c1e4ebd4dbc95 100644 --- a/src/operators/math/depthwise_conv5x5.h +++ b/src/operators/math/depthwise_conv5x5.h @@ -17,7 +17,6 @@ limitations under the License. 
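The remain-w stores here and in the 3x3 kernels all use the same deliberate switch fall-through: for three leftover outputs, lane 2 is stored first and control falls into the two-lane store, so each width writes exactly remain floats with no per-lane branching. The idiom, isolated:

#include <arm_neon.h>

// Store the low `remain` (1..3) lanes of v to dst.
static inline void StoreRemain(float *dst, float32x4_t v, int remain) {
  switch (remain) {
    case 3:
      vst1q_lane_f32(dst + 2, v, 2);
      // fall through
    case 2:
      vst1_f32(dst, vget_low_f32(v));
      break;
    case 1:
      vst1q_lane_f32(dst, v, 0);
      break;
  }
}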
*/ #include #include #include "framework/tensor.h" -#include "operators/math/conv_func.h" namespace paddle_mobile { namespace operators { diff --git a/src/operators/math/gemm.cpp b/src/operators/math/gemm.cpp index 869a61089621e8ed436944c26bc3cffc78159f46..1fa78d161621b3c7928a0ce6b554c14aac3fd6b6 100644 --- a/src/operators/math/gemm.cpp +++ b/src/operators/math/gemm.cpp @@ -27,390 +27,418 @@ namespace paddle_mobile { namespace operators { namespace math { -// 将A矩阵分块复制到连续内存(RowMajor) -void Gemm::PackMatrixA_4r(int m, int k, int m_tail, const float *A, int lda, - float *buffer) { - const float *a0, *a1, *a2, *a3; - for (int i = 0; i < m - m_tail; i += MR) { - a0 = A + i * lda; - a1 = A + (i + 1) * lda; - a2 = A + (i + 2) * lda; - a3 = A + (i + 3) * lda; - for (int j = 0; j < k; ++j) { - *buffer++ = *a0++; - *buffer++ = *a1++; - *buffer++ = *a2++; - *buffer++ = *a3++; - } - } - - if (m_tail != 0) { - a0 = &A(m - m_tail, 0); - a1 = a0 + lda; - a2 = a0 + 2 * lda; - a3 = a0 + 3 * lda; - switch (m_tail) { - case 1: - a1 = zero; - case 2: - a2 = zero; - case 3: - a3 = zero; - break; - default: - break; - } - for (int j = 0; j < k; ++j) { - *buffer++ = *a0++; - *buffer++ = *a1++; - *buffer++ = *a2++; - *buffer++ = *a3++; - } - } +#if __ARM_NEON +inline float32x4_t vandq_f32(float32x4_t x, uint32x4_t mask) { + return vreinterpretq_f32_u32(vandq_u32(vreinterpretq_u32_f32(x), mask)); } +#endif void Gemm::PackMatrixA_6r(int m, int k, int m_tail, const float *A, int lda, - float *buffer) { - const int i_length = m - m_tail; - for (int i = 0; i < i_length; i += MR) { + float *buffer, const bool parallel) { + uint32_t mask[8] = {0, 1, 2, 3, 4, 5, 4, 5}; + int remain_k = k & 0x3; + uint32x4_t vzero = vdupq_n_u32(0); + uint32x4_t vmask1 = vcltq_u32(vld1q_u32(mask), vdupq_n_u32(remain_k)); + + #pragma omp parallel for if (parallel) + for (int i = 0; i < m - 5; i += 6) { const float *a0 = A + i * lda; const float *a1 = A + (i + 1) * lda; const float *a2 = A + (i + 2) * lda; const float *a3 = A + (i + 3) * lda; const float *a4 = A + (i + 4) * lda; const float *a5 = A + (i + 5) * lda; - float *local_buffer = buffer + i * k; - for (int j = 0; j < k; ++j) { - *local_buffer++ = *a0++; - *local_buffer++ = *a1++; - *local_buffer++ = *a2++; - *local_buffer++ = *a3++; - *local_buffer++ = *a4++; - *local_buffer++ = *a5++; - } - } - if (m_tail != 0) { - const float *a0 = &A(i_length, 0); - const float *a1 = a0 + lda; - const float *a2 = a0 + 2 * lda; - const float *a3 = a0 + 3 * lda; - const float *a4 = a0 + 4 * lda; - const float *a5 = a0 + 5 * lda; - float *local_buffer = buffer + i_length * k; - switch (m_tail) { - case 1: - a1 = zero; - case 2: - a2 = zero; - case 3: - a3 = zero; - case 4: - a4 = zero; - case 5: - a5 = zero; - break; - default: - break; - } - for (int j = 0; j < k; ++j) { - *local_buffer++ = *a0++; - *local_buffer++ = *a1++; - *local_buffer++ = *a2++; - *local_buffer++ = *a3++; - *local_buffer++ = *a4++; - *local_buffer++ = *a5++; - } - } -} + float *out_ptr = buffer + i * k; -void Gemm::PackMatrixA_omp_6r(int m, int k, int m_tail, const float *A, int lda, - float *buffer) { - const int i_length = m - m_tail; -#pragma omp parallel for - for (int i = 0; i < i_length; i += MR) { - const float *a0 = A + i * lda; - const float *a1 = A + (i + 1) * lda; - const float *a2 = A + (i + 2) * lda; - const float *a3 = A + (i + 3) * lda; - const float *a4 = A + (i + 4) * lda; - const float *a5 = A + (i + 5) * lda; - float *local_buffer = buffer + i * k; - for (int j = 0; j < k; ++j) { - *local_buffer++ = *a0++; - 
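NEON has no bitwise-AND intrinsic for float vectors, which is what the small vandq_f32 shim at the top of the rewritten gemm.cpp supplies: bit-cast to u32, AND against a lane mask, bit-cast back, so masked-off lanes become +0.0f. A self-contained illustration (values arbitrary):

#include <arm_neon.h>
#include <stdint.h>

// The patch's shim, reproduced for context.
static inline float32x4_t vandq_f32(float32x4_t x, uint32x4_t mask) {
  return vreinterpretq_f32_u32(vandq_u32(vreinterpretq_u32_f32(x), mask));
}

static inline float32x4_t KeepLowTwo(float32x4_t v) {
  const uint32_t idx[4] = {0, 1, 2, 3};
  uint32x4_t keep2 = vcltq_u32(vld1q_u32(idx), vdupq_n_u32(2));
  return vandq_f32(v, keep2);  // {v0, v1, +0.0f, +0.0f}
}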
*local_buffer++ = *a1++; - *local_buffer++ = *a2++; - *local_buffer++ = *a3++; - *local_buffer++ = *a4++; - *local_buffer++ = *a5++; - } - } - if (m_tail != 0) { - const float *a0 = &A(i_length, 0); - const float *a1 = a0 + lda; - const float *a2 = a0 + 2 * lda; - const float *a3 = a0 + 3 * lda; - const float *a4 = a0 + 4 * lda; - const float *a5 = a0 + 5 * lda; - float *local_buffer = buffer + i_length * k; - switch (m_tail) { - case 1: - a1 = zero; - case 2: - a2 = zero; - case 3: - a3 = zero; - case 4: - a4 = zero; - case 5: - a5 = zero; - break; - default: - break; - } - for (int j = 0; j < k; ++j) { - *local_buffer++ = *a0++; - *local_buffer++ = *a1++; - *local_buffer++ = *a2++; - *local_buffer++ = *a3++; - *local_buffer++ = *a4++; - *local_buffer++ = *a5++; + int loops = k >> 2; + if (loops > 0) { +#if __aarch64__ + for (int l = 0; l < loops; ++l) { + float32x4_t _d0 = vld1q_f32(a0); + float32x4_t _d1 = vld1q_f32(a1); + float32x4_t _d2 = vld1q_f32(a2); + float32x4_t _d3 = vld1q_f32(a3); + float32x4_t _d4 = vld1q_f32(a4); + float32x4_t _d5 = vld1q_f32(a5); + + float32x4x2_t _q0 = vtrnq_f32(_d0, _d1); + float32x4x2_t _q1 = vtrnq_f32(_d2, _d3); + float32x4x2_t _q3 = vtrnq_f32(_d4, _d5); + _d0 = vcombine_f32(vget_low_f32(_q0.val[0]), vget_low_f32(_q1.val[0])); + _d1 = vcombine_f32(vget_low_f32(_q0.val[1]), vget_low_f32(_q1.val[1])); + _d2 = + vcombine_f32(vget_high_f32(_q0.val[0]), vget_high_f32(_q1.val[0])); + _d3 = + vcombine_f32(vget_high_f32(_q0.val[1]), vget_high_f32(_q1.val[1])); + + vst1q_f32(out_ptr, _d0); + vst1_f32(out_ptr + 4, vget_low_f32(_q3.val[0])); + vst1q_f32(out_ptr + 6, _d1); + vst1_f32(out_ptr + 10, vget_low_f32(_q3.val[1])); + vst1q_f32(out_ptr + 12, _d2); + vst1_f32(out_ptr + 16, vget_high_f32(_q3.val[0])); + vst1q_f32(out_ptr + 18, _d3); + vst1_f32(out_ptr + 22, vget_high_f32(_q3.val[1])); + + a0 += 4; + a1 += 4; + a2 += 4; + a3 += 4; + a4 += 4; + a5 += 4; + out_ptr += 24; + } +#else + asm volatile( + "loop_4k_%=: \n" + "vld1.32 {d0-d1}, [%[a0]]! \n" + "vld1.32 {d2-d3}, [%[a1]]! \n" + "vld1.32 {d4-d5}, [%[a2]]! \n" + "vld1.32 {d6-d7}, [%[a3]]! \n" + "vld1.32 {d8-d9}, [%[a4]]! \n" + "vld1.32 {d10-d11}, [%[a5]]! \n" + "vtrn.32 q0, q1 \n" + "vtrn.32 q2, q3 \n" + "vtrn.32 q4, q5 \n" + "vswp.32 d1, d4 \n" + "vswp.32 d3, d6 \n" + + "vst1.32 {q0}, [%[out]]! \n" + "vst1.32 {d8}, [%[out]]! \n" + "vst1.32 {q1}, [%[out]]! \n" + "vst1.32 {d10}, [%[out]]! \n" + "vst1.32 {q2}, [%[out]]! \n" + "vst1.32 {d9}, [%[out]]! \n" + "vst1.32 {q3}, [%[out]]! \n" + "vst1.32 {d11}, [%[out]]! 
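Both the intrinsics and the asm variants of the 6-row pack hinge on the classic NEON 4x4 transpose: vtrn on two row pairs, then exchange the crossed 64-bit halves (vswp d1, d4 and d3, d6 in the asm; vcombine of low/high halves in the intrinsics). Rows a4/a5 go through their own vtrn pair and land as the trailing two floats of each 6-wide packed row. The transpose core in isolation:

#include <arm_neon.h>

// 4x4 transpose: element (i, j) moves to (j, i).
static inline void Transpose4x4(float32x4_t r[4]) {
  float32x4x2_t t01 = vtrnq_f32(r[0], r[1]);
  float32x4x2_t t23 = vtrnq_f32(r[2], r[3]);
  r[0] = vcombine_f32(vget_low_f32(t01.val[0]), vget_low_f32(t23.val[0]));
  r[1] = vcombine_f32(vget_low_f32(t01.val[1]), vget_low_f32(t23.val[1]));
  r[2] = vcombine_f32(vget_high_f32(t01.val[0]), vget_high_f32(t23.val[0]));
  r[3] = vcombine_f32(vget_high_f32(t01.val[1]), vget_high_f32(t23.val[1]));
}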
\n" + + "subs %[loops], #1 \n" + "bne loop_4k_%= \n" + : [out] "+r"(out_ptr), [a0] "+r"(a0), [a1] "+r"(a1), [a2] "+r"(a2), + [a3] "+r"(a3), [a4] "+r"(a4), [a5] "+r"(a5), [loops] "+r"(loops) + : + : "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5"); +#endif } - } -} -void Gemm::PackMatrixA_8r(int m, int k, int m_tail, const float *A, int lda, - float *buffer) { - const int i_length = m - m_tail; - for (int i = 0; i < i_length; i += MR) { - const float *a0 = A + i * lda; - const float *a1 = A + (i + 1) * lda; - const float *a2 = A + (i + 2) * lda; - const float *a3 = A + (i + 3) * lda; - const float *a4 = A + (i + 4) * lda; - const float *a5 = A + (i + 5) * lda; - const float *a6 = A + (i + 6) * lda; - const float *a7 = A + (i + 7) * lda; - float *local_buffer = buffer + i * k; - for (int j = 0; j < k; ++j) { - *local_buffer++ = *a0++; - *local_buffer++ = *a1++; - *local_buffer++ = *a2++; - *local_buffer++ = *a3++; - *local_buffer++ = *a4++; - *local_buffer++ = *a5++; - *local_buffer++ = *a6++; - *local_buffer++ = *a7++; - } - } - if (m_tail != 0) { - const float *a0 = &A(i_length, 0); + if (remain_k > 0) { + float32x4_t _d0 = vld1q_f32(a0); + float32x4_t _d1 = vld1q_f32(a1); + float32x4_t _d2 = vld1q_f32(a2); + float32x4_t _d3 = vld1q_f32(a3); + float32x4_t _d4 = vld1q_f32(a4); + float32x4_t _d5 = vld1q_f32(a5); + + _d0 = vandq_f32(_d0, vmask1); + _d1 = vandq_f32(_d1, vmask1); + _d2 = vandq_f32(_d2, vmask1); + _d3 = vandq_f32(_d3, vmask1); + _d4 = vandq_f32(_d4, vmask1); + _d5 = vandq_f32(_d5, vmask1); + + float32x4x2_t _q0 = vtrnq_f32(_d0, _d1); + float32x4x2_t _q1 = vtrnq_f32(_d2, _d3); + float32x4x2_t _q3 = vtrnq_f32(_d4, _d5); + _d0 = vcombine_f32(vget_low_f32(_q0.val[0]), vget_low_f32(_q1.val[0])); + _d1 = vcombine_f32(vget_low_f32(_q0.val[1]), vget_low_f32(_q1.val[1])); + _d2 = vcombine_f32(vget_high_f32(_q0.val[0]), vget_high_f32(_q1.val[0])); + + switch (remain_k) { + case 3: + vst1q_f32(out_ptr + 12, _d2); + vst1_f32(out_ptr + 16, vget_high_f32(_q3.val[0])); + case 2: + vst1q_f32(out_ptr + 6, _d1); + vst1_f32(out_ptr + 10, vget_low_f32(_q3.val[1])); + case 1: + vst1q_f32(out_ptr, _d0); + vst1_f32(out_ptr + 4, vget_low_f32(_q3.val[0])); + default: + break; + } + } + } + + int remain_m = m % 6; + if (remain_m) { + int remain_m_start = m - remain_m; + const float *a0 = A + remain_m_start * lda; const float *a1 = a0 + lda; const float *a2 = a0 + 2 * lda; const float *a3 = a0 + 3 * lda; const float *a4 = a0 + 4 * lda; const float *a5 = a0 + 5 * lda; - const float *a6 = a0 + 6 * lda; - const float *a7 = a0 + 7 * lda; - float *local_buffer = buffer + i_length * k; - switch (m_tail) { - case 1: - a1 = zero; - case 2: - a2 = zero; - case 3: - a3 = zero; - case 4: - a4 = zero; - case 5: - a5 = zero; - case 6: - a6 = zero; - case 7: - a7 = zero; - break; - default: - break; - } - for (int j = 0; j < k; ++j) { - *local_buffer++ = *a0++; - *local_buffer++ = *a1++; - *local_buffer++ = *a2++; - *local_buffer++ = *a3++; - *local_buffer++ = *a4++; - *local_buffer++ = *a5++; - *local_buffer++ = *a6++; - *local_buffer++ = *a7++; + float *out_ptr = buffer + remain_m_start * k; + + uint32x4_t vmask2 = vcltq_u32(vld1q_u32(mask), vdupq_n_u32(remain_m)); + uint32x4_t vmask3 = vcltq_u32(vld1q_u32(mask + 4), vdupq_n_u32(remain_m)); + + int loops = k >> 2; + if (loops > 0) { +#if __aarch64__ + for (int l = 0; l < loops; ++l) { + float32x4_t _d0 = vld1q_f32(a0); + float32x4_t _d1 = vld1q_f32(a1); + float32x4_t _d2 = vld1q_f32(a2); + float32x4_t _d3 = vld1q_f32(a3); + float32x4_t _d4 = vld1q_f32(a4); + 
float32x4_t _d5 = vld1q_f32(a5); + + float32x4x2_t _q0 = vtrnq_f32(_d0, _d1); + float32x4x2_t _q1 = vtrnq_f32(_d2, _d3); + float32x4x2_t _q3 = vtrnq_f32(_d4, _d5); + _d0 = vcombine_f32(vget_low_f32(_q0.val[0]), vget_low_f32(_q1.val[0])); + _d1 = vcombine_f32(vget_low_f32(_q0.val[1]), vget_low_f32(_q1.val[1])); + _d2 = + vcombine_f32(vget_high_f32(_q0.val[0]), vget_high_f32(_q1.val[0])); + _d3 = + vcombine_f32(vget_high_f32(_q0.val[1]), vget_high_f32(_q1.val[1])); + + _d0 = vandq_f32(_d0, vmask2); + _d1 = vandq_f32(_d1, vmask2); + _d2 = vandq_f32(_d2, vmask2); + _d3 = vandq_f32(_d3, vmask2); + _d4 = vandq_f32(_q3.val[0], vmask3); + _d5 = vandq_f32(_q3.val[1], vmask3); + + vst1q_f32(out_ptr, _d0); + vst1_f32(out_ptr + 4, vget_low_f32(_d4)); + vst1q_f32(out_ptr + 6, _d1); + vst1_f32(out_ptr + 10, vget_low_f32(_d5)); + vst1q_f32(out_ptr + 12, _d2); + vst1_f32(out_ptr + 16, vget_high_f32(_d4)); + vst1q_f32(out_ptr + 18, _d3); + vst1_f32(out_ptr + 22, vget_high_f32(_d5)); + + a0 += 4; + a1 += 4; + a2 += 4; + a3 += 4; + a4 += 4; + a5 += 4; + out_ptr += 24; + } +#else + asm volatile( + "loop_4k_%=: \n" + "vld1.32 {d0-d1}, [%[a0]]! \n" + "vld1.32 {d2-d3}, [%[a1]]! \n" + "vld1.32 {d4-d5}, [%[a2]]! \n" + "vld1.32 {d6-d7}, [%[a3]]! \n" + "vld1.32 {d8-d9}, [%[a4]]! \n" + "vld1.32 {d10-d11}, [%[a5]]! \n" + "vtrn.32 q0, q1 \n" + "vtrn.32 q2, q3 \n" + "vtrn.32 q4, q5 \n" + "vswp.32 d1, d4 \n" + "vswp.32 d3, d6 \n" + + "vbif q0, %q[vzero], %q[vmask2] \n" + "vbif q1, %q[vzero], %q[vmask2] \n" + "vbif q2, %q[vzero], %q[vmask2] \n" + "vbif q3, %q[vzero], %q[vmask2] \n" + "vbif q4, %q[vzero], %q[vmask3] \n" + "vbif q5, %q[vzero], %q[vmask3] \n" + + "vst1.32 {q0}, [%[out]]! \n" + "vst1.32 {d8}, [%[out]]! \n" + "vst1.32 {q1}, [%[out]]! \n" + "vst1.32 {d10}, [%[out]]! \n" + "vst1.32 {q2}, [%[out]]! \n" + "vst1.32 {d9}, [%[out]]! \n" + "vst1.32 {q3}, [%[out]]! \n" + "vst1.32 {d11}, [%[out]]! 
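These masked tails are also why every Free(zero) disappears later in the patch: instead of pointing out-of-range source rows at a shared KC-float zero buffer, the pack loads real data and zeroes the invalid lanes and rows with compare-generated masks (vbif against vzero in the asm, vandq_f32 in the intrinsics). Mask construction in the hunk's own style:

#include <arm_neon.h>
#include <stdint.h>

// All-ones in lanes below `remain`, zero elsewhere, e.g. {~0u, ~0u, ~0u, 0u}
// for remain == 3.
static inline uint32x4_t TailMask(uint32_t remain) {
  const uint32_t lane_idx[4] = {0, 1, 2, 3};
  return vcltq_u32(vld1q_u32(lane_idx), vdupq_n_u32(remain));
}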
\n" + + "subs %[loops], #1 \n" + "bne loop_4k_%= \n" + : [out] "+r"(out_ptr), [a0] "+r"(a0), [a1] "+r"(a1), [a2] "+r"(a2), + [a3] "+r"(a3), [a4] "+r"(a4), [a5] "+r"(a5), [loops] "+r"(loops) + : [vmask2] "w"(vmask2), [vmask3] "w"(vmask3), [vzero] "w"(vzero) + : "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5"); +#endif } - } -} -void Gemm::PackMatrixA_omp_8r(int m, int k, int m_tail, const float *A, int lda, - float *buffer) { - const int i_length = m - m_tail; -#pragma omp parallel for - for (int i = 0; i < i_length; i += MR) { - const float *a0 = A + i * lda; - const float *a1 = A + (i + 1) * lda; - const float *a2 = A + (i + 2) * lda; - const float *a3 = A + (i + 3) * lda; - const float *a4 = A + (i + 4) * lda; - const float *a5 = A + (i + 5) * lda; - const float *a6 = A + (i + 6) * lda; - const float *a7 = A + (i + 7) * lda; - float *local_buffer = buffer + i * k; - for (int j = 0; j < k; ++j) { - *local_buffer++ = *a0++; - *local_buffer++ = *a1++; - *local_buffer++ = *a2++; - *local_buffer++ = *a3++; - *local_buffer++ = *a4++; - *local_buffer++ = *a5++; - *local_buffer++ = *a6++; - *local_buffer++ = *a7++; - } - } - if (m_tail != 0) { - const float *a0 = &A(i_length, 0); - const float *a1 = a0 + lda; - const float *a2 = a0 + 2 * lda; - const float *a3 = a0 + 3 * lda; - const float *a4 = a0 + 4 * lda; - const float *a5 = a0 + 5 * lda; - const float *a6 = a0 + 6 * lda; - const float *a7 = a0 + 7 * lda; - float *local_buffer = buffer + i_length * k; - switch (m_tail) { - case 1: - a1 = zero; - case 2: - a2 = zero; - case 3: - a3 = zero; - case 4: - a4 = zero; - case 5: - a5 = zero; - case 6: - a6 = zero; - case 7: - a7 = zero; - break; - default: - break; - } - for (int j = 0; j < k; ++j) { - *local_buffer++ = *a0++; - *local_buffer++ = *a1++; - *local_buffer++ = *a2++; - *local_buffer++ = *a3++; - *local_buffer++ = *a4++; - *local_buffer++ = *a5++; - *local_buffer++ = *a6++; - *local_buffer++ = *a7++; + if (remain_k > 0) { + float32x4_t _d0 = vld1q_f32(a0); + float32x4_t _d1 = vld1q_f32(a1); + float32x4_t _d2 = vld1q_f32(a2); + float32x4_t _d3 = vld1q_f32(a3); + float32x4_t _d4 = vld1q_f32(a4); + float32x4_t _d5 = vld1q_f32(a5); + + _d0 = vandq_f32(_d0, vmask1); + _d1 = vandq_f32(_d1, vmask1); + _d2 = vandq_f32(_d2, vmask1); + _d3 = vandq_f32(_d3, vmask1); + _d4 = vandq_f32(_d4, vmask1); + _d5 = vandq_f32(_d5, vmask1); + + float32x4x2_t _q0 = vtrnq_f32(_d0, _d1); + float32x4x2_t _q1 = vtrnq_f32(_d2, _d3); + float32x4x2_t _q3 = vtrnq_f32(_d4, _d5); + _d0 = vcombine_f32(vget_low_f32(_q0.val[0]), vget_low_f32(_q1.val[0])); + _d1 = vcombine_f32(vget_low_f32(_q0.val[1]), vget_low_f32(_q1.val[1])); + _d2 = vcombine_f32(vget_high_f32(_q0.val[0]), vget_high_f32(_q1.val[0])); + // _d3 = vcombine_f32(vget_high_f32(_q0.val[1]), + // vget_high_f32(_q1.val[1])); + + _d0 = vandq_f32(_d0, vmask2); + _d1 = vandq_f32(_d1, vmask2); + _d2 = vandq_f32(_d2, vmask2); + // _d3 = vandq_f32(_d3, vmask2); + _d4 = vandq_f32(_q3.val[0], vmask3); + _d5 = vandq_f32(_q3.val[1], vmask3); + + switch (remain_k) { + case 3: + vst1q_f32(out_ptr + 12, _d2); + vst1_f32(out_ptr + 16, vget_high_f32(_d4)); + case 2: + vst1q_f32(out_ptr + 6, _d1); + vst1_f32(out_ptr + 10, vget_low_f32(_d5)); + case 1: + vst1q_f32(out_ptr, _d0); + vst1_f32(out_ptr + 4, vget_low_f32(_d4)); + default: + break; + } } } } // 将B矩阵分块复制到连续内存(RowMajor) void Gemm::PackMatrixB_8c(int k, int n, int n_tail, const float *B, int ldb, - float *buffer) { + float *buffer, const bool parallel) { const int j_length = n - n_tail; - for (int j = 0; j < j_length; j 
+= NR) { - float *local_buffer = buffer + j * k; - for (int i = 0; i < k; ++i) { + + #pragma omp parallel for if (parallel) + for (int i = 0; i < k; ++i) { + int j = 0; + for (; j < j_length - 31; j += 32) { + float *local_buffer0 = buffer + j * k + i * NR; + float *local_buffer1 = buffer + (j + 8) * k + i * NR; + float *local_buffer2 = buffer + (j + 16) * k + i * NR; + float *local_buffer3 = buffer + (j + 24) * k + i * NR; + const float *b0 = B + i * ldb + j; +#if __aarch64__ + asm volatile( + "prfm pldl1keep, [%[b0]] \n" + "ld1 {v0.4s, v1.4s}, [%[b0]], #32 \n" + "ld1 {v2.4s, v3.4s}, [%[b0]], #32 \n" + "ld1 {v4.4s, v5.4s}, [%[b0]], #32 \n" + "ld1 {v6.4s, v7.4s}, [%[b0]] \n" + "st1 {v0.4s, v1.4s}, [%[local_buffer0]], #32 \n" + "st1 {v2.4s, v3.4s}, [%[local_buffer1]], #32 \n" + "st1 {v4.4s, v5.4s}, [%[local_buffer2]], #32 \n" + "st1 {v6.4s, v7.4s}, [%[local_buffer3]], #32 \n" + : [local_buffer0] "+r"(local_buffer0), + [local_buffer1] "+r"(local_buffer1), + [local_buffer2] "+r"(local_buffer2), + [local_buffer3] "+r"(local_buffer3), [b0] "+r"(b0) + : + : "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7"); +#else + asm volatile( + // "pld [%[b]] \n" + "vld1.32 {q0, q1}, [%[b0]]! \n" + "vld1.32 {q2, q3}, [%[b0]]! \n" + "vld1.32 {q4, q5}, [%[b0]]! \n" + "vld1.32 {q6, q7}, [%[b0]]! \n" + "vst1.32 {q0, q1}, [%[local_buffer0]]! \n" + "vst1.32 {q2, q3}, [%[local_buffer1]]! \n" + "vst1.32 {q4, q5}, [%[local_buffer2]]! \n" + "vst1.32 {q6, q7}, [%[local_buffer3]]! \n" + : [local_buffer0] "+r"(local_buffer0), + [local_buffer1] "+r"(local_buffer1), + [local_buffer2] "+r"(local_buffer2), + [local_buffer3] "+r"(local_buffer3), [b0] "+r"(b0) + : + : "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7"); +#endif // __aarch64__ + } + for (; j < j_length - 15; j += 16) { + float *local_buffer0 = buffer + j * k + i * NR; + float *local_buffer1 = buffer + (j + 8) * k + i * NR; const float *b0 = &B(i, j); #if __ARM_NEON #if __aarch64__ asm volatile( - "prfm pldl1keep, [%[b0]] \n\t" - "ld1 {v0.4s, v1.4s}, [%[b0]] \n\t" - "st1 {v0.4s, v1.4s}, [%[local_buffer]], #32 \n\t" - : [local_buffer] "+r"(local_buffer) - : [b0] "r"(b0) - : "memory", "v0", "v1"); + "prfm pldl1keep, [%[b0]] \n" + "ld1 {v0.4s, v1.4s}, [%[b0]], #32 \n" + "ld1 {v2.4s, v3.4s}, [%[b0]] \n" + "st1 {v0.4s, v1.4s}, [%[local_buffer0]], #32 \n" + "st1 {v2.4s, v3.4s}, [%[local_buffer1]], #32 \n" + : [local_buffer0] "+r"(local_buffer0), + [local_buffer1] "+r"(local_buffer1), [b0] "+r"(b0) + : + : "memory", "v0", "v1", "v2", "v3"); #else asm volatile( - // "pld [%[b0]] \n\t" - "vld1.32 {q0, q1}, [%[b0]] \n\t" - "vst1.32 {q0, q1}, [%[local_buffer]]! \n\t" - : [local_buffer] "+r"(local_buffer) - : [b0] "r"(b0) - : "memory", "q0", "q1"); + // "pld [%[b0]] \n" + "vld1.32 {q0, q1}, [%[b0]]! \n" + "vld1.32 {q2, q3}, [%[b0]] \n" + "vst1.32 {q0, q1}, [%[local_buffer0]]! \n" + "vst1.32 {q2, q3}, [%[local_buffer1]]! 
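The rewritten PackMatrixB_8c flips the loop nest (rows of B outermost, so OpenMP can split over k) and sweeps each row in 32-, 16- and finally 8-column steps. The four local_buffer pointers fall out of the packed layout: B is stored as NR = 8-column panels, each panel contiguous over all k rows, so one 32-wide load scatters into four panel destinations. The addressing, as I read the hunk:

// Destination of B(i, j..j+7) in the packed buffer; requires j % 8 == 0.
static inline float *PackedBDst(float *buffer, int k, int i, int j) {
  return buffer + j * k + i * 8;  // panel base + row offset, NR == 8
}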
\n" + : [local_buffer0] "+r"(local_buffer0), + [local_buffer1] "+r"(local_buffer1), [b0] "+r"(b0) + : + : "memory", "q0", "q1", "q2", "q3"); #endif // __aarch64__ -#else - *local_buffer++ = *b0++; - *local_buffer++ = *b0++; - *local_buffer++ = *b0++; - *local_buffer++ = *b0++; - *local_buffer++ = *b0++; - *local_buffer++ = *b0++; - *local_buffer++ = *b0++; - *local_buffer++ = *b0++; #endif // __ARM_NEON } - } - if (n_tail != 0) { - float *local_buffer = buffer + j_length * k; - for (int i = 0; i < k; ++i) { - const float *b0 = &B(i, j_length); - for (int j = j_length; j < n; ++j) { - *local_buffer++ = *b0++; - } - for (int j = n; j < j_length + NR; ++j) { - *local_buffer++ = 0; - } - } - } -} - -void Gemm::PackMatrixB_omp_8c(int k, int n, int n_tail, const float *B, int ldb, - float *buffer) { - const int j_length = n - n_tail; -#pragma omp parallel for - for (int j = 0; j < j_length; j += NR) { - float *local_buffer = buffer + j * k; - for (int i = 0; i < k; ++i) { + for (; j < j_length; j += NR) { + float *local_buffer = buffer + j * k + i * NR; const float *b0 = &B(i, j); -#if __ARM_NEON #if __aarch64__ asm volatile( - "prfm pldl1keep, [%[b0]] \n\t" - "ld1 {v0.4s, v1.4s}, [%[b0]] \n\t" - "st1 {v0.4s, v1.4s}, [%[local_buffer]], #32 \n\t" + "prfm pldl1keep, [%[b0]] \n" + "ld1 {v0.4s, v1.4s}, [%[b0]] \n" + "st1 {v0.4s, v1.4s}, [%[local_buffer]], #32 \n" : [local_buffer] "+r"(local_buffer) : [b0] "r"(b0) : "memory", "v0", "v1"); #else asm volatile( - // "pld [%[b0]] \n\t" - "vld1.32 {q0, q1}, [%[b0]] \n\t" - "vst1.32 {q0, q1}, [%[local_buffer]]! \n\t" + // "pld [%[b]] \n" + "vld1.32 {q0, q1}, [%[b0]] \n" + "vst1.32 {q0, q1}, [%[local_buffer]] \n" : [local_buffer] "+r"(local_buffer) : [b0] "r"(b0) : "memory", "q0", "q1"); #endif // __aarch64__ -#else - *local_buffer++ = *b0++; - *local_buffer++ = *b0++; - *local_buffer++ = *b0++; - *local_buffer++ = *b0++; - *local_buffer++ = *b0++; - *local_buffer++ = *b0++; - *local_buffer++ = *b0++; - *local_buffer++ = *b0++; -#endif // __ARM_NEON } } if (n_tail != 0) { + uint32_t mask[8] = {0, 1, 2, 3, 4, 5, 6, 7}; + uint32x4_t vzero = vdupq_n_u32(0); + uint32x4_t vmask1 = vcltq_u32(vld1q_u32(mask), vdupq_n_u32(n_tail)); + uint32x4_t vmask2 = vcltq_u32(vld1q_u32(mask + 4), vdupq_n_u32(n_tail)); + float *local_buffer = buffer + j_length * k; for (int i = 0; i < k; ++i) { const float *b0 = &B(i, j_length); - for (int j = j_length; j < n; ++j) { - *local_buffer++ = *b0++; - } - for (int j = n; j < j_length + NR; ++j) { - *local_buffer++ = 0; - } +#if __aarch64__ + asm volatile( + "prfm pldl1keep, [%[b0]] \n" + "ld1 {v0.4s, v1.4s}, [%[b0]] \n" + "BIF v0.8b, %[vzero].8b, %[vmask1].8b \n" + "BIF v1.8b, %[vzero].8b, %[vmask2].8b \n" + "st1 {v0.4s, v1.4s}, [%[local_buffer]], #32 \n" + : [local_buffer] "+r"(local_buffer) + : [vmask1] "w"(vmask1), [vmask2] "w"(vmask2), [vzero] "w"(vzero), + [b0] "r"(b0) + : "memory", "v0", "v1"); +#else + asm volatile( + "vld1.32 {q0, q1}, [%[b0]] \n" + "vbif q0, %q[vzero], %q[vmask1] \n" + "vbif q1, %q[vzero], %q[vmask2] \n" + "vst1.32 {q0, q1}, [%[local_buffer]]! 
\n" + : [local_buffer] "+r"(local_buffer) + : [vmask1] "w"(vmask1), [vmask2] "w"(vmask2), [vzero] "w"(vzero), + [b0] "r"(b0) + : "memory", "q0", "q1"); +#endif } } } @@ -418,39 +446,10 @@ void Gemm::PackMatrixB_omp_8c(int k, int n, int n_tail, const float *B, int ldb, #if __ARM_NEON #if __aarch64__ void Gemm::PackMatrixB_12c(int k, int n, int n_tail, const float *B, int ldb, - float *buffer) { + float *buffer, const bool parallel) { const int j_length = n - n_tail; - for (int j = 0; j < j_length; j += NR) { - float *local_buffer = buffer + j * k; - for (int i = 0; i < k; ++i) { - const float *b0 = &B(i, j); - asm volatile( - "prfm pldl2keep, [%[b0], #64] \n\t" - "ld1 {v0.4s, v1.4s, v2.4s}, [%[b0]] \n\t" - "st1 {v0.4s, v1.4s, v2.4s}, [%[local_buffer]], #48 \n\t" - : [local_buffer] "+r"(local_buffer) - : [b0] "r"(b0) - : "memory", "v0", "v1", "v2"); - } - } - if (n_tail != 0) { - float *local_buffer = buffer + j_length * k; - for (int i = 0; i < k; ++i) { - const float *b0 = &B(i, j_length); - for (int j = j_length; j < n; ++j) { - *local_buffer++ = *b0++; - } - for (int j = n; j < j_length + NR; ++j) { - *local_buffer++ = 0; - } - } - } -} -void Gemm::PackMatrixB_omp_12c(int k, int n, int n_tail, const float *B, - int ldb, float *buffer) { - const int j_length = n - n_tail; -#pragma omp parallel for + #pragma omp parallel for if (parallel) for (int j = 0; j < j_length; j += NR) { float *local_buffer = buffer + j * k; for (int i = 0; i < k; ++i) { @@ -479,39 +478,10 @@ void Gemm::PackMatrixB_omp_12c(int k, int n, int n_tail, const float *B, } void Gemm::PackMatrixB_16c(int k, int n, int n_tail, const float *B, int ldb, - float *buffer) { + float *buffer, const bool parallel) { const int j_length = n - n_tail; - for (int j = 0; j < n - n_tail; j += NR) { - float *local_buffer = buffer + j * k; - for (int i = 0; i < k; ++i) { - const float *b0 = &B(i, j); - asm volatile( - "prfm pldl2keep, [%[b0], #64] \n\t" - "ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [%[b0]] \n\t" - "st1 {v0.4s, v1.4s, v2.4s, v3.4s}, [%[local_buffer]], #64 \n\t" - : [local_buffer] "+r"(local_buffer) - : [b0] "r"(b0) - : "memory", "v0", "v1", "v2", "v3"); - } - } - if (n_tail != 0) { - float *local_buffer = buffer + j_length * k; - for (int i = 0; i < k; ++i) { - const float *b0 = &B(i, j_length); - for (int j = j_length; j < n; ++j) { - *local_buffer++ = *b0++; - } - for (int j = n; j < j_length + NR; ++j) { - *local_buffer++ = 0; - } - } - } -} -void Gemm::PackMatrixB_omp_16c(int k, int n, int n_tail, const float *B, - int ldb, float *buffer) { - const int j_length = n - n_tail; -#pragma omp parallel for + #pragma omp parallel for if (parallel) for (int j = 0; j < n - n_tail; j += NR) { float *local_buffer = buffer + j * k; for (int i = 0; i < k; ++i) { @@ -3285,25 +3255,23 @@ void Gemm::Sgemm(int m, int n, int k, float alpha, const float *A, int lda, paddle_mobile::memory::Alloc(sizeof(float) * KC * NC)); packedC = static_cast( paddle_mobile::memory::Alloc(sizeof(float) * MC * NC)); - zero = static_cast(paddle_mobile::memory::Alloc(sizeof(float) * KC)); - memset(static_cast(zero), 0, sizeof(float) * KC); int mc, nc; for (int j = 0; j < n; j += NC) { nc = s_min(n - j, NC); #if __aarch64__ // PackMatrixB_12c(KC, nc, nc % NR, &B(0, j), ldb, packedB); - PackMatrixB_16c(KC, nc, nc % NR, &B(0, j), ldb, packedB); + PackMatrixB_16c(KC, nc, nc % NR, &B(0, j), ldb, packedB, false); #else - PackMatrixB_8c(KC, nc, nc % NR, &B(0, j), ldb, packedB); + PackMatrixB_8c(KC, nc, nc % NR, &B(0, j), ldb, packedB, false); #endif for (int i = 0; i < m; 
i += MC) { mc = s_min(m - i, MC); #if __aarch64__ - PackMatrixA_6r(mc, KC, mc % MR, &A(i, 0), lda, packedA); + PackMatrixA_6r(mc, KC, mc % MR, &A(i, 0), lda, packedA, false); // PackMatrixA_8r(mc, KC, mc % MR, &A(i, 0), lda, packedA); #else - PackMatrixA_6r(mc, KC, mc % MR, &A(i, 0), lda, packedA); + PackMatrixA_6r(mc, KC, mc % MR, &A(i, 0), lda, packedA, false); #endif if (bias == nullptr) { InnerKernelWithBias(mc, nc, alpha, packedA, packedB, beta, packedC, @@ -3318,7 +3286,6 @@ void Gemm::Sgemm(int m, int n, int k, float alpha, const float *A, int lda, paddle_mobile::memory::Free(packedA); paddle_mobile::memory::Free(packedB); paddle_mobile::memory::Free(packedC); - paddle_mobile::memory::Free(zero); } void Gemm::SgemmWithBn(int m, int n, int k, float alpha, const float *A, @@ -3358,25 +3325,23 @@ void Gemm::SgemmWithBn(int m, int n, int k, float alpha, const float *A, paddle_mobile::memory::Alloc(sizeof(float) * KC * NC)); packedC = static_cast( paddle_mobile::memory::Alloc(sizeof(float) * MC * NC)); - zero = static_cast(paddle_mobile::memory::Alloc(sizeof(float) * KC)); - memset(static_cast(zero), 0, sizeof(float) * KC); int mc, nc; for (int j = 0; j < n; j += NC) { nc = s_min(n - j, NC); #if __aarch64__ // PackMatrixB_12c(KC, nc, nc % NR, &B(0, j), ldb, packedB); - PackMatrixB_16c(KC, nc, nc % NR, &B(0, j), ldb, packedB); + PackMatrixB_16c(KC, nc, nc % NR, &B(0, j), ldb, packedB, false); #else - PackMatrixB_8c(KC, nc, nc % NR, &B(0, j), ldb, packedB); + PackMatrixB_8c(KC, nc, nc % NR, &B(0, j), ldb, packedB, false); #endif for (int i = 0; i < m; i += MC) { mc = s_min(m - i, MC); #if __aarch64__ - PackMatrixA_6r(mc, KC, mc % MR, &A(i, 0), lda, packedA); + PackMatrixA_6r(mc, KC, mc % MR, &A(i, 0), lda, packedA, false); // PackMatrixA_8r(mc, KC, mc % MR, &A(i, 0), lda, packedA); #else - PackMatrixA_6r(mc, KC, mc % MR, &A(i, 0), lda, packedA); + PackMatrixA_6r(mc, KC, mc % MR, &A(i, 0), lda, packedA, false); #endif if (bias == nullptr) { InnerKernelWithBn(mc, nc, alpha, packedA, packedB, beta, packedC, @@ -3392,7 +3357,6 @@ void Gemm::SgemmWithBn(int m, int n, int k, float alpha, const float *A, paddle_mobile::memory::Free(packedA); paddle_mobile::memory::Free(packedB); paddle_mobile::memory::Free(packedC); - paddle_mobile::memory::Free(zero); } void Gemm::SgemmWithPRelu(int m, int n, int k, const float *A, int lda, @@ -3431,28 +3395,23 @@ void Gemm::SgemmWithPRelu(int m, int n, int k, const float *A, int lda, paddle_mobile::memory::Alloc(sizeof(float) * KC * NC)); packedC = static_cast( paddle_mobile::memory::Alloc(sizeof(float) * MC * NC)); - zero = static_cast(paddle_mobile::memory::Alloc(sizeof(float) * KC)); - - for (int l = 0; l < KC; ++l) { - zero[l] = 0; - } int mc, nc; for (int j = 0; j < n; j += NC) { nc = s_min(n - j, NC); #if __aarch64__ // PackMatrixB_12c(KC, nc, nc % NR, &B(0, j), ldb, packedB); - PackMatrixB_16c(KC, nc, nc % NR, &B(0, j), ldb, packedB); + PackMatrixB_16c(KC, nc, nc % NR, &B(0, j), ldb, packedB, false); #else - PackMatrixB_8c(KC, nc, nc % NR, &B(0, j), ldb, packedB); + PackMatrixB_8c(KC, nc, nc % NR, &B(0, j), ldb, packedB, false); #endif for (int i = 0; i < m; i += MC) { mc = s_min(m - i, MC); #if __aarch64__ - PackMatrixA_6r(mc, KC, mc % MR, &A(i, 0), lda, packedA); + PackMatrixA_6r(mc, KC, mc % MR, &A(i, 0), lda, packedA, false); // PackMatrixA_8r(mc, KC, mc % MR, &A(i, 0), lda, packedA); #else - PackMatrixA_6r(mc, KC, mc % MR, &A(i, 0), lda, packedA); + PackMatrixA_6r(mc, KC, mc % MR, &A(i, 0), lda, packedA, false); #endif if (bias1 == nullptr) { 
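With the packing routines unified, the old *_omp variants are gone and the distinction lives in the new trailing flag: #pragma omp parallel for if (parallel) is standard OpenMP, and when the clause evaluates false the loop runs serially on the calling thread. The serial Sgemm family passes false, while Sgemm_omp passes true for the single up-front pack and false for the per-thread packs inside its parallel region. In isolation:

void Scale(float *x, int n, bool parallel) {
  // Parallel only when requested; otherwise executes on the calling thread.
  #pragma omp parallel for if (parallel)
  for (int i = 0; i < n; ++i) {
    x[i] *= 2.0f;
  }
}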
InnerKernelWithPRelu(mc, nc, packedA, packedB, packedC, &C(i, j), ldc, @@ -3467,7 +3426,6 @@ void Gemm::SgemmWithPRelu(int m, int n, int k, const float *A, int lda, paddle_mobile::memory::Free(packedA); paddle_mobile::memory::Free(packedB); paddle_mobile::memory::Free(packedC); - paddle_mobile::memory::Free(zero); } // 32-bit float matrix multiplication @@ -3489,8 +3447,6 @@ void Gemm::Sgemm_omp(int m, int n, int k, float alpha, const float *A, int lda, int L = (max_threads > 2) ? 64 : 32; int L1 = L / max_threads * 1024; KC = k; - zero = static_cast<float *>(paddle_mobile::memory::Alloc(sizeof(float) * KC)); - memset(static_cast<void *>(zero), 0, sizeof(float) * KC); if (m > n) { // block over A MC = L1 / (KC * sizeof(float)); @@ -3506,17 +3462,17 @@ void Gemm::Sgemm_omp(int m, int n, int k, float alpha, const float *A, int lda, #if __aarch64__ procPackA = &Gemm::PackMatrixA_6r; - procPackB = &Gemm::PackMatrixB_omp_16c; + procPackB = &Gemm::PackMatrixB_16c; procAddDot = &Gemm::AddDot6x16; #else procPackA = &Gemm::PackMatrixA_6r; - procPackB = &Gemm::PackMatrixB_omp_8c; + procPackB = &Gemm::PackMatrixB_8c; procAddDot = &Gemm::AddDot6x8; #endif packedB = static_cast<float *>( paddle_mobile::memory::Alloc(sizeof(float) * KC * NC)); - (*this.*procPackB)(KC, n, n % NR, B, ldb, packedB); + (*this.*procPackB)(KC, n, n % NR, B, ldb, packedB, true); packedA = static_cast<float *>( paddle_mobile::memory::Alloc(sizeof(float) * MC * KC * max_threads)); } else { @@ -3533,19 +3489,19 @@ void Gemm::Sgemm_omp(int m, int n, int k, float alpha, const float *A, int lda, MC = (m + MR - 1) / MR * MR; #if __aarch64__ - procPackA = &Gemm::PackMatrixA_omp_6r; + procPackA = &Gemm::PackMatrixA_6r; procPackB = &Gemm::PackMatrixB_16c; procAddDot = &Gemm::AddDot6x16; #else - procPackA = &Gemm::PackMatrixA_omp_6r; + procPackA = &Gemm::PackMatrixA_6r; procPackB = &Gemm::PackMatrixB_8c; procAddDot = &Gemm::AddDot6x8; #endif packedA = static_cast<float *>( paddle_mobile::memory::Alloc(sizeof(float) * MC * KC)); - (*this.*procPackA)(m, KC, m % MR, A, lda, packedA); + (*this.*procPackA)(m, KC, m % MR, A, lda, packedA, true); packedB = static_cast<float *>( paddle_mobile::memory::Alloc(sizeof(float) * KC * NC * max_threads)); } @@ -3565,7 +3521,7 @@ void Gemm::Sgemm_omp(int m, int n, int k, float alpha, const float *A, int lda, mc = s_min(m - i, MC); float *local_A = packedA + MC * KC * local_threads; float *local_C = packedC + MC * NC * local_threads; - (*this.*procPackA)(mc, KC, mc % MR, &A(i, 0), lda, local_A); + (*this.*procPackA)(mc, KC, mc % MR, &A(i, 0), lda, local_A, false); if (bias == nullptr) { InnerKernelWithBias(mc, n, alpha, local_A, packedB, beta, local_C, &C(i, 0), ldc, relu, nullptr); @@ -3587,7 +3543,7 @@ void Gemm::Sgemm_omp(int m, int n, int k, float alpha, const float *A, int lda, nc = s_min(n - j, NC); float *local_B = packedB + KC * NC * local_threads; float *local_C = packedC + MC * NC * local_threads; - (*this.*procPackB)(KC, nc, nc % NR, &B(0, j), ldb, local_B); + (*this.*procPackB)(KC, nc, nc % NR, &B(0, j), ldb, local_B, false); InnerKernelWithBias(m, nc, alpha, packedA, local_B, beta, local_C, &C(0, j), ldc, relu, bias); } @@ -3596,7 +3552,6 @@ void Gemm::Sgemm_omp(int m, int n, int k, float alpha, const float *A, int lda, paddle_mobile::memory::Free(packedA); paddle_mobile::memory::Free(packedB); paddle_mobile::memory::Free(packedC); - paddle_mobile::memory::Free(zero); } void Gemm::SgemmWithBn_omp(int m, int n, int k, float alpha, const float *A, @@ -3611,8 +3566,6 @@ void Gemm::SgemmWithBn_omp(int m, int n, int k, float alpha, const float *A, int L1 = 64 / max_threads * 
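In the `_omp` drivers, whichever operand is not pre-packed up front gets a per-thread slice of scratch space, indexed as `block_size * thread_id`. A runnable sketch of that indexing (names hypothetical):

```cpp
#include <cstdio>
#include <vector>
#ifdef _OPENMP
#include <omp.h>
#endif

int main() {
  const int MC = 4, KC = 8;
  int max_threads = 1;
#ifdef _OPENMP
  max_threads = omp_get_max_threads();
#endif
  // One MC*KC scratch panel per thread, like packedA in Sgemm_omp.
  std::vector<float> packedA(static_cast<size_t>(MC) * KC * max_threads);
  #pragma omp parallel for
  for (int i = 0; i < 16; ++i) {
    int tid = 0;
#ifdef _OPENMP
    tid = omp_get_thread_num();
#endif
    float *local_A = packedA.data() + MC * KC * tid;  // private slice
    local_A[0] = static_cast<float>(i);               // no data race
  }
  printf("%f\n", packedA[0]);
}
```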
1024; KC = k; - zero = static_cast<float *>(paddle_mobile::memory::Alloc(sizeof(float) * KC)); - memset(static_cast<void *>(zero), 0, sizeof(float) * KC); if (m > n) { // block over A MC = L1 / (KC * sizeof(float)); @@ -3628,17 +3581,17 @@ void Gemm::SgemmWithBn_omp(int m, int n, int k, float alpha, const float *A, #if __aarch64__ procPackA = &Gemm::PackMatrixA_6r; - procPackB = &Gemm::PackMatrixB_omp_16c; + procPackB = &Gemm::PackMatrixB_16c; procAddDot = &Gemm::AddDot6x16; #else procPackA = &Gemm::PackMatrixA_6r; - procPackB = &Gemm::PackMatrixB_omp_8c; + procPackB = &Gemm::PackMatrixB_8c; procAddDot = &Gemm::AddDot6x8; #endif packedB = static_cast<float *>( paddle_mobile::memory::Alloc(sizeof(float) * KC * NC)); - (*this.*procPackB)(KC, n, n % NR, B, ldb, packedB); + (*this.*procPackB)(KC, n, n % NR, B, ldb, packedB, true); packedA = static_cast<float *>( paddle_mobile::memory::Alloc(sizeof(float) * MC * KC * max_threads)); } else { @@ -3655,18 +3608,18 @@ void Gemm::SgemmWithBn_omp(int m, int n, int k, float alpha, const float *A, MC = (m + MR - 1) / MR * MR; #if __aarch64__ - procPackA = &Gemm::PackMatrixA_omp_6r; + procPackA = &Gemm::PackMatrixA_6r; procPackB = &Gemm::PackMatrixB_16c; procAddDot = &Gemm::AddDot6x16; #else - procPackA = &Gemm::PackMatrixA_omp_6r; + procPackA = &Gemm::PackMatrixA_6r; procPackB = &Gemm::PackMatrixB_8c; procAddDot = &Gemm::AddDot6x8; #endif packedA = static_cast<float *>( paddle_mobile::memory::Alloc(sizeof(float) * MC * KC)); - (*this.*procPackA)(m, KC, m % MR, A, lda, packedA); + (*this.*procPackA)(m, KC, m % MR, A, lda, packedA, true); packedB = static_cast<float *>( paddle_mobile::memory::Alloc(sizeof(float) * KC * NC * max_threads)); } @@ -3686,7 +3639,7 @@ void Gemm::SgemmWithBn_omp(int m, int n, int k, float alpha, const float *A, mc = s_min(m - i, MC); float *local_A = packedA + MC * KC * local_threads; float *local_C = packedC + MC * NC * local_threads; - (*this.*procPackA)(mc, KC, mc % MR, &A(i, 0), lda, local_A); + (*this.*procPackA)(mc, KC, mc % MR, &A(i, 0), lda, local_A, false); if (bias == nullptr) { InnerKernelWithBn(mc, n, alpha, local_A, packedB, beta, local_C, &C(i, 0), ldc, relu, new_scale + i, new_bias + i); @@ -3709,7 +3662,7 @@ void Gemm::SgemmWithBn_omp(int m, int n, int k, float alpha, const float *A, nc = s_min(n - j, NC); float *local_B = packedB + KC * NC * local_threads; float *local_C = packedC + MC * NC * local_threads; - (*this.*procPackB)(KC, nc, nc % NR, &B(0, j), ldb, local_B); + (*this.*procPackB)(KC, nc, nc % NR, &B(0, j), ldb, local_B, false); if (bias == nullptr) { InnerKernelWithBn(m, nc, alpha, packedA, local_B, beta, local_C, &C(0, j), ldc, relu, new_scale, new_bias); @@ -3724,7 +3677,6 @@ void Gemm::SgemmWithBn_omp(int m, int n, int k, float alpha, const float *A, paddle_mobile::memory::Free(packedA); paddle_mobile::memory::Free(packedB); paddle_mobile::memory::Free(packedC); - paddle_mobile::memory::Free(zero); } void Gemm::SgemmWithPRelu_omp(int m, int n, int k, const float *A, int lda, @@ -3739,8 +3691,6 @@ void Gemm::SgemmWithPRelu_omp(int m, int n, int k, const float *A, int lda, int L1 = 8 * 1024; KC = k; - zero = static_cast<float *>(paddle_mobile::memory::Alloc(sizeof(float) * KC)); - memset(static_cast<void *>(zero), 0, sizeof(float) * KC); if (m > n) { // block over A MC = L1 / (KC * sizeof(float)); @@ -3756,17 +3706,17 @@ void Gemm::SgemmWithPRelu_omp(int m, int n, int k, const float *A, int lda, #if __aarch64__ procPackA = &Gemm::PackMatrixA_6r; - procPackB = &Gemm::PackMatrixB_omp_16c; + procPackB = &Gemm::PackMatrixB_16c; procAddDot = &Gemm::AddDot6x16; #else procPackA = 
&Gemm::PackMatrixA_6r; - procPackB = &Gemm::PackMatrixB_omp_8c; + procPackB = &Gemm::PackMatrixB_8c; procAddDot = &Gemm::AddDot6x8; #endif packedB = static_cast( paddle_mobile::memory::Alloc(sizeof(float) * KC * NC)); - (*this.*procPackB)(KC, n, n % NR, B, ldb, packedB); + (*this.*procPackB)(KC, n, n % NR, B, ldb, packedB, true); packedA = static_cast( paddle_mobile::memory::Alloc(sizeof(float) * MC * KC * max_threads)); } else { @@ -3783,18 +3733,18 @@ void Gemm::SgemmWithPRelu_omp(int m, int n, int k, const float *A, int lda, MC = (m + MR - 1) / MR * MR; #if __aarch64__ - procPackA = &Gemm::PackMatrixA_omp_6r; + procPackA = &Gemm::PackMatrixA_6r; procPackB = &Gemm::PackMatrixB_16c; procAddDot = &Gemm::AddDot6x16; #else - procPackA = &Gemm::PackMatrixA_omp_6r; + procPackA = &Gemm::PackMatrixA_6r; procPackB = &Gemm::PackMatrixB_8c; procAddDot = &Gemm::AddDot6x8; #endif packedA = static_cast( paddle_mobile::memory::Alloc(sizeof(float) * MC * KC)); - (*this.*procPackA)(m, KC, m % MR, A, lda, packedA); + (*this.*procPackA)(m, KC, m % MR, A, lda, packedA, true); packedB = static_cast( paddle_mobile::memory::Alloc(sizeof(float) * KC * NC * max_threads)); } @@ -3814,7 +3764,7 @@ void Gemm::SgemmWithPRelu_omp(int m, int n, int k, const float *A, int lda, mc = s_min(m - i, MC); float *local_A = packedA + MC * KC * local_threads; float *local_C = packedC + MC * NC * local_threads; - (*this.*procPackA)(mc, KC, mc % MR, &A(i, 0), lda, local_A); + (*this.*procPackA)(mc, KC, mc % MR, &A(i, 0), lda, local_A, false); if (bias1 == nullptr) { InnerKernelWithPRelu(mc, n, local_A, packedB, local_C, &C(i, 0), ldc, p + i, mode, bias + i, nullptr); @@ -3836,7 +3786,7 @@ void Gemm::SgemmWithPRelu_omp(int m, int n, int k, const float *A, int lda, nc = s_min(n - j, NC); float *local_B = packedB + KC * NC * local_threads; float *local_C = packedC + MC * NC * local_threads; - (*this.*procPackB)(KC, nc, nc % NR, &B(0, j), ldb, local_B); + (*this.*procPackB)(KC, nc, nc % NR, &B(0, j), ldb, local_B, false); if (bias1 == nullptr) { InnerKernelWithPRelu(m, nc, packedA, local_B, local_C, &C(0, j), ldc, p, mode, bias, nullptr); @@ -3850,7 +3800,6 @@ void Gemm::SgemmWithPRelu_omp(int m, int n, int k, const float *A, int lda, paddle_mobile::memory::Free(packedA); paddle_mobile::memory::Free(packedB); paddle_mobile::memory::Free(packedC); - paddle_mobile::memory::Free(zero); } } // namespace math diff --git a/src/operators/math/gemm.h b/src/operators/math/gemm.h index effab20b2045fbe93590189e28bac24d1f72ab2c..113e04fe3c94d4f9edbfab0520ca881b9cbab4e7 100644 --- a/src/operators/math/gemm.h +++ b/src/operators/math/gemm.h @@ -46,37 +46,25 @@ namespace math { class Gemm { public: - typedef void (Gemm::*FnPack)(int, int, int, const float *, int, float *); + typedef void (Gemm::*FnPack)(int, int, int, const float *, int, float *, + const bool); typedef void (Gemm::*FnAddDot)(int, const float *, const float *, float *, int); FnPack procPackA; FnPack procPackB; FnAddDot procAddDot; - // 将 A\B 矩阵分块复制到连续内存(RowMajor) - void PackMatrixA_4r(int m, int k, int m_tail, const float *A, int lda, - float *buffer); void PackMatrixA_6r(int m, int k, int m_tail, const float *A, int lda, - float *buffer); - void PackMatrixA_omp_6r(int m, int k, int m_tail, const float *A, int lda, - float *buffer); + float *buffer, const bool parallel); void PackMatrixA_8r(int m, int k, int m_tail, const float *A, int lda, - float *buffer); - void PackMatrixA_omp_8r(int m, int k, int m_tail, const float *A, int lda, - float *buffer); + float *buffer, const bool 
parallel); void PackMatrixB_8c(int k, int n, int n_tail, const float *B, int ldb, - float *buffer); - void PackMatrixB_omp_8c(int k, int n, int n_tail, const float *B, int ldb, - float *buffer); + float *buffer, const bool parallel); #if __aarch64__ void PackMatrixB_12c(int k, int n, int n_tail, const float *B, int ldb, - float *buffer); - void PackMatrixB_omp_12c(int k, int n, int n_tail, const float *B, int ldb, - float *buffer); + float *buffer, const bool parallel); void PackMatrixB_16c(int k, int n, int n_tail, const float *B, int ldb, - float *buffer); - void PackMatrixB_omp_16c(int k, int n, int n_tail, const float *B, int ldb, - float *buffer); + float *buffer, const bool parallel); #endif // Blocked matrix multiplication @@ -272,7 +260,6 @@ class Gemm { float *packedA; float *packedB; float *packedC; - float *zero; // 8 bits int int8_t *packedA_int8; diff --git a/src/operators/math/gemm/cblas.cc b/src/operators/math/gemm/cblas.cc new file mode 100644 index 0000000000000000000000000000000000000000..adc375b62913f0ad1105080f8c26b547e96671f3 --- /dev/null +++ b/src/operators/math/gemm/cblas.cc @@ -0,0 +1,53 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#if defined(__ARM_NEON__) || defined(__ARM_NEON) + +#pragma once + +#include "operators/math/gemm/cblas.h" +#include "operators/math/gemm/cpu_info.h" +#include "operators/math/gemm/executor.h" +#include "operators/math/gemm/strategy.h" + +namespace paddle_mobile { +namespace operators { +namespace math { + +void cblas_sgemm(const bool transA, const bool transB, const int M, const int N, + const int K, const float alpha, const float *A, const int lda, + const float *B, const int ldb, const float beta, float *C, + const int ldc) { + if (N == 1) { + return cblas_sgemv(transA, M, K, alpha, A, lda, B, beta, C); + } else if (M == 1) { + return cblas_sgemv(!transB, N, K, alpha, B, ldb, A, beta, C); + } else { + GemmExecutor exec(transA, transB, M, N, K); + exec(alpha, A, lda, B, ldb, beta, C, ldc); + } +} + +void cblas_sgemv(const bool trans, const int M, const int N, const float alpha, + const float *A, const int lda, const float *B, + const float beta, float *C) { + GemvExecutor exec(trans, M, N); + exec(alpha, A, lda, B, beta, C); +} + +} // namespace math +} // namespace operators +} // namespace paddle_mobile + +#endif diff --git a/src/operators/kernel/arm/conv_add_add_prelu_kernel.cpp b/src/operators/math/gemm/cblas.h similarity index 57% rename from src/operators/kernel/arm/conv_add_add_prelu_kernel.cpp rename to src/operators/math/gemm/cblas.h index 2f6f5f3ac719b3fd32aac54ce36eb534f7d99dd7..c7c9201869f56a7d339cccfcb3d898a4751836a6 100644 --- a/src/operators/kernel/arm/conv_add_add_prelu_kernel.cpp +++ b/src/operators/math/gemm/cblas.h @@ -12,28 +12,21 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
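Stepping back to the `FnPack` change in gemm.h above: the typedef now carries the extra `const bool` parameter, and call sites invoke it through the `(*this.*procPackB)(...)` pointer-to-member syntax. A compact, self-contained illustration of that mechanism (hypothetical class, not from this patch):

```cpp
#include <cstdio>

class Gemm {
 public:
  // Same shape as FnPack after this patch: the trailing bool selects
  // parallel packing at run time.
  typedef void (Gemm::*FnPack)(int, float *, const bool);

  void PackA(int k, float *buf, const bool parallel) {
    (void)buf;
    printf("PackA k=%d parallel=%d\n", k, parallel ? 1 : 0);
  }

  void Run() {
    FnPack procPack = &Gemm::PackA;
    float buf[4] = {0};
    (*this.*procPack)(4, buf, true);   // whole-matrix pack: parallel
    (*this.*procPack)(4, buf, false);  // per-block pack: serial
  }
};

int main() { Gemm{}.Run(); }
```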
*/ -#ifdef FUSION_CONVADDADDPRELU_OP - -#include "operators/kernel/conv_add_add_prelu_kernel.h" -#include "operators/kernel/central-arm-func/conv_add_add_prelu_arm_func.h" +#pragma once namespace paddle_mobile { namespace operators { +namespace math { -template <> -bool ConvAddAddPReluKernel::Init( - FusionConvAddAddPReluParam *param) { - return true; -} +void cblas_sgemm(const bool transA, const bool transB, const int M, const int N, + const int K, const float alpha, const float *A, const int lda, + const float *B, const int ldb, const float beta, float *C, + const int ldc); -template <> -void ConvAddAddPReluKernel::Compute( - const FusionConvAddAddPReluParam ¶m) { - ConvAddAddPReluCompute(param); -} -template class ConvAddAddPReluKernel; +void cblas_sgemv(const bool trans, const int M, const int N, const float alpha, + const float *A, const int lda, const float *B, + const float beta, float *C); +} // namespace math } // namespace operators } // namespace paddle_mobile - -#endif diff --git a/src/operators/kernel/conv_add_add_prelu_kernel.h b/src/operators/math/gemm/cpu_info.h similarity index 51% rename from src/operators/kernel/conv_add_add_prelu_kernel.h rename to src/operators/math/gemm/cpu_info.h index fadaf7564ceeb7a52215dc335135016be02bc1ab..54975797c782be1964c562f38fd12edbcd6a2f0e 100644 --- a/src/operators/kernel/conv_add_add_prelu_kernel.h +++ b/src/operators/math/gemm/cpu_info.h @@ -14,32 +14,42 @@ limitations under the License. */ #pragma once -#ifdef FUSION_CONVADDADDPRELU_OP - -#include -#include "framework/ddim.h" -#include "framework/operator.h" -#include "operators/math/conv_func.h" -#include "operators/math/im2col.h" -#include "operators/math/math_function.h" -#include "operators/math/vol2col.h" -#include "operators/op_param.h" +#define MOBILE_MAX_CPU_NUM 8 namespace paddle_mobile { namespace operators { +namespace math { + +struct CPUInfo { + private: + CPUInfo() { + // TODO(hjchen2) + num_cpus = 4; + for (int i = 0; i < num_cpus; ++i) { + cpu_frequency[i] = 2400; // 2400 MHz + max_cpu_frequency[i] = 2400; // 2400 MHz + } + // L1_cache = 32000; // 32K + L1_cache = 32 * 1024; + L2_cache = 2000000; // 2M + // L2_cache = 512000; + } + virtual ~CPUInfo() {} -using framework::DDim; -using framework::OpKernelBase; - -template -class ConvAddAddPReluKernel - : public OpKernelBase> { public: - void Compute(const FusionConvAddAddPReluParam ¶m); - bool Init(FusionConvAddAddPReluParam *param); + static CPUInfo* Info() { + static CPUInfo* ctx = new CPUInfo; + return ctx; + } + + int num_cpus; + int cpu_frequency[MOBILE_MAX_CPU_NUM]; + int max_cpu_frequency[MOBILE_MAX_CPU_NUM]; + + int L1_cache; + int L2_cache; }; +} // namespace math } // namespace operators } // namespace paddle_mobile - -#endif diff --git a/src/operators/math/gemm/executor.h b/src/operators/math/gemm/executor.h new file mode 100644 index 0000000000000000000000000000000000000000..ddbed3dbdf6a5399b0f945d7da98ed536ee5e4e2 --- /dev/null +++ b/src/operators/math/gemm/executor.h @@ -0,0 +1,261 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
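cblas_sgemm's early-outs are worth spelling out: when N == 1, C = alpha * A * b + beta * C is a plain matrix-vector product, and when M == 1, the row result c = alpha * a * B + beta * c is computed as a GEMV against B read column-wise, which is why the call flips `!transB`. A scalar sketch of the same dispatch (reference only; no transpose support, row-major layouts assumed):

```cpp
// Scalar reference for the shape-based dispatch in cblas_sgemm.
// A is M x K (lda >= K), B is K x N (ldb >= N), C is M x N (ldc >= N).
void ref_sgemm(int M, int N, int K, float alpha, const float *A, int lda,
               const float *B, int ldb, float beta, float *C, int ldc) {
  if (N == 1) {  // C is a column vector: GEMV with A
    for (int i = 0; i < M; ++i) {
      float acc = 0.f;
      for (int p = 0; p < K; ++p) acc += A[i * lda + p] * B[p];
      C[i] = alpha * acc + beta * C[i];
    }
  } else if (M == 1) {  // C is a row vector: GEMV with B^T (the !transB trick)
    for (int j = 0; j < N; ++j) {
      float acc = 0.f;
      for (int p = 0; p < K; ++p) acc += A[p] * B[p * ldb + j];
      C[j] = alpha * acc + beta * C[j];
    }
  } else {  // general case, handled by GemmExecutor in the real code
    for (int i = 0; i < M; ++i)
      for (int j = 0; j < N; ++j) {
        float acc = 0.f;
        for (int p = 0; p < K; ++p) acc += A[i * lda + p] * B[p * ldb + j];
        C[i * ldc + j] = alpha * acc + beta * C[i * ldc + j];
      }
  }
}
```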
+See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include +#ifdef _OPENMP +#include +#endif +// #include +// #include +#include "common/log.h" +#include "memory/t_malloc.h" +#include "operators/math/gemm/cpu_info.h" +#include "operators/math/gemm/gemm_kernel.h" + +namespace paddle_mobile { +namespace operators { +namespace math { + +static CPUInfo *info = CPUInfo::Info(); + +int CeilDiv(const int &x, const int &y) { return (x + y - 1) / y; } +unsigned int ResetL1Cache(const unsigned int L1_size, const int thread_num, + const int N, const int K) { + unsigned int L1 = L1_size; + if (thread_num == 1) { + if (N >= 30000 && K > 100) { + L1 *= 4; + } else if (N >= 10000 && K > 100) { + L1 *= 2; + } + } + return L1; +} + +class Executor { + public: + Executor() : num_threads_(1) { +#ifdef _OPENMP + num_threads_ = omp_get_max_threads(); +#endif + } + virtual ~Executor() {} + + protected: + int num_threads_; +}; + +template +class GemmExecutor : public Executor { + typedef typename Strategy::Itype Itype; + typedef typename Strategy::Otype Otype; + + public: + GemmExecutor(const bool transA, const bool transB, const int M, const int N, + const int K) + : Executor(), transA_(transA), transB_(transB), M_(M), N_(N), K_(K) { + unsigned int L1_size = 0; + unsigned int L2_size = 0; + if (M_ > N_) { + L2_size = ResetL1Cache(info->L1_cache, num_threads_, M_, K_); + L1_size = info->L2_cache; + } else { + L1_size = ResetL1Cache(info->L1_cache, num_threads_, N_, K_); + L2_size = info->L2_cache; + } + + rhs_tile_num_ = L1_size / (K_ * sizeof(Itype)); + if (rhs_tile_num_ == 0) { + rhs_tile_num_ = Strategy::out_width(); + } else { + int n_block = CeilDiv(N_, rhs_tile_num_); + rhs_tile_num_ = CeilDiv(N_, n_block); + rhs_tile_num_ = CeilDiv(rhs_tile_num_, Strategy::out_width()); + rhs_tile_num_ *= Strategy::out_width(); + } + + // lhs_tile_num_ = CeilDiv(M, Strategy::out_height()) * + // Strategy::out_height(); + lhs_tile_num_ = L2_size / (K_ * sizeof(Itype)); + if (lhs_tile_num_ == 0) { + lhs_tile_num_ = Strategy::out_height(); + } else { + int m_block = CeilDiv(M_, lhs_tile_num_); + lhs_tile_num_ = CeilDiv(M_, m_block); + lhs_tile_num_ = CeilDiv(lhs_tile_num_, Strategy::out_height()); + lhs_tile_num_ *= Strategy::out_height(); + } + } + + void operator()(const float alpha, const Itype *A, const int lda, + const Itype *B, const int ldb, const float beta, Otype *C, + const int ldc) { + // struct timeval tv_begin, tv_end; + // gettimeofday(&tv_begin,NULL); + if (M_ > N_) { + int nblock = CeilDiv(N_, Strategy::out_width()) * Strategy::out_width(); + lhs_worksize_ = sizeof(Itype) * lhs_tile_num_ * K_; + rhs_worksize_ = sizeof(Itype) * K_ * nblock * num_threads_; + out_worksize_ = sizeof(Otype) * lhs_tile_num_ * nblock * num_threads_; + ldc_ = nblock; + } else { + int mblock = CeilDiv(M_, Strategy::out_height()) * Strategy::out_height(); + lhs_worksize_ = sizeof(Itype) * mblock * K_; + rhs_worksize_ = sizeof(Itype) * K_ * rhs_tile_num_ * num_threads_; + out_worksize_ = sizeof(Otype) * mblock * rhs_tile_num_ * num_threads_; + ldc_ = rhs_tile_num_; + } + + lhs_workspace_ = + static_cast(paddle_mobile::memory::Alloc(lhs_worksize_)); + rhs_workspace_ = + static_cast(paddle_mobile::memory::Alloc(rhs_worksize_)); + out_workspace_ = + static_cast(paddle_mobile::memory::Alloc(out_worksize_)); + + // std::cout << "M: " << M_ << ", N: " << N_ << ", K: " << K_ << std::endl; + // std::cout << "lhs_block: " << CeilDiv(M_, lhs_tile_num_) << ", " + // << 
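The constructor's tile sizing does three things: derive how many rhs columns fit in L1 for a K-deep panel, split N into roughly equal blocks, and round the block width up to a multiple of the kernel width. A tiny runnable version of that arithmetic (values illustrative only):

```cpp
#include <cstdio>

int CeilDiv(int x, int y) { return (x + y - 1) / y; }

int main() {
  const int out_width = 16;            // Strategy::out_width() on aarch64
  const int K = 256, N = 1000;
  const unsigned L1 = 32 * 1024;       // bytes, as in CPUInfo

  int tile = L1 / (K * sizeof(float)); // columns of B that fit in L1
  if (tile == 0) {
    tile = out_width;                  // degenerate case: one kernel width
  } else {
    int n_block = CeilDiv(N, tile);    // number of column blocks
    tile = CeilDiv(N, n_block);        // balance the block widths
    tile = CeilDiv(tile, out_width) * out_width;  // round to kernel width
  }
  printf("rhs tile = %d columns\n", tile);        // prints 32 here
}
```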
"rhs_block: " << CeilDiv(N_, rhs_tile_num_) << std::endl; + + if (M_ > N_) { + strategy_.pack_rhs(K_, N_, B, ldb, rhs_workspace_, true); + + #pragma omp parallel for if (M_ > 128) + for (int lhs_block = 0; lhs_block < M_; lhs_block += lhs_tile_num_) { + int lhs_range = std::min(M_ - lhs_block, lhs_tile_num_); +#ifdef _OPENMP + int thread_id = omp_get_thread_num(); +#else + int thread_id = 0; +#endif + float *local_A = lhs_workspace_ + lhs_tile_num_ * K_ * thread_id; + float *local_C = out_workspace_ + lhs_tile_num_ * ldc_ * thread_id; + // load lhs into lhs_workspace + strategy_.pack_lhs(lhs_range, K_, A + lhs_block * lda, lda, local_A, + false); + for (int rhs_block = 0; rhs_block < N_; rhs_block += rhs_tile_num_) { + int rhs_range = std::min(N_ - rhs_block, rhs_tile_num_); + float *local_B = rhs_workspace_ + K_ * rhs_block; + for (int rhs_tile = 0; rhs_tile < rhs_range; + rhs_tile += Strategy::out_width()) { + for (int lhs_tile = 0; lhs_tile < lhs_range; + lhs_tile += Strategy::out_height()) { + int offset = lhs_tile * ldc_ + rhs_block + rhs_tile; + strategy_.kernel(local_A + lhs_tile * K_, local_B + rhs_tile * K_, + K_, local_C + offset, ldc_); + } + } + } + strategy_.write(lhs_range, N_, alpha, local_C, ldc_, beta, + C + lhs_block * ldc, ldc); + } + } else { + strategy_.pack_lhs(M_, K_, A, lda, lhs_workspace_, true); + + #pragma omp parallel for if (N_ > 128) + for (int rhs_block = 0; rhs_block < N_; rhs_block += rhs_tile_num_) { + int rhs_range = std::min(N_ - rhs_block, rhs_tile_num_); +#ifdef _OPENMP + int thread_id = omp_get_thread_num(); +#else + int thread_id = 0; +#endif + float *local_B = rhs_workspace_ + K_ * rhs_tile_num_ * thread_id; + float *local_C = out_workspace_ + lhs_tile_num_ * ldc_ * thread_id; + // load rhs into rhs_workspace + strategy_.pack_rhs(K_, rhs_range, B + rhs_block, ldb, local_B, false); + for (int lhs_block = 0; lhs_block < M_; lhs_block += lhs_tile_num_) { + int lhs_range = std::min(M_ - lhs_block, lhs_tile_num_); + float *local_A = lhs_workspace_ + lhs_block * K_; + for (int lhs_tile = 0; lhs_tile < lhs_range; + lhs_tile += Strategy::out_height()) { + for (int rhs_tile = 0; rhs_tile < rhs_range; + rhs_tile += Strategy::out_width()) { + int offset = (lhs_block + lhs_tile) * ldc_ + rhs_tile; + strategy_.kernel(local_A + lhs_tile * K_, local_B + rhs_tile * K_, + K_, local_C + offset, ldc_); + } + } + } + strategy_.write(M_, rhs_range, alpha, local_C, ldc_, beta, + C + rhs_block, ldc); + } + } + + paddle_mobile::memory::Free(lhs_workspace_); + paddle_mobile::memory::Free(rhs_workspace_); + paddle_mobile::memory::Free(out_workspace_); + + // gettimeofday(&tv_end,NULL); + // float elapsed = (tv_end.tv_sec - tv_begin.tv_sec) * 1000.f + + // (tv_end.tv_usec - tv_begin.tv_usec) / 1000.f; + // std::cout << "elapsed: " << elapsed << "ms, speed: " + // << (M_ * N_ * K_ / 1000.f / 1000.f) / elapsed + // << " gflops" << std::endl; + } + + virtual ~GemmExecutor() {} + + private: + const unsigned int M_; + const unsigned int N_; + const unsigned int K_; + const bool transA_; + const bool transB_; + + unsigned int lhs_tile_num_ = 0; + unsigned int rhs_tile_num_ = 0; + unsigned int out_tile_num_ = 0; + + unsigned int lhs_worksize_ = 0; + unsigned int rhs_worksize_ = 0; + unsigned int out_worksize_ = 0; + unsigned int ldc_ = 0; + + Itype *lhs_workspace_ = nullptr; + Itype *rhs_workspace_ = nullptr; + Otype *out_workspace_ = nullptr; + + Strategy strategy_; +}; + +template +class GemvExecutor : public Executor { + typedef typename Strategy::Itype Itype; + typedef typename 
Strategy::Otype Otype; + + public: + GemvExecutor(const bool transA, const int M, const int N) + : Executor(), M_(M), N_(N), trans_(transA) {} + + void operator()(const float alpha, const Itype *A, const int lda, + const Itype *B, const float beta, Otype *C) { + strategy_.kernel(trans_, M_, N_, alpha, A, lda, B, beta, C); + } + + virtual ~GemvExecutor() {} + + private: + const unsigned int M_; + const unsigned int N_; + const bool trans_; + + Strategy strategy_; +}; + +} // namespace math +} // namespace operators +} // namespace paddle_mobile diff --git a/src/operators/math/gemm/gemm_kernel.h b/src/operators/math/gemm/gemm_kernel.h new file mode 100644 index 0000000000000000000000000000000000000000..7cbbb09270acc9a58ccb414464e11879d97e2292 --- /dev/null +++ b/src/operators/math/gemm/gemm_kernel.h @@ -0,0 +1,526 @@ +/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#if defined(__ARM_NEON__) || defined(__ARM_NEON) + +#include +#include "operators/math/math.h" + +namespace paddle_mobile { +namespace operators { +namespace math { + +#if __aarch64__ +void sgemm_6x16(const float *lhs, const float *rhs, const int k, float *output, + const int ldc) { + int kc1 = k; + int step = 4 * ldc; + int step1 = 4 * 6; + asm volatile( + "dup v6.4s, wzr \n\t" + "dup v7.4s, wzr \n\t" + "dup v8.4s, wzr \n\t" + "dup v9.4s, wzr \n\t" + "dup v10.4s, wzr \n\t" + "dup v11.4s, wzr \n\t" + "dup v12.4s, wzr \n\t" + "dup v13.4s, wzr \n\t" + + "dup v14.4s, wzr \n\t" + "dup v15.4s, wzr \n\t" + "dup v16.4s, wzr \n\t" + "dup v17.4s, wzr \n\t" + "dup v18.4s, wzr \n\t" + "dup v19.4s, wzr \n\t" + "dup v20.4s, wzr \n\t" + "dup v21.4s, wzr \n\t" + + "dup v22.4s, wzr \n\t" + "dup v23.4s, wzr \n\t" + "dup v24.4s, wzr \n\t" + "dup v25.4s, wzr \n\t" + "dup v26.4s, wzr \n\t" + "dup v27.4s, wzr \n\t" + "dup v28.4s, wzr \n\t" + "dup v29.4s, wzr \n\t" + + "subs %[kc1], %[kc1], #1 \n\t" + "blt 2f \n\t" + "1: \n\t" + + "prfm pldl1keep, [%[lhs], #24] \n\t" + "prfm pldl1keep, [%[rhs], #64] \n\t" + + "ld1 {v0.4s, v1.4s}, [%[lhs]], %[step1] \n\t" + "ld1 {v2.4s, v3.4s, v4.4s, v5.4s}, [%[rhs]], #64 \n\t" + + "fmla v6.4s, v2.4s, v0.s[0] \n\t" + "fmla v7.4s, v3.4s, v0.s[0] \n\t" + "fmla v8.4s, v4.4s, v0.s[0] \n\t" + "fmla v9.4s, v5.4s, v0.s[0] \n\t" + + "fmla v10.4s, v2.4s, v0.s[1] \n\t" + "fmla v11.4s, v3.4s, v0.s[1] \n\t" + "fmla v12.4s, v4.4s, v0.s[1] \n\t" + "fmla v13.4s, v5.4s, v0.s[1] \n\t" + + "fmla v14.4s, v2.4s, v0.s[2] \n\t" + "fmla v15.4s, v3.4s, v0.s[2] \n\t" + "fmla v16.4s, v4.4s, v0.s[2] \n\t" + "fmla v17.4s, v5.4s, v0.s[2] \n\t" + + "fmla v18.4s, v2.4s, v0.s[3] \n\t" + "fmla v19.4s, v3.4s, v0.s[3] \n\t" + "fmla v20.4s, v4.4s, v0.s[3] \n\t" + "fmla v21.4s, v5.4s, v0.s[3] \n\t" + + "fmla v22.4s, v2.4s, v1.s[0] \n\t" + "fmla v23.4s, v3.4s, v1.s[0] \n\t" + "fmla v24.4s, v4.4s, v1.s[0] \n\t" + "fmla v25.4s, v5.4s, v1.s[0] \n\t" + + "fmla v26.4s, v2.4s, v1.s[1] \n\t" + "fmla v27.4s, v3.4s, v1.s[1] \n\t" + "fmla v28.4s, v4.4s, v1.s[1] \n\t" + "fmla v29.4s, v5.4s, v1.s[1] \n\t" + 
+ "subs %[kc1], %[kc1], #1 \n\t" + "bge 1b \n\t" + "2: \n\t" + + "st1 {v6.4s, v7.4s, v8.4s, v9.4s}, [%[c]], %[step] \n\t" + "st1 {v10.4s, v11.4s, v12.4s, v13.4s}, [%[c]], %[step] \n\t" + "st1 {v14.4s, v15.4s, v16.4s, v17.4s}, [%[c]], %[step] \n\t" + "st1 {v18.4s, v19.4s, v20.4s, v21.4s}, [%[c]], %[step] \n\t" + "st1 {v22.4s, v23.4s, v24.4s, v25.4s}, [%[c]], %[step] \n\t" + "st1 {v26.4s, v27.4s, v28.4s, v29.4s}, [%[c]], %[step] \n\t" + : [lhs] "+r"(lhs), [rhs] "+r"(rhs), [c] "+r"(output), [kc1] "+r"(kc1) + : [step] "r"(step), [step1] "r"(step1) + : "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", + "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", + "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29"); +} +#else +void sgemm_6x8(const float *lhs, const float *rhs, const int k, float *output, + const int ldc) { + int kc1 = k >> 3; // k / 8 + int kc2 = k & 0x7; // k % 8 + int step = sizeof(float) * ldc; + asm volatile( + "pld [%[lhs]] \n\t" + "pld [%[lhs], #64] \n\t" + "pld [%[rhs]] \n\t" + "pld [%[rhs], #64] \n\t" + + "vmov.f32 q4, #0.0 \n\t" + "vmov.f32 q5, #0.0 \n\t" + "vmov.f32 q6, #0.0 \n\t" + "vmov.f32 q7, #0.0 \n\t" + "vmov.f32 q8, #0.0 \n\t" + "vmov.f32 q9, #0.0 \n\t" + "vmov.f32 q10, #0.0 \n\t" + "vmov.f32 q11, #0.0 \n\t" + "vmov.f32 q12, #0.0 \n\t" + "vmov.f32 q13, #0.0 \n\t" + "vmov.f32 q14, #0.0 \n\t" + "vmov.f32 q15, #0.0 \n\t" + + "subs %[kc1], %[kc1], #1 \n\t" + "blt 2f \n\t" + "1: \n\t" + + "pld [%[lhs], #128] \n\t" + "pld [%[rhs], #128] \n\t" + + "vld1.32 {d0-d2}, [%[lhs]]! \n\t" + "vld1.32 {q2, q3}, [%[rhs]]! \n\t" + + "vmla.f32 q4, q2, d0[0] \n\t" + "vmla.f32 q5, q3, d0[0] \n\t" + "vmla.f32 q6, q2, d0[1] \n\t" + "vmla.f32 q7, q3, d0[1] \n\t" + "vmla.f32 q8, q2, d1[0] \n\t" + "vmla.f32 q9, q3, d1[0] \n\t" + "vmla.f32 q10, q2, d1[1] \n\t" + "vmla.f32 q11, q3, d1[1] \n\t" + "vmla.f32 q12, q2, d2[0] \n\t" + "vmla.f32 q13, q3, d2[0] \n\t" + "vmla.f32 q14, q2, d2[1] \n\t" + "vmla.f32 q15, q3, d2[1] \n\t" + + "vld1.32 {d0-d2}, [%[lhs]]! \n\t" + "vld1.32 {q2, q3}, [%[rhs]]! \n\t" + + "vmla.f32 q4, q2, d0[0] \n\t" + "vmla.f32 q5, q3, d0[0] \n\t" + "vmla.f32 q6, q2, d0[1] \n\t" + "vmla.f32 q7, q3, d0[1] \n\t" + "vmla.f32 q8, q2, d1[0] \n\t" + "vmla.f32 q9, q3, d1[0] \n\t" + "vmla.f32 q10, q2, d1[1] \n\t" + "vmla.f32 q11, q3, d1[1] \n\t" + "vmla.f32 q12, q2, d2[0] \n\t" + "vmla.f32 q13, q3, d2[0] \n\t" + "vmla.f32 q14, q2, d2[1] \n\t" + "vmla.f32 q15, q3, d2[1] \n\t" + + "pld [%[lhs], #128] \n\t" + "pld [%[rhs], #128] \n\t" + + "vld1.32 {d0-d2}, [%[lhs]]! \n\t" + "vld1.32 {q2, q3}, [%[rhs]]! \n\t" + + "vmla.f32 q4, q2, d0[0] \n\t" + "vmla.f32 q5, q3, d0[0] \n\t" + "vmla.f32 q6, q2, d0[1] \n\t" + "vmla.f32 q7, q3, d0[1] \n\t" + "vmla.f32 q8, q2, d1[0] \n\t" + "vmla.f32 q9, q3, d1[0] \n\t" + "vmla.f32 q10, q2, d1[1] \n\t" + "vmla.f32 q11, q3, d1[1] \n\t" + "vmla.f32 q12, q2, d2[0] \n\t" + "vmla.f32 q13, q3, d2[0] \n\t" + "vmla.f32 q14, q2, d2[1] \n\t" + "vmla.f32 q15, q3, d2[1] \n\t" + + "vld1.32 {d0-d2}, [%[lhs]]! \n\t" + "vld1.32 {q2, q3}, [%[rhs]]! \n\t" + + "vmla.f32 q4, q2, d0[0] \n\t" + "vmla.f32 q5, q3, d0[0] \n\t" + "vmla.f32 q6, q2, d0[1] \n\t" + "vmla.f32 q7, q3, d0[1] \n\t" + "vmla.f32 q8, q2, d1[0] \n\t" + "vmla.f32 q9, q3, d1[0] \n\t" + "vmla.f32 q10, q2, d1[1] \n\t" + "vmla.f32 q11, q3, d1[1] \n\t" + "vmla.f32 q12, q2, d2[0] \n\t" + "vmla.f32 q13, q3, d2[0] \n\t" + "vmla.f32 q14, q2, d2[1] \n\t" + "vmla.f32 q15, q3, d2[1] \n\t" + + "pld [%[lhs], #128] \n\t" + "pld [%[rhs], #128] \n\t" + + "vld1.32 {d0-d2}, [%[lhs]]! 
\n\t" + "vld1.32 {q2, q3}, [%[rhs]]! \n\t" + + "vmla.f32 q4, q2, d0[0] \n\t" + "vmla.f32 q5, q3, d0[0] \n\t" + "vmla.f32 q6, q2, d0[1] \n\t" + "vmla.f32 q7, q3, d0[1] \n\t" + "vmla.f32 q8, q2, d1[0] \n\t" + "vmla.f32 q9, q3, d1[0] \n\t" + "vmla.f32 q10, q2, d1[1] \n\t" + "vmla.f32 q11, q3, d1[1] \n\t" + "vmla.f32 q12, q2, d2[0] \n\t" + "vmla.f32 q13, q3, d2[0] \n\t" + "vmla.f32 q14, q2, d2[1] \n\t" + "vmla.f32 q15, q3, d2[1] \n\t" + + "vld1.32 {d0-d2}, [%[lhs]]! \n\t" + "vld1.32 {q2, q3}, [%[rhs]]! \n\t" + + "vmla.f32 q4, q2, d0[0] \n\t" + "vmla.f32 q5, q3, d0[0] \n\t" + "vmla.f32 q6, q2, d0[1] \n\t" + "vmla.f32 q7, q3, d0[1] \n\t" + "vmla.f32 q8, q2, d1[0] \n\t" + "vmla.f32 q9, q3, d1[0] \n\t" + "vmla.f32 q10, q2, d1[1] \n\t" + "vmla.f32 q11, q3, d1[1] \n\t" + "vmla.f32 q12, q2, d2[0] \n\t" + "vmla.f32 q13, q3, d2[0] \n\t" + "vmla.f32 q14, q2, d2[1] \n\t" + "vmla.f32 q15, q3, d2[1] \n\t" + + "pld [%[lhs], #128] \n\t" + "pld [%[rhs], #128] \n\t" + + "vld1.32 {d0-d2}, [%[lhs]]! \n\t" + "vld1.32 {q2, q3}, [%[rhs]]! \n\t" + + "vmla.f32 q4, q2, d0[0] \n\t" + "vmla.f32 q5, q3, d0[0] \n\t" + "vmla.f32 q6, q2, d0[1] \n\t" + "vmla.f32 q7, q3, d0[1] \n\t" + "vmla.f32 q8, q2, d1[0] \n\t" + "vmla.f32 q9, q3, d1[0] \n\t" + "vmla.f32 q10, q2, d1[1] \n\t" + "vmla.f32 q11, q3, d1[1] \n\t" + "vmla.f32 q12, q2, d2[0] \n\t" + "vmla.f32 q13, q3, d2[0] \n\t" + "vmla.f32 q14, q2, d2[1] \n\t" + "vmla.f32 q15, q3, d2[1] \n\t" + + "vld1.32 {d0-d2}, [%[lhs]]! \n\t" + "vld1.32 {q2, q3}, [%[rhs]]! \n\t" + + "vmla.f32 q4, q2, d0[0] \n\t" + "vmla.f32 q5, q3, d0[0] \n\t" + "vmla.f32 q6, q2, d0[1] \n\t" + "vmla.f32 q7, q3, d0[1] \n\t" + "vmla.f32 q8, q2, d1[0] \n\t" + "vmla.f32 q9, q3, d1[0] \n\t" + "vmla.f32 q10, q2, d1[1] \n\t" + "vmla.f32 q11, q3, d1[1] \n\t" + "vmla.f32 q12, q2, d2[0] \n\t" + "vmla.f32 q13, q3, d2[0] \n\t" + "vmla.f32 q14, q2, d2[1] \n\t" + "vmla.f32 q15, q3, d2[1] \n\t" + + "subs %[kc1], %[kc1], #1 \n\t" + "bge 1b \n\t" + "2: \n\t" + + "subs %[kc2], %[kc2], #1 \n\t" + "blt 4f \n\t" + "3: \n\t" + + "vld1.32 {d0-d2}, [%[lhs]]! \n\t" + "vld1.32 {q2, q3}, [%[rhs]]! 
\n\t" + + "vmla.f32 q4, q2, d0[0] \n\t" + "vmla.f32 q5, q3, d0[0] \n\t" + "vmla.f32 q6, q2, d0[1] \n\t" + "vmla.f32 q7, q3, d0[1] \n\t" + "vmla.f32 q8, q2, d1[0] \n\t" + "vmla.f32 q9, q3, d1[0] \n\t" + "vmla.f32 q10, q2, d1[1] \n\t" + "vmla.f32 q11, q3, d1[1] \n\t" + "vmla.f32 q12, q2, d2[0] \n\t" + "vmla.f32 q13, q3, d2[0] \n\t" + "vmla.f32 q14, q2, d2[1] \n\t" + "vmla.f32 q15, q3, d2[1] \n\t" + + "subs %[kc2], %[kc2], #1 \n\t" + "bge 3b \n\t" + "4: \n\t" + + "mov r5, %[c] \n\t" + "mov r6, %[step] \n\t" + "vst1.32 {q4, q5}, [r5], r6 \n\t" + "vst1.32 {q6, q7}, [r5], r6 \n\t" + "vst1.32 {q8, q9}, [r5], r6 \n\t" + "vst1.32 {q10, q11}, [r5], r6 \n\t" + "vst1.32 {q12, q13}, [r5], r6 \n\t" + "vst1.32 {q14, q15}, [r5] \n\t" + : + : [lhs] "r"(lhs), [rhs] "r"(rhs), [c] "r"(output), [kc1] "r"(kc1), + [kc2] "r"(kc2), [step] "r"(step) + : "cc", "memory", "r5", "r6", "q0", "q1", "q2", "q3", "q4", "q5", "q6", + "q7", "q8", "q9", "q10", "q11", "q12", "q13", "q14", "q15"); +} +#endif // __aarch64__ + +void sgemv_notrans_mx1(const int M, const int N, const float alpha, + const float *A, const int lda, const float *B, + const float beta, float *C) { + uint32_t mask[4] = {0, 1, 2, 3}; + int remain_n = N & 0x3; + uint32x4_t vmask = vcltq_u32(vld1q_u32(mask), vdupq_n_u32(remain_n)); + float32x4_t _valpha = vdupq_n_f32(alpha); + + #pragma omp parallel for + for (int m = 0; m < M - 3; m += 4) { + const float *in0 = A + m * lda; + const float *in1 = in0 + lda; + const float *in2 = in1 + lda; + const float *in3 = in2 + lda; + float *output = C + m; + + float32x4_t _sum0, _sum1, _sum2, _sum3; + _sum0 = vdupq_n_f32(0.f); + _sum1 = vdupq_n_f32(0.f); + _sum2 = vdupq_n_f32(0.f); + _sum3 = vdupq_n_f32(0.f); + int n = 0; + for (; n < N - 3; n += 4) { + float32x4_t _r0 = vld1q_f32(in0 + n); + float32x4_t _r1 = vld1q_f32(in1 + n); + float32x4_t _r2 = vld1q_f32(in2 + n); + float32x4_t _r3 = vld1q_f32(in3 + n); + float32x4_t _b = vld1q_f32(B + n); + _sum0 = vmlaq_f32(_sum0, _r0, _b); + _sum1 = vmlaq_f32(_sum1, _r1, _b); + _sum2 = vmlaq_f32(_sum2, _r2, _b); + _sum3 = vmlaq_f32(_sum3, _r3, _b); + } + if (n < N) { + float32x4_t _r0 = vld1q_f32(in0 + n); + float32x4_t _r1 = vld1q_f32(in1 + n); + float32x4_t _r2 = vld1q_f32(in2 + n); + float32x4_t _r3 = vld1q_f32(in3 + n); + float32x4_t _b = vld1q_f32(B + n); + _r0 = vandq_f32_u32(_r0, vmask); + _r1 = vandq_f32_u32(_r1, vmask); + _r2 = vandq_f32_u32(_r2, vmask); + _r3 = vandq_f32_u32(_r3, vmask); + _b = vandq_f32_u32(_b, vmask); + _sum0 = vmlaq_f32(_sum0, _r0, _b); + _sum1 = vmlaq_f32(_sum1, _r1, _b); + _sum2 = vmlaq_f32(_sum2, _r2, _b); + _sum3 = vmlaq_f32(_sum3, _r3, _b); + } + _sum0 = vpaddq_f32(_sum0, _sum1); + _sum2 = vpaddq_f32(_sum2, _sum3); + _sum0 = vpaddq_f32(_sum0, _sum2); + _sum0 = vmulq_f32(_sum0, _valpha); + if (beta != 0.f) { + _sum2 = vmulq_n_f32(vld1q_f32(output), beta); + _sum0 = vaddq_f32(_sum0, _sum2); + } + // restore + vst1q_f32(output, _sum0); + } + // remain m + for (int m = (M & 0xfffc); m < M; ++m) { + const float *in0 = A + m * lda; + float *output = C + m; + float32x4_t _sum0 = vdupq_n_f32(0.f); + + int n = 0; + for (; n < N - 3; n += 4) { + float32x4_t _r0 = vld1q_f32(in0 + n); + float32x4_t _b = vld1q_f32(B + n); + _sum0 = vmlaq_f32(_sum0, _r0, _b); + } + if (n < N) { + float32x4_t _r0 = vld1q_f32(in0 + n); + float32x4_t _b = vld1q_f32(B + n); + _r0 = vandq_f32_u32(_r0, vmask); + _b = vandq_f32_u32(_b, vmask); + _sum0 = vmlaq_f32(_sum0, _r0, _b); + } + _sum0 = vpaddq_f32(_sum0, _sum0); + _sum0 = vmulq_f32(_sum0, _valpha); + if (beta != 0.f) { + 
float32x4_t _sum2 = vmulq_n_f32(vld1q_f32(output), beta); + _sum0 = vpaddq_f32(_sum0, _sum2); + } + // restore + *output = vgetq_lane_f32(_sum0, 0) + vgetq_lane_f32(_sum0, 1); + } +} + +void sgemv_trans_mx1(const int M, const int N, const float alpha, + const float *A, const int lda, const float *B, + const float beta, float *C) { + float32x4_t _valpha = vdupq_n_f32(alpha); + if (beta == 0.f) { + float32x4_t vzero = vdupq_n_f32(0.f); + for (int m = 0; m < M - 3; m += 4) { + vst1q_f32(C + m, vzero); + } + for (int m = (M & 0xfffc); m < M; ++m) { + C[m] = 0.f; + } + } else { + float32x4_t vbeta = vdupq_n_f32(beta); + for (int m = 0; m < M - 3; m += 4) { + float32x4_t _vc = vld1q_f32(C + m); + _vc = vmulq_f32(_vc, vbeta); + vst1q_f32(C + m, _vc); + } + for (int m = (M & 0xfffc); m < M; ++m) { + C[m] *= beta; + } + } + + #pragma omp parallel for + for (int n = 0; n < N - 3; n += 4) { + const float *in0 = A + n * lda; + const float *in1 = in0 + lda; + const float *in2 = in1 + lda; + const float *in3 = in2 + lda; + float32x4_t _b = vld1q_f32(B + n); + float32x4_t _sum0; + int m = 0; + for (; m < M - 3; m += 4) { + float32x4_t _r0 = vld1q_f32(in0 + m); + float32x4_t _r1 = vld1q_f32(in1 + m); + float32x4_t _r2 = vld1q_f32(in2 + m); + float32x4_t _r3 = vld1q_f32(in3 + m); + float32x4_t _vc = vld1q_f32(C + m); + + _sum0 = vmulq_lane_f32(_r0, vget_low_f32(_b), 0); + _sum0 = vmlaq_lane_f32(_sum0, _r1, vget_low_f32(_b), 1); + _sum0 = vmlaq_lane_f32(_sum0, _r2, vget_high_f32(_b), 0); + _sum0 = vmlaq_lane_f32(_sum0, _r3, vget_high_f32(_b), 1); + _sum0 = vmulq_f32(_sum0, _valpha); + _sum0 = vaddq_f32(_sum0, _vc); + vst1q_f32(C + m, _sum0); + } + if (m < M) { + float32x4_t _r0 = vld1q_f32(in0 + m); + float32x4_t _r1 = vld1q_f32(in1 + m); + float32x4_t _r2 = vld1q_f32(in2 + m); + float32x4_t _r3 = vld1q_f32(in3 + m); + float32x4_t _vc = vld1q_f32(C + m); + + _sum0 = vmulq_lane_f32(_r0, vget_low_f32(_b), 0); + _sum0 = vmlaq_lane_f32(_sum0, _r1, vget_low_f32(_b), 1); + _sum0 = vmlaq_lane_f32(_sum0, _r2, vget_high_f32(_b), 0); + _sum0 = vmlaq_lane_f32(_sum0, _r3, vget_high_f32(_b), 1); + _sum0 = vmulq_f32(_sum0, _valpha); + _sum0 = vaddq_f32(_sum0, _vc); + switch (M - m) { + case 3: + vst1q_lane_f32(C + m + 2, _sum0, 2); + case 2: + vst1_f32(C + m, vget_low_f32(_sum0)); + break; + case 1: + vst1q_lane_f32(C + m, _sum0, 0); + break; + } + } + } + // remain n + for (int n = (N & 0xfffc); n < N; ++n) { + const float *in0 = A + n * lda; + float32x4_t _b = vld1q_dup_f32(B + n); + float32x4_t _sum0; + int m = 0; + for (; m < M - 3; m += 4) { + float32x4_t _r0 = vld1q_f32(in0 + m); + _sum0 = vld1q_f32(C + m); + _r0 = vmulq_f32(_r0, _b); + _r0 = vmulq_f32(_valpha, _r0); + _sum0 = vaddq_f32(_sum0, _r0); + vst1q_f32(C + m, _sum0); + } + for (; m < M; ++m) { + C[m] += alpha * (in0[m] * B[n]); + } + } +} + +void sgemv_mx1(const bool trans, const int M, const int N, const float alpha, + const float *A, const int lda, const float *B, const float beta, + float *C) { + if (trans) { + sgemv_trans_mx1(M, N, alpha, A, lda, B, beta, C); + } else { + sgemv_notrans_mx1(M, N, alpha, A, lda, B, beta, C); + } +} + +} // namespace math +} // namespace operators +} // namespace paddle_mobile + +#endif // __ARM_NEON__ diff --git a/src/operators/math/gemm/pack_kernel.h b/src/operators/math/gemm/pack_kernel.h new file mode 100644 index 0000000000000000000000000000000000000000..b1f6a9d35ec5630a1bc0ae9fc997dcb05419f0aa --- /dev/null +++ b/src/operators/math/gemm/pack_kernel.h @@ -0,0 +1,801 @@ +/* Copyright (c) 2019 PaddlePaddle Authors. 
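The two sgemv paths implement y = alpha * op(A) * x + beta * y, and a scalar reference is handy for unit-testing them. One thing such a test tends to catch: in the non-transposed remainder path above, the beta contribution is folded in with vpaddq_f32, which pairs adjacent lanes rather than adding elementwise, so beta != 0 tails deserve explicit coverage. A reference sketch, assuming the row-major conventions of the callers above:

```cpp
// Scalar reference for sgemv_mx1: y = alpha * op(A) * x + beta * y.
// y has length M, x has length N. trans == false: A is M x N (lda >= N);
// trans == true: A is stored N x M (lda >= M) and used transposed.
void ref_sgemv(bool trans, int M, int N, float alpha, const float *A, int lda,
               const float *x, float beta, float *y) {
  for (int m = 0; m < M; ++m) {
    float acc = 0.f;
    for (int n = 0; n < N; ++n)
      acc += (trans ? A[n * lda + m] : A[m * lda + n]) * x[n];
    y[m] = alpha * acc + beta * y[m];
  }
}
```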
All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#if defined(__ARM_NEON__) || defined(__ARM_NEON) + +#include +#ifdef _OPENMP +#include +#endif +#include "operators/math/math.h" + +namespace paddle_mobile { +namespace operators { +namespace math { + +void pack_lhs_6r(const int m, const int k, const float *A, const int lda, + float *output, const bool unroll) { + uint32_t mask[8] = {0, 1, 2, 3, 4, 5, 4, 5}; + int remain_k = k & 0x3; + uint32x4_t vzero = vdupq_n_u32(0); + uint32x4_t vmask1 = vcltq_u32(vld1q_u32(mask), vdupq_n_u32(remain_k)); + + #pragma omp parallel for if (unroll) + for (int i = 0; i < m - 5; i += 6) { + const float *a0 = A + i * lda; + const float *a1 = A + (i + 1) * lda; + const float *a2 = A + (i + 2) * lda; + const float *a3 = A + (i + 3) * lda; + const float *a4 = A + (i + 4) * lda; + const float *a5 = A + (i + 5) * lda; + float *out_ptr = output + i * k; + + int loops = k >> 2; + if (loops > 0) { +#if __aarch64__ + for (int l = 0; l < loops; ++l) { + float32x4_t _d0 = vld1q_f32(a0); + float32x4_t _d1 = vld1q_f32(a1); + float32x4_t _d2 = vld1q_f32(a2); + float32x4_t _d3 = vld1q_f32(a3); + float32x4_t _d4 = vld1q_f32(a4); + float32x4_t _d5 = vld1q_f32(a5); + + float32x4x2_t _q0 = vtrnq_f32(_d0, _d1); + float32x4x2_t _q1 = vtrnq_f32(_d2, _d3); + float32x4x2_t _q3 = vtrnq_f32(_d4, _d5); + _d0 = vcombine_f32(vget_low_f32(_q0.val[0]), vget_low_f32(_q1.val[0])); + _d1 = vcombine_f32(vget_low_f32(_q0.val[1]), vget_low_f32(_q1.val[1])); + _d2 = + vcombine_f32(vget_high_f32(_q0.val[0]), vget_high_f32(_q1.val[0])); + _d3 = + vcombine_f32(vget_high_f32(_q0.val[1]), vget_high_f32(_q1.val[1])); + + vst1q_f32(out_ptr, _d0); + vst1_f32(out_ptr + 4, vget_low_f32(_q3.val[0])); + vst1q_f32(out_ptr + 6, _d1); + vst1_f32(out_ptr + 10, vget_low_f32(_q3.val[1])); + vst1q_f32(out_ptr + 12, _d2); + vst1_f32(out_ptr + 16, vget_high_f32(_q3.val[0])); + vst1q_f32(out_ptr + 18, _d3); + vst1_f32(out_ptr + 22, vget_high_f32(_q3.val[1])); + + a0 += 4; + a1 += 4; + a2 += 4; + a3 += 4; + a4 += 4; + a5 += 4; + out_ptr += 24; + } +#else + asm volatile( + "loop_4k_%=: \n" + "vld1.32 {d0-d1}, [%[a0]]! \n" + "vld1.32 {d2-d3}, [%[a1]]! \n" + "vld1.32 {d4-d5}, [%[a2]]! \n" + "vld1.32 {d6-d7}, [%[a3]]! \n" + "vld1.32 {d8-d9}, [%[a4]]! \n" + "vld1.32 {d10-d11}, [%[a5]]! \n" + "vtrn.32 q0, q1 \n" + "vtrn.32 q2, q3 \n" + "vtrn.32 q4, q5 \n" + "vswp.32 d1, d4 \n" + "vswp.32 d3, d6 \n" + + "vst1.32 {q0}, [%[out]]! \n" + "vst1.32 {d8}, [%[out]]! \n" + "vst1.32 {q1}, [%[out]]! \n" + "vst1.32 {d10}, [%[out]]! \n" + "vst1.32 {q2}, [%[out]]! \n" + "vst1.32 {d9}, [%[out]]! \n" + "vst1.32 {q3}, [%[out]]! \n" + "vst1.32 {d11}, [%[out]]! 
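The tail handling throughout pack_kernel.h builds a per-lane mask by comparing a lane-index vector against the remainder count, then bitwise-ANDs loads with it so out-of-range lanes become zero; this is also why the shared `zero` scratch buffer could be dropped. Note the vector loads themselves still read a few elements past the logical end of a row, which the source buffers are assumed to tolerate. The same idea in portable C++ (emulation only, no NEON):

```cpp
#include <cstdint>
#include <cstdio>
#include <cstring>

int main() {
  // Emulates: vmask = vclt(lane_index, remain); load = load & vmask;
  const uint32_t lane_index[4] = {0, 1, 2, 3};
  const int remain = 2;  // e.g. k % 4 == 2: only two valid tail elements
  const float src[4] = {1.f, 2.f, 3.f, 4.f};  // lanes past the tail are
                                              // read but masked to zero
  float dst[4];
  for (int lane = 0; lane < 4; ++lane) {
    uint32_t mask = lane_index[lane] < (uint32_t)remain ? 0xFFFFFFFFu : 0u;
    uint32_t bits;
    memcpy(&bits, &src[lane], sizeof(bits));
    bits &= mask;  // bitwise AND, exactly like vandq on the mask vector
    memcpy(&dst[lane], &bits, sizeof(bits));
  }
  printf("%g %g %g %g\n", dst[0], dst[1], dst[2], dst[3]);  // 1 2 0 0
}
```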
\n" + + "subs %[loops], #1 \n" + "bne loop_4k_%= \n" + : [out] "+r"(out_ptr), [a0] "+r"(a0), [a1] "+r"(a1), [a2] "+r"(a2), + [a3] "+r"(a3), [a4] "+r"(a4), [a5] "+r"(a5), [loops] "+r"(loops) + : + : "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5"); +#endif + } + + if (remain_k > 0) { + float32x4_t _d0 = vld1q_f32(a0); + float32x4_t _d1 = vld1q_f32(a1); + float32x4_t _d2 = vld1q_f32(a2); + float32x4_t _d3 = vld1q_f32(a3); + float32x4_t _d4 = vld1q_f32(a4); + float32x4_t _d5 = vld1q_f32(a5); + + _d0 = vandq_f32_u32(_d0, vmask1); + _d1 = vandq_f32_u32(_d1, vmask1); + _d2 = vandq_f32_u32(_d2, vmask1); + _d3 = vandq_f32_u32(_d3, vmask1); + _d4 = vandq_f32_u32(_d4, vmask1); + _d5 = vandq_f32_u32(_d5, vmask1); + + float32x4x2_t _q0 = vtrnq_f32(_d0, _d1); + float32x4x2_t _q1 = vtrnq_f32(_d2, _d3); + float32x4x2_t _q3 = vtrnq_f32(_d4, _d5); + _d0 = vcombine_f32(vget_low_f32(_q0.val[0]), vget_low_f32(_q1.val[0])); + _d1 = vcombine_f32(vget_low_f32(_q0.val[1]), vget_low_f32(_q1.val[1])); + _d2 = vcombine_f32(vget_high_f32(_q0.val[0]), vget_high_f32(_q1.val[0])); + + switch (remain_k) { + case 3: + vst1q_f32(out_ptr + 12, _d2); + vst1_f32(out_ptr + 16, vget_high_f32(_q3.val[0])); + case 2: + vst1q_f32(out_ptr + 6, _d1); + vst1_f32(out_ptr + 10, vget_low_f32(_q3.val[1])); + case 1: + vst1q_f32(out_ptr, _d0); + vst1_f32(out_ptr + 4, vget_low_f32(_q3.val[0])); + default: + break; + } + } + } + + int remain_m = m % 6; + if (remain_m) { + int remain_m_start = m - remain_m; + const float *a0 = A + remain_m_start * lda; + const float *a1 = a0 + lda; + const float *a2 = a0 + 2 * lda; + const float *a3 = a0 + 3 * lda; + const float *a4 = a0 + 4 * lda; + const float *a5 = a0 + 5 * lda; + float *out_ptr = output + remain_m_start * k; + + uint32x4_t vmask2 = vcltq_u32(vld1q_u32(mask), vdupq_n_u32(remain_m)); + uint32x4_t vmask3 = vcltq_u32(vld1q_u32(mask + 4), vdupq_n_u32(remain_m)); + const float zerobuff[4] = {0.f, 0.f, 0.f, 0.f}; + + int lk = 0; + for (; lk < k - 3; lk += 4) { + switch (remain_m) { + case 1: + a1 = zerobuff; + case 2: + a2 = zerobuff; + case 3: + a3 = zerobuff; + case 4: + a4 = zerobuff; + case 5: + a5 = zerobuff; + default: + break; + } +#if __aarch64__ + float32x4_t _d0 = vld1q_f32(a0); + float32x4_t _d1 = vld1q_f32(a1); + float32x4_t _d2 = vld1q_f32(a2); + float32x4_t _d3 = vld1q_f32(a3); + float32x4_t _d4 = vld1q_f32(a4); + float32x4_t _d5 = vld1q_f32(a5); + + float32x4x2_t _q0 = vtrnq_f32(_d0, _d1); + float32x4x2_t _q1 = vtrnq_f32(_d2, _d3); + float32x4x2_t _q3 = vtrnq_f32(_d4, _d5); + _d0 = vcombine_f32(vget_low_f32(_q0.val[0]), vget_low_f32(_q1.val[0])); + _d1 = vcombine_f32(vget_low_f32(_q0.val[1]), vget_low_f32(_q1.val[1])); + _d2 = vcombine_f32(vget_high_f32(_q0.val[0]), vget_high_f32(_q1.val[0])); + _d3 = vcombine_f32(vget_high_f32(_q0.val[1]), vget_high_f32(_q1.val[1])); + + _d0 = vandq_f32_u32(_d0, vmask2); + _d1 = vandq_f32_u32(_d1, vmask2); + _d2 = vandq_f32_u32(_d2, vmask2); + _d3 = vandq_f32_u32(_d3, vmask2); + _d4 = vandq_f32_u32(_q3.val[0], vmask3); + _d5 = vandq_f32_u32(_q3.val[1], vmask3); + + vst1q_f32(out_ptr, _d0); + vst1_f32(out_ptr + 4, vget_low_f32(_d4)); + vst1q_f32(out_ptr + 6, _d1); + vst1_f32(out_ptr + 10, vget_low_f32(_d5)); + vst1q_f32(out_ptr + 12, _d2); + vst1_f32(out_ptr + 16, vget_high_f32(_d4)); + vst1q_f32(out_ptr + 18, _d3); + vst1_f32(out_ptr + 22, vget_high_f32(_d5)); + + a0 += 4; + a1 += 4; + a2 += 4; + a3 += 4; + a4 += 4; + a5 += 4; + out_ptr += 24; +#else + asm volatile( + "vld1.32 {d0-d1}, [%[a0]]! \n" + "vld1.32 {d2-d3}, [%[a1]]! 
\n" + "vld1.32 {d4-d5}, [%[a2]]! \n" + "vld1.32 {d6-d7}, [%[a3]]! \n" + "vld1.32 {d8-d9}, [%[a4]]! \n" + "vld1.32 {d10-d11}, [%[a5]]! \n" + "vtrn.32 q0, q1 \n" + "vtrn.32 q2, q3 \n" + "vtrn.32 q4, q5 \n" + "vswp.32 d1, d4 \n" + "vswp.32 d3, d6 \n" + + "vbif q0, %q[vzero], %q[vmask2] \n" + "vbif q1, %q[vzero], %q[vmask2] \n" + "vbif q2, %q[vzero], %q[vmask2] \n" + "vbif q3, %q[vzero], %q[vmask2] \n" + "vbif q4, %q[vzero], %q[vmask3] \n" + "vbif q5, %q[vzero], %q[vmask3] \n" + + "vst1.32 {q0}, [%[out]]! \n" + "vst1.32 {d8}, [%[out]]! \n" + "vst1.32 {q1}, [%[out]]! \n" + "vst1.32 {d10}, [%[out]]! \n" + "vst1.32 {q2}, [%[out]]! \n" + "vst1.32 {d9}, [%[out]]! \n" + "vst1.32 {q3}, [%[out]]! \n" + "vst1.32 {d11}, [%[out]]! \n" + : [out] "+r"(out_ptr), [a0] "+r"(a0), [a1] "+r"(a1), [a2] "+r"(a2), + [a3] "+r"(a3), [a4] "+r"(a4), [a5] "+r"(a5) + : [vmask2] "w"(vmask2), [vmask3] "w"(vmask3), [vzero] "w"(vzero) + : "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5"); +#endif + } + // remain k + switch (remain_m) { + case 1: + a1 = zerobuff; + case 2: + a2 = zerobuff; + case 3: + a3 = zerobuff; + case 4: + a4 = zerobuff; + case 5: + a5 = zerobuff; + default: + break; + } + for (; lk < k; ++lk) { + *out_ptr++ = *a0++; + *out_ptr++ = *a1++; + *out_ptr++ = *a2++; + *out_ptr++ = *a3++; + *out_ptr++ = *a4++; + *out_ptr++ = *a5++; + } + } +} + +#if __aarch64__ +void pack_rhs_16c(int k, int n, const float *B, int ldb, float *output, + const bool unroll) { + uint32_t mask[8] = {0, 1, 2, 3, 4, 5, 6, 7}; + uint32_t remain_n = n & 0x7; + float32x4_t vzero = vdupq_n_f32(0.f); + uint32x4_t vmask1 = vcltq_u32(vld1q_u32(mask), vdupq_n_u32(remain_n)); + uint32x4_t vmask2 = vcltq_u32(vld1q_u32(mask + 4), vdupq_n_u32(remain_n)); + + #pragma omp parallel for if (unroll) + for (int i = 0; i < k - 3; i += 4) { + const float *b0 = B + i * ldb; + const float *b1 = b0 + ldb; + const float *b2 = b1 + ldb; + const float *b3 = b2 + ldb; + int j = 0; + asm volatile( + "prfm pldl1keep, [%[b0]] \n" + "prfm pldl1keep, [%[b1]] \n" + "prfm pldl1keep, [%[b2]] \n" + "prfm pldl1keep, [%[b3]] \n" + : + : [b0] "r"(b0), [b1] "r"(b1), [b2] "r"(b2), [b3] "r"(b3)); + + for (; j < n - 15; j += 16) { + float *out_ptr0 = output + j * k + 16 * i; + asm volatile( + "ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [%[b0]], #64 \n" + "ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [%[b1]], #64 \n" + "st1 {v0.4s, v1.4s, v2.4s, v3.4s}, [%[out_ptr0]], #64 \n" + "st1 {v4.4s, v5.4s, v6.4s, v7.4s}, [%[out_ptr0]], #64 \n" + + "ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [%[b2]], #64 \n" + "ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [%[b3]], #64 \n" + "st1 {v0.4s, v1.4s, v2.4s, v3.4s}, [%[out_ptr0]], #64 \n" + "st1 {v4.4s, v5.4s, v6.4s, v7.4s}, [%[out_ptr0]], #64 \n" + : [out_ptr0] "+r"(out_ptr0), [b0] "+r"(b0), [b1] "+r"(b1), + [b2] "+r"(b2), [b3] "+r"(b3) + : + : "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7"); + } + for (; j < n - 7; j += 8) { + float *out_ptr0 = output + (j & 0xFFF0) * k + 16 * i + (j & 0xF); + int step = 64; + asm volatile( + "ld1 {v0.4s, v1.4s}, [%[b0]], #32 \n" + "ld1 {v2.4s, v3.4s}, [%[b1]], #32 \n" + "ld1 {v4.4s, v5.4s}, [%[b2]], #32 \n" + "ld1 {v6.4s, v7.4s}, [%[b3]], #32 \n" + + "st1 {v0.4s, v1.4s}, [%[out_ptr0]], %[step] \n" + "st1 {v2.4s, v3.4s}, [%[out_ptr0]], %[step] \n" + "st1 {v4.4s, v5.4s}, [%[out_ptr0]], %[step] \n" + "st1 {v6.4s, v7.4s}, [%[out_ptr0]], %[step] \n" + : [out_ptr0] "+r"(out_ptr0), [b0] "+r"(b0), [b1] "+r"(b1), + [b2] "+r"(b2), [b3] "+r"(b3) + : [step] "r"(step) + : "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7"); + } + if (j < n) { + 
float *out_ptr0 = output + (j & 0xFFF0) * k + 16 * i + (j & 0xF); + int step = 64; + asm volatile( + "ld1 {v0.4s, v1.4s}, [%[b0]] \n" + "ld1 {v2.4s, v3.4s}, [%[b1]] \n" + "ld1 {v4.4s, v5.4s}, [%[b2]] \n" + "ld1 {v6.4s, v7.4s}, [%[b3]] \n" + + "and v0.16b, v0.16b, %[vmask1].16b \n" + "and v1.16b, v1.16b, %[vmask2].16b \n" + "and v2.16b, v2.16b, %[vmask1].16b \n" + "and v3.16b, v3.16b, %[vmask2].16b \n" + "and v4.16b, v4.16b, %[vmask1].16b \n" + "and v5.16b, v5.16b, %[vmask2].16b \n" + "and v6.16b, v6.16b, %[vmask1].16b \n" + "and v7.16b, v7.16b, %[vmask2].16b \n" + + "st1 {v0.4s, v1.4s}, [%[out_ptr0]], %[step] \n" + "st1 {v2.4s, v3.4s}, [%[out_ptr0]], %[step] \n" + "st1 {v4.4s, v5.4s}, [%[out_ptr0]], %[step] \n" + "st1 {v6.4s, v7.4s}, [%[out_ptr0]], %[step] \n" + : [out_ptr0] "+r"(out_ptr0) + : [vmask1] "w"(vmask1), [vmask2] "w"(vmask2), [b0] "r"(b0), + [b1] "r"(b1), [b2] "r"(b2), [b3] "r"(b3), [step] "r"(step) + : "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7"); + j += 8; + } + + if (j & 0xf) { + float *out_ptr0 = output + (j & 0xFFF0) * k + 16 * i + (j & 0xF); + vst1q_f32(out_ptr0, vzero); + vst1q_f32(out_ptr0 + 4, vzero); + out_ptr0 += 16; + vst1q_f32(out_ptr0, vzero); + vst1q_f32(out_ptr0 + 4, vzero); + out_ptr0 += 16; + vst1q_f32(out_ptr0, vzero); + vst1q_f32(out_ptr0 + 4, vzero); + out_ptr0 += 16; + vst1q_f32(out_ptr0, vzero); + vst1q_f32(out_ptr0 + 4, vzero); + } + } + // remain k + for (int i = (k & 0xFFFC); i < k; ++i) { + const float *b0 = B + i * ldb; + int j = 0; + asm volatile("prfm pldl1keep, [%[b0]] \n" + : + : [b0] "r"(b0)); + + for (; j < n - 15; j += 16) { + float *out_ptr0 = output + j * k + 16 * i; + asm volatile( + "ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [%[b0]], #64 \n" + "st1 {v0.4s, v1.4s, v2.4s, v3.4s}, [%[out_ptr0]], #64 \n" + : [out_ptr0] "+r"(out_ptr0), [b0] "+r"(b0) + : + : "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7"); + } + for (; j < n - 7; j += 8) { + float *out_ptr0 = output + (j & 0xFFF0) * k + 16 * i + (j & 0xF); + int step = 64; + asm volatile( + "ld1 {v0.4s, v1.4s}, [%[b0]], #32 \n" + "st1 {v0.4s, v1.4s}, [%[out_ptr0]], %[step] \n" + : [out_ptr0] "+r"(out_ptr0), [b0] "+r"(b0) + : [step] "r"(step) + : "memory", "v0", "v1"); + } + if (j < n) { + float *out_ptr0 = output + (j & 0xFFF0) * k + 16 * i + (j & 0xF); + asm volatile( + "ld1 {v0.4s, v1.4s}, [%[b0]] \n" + "and v0.16b, v0.16b, %[vmask1].16b \n" + "and v1.16b, v1.16b, %[vmask2].16b \n" + "st1 {v0.4s, v1.4s}, [%[out_ptr0]] \n" + : [out_ptr0] "+r"(out_ptr0) + : [vmask1] "w"(vmask1), [vmask2] "w"(vmask2), [b0] "r"(b0) + : "memory", "v0", "v1"); + j += 8; + } + if (j & 0xf) { + float *out_ptr0 = output + (j & 0xFFF0) * k + 16 * i + (j & 0xF); + vst1q_f32(out_ptr0, vzero); + vst1q_f32(out_ptr0 + 4, vzero); + } + } +} +#else + +void pack_rhs_8c(int k, int n, const float *B, int ldb, float *output, + const bool unroll) { + uint32_t mask[8] = {0, 1, 2, 3, 4, 5, 6, 7}; + uint32_t remain_n = n & 0x7; + uint32x4_t vmask1 = vcltq_u32(vld1q_u32(mask), vdupq_n_u32(remain_n)); + uint32x4_t vmask2 = vcltq_u32(vld1q_u32(mask + 4), vdupq_n_u32(remain_n)); + + #pragma omp parallel for if (unroll) + for (int i = 0; i < k - 3; i += 4) { + const float *b0 = B + i * ldb; + const float *b1 = b0 + ldb; + const float *b2 = b1 + ldb; + const float *b3 = b2 + ldb; + int j = 0; + for (; j < n - 15; j += 16) { + float *out_ptr0 = output + j * k + 8 * i; + float *out_ptr1 = out_ptr0 + 8 * k; + asm volatile( + "vld1.32 {q0, q1}, [%[b0]]! \n" + "vld1.32 {q2, q3}, [%[b1]]! \n" + "vld1.32 {q4, q5}, [%[b0]]! 
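The addressing in pack_rhs_16c is compact enough to deserve a scalar model: B's columns are grouped into panels of 16, and within a panel each source row contributes 16 contiguous floats. The code computes the panel base as `(j & 0xFFF0) * k`, i.e. j rounded down to a multiple of 16, which is equivalent to `(j / 16) * 16` only while j stays below 0x10000; a hypothetical helper making the layout explicit:

```cpp
#include <cstddef>
#include <cstdio>

// Element (i, j) of B lands at panel_base + 16 * i + (j % 16), where the
// panel base is (j / 16) * 16 * k. Mirrors `(j & 0xFFF0) * k + 16 * i +
// (j & 0xF)` in pack_rhs_16c for j < 65536.
size_t packed_rhs_16c_index(int i, int j, int k) {
  return static_cast<size_t>(j / 16) * 16 * k + 16 * i + (j % 16);
}

int main() { printf("%zu\n", packed_rhs_16c_index(2, 21, 8)); }  // 165
```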
\n" + "vld1.32 {q6, q7}, [%[b1]]! \n" + "vst1.32 {q0, q1}, [%[out_ptr0]]! \n" + "vst1.32 {q2, q3}, [%[out_ptr0]]! \n" + "vst1.32 {q4, q5}, [%[out_ptr1]]! \n" + "vst1.32 {q6, q7}, [%[out_ptr1]]! \n" + + "vld1.32 {q0, q1}, [%[b2]]! \n" + "vld1.32 {q2, q3}, [%[b3]]! \n" + "vld1.32 {q4, q5}, [%[b2]]! \n" + "vld1.32 {q6, q7}, [%[b3]]! \n" + "vst1.32 {q0, q1}, [%[out_ptr0]]! \n" + "vst1.32 {q2, q3}, [%[out_ptr0]]! \n" + "vst1.32 {q4, q5}, [%[out_ptr1]]! \n" + "vst1.32 {q6, q7}, [%[out_ptr1]]! \n" + : [out_ptr0] "+r"(out_ptr0), [out_ptr1] "+r"(out_ptr1), [b0] "+r"(b0), + [b1] "+r"(b1), [b2] "+r"(b2), [b3] "+r"(b3) + : + : "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7"); + } + for (; j < n - 7; j += 8) { + float *out_ptr0 = output + j * k + 8 * i; + asm volatile( + "vld1.32 {q0, q1}, [%[b0]]! \n" + "vld1.32 {q2, q3}, [%[b1]]! \n" + "vld1.32 {q4, q5}, [%[b2]]! \n" + "vld1.32 {q6, q7}, [%[b3]]! \n" + "vst1.32 {q0, q1}, [%[out_ptr0]]! \n" + "vst1.32 {q2, q3}, [%[out_ptr0]]! \n" + "vst1.32 {q4, q5}, [%[out_ptr0]]! \n" + "vst1.32 {q6, q7}, [%[out_ptr0]]! \n" + : [out_ptr0] "+r"(out_ptr0), [b0] "+r"(b0), [b1] "+r"(b1), + [b2] "+r"(b2), [b3] "+r"(b3) + : + : "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7"); + } + if (j < n) { + float *out_ptr0 = output + j * k + 8 * i; + asm volatile( + "vld1.32 {q0, q1}, [%[b0]] \n" + "vld1.32 {q2, q3}, [%[b1]] \n" + "vld1.32 {q4, q5}, [%[b2]] \n" + "vld1.32 {q6, q7}, [%[b3]] \n" + "vand q0, q0, %q[vmask1] \n" + "vand q1, q1, %q[vmask2] \n" + "vand q2, q2, %q[vmask1] \n" + "vand q3, q3, %q[vmask2] \n" + "vand q4, q4, %q[vmask1] \n" + "vand q5, q5, %q[vmask2] \n" + "vand q6, q6, %q[vmask1] \n" + "vand q7, q7, %q[vmask2] \n" + + "vst1.32 {q0, q1}, [%[out_ptr0]]! \n" + "vst1.32 {q2, q3}, [%[out_ptr0]]! \n" + "vst1.32 {q4, q5}, [%[out_ptr0]]! \n" + "vst1.32 {q6, q7}, [%[out_ptr0]]! \n" + : [out_ptr0] "+r"(out_ptr0) + : [vmask1] "w"(vmask1), [vmask2] "w"(vmask2), [b0] "r"(b0), + [b1] "r"(b1), [b2] "r"(b2), [b3] "r"(b3) + : "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7"); + } + } + // remain k + for (int i = (k & 0xFFFC); i < k; ++i) { + const float *b0 = B + i * ldb; + int j = 0; + for (; j < n - 15; j += 16) { + float *out_ptr0 = output + j * k + 8 * i; + float *out_ptr1 = out_ptr0 + 8 * k; + asm volatile( + "vld1.32 {q0, q1}, [%[b0]]! \n" + "vld1.32 {q2, q3}, [%[b0]]! \n" + "vst1.32 {q0, q1}, [%[out_ptr0]]! \n" + "vst1.32 {q2, q3}, [%[out_ptr1]]! \n" + : [out_ptr0] "+r"(out_ptr0), [out_ptr1] "+r"(out_ptr1), [b0] "+r"(b0) + : + : "memory", "q0", "q1", "q2", "q3"); + } + for (; j < n - 7; j += 8) { + float *out_ptr0 = output + j * k + 8 * i; + asm volatile( + "vld1.32 {q0, q1}, [%[b0]]! \n" + "vst1.32 {q0, q1}, [%[out_ptr0]]! 
\n" + : [out_ptr0] "+r"(out_ptr0), [b0] "+r"(b0) + : + : "memory", "q0", "q1"); + } + if (j < n) { + float *out_ptr0 = output + j * k + 8 * i; + asm volatile( + "vld1.32 {q0, q1}, [%[b0]] \n" + "vand q0, q0, %q[vmask1] \n" + "vand q1, q1, %q[vmask2] \n" + "vst1.32 {q0, q1}, [%[out_ptr0]] \n" + : [out_ptr0] "+r"(out_ptr0) + : [vmask1] "w"(vmask1), [vmask2] "w"(vmask2), [b0] "r"(b0) + : "memory", "q0", "q1"); + } + } +} +#endif // __aarch64__ + +void write_back_alpha_beta(const int mc, const int nc, const float alpha, + const float *c, const int ldc1, const float beta, + float *C, const int ldc2) { + int nc1 = nc / 4; + int _nc1 = nc % 4; + + float32x4_t _alpha = vdupq_n_f32(alpha); + float32x4_t _beta = vdupq_n_f32(beta); + float32x4_t cv, cv2; + for (int i = 0; i < mc; ++i) { + const float *c_ptr = c + i * ldc1; + float *C_ptr = C + i * ldc2; + for (int j = 0; j < nc1; ++j) { + cv = vld1q_f32(c_ptr); + cv = vmulq_f32(_alpha, cv); + cv2 = vld1q_f32(C_ptr); + cv = vmlaq_f32(cv, _beta, cv2); + vst1q_f32(C_ptr, cv); + c_ptr += 4; + C_ptr += 4; + } + if (_nc1 != 0) { + cv = vld1q_f32(c_ptr); + cv = vmulq_f32(_alpha, cv); + cv2 = vld1q_f32(C_ptr); + cv = vmlaq_f32(cv, _beta, cv2); + switch (_nc1) { + case 3: + vst1q_lane_f32(C_ptr + 2, cv, 2); + case 2: + vst1_f32(C_ptr, vget_low_f32(cv)); + break; + case 1: + vst1q_lane_f32(C_ptr, cv, 0); + break; + } + } + } +} + +#if __aarch64__ +void write_back_alpha1_beta0(const int mc, const int nc, const float *c, + const int ldc1, float *C, const int ldc2) { + int nc1 = nc / 4; + int _nc1 = nc % 4; + + const float *c_ptr; + float *C_ptr; + float32x4_t cv; + for (int i = 0; i < mc; ++i) { + c_ptr = c + i * ldc1; + C_ptr = C + i * ldc2; + for (int j = 0; j < nc1; ++j) { + cv = vld1q_f32(c_ptr); + vst1q_f32(C_ptr, cv); + c_ptr += 4; + C_ptr += 4; + } + if (_nc1 != 0) { + cv = vld1q_f32(c_ptr); + switch (_nc1) { + case 3: + vst1q_lane_f32(C_ptr + 2, cv, 2); + case 2: + vst1_f32(C_ptr, vget_low_f32(cv)); + break; + case 1: + vst1q_lane_f32(C_ptr, cv, 0); + break; + } + } + } +} + +void write_back_alpha1_beta1(const int mc, const int nc, const float *c, + const int ldc1, float *C, const int ldc2) { + int nc1 = nc / 4; + int _nc1 = nc % 4; + + const float *c_ptr; + float *C_ptr; + float32x4_t cv, cv2; + for (int i = 0; i < mc; ++i) { + c_ptr = c + i * ldc1; + C_ptr = C + i * ldc2; + for (int j = 0; j < nc1; ++j) { + cv = vld1q_f32(c_ptr); + cv2 = vld1q_f32(C_ptr); + cv = vaddq_f32(cv, cv2); + vst1q_f32(C_ptr, cv); + c_ptr += 4; + C_ptr += 4; + } + if (_nc1 != 0) { + cv = vld1q_f32(c_ptr); + cv2 = vld1q_f32(C_ptr); + cv = vaddq_f32(cv, cv2); + switch (_nc1) { + case 3: + vst1q_lane_f32(C_ptr + 2, cv, 2); + case 2: + vst1_f32(C_ptr, vget_low_f32(cv)); + break; + case 1: + vst1q_lane_f32(C_ptr, cv, 0); + break; + } + } + } +} + +#else +void write_back_alpha1_beta0(const int mc, const int nc, const float *c, + const int ldc1, float *C, const int ldc2) { + int nc1 = nc / 16; + int nc2 = nc % 16; + int step1 = 4 * (ldc1 - 16 * nc1); + int step2 = 4 * ldc2; + int volatile m = mc; + + const float *volatile c_ptr = c; + float *volatile C_ptr = C; + if (nc1 > 0) { + asm volatile( + "subs %[mc], %[mc], #1 \n\t" + "blt end_mc_%= \n\t" + "loop_mc_%=: \n\t" + + "mov r6, %[C_ptr] \n\t" + "mov r5, %[nc1] \n\t" + "subs r5, r5, #1 \n\t" + "blt end_nc1_%= \n\t" + "loop_nc1_%=: \n\t" + + "vld1.32 {q0, q1}, [%[c_ptr]]! \n\t" + "vst1.32 {q0, q1}, [r6]! \n\t" + + "vld1.32 {q2, q3}, [%[c_ptr]]! \n\t" + "vst1.32 {q2, q3}, [r6]! 
\n\t" + + "subs r5, r5, #1 \n\t" + "bge loop_nc1_%= \n\t" + "end_nc1_%=: \n\t" + + "add %[c_ptr], %[c_ptr], %[step1] \n\t" + "add %[C_ptr], %[C_ptr], %[step2] \n\t" + "subs %[mc], %[mc], #1 \n\t" + "bge loop_mc_%= \n\t" + "end_mc_%=: \n\t" + : + : [C_ptr] "r"(C_ptr), [c_ptr] "r"(c_ptr), [mc] "r"(m), [nc1] "r"(nc1), + [step1] "r"(step1), [step2] "r"(step2) + : "memory", "r5", "r6", "q0", "q1", "q2", "q3"); + } + + if (nc2 != 0) { + for (int i = 0; i < mc; i++) { + const float *c0 = c_ptr + nc1 * 16 + i * ldc1; + float *C0 = C_ptr + nc1 * 16 + i * ldc2; + for (int j = 0; j < nc2; j++) { + *C0++ = *c0++; + } + } + } +} + +void write_back_alpha1_beta1(const int mc, const int nc, const float *c, + const int ldc1, float *C, const int ldc2) { + int nc1 = nc / 16; + int nc2 = nc % 16; + int step1 = 4 * (ldc1 - 16 * nc1); + int step2 = 4 * ldc2; + int volatile m = mc; + + const float *volatile c_ptr = c; + float *volatile C_ptr = C; + if (nc1 > 0) { + asm volatile( + "subs %[mc], %[mc], #1 \n\t" + "blt end_mc_%= \n\t" + "loop_mc_%=: \n\t" + + "mov r6, %[C_ptr] \n\t" + "mov r5, %[nc1] \n\t" + "subs r5, r5, #1 \n\t" + "blt end_nc1_%= \n\t" + "loop_nc1_%=: \n\t" + + "vld1.32 {q0, q1}, [%[c_ptr]]! \n\t" + "vld1.32 {q2, q3}, [r6] \n\t" + "vadd.f32 q0, q0, q2 \n\t" + "vadd.f32 q1, q1, q3 \n\t" + "vst1.32 {q0, q1}, [r6]! \n\t" + + "vld1.32 {q0, q1}, [%[c_ptr]]! \n\t" + "vld1.32 {q2, q3}, [r6] \n\t" + "vadd.f32 q0, q0, q2 \n\t" + "vadd.f32 q1, q1, q3 \n\t" + "vst1.32 {q0, q1}, [r6]! \n\t" + + "subs r5, r5, #1 \n\t" + "bge loop_nc1_%= \n\t" + "end_nc1_%=: \n\t" + + "add %[c_ptr], %[c_ptr], %[step1] \n\t" + "add %[C_ptr], %[C_ptr], %[step2] \n\t" + "subs %[mc], %[mc], #1 \n\t" + "bge loop_mc_%= \n\t" + "end_mc_%=: \n\t" + : + : [C_ptr] "r"(C_ptr), [c_ptr] "r"(c_ptr), [mc] "r"(m), [nc1] "r"(nc1), + [step1] "r"(step1), [step2] "r"(step2) + : "memory", "r5", "r6", "q0", "q1", "q2", "q3"); + } + + if (nc2 != 0) { + for (int i = 0; i < mc; i++) { + const float *c0 = c_ptr + nc1 * 16 + i * ldc1; + float *C0 = C_ptr + nc1 * 16 + i * ldc2; + for (int j = 0; j < nc2; j++) { + *C0++ += *c0++; + } + } + } +} +#endif // __aarch64__ + +void write_back(const int mc, const int nc, const float alpha, const float *c, + const int ldc1, const float beta, float *C, const int ldc2) { + if (alpha == 1.f && beta == 0.f) { + write_back_alpha1_beta0(mc, nc, c, ldc1, C, ldc2); + } else if (alpha == 1.f && beta == 1.f) { + write_back_alpha1_beta1(mc, nc, c, ldc1, C, ldc2); + } else { + write_back_alpha_beta(mc, nc, alpha, c, ldc1, beta, C, ldc2); + } +} + +} // namespace math +} // namespace operators +} // namespace paddle_mobile + +#endif // __ARM_NEON__ diff --git a/src/operators/math/gemm/strategy.h b/src/operators/math/gemm/strategy.h new file mode 100644 index 0000000000000000000000000000000000000000..11e24fb1c31ae6a6e02422dea95de2874ccebc5f --- /dev/null +++ b/src/operators/math/gemm/strategy.h @@ -0,0 +1,120 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#pragma once + +#include "operators/math/gemm/gemm_kernel.h" +#include "operators/math/gemm/pack_kernel.h" + +namespace paddle_mobile { +namespace operators { +namespace math { + +struct SgemmStrategy { + typedef float Itype; + typedef float Otype; + + typedef void (*packLhsFunc)(const int, const int, const Itype *, const int, + Itype *, const bool); + typedef void (*packRhsFunc)(const int, const int, const Itype *, const int, + Itype *, const bool); + typedef void (*kernelFunc)(const Itype *, const Itype *, const int, Otype *, + const int); + typedef void (*WriteFunc)(const int, const int, const float alpha, + const Otype *, const int, const float beta, Otype *, + const int); + + packLhsFunc pack_lhs; + packRhsFunc pack_rhs; + kernelFunc kernel; + WriteFunc write; + + static int out_width() { +#if __aarch64__ + return 16; +#else + return 8; +#endif + } + + static int out_height() { return 6; } + + SgemmStrategy() { + pack_lhs = pack_lhs_6r; +#if __aarch64__ + pack_rhs = pack_rhs_16c; + kernel = sgemm_6x16; +#else + pack_rhs = pack_rhs_8c; + kernel = sgemm_6x8; +#endif + write = write_back; + } +}; + +struct I8o32gemmStrategy { + typedef int8_t Itype; + typedef int32_t Otype; + + typedef void (*kern_type)(const Itype *, const Itype *, const int, Otype *, + const int); + kern_type kernel; + + static int out_width() { return 8; } + + static int out_height() { +#if __aarch64__ + return 12; +#else + return 6; +#endif + } + + I8o32gemmStrategy() {} +}; + +struct SgemvStrategy { + typedef float Itype; + typedef float Otype; + + typedef void (*kernelFunc)(const bool, const int, const int, const float, + const Itype *, const int, const Itype *, + const float, Otype *); + kernelFunc kernel; + + SgemvStrategy() { kernel = sgemv_mx1; } +}; + +struct I8o32gemvStrategy { + typedef int8_t Itype; + typedef int32_t Otype; + + typedef void (*kern_type)(const Itype *, const Itype *, const int, Otype *, + const int); + kern_type kernel; + + static int out_width() { return 1; } + + static int out_height() { +#if __aarch64__ + return 12; +#else + return 6; +#endif + } +}; + +} // namespace math +} // namespace operators +} // namespace paddle_mobile diff --git a/src/operators/math/gru_compute.cpp b/src/operators/math/gru_compute.cpp index 19c7a2685c347340a0d3bd10b1c5828bfd437d4f..d30ea5aa47141d2f4398a36ce6ec7a8885110196 100644 --- a/src/operators/math/gru_compute.cpp +++ b/src/operators/math/gru_compute.cpp @@ -17,7 +17,7 @@ limitations under the License. 
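// The strategy structs above bundle the per-ISA kernels (pack_lhs, pack_rhs,
// the register-tile micro-kernel and the write-back) behind one plain object,
// so a blocked GEMM driver can be written once and instantiated per
// architecture. A minimal sketch of how such a driver might consume
// SgemmStrategy follows; the GemmExecutorSketch name, the heap-allocated
// packing buffers and the single-level blocking are illustrative assumptions,
// not the executor this patch actually ships.
#include <cstdlib>

template <typename Strategy>
struct GemmExecutorSketch {
  void operator()(int m, int n, int k, const float *A, int lda,
                  const float *B, int ldb, float alpha, float beta, float *C,
                  int ldc) {
    Strategy s;
    // Round the panel sizes up to the micro-kernel's register tile.
    const int mh = Strategy::out_height();
    const int nw = Strategy::out_width();
    const int mb = (m + mh - 1) / mh * mh;
    const int nb = (n + nw - 1) / nw * nw;
    float *pa = static_cast<float *>(malloc(sizeof(float) * mb * k));
    float *pb = static_cast<float *>(malloc(sizeof(float) * nb * k));
    float *cb = static_cast<float *>(malloc(sizeof(float) * mb * nb));
    s.pack_lhs(m, k, A, lda, pa, true);  // mh-row panels, tail zero padded
    s.pack_rhs(k, n, B, ldb, pb, true);  // nw-column panels, tail zero padded
    for (int i = 0; i < mb; i += mh) {
      for (int j = 0; j < nb; j += nw) {
        // Each call emits one out_height() x out_width() tile of A*B.
        s.kernel(pa + i * k, pb + j * k, k, cb + i * nb + j, nb);
      }
    }
    // write_back folds in alpha/beta and clips the zero-padded edges.
    s.write(m, n, alpha, cb, nb, beta, C, ldc);
    free(pa);
    free(pb);
    free(cb);
  }
};
// The cblas_sgemm front-end used by the hunks below is presumably a thin
// row-major wrapper over this same pack/kernel/write machinery.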
*/ #include "operators/math/gru_compute.h" #include "common/types.h" #include "operators/math/activation.h" -#include "operators/math/gemm.h" +#include "operators/math/gemm/cblas.h" #include "operators/math/gru_cpu_kernel.h" namespace paddle_mobile { @@ -29,35 +29,19 @@ struct GRUUnitFunctor { static void compute(GRUMetaValue value, int frame_size, int batch_size, const ActivationType active_node, const ActivationType active_gate) { - Gemm gemm; if (value.prev_out_value) { -#ifdef _OPENMP - gemm.Sgemm_omp(batch_size, frame_size * 2, frame_size, 1, - value.prev_out_value, frame_size, value.gate_weight, - frame_size * 2, 1, value.gate_value, frame_size * 3, false, - static_cast(nullptr)); -#else - gemm.Sgemm(batch_size, frame_size * 2, frame_size, 1, - value.prev_out_value, frame_size, value.gate_weight, - frame_size * 2, 1, value.gate_value, frame_size * 3, false, - static_cast(nullptr)); -#endif + cblas_sgemm(false, false, batch_size, frame_size * 2, frame_size, 1.f, + value.prev_out_value, frame_size, value.gate_weight, + frame_size * 2, 1.f, value.gate_value, frame_size * 3); } forward_reset_output(value, frame_size, batch_size, active_gate); if (value.prev_out_value) { -#ifdef _OPENMP - gemm.Sgemm_omp(batch_size, frame_size, frame_size, 1, - value.reset_output_value, frame_size, value.state_weight, - frame_size, 1, value.gate_value + frame_size * 2, - frame_size * 3, false, static_cast(nullptr)); -#else - gemm.Sgemm(batch_size, frame_size, frame_size, 1, - value.reset_output_value, frame_size, value.state_weight, - frame_size, 1, value.gate_value + frame_size * 2, - frame_size * 3, false, static_cast(nullptr)); -#endif + cblas_sgemm(false, false, batch_size, frame_size, frame_size, 1.f, + value.reset_output_value, frame_size, value.state_weight, + frame_size, 1.f, value.gate_value + frame_size * 2, + frame_size * 3); } forward_final_output(value, frame_size, batch_size, active_node); @@ -65,6 +49,7 @@ struct GRUUnitFunctor { }; template struct GRUUnitFunctor; + } // namespace math } // namespace operators } // namespace paddle_mobile diff --git a/src/operators/math/im2col.cpp b/src/operators/math/im2col.cpp index 9449ad70819f2ea114fac8848f6ee023871d47f2..fedd17ed0cd95f50348e40207b0dde9243c0d714 100644 --- a/src/operators/math/im2col.cpp +++ b/src/operators/math/im2col.cpp @@ -44,7 +44,17 @@ void ExtractToImg(const float *im_data, float *col_data, const int im_height, for (int i = start_height; i < end_height; i += stride_h) { if (stride_w == 1) { - memcpy(col_data, im_data, extract * sizeof(float)); + // memcpy(col_data, im_data, extract * sizeof(float)); + int s = 0; +#if __ARM_NEON + for (; s < extract - 3; s += 4) { + float32x4_t img = vld1q_f32(im_data + s); + vst1q_f32(col_data + s, img); + } +#endif + for (; s < extract; ++s) { + col_data[s] = im_data[s]; + } } else if (stride_w == 2) { int s = 0; #if __ARM_NEON @@ -109,325 +119,7 @@ void Im2ColFunctor::operator()( const float *im_data = im.data(); float *col_data = col->data(); #if __ARM_NEON - const int osize = col_height; - const int isize = im_height; - bool pad1 = padding[0] > 0; - bool pad2 = - (pad1 && padding[1] && - (((isize - 2 * padding[0] + filter_height) % stride[0] == 0) ? 
1 : 0)); - int fill = isize % 2; - if (stride[0] == 1 && filter_height == 3 && pad1 && pad2 && - dilation[0] == 1 && im_height > 2 && im_height == im_width) { - for (int c = 0; c < im_channels; ++c) { - int oosize = osize * osize; - int nk4 = osize / 4; - int mk4 = osize % 4; - - float *col0 = col_data + 0 * oosize + 2 * osize + 2; - float *col1 = col_data + 1 * oosize + 2 * osize + 1; - float *col2 = col_data + 2 * oosize + 2 * osize; - - float *col3 = col_data + 3 * oosize + osize + 2; - float *col4 = col_data + 4 * oosize + osize + 1; - float *col5 = col_data + 5 * oosize + osize; - - float *col6 = col_data + 6 * oosize + 2; - float *col7 = col_data + 7 * oosize + 1; - float *col8 = col_data + 8 * oosize; - - float32x4_t im1; - const float *im_tmp_data = im_data + osize + 1; - - int rrsize = oosize - osize - 1; - int nr4 = rrsize / 4; - int mr4 = rrsize % 4; - for (int i = 0; i < nr4; ++i) { - im1 = vld1q_f32(im_tmp_data); - vst1q_f32(col0, im1); - vst1q_f32(col1, im1); - vst1q_f32(col2, im1); - vst1q_f32(col3, im1); - vst1q_f32(col4, im1); - vst1q_f32(col5, im1); - vst1q_f32(col6, im1); - vst1q_f32(col7, im1); - vst1q_f32(col8, im1); - - col0 += 4; - col1 += 4; - col2 += 4; - col3 += 4; - col4 += 4; - col5 += 4; - col6 += 4; - col7 += 4; - col8 += 4; - - im_tmp_data += 4; - } - for (int i = 0; i < mr4; ++i) { - *col0 = *im_tmp_data; - *col1 = *im_tmp_data; - *col2 = *im_tmp_data; - *col3 = *im_tmp_data; - *col4 = *im_tmp_data; - *col5 = *im_tmp_data; - *col6 = *im_tmp_data; - *col7 = *im_tmp_data; - *col8 = *im_tmp_data; - - col0++; - col1++; - col2++; - col3++; - col4++; - col5++; - col6++; - col7++; - col8++; - - im_tmp_data++; - } - - im_tmp_data = im_data + 1; - col0 = col_data + 0 * oosize + osize + 2; - col1 = col_data + 1 * oosize + osize + 1; - col2 = col_data + 2 * oosize + osize; - - col3 = col_data + 3 * oosize + 2; - col4 = col_data + 4 * oosize + 1; - col5 = col_data + 5 * oosize; - - for (int i = 0; i < nk4; i++) { - im1 = vld1q_f32(im_tmp_data); - vst1q_f32(col0, im1); - vst1q_f32(col1, im1); - vst1q_f32(col2, im1); - vst1q_f32(col3, im1); - vst1q_f32(col4, im1); - vst1q_f32(col5, im1); - - col0 += 4; - col1 += 4; - col2 += 4; - col3 += 4; - col4 += 4; - col5 += 4; - im_tmp_data += 4; - } - - for (int i = 0; i < mk4; i++) { - *col0 = *im_tmp_data; - *col1 = *im_tmp_data; - *col2 = *im_tmp_data; - *col3 = *im_tmp_data; - *col4 = *im_tmp_data; - *col5 = *im_tmp_data; - col0++; - col1++; - col2++; - col3++; - col4++; - col5++; - - im_tmp_data++; - } - - // fill 0 1 11; - for (int i = 0; i < osize; ++i) { - col_data[0 * oosize + i * osize] = 0.0; - col_data[3 * oosize + i * osize] = 0.0; - col_data[6 * oosize + i * osize] = 0.0; - - col_data[2 * oosize + osize - 1 + i * osize] = 0.0; - col_data[5 * oosize + osize - 1 + i * osize] = 0.0; - col_data[8 * oosize + osize - 1 + i * osize] = 0.0; - } - - col_data[0 * oosize + osize + 1] = im_data[0]; - col_data[3 * oosize + 1] = im_data[0]; - col_data[6 * oosize + 1] = im_data[osize]; - - col_data[1 * oosize + osize] = im_data[0]; - col_data[4 * oosize] = im_data[0]; - col_data[7 * oosize] = im_data[osize]; - - float32x4_t zero4; - zero4 = vdupq_n_f32(0.0); - auto col_z0 = col_data; - auto col_z1 = col_data + oosize; - auto col_z2 = col_data + 2 * oosize; - auto col_z6 = col_data + 6 * oosize + osize * (osize - 1); - auto col_z7 = col_data + 7 * oosize + osize * (osize - 1); - auto col_z8 = col_data + 8 * oosize + osize * (osize - 1); - - for (int i = 0; i < nk4; ++i) { - vst1q_f32(col_z0, zero4); - vst1q_f32(col_z1, zero4); - 
vst1q_f32(col_z2, zero4); - vst1q_f32(col_z6, zero4); - vst1q_f32(col_z7, zero4); - vst1q_f32(col_z8, zero4); - - col_z0 += 4; - col_z1 += 4; - col_z2 += 4; - col_z6 += 4; - col_z7 += 4; - col_z8 += 4; - } - - for (int i = 0; i < mk4; ++i) { - col_z0[i] = 0.0; - col_z1[i] = 0.0; - col_z2[i] = 0.0; - col_z6[i] = 0.0; - col_z7[i] = 0.0; - col_z8[i] = 0.0; - } - col_data += 9 * oosize; - im_data += isize * isize; - } - } else if (stride[0] == 2 && filter_height == 3 && pad1 && dilation[0] == 1 && - im_height > 2 && im_height == im_width) { - for (int c = 0; c < im_channels; ++c) { - int oosize = osize * osize; - int nk4 = osize / 4; - int mk4 = osize % 4; - - // 3 2 3 1 0 1 3 2 3 - float *col0 = col_data + 0 * oosize + osize + 1; - float *col1 = col_data + 1 * oosize + osize; - float *col2 = col_data + 2 * oosize + osize; - - float *col3 = col_data + 3 * oosize + 1; - float *col4 = col_data + 4 * oosize; - float *col5 = col_data + 5 * oosize; - - float *col6 = col_data + 6 * oosize + 1; - float *col7 = col_data + 7 * oosize; - float *col8 = col_data + 8 * oosize; - - float32x4x2_t im01; - float32x4x2_t im23; - const float *im_tmp_data0 = im_data; - const float *im_tmp_data2 = im_data + isize; - - for (int j = 0; j < osize; ++j) { - for (int i = 0; i < nk4; ++i) { - im01 = vld2q_f32(im_tmp_data0); - im23 = vld2q_f32(im_tmp_data2); - vst1q_f32(col0, im23.val[1]); - vst1q_f32(col1, im23.val[0]); - vst1q_f32(col2, im23.val[1]); - vst1q_f32(col3, im01.val[1]); - vst1q_f32(col4, im01.val[0]); - vst1q_f32(col5, im01.val[1]); - vst1q_f32(col6, im23.val[1]); - vst1q_f32(col7, im23.val[0]); - vst1q_f32(col8, im23.val[1]); - - col0 += 4; - col1 += 4; - col2 += 4; - col3 += 4; - col4 += 4; - col5 += 4; - col6 += 4; - col7 += 4; - col8 += 4; - - im_tmp_data0 += 8; - im_tmp_data2 += 8; - } - const float *im_tmp_data1 = im_tmp_data0 + 1; - const float *im_tmp_data3 = im_tmp_data2 + 1; - for (int i = 0; i < mk4; ++i) { - *col0 = *im_tmp_data3; - *col1 = *im_tmp_data2; - *col2 = *im_tmp_data3; - *col3 = *im_tmp_data1; - *col4 = *im_tmp_data0; - *col5 = *im_tmp_data1; - *col6 = *im_tmp_data3; - *col7 = *im_tmp_data2; - *col8 = *im_tmp_data3; - - col0++; - col1++; - col2++; - col3++; - col4++; - col5++; - col6++; - col7++; - col8++; - im_tmp_data0 += 2; - im_tmp_data1 += 2; - im_tmp_data2 += 2; - im_tmp_data3 += 2; - } - im_tmp_data0 += (isize - fill); - im_tmp_data2 += (isize - fill); - } - for (int i = 0; i < osize; ++i) { - col_data[0 * oosize + i * osize] = 0.0; - col_data[3 * oosize + i * osize] = 0.0; - col_data[6 * oosize + i * osize] = 0.0; - if (pad2) { - col_data[2 * oosize + osize - 1 + i * osize] = 0.0; - col_data[5 * oosize + osize - 1 + i * osize] = 0.0; - col_data[8 * oosize + osize - 1 + i * osize] = 0.0; - } - } - float32x4_t zero4; - zero4 = vdupq_n_f32(0.0); - auto col_z0 = col_data; - auto col_z1 = col_data + oosize; - auto col_z2 = col_data + 2 * oosize; - auto col_z6 = col_data + 6 * oosize + osize * (osize - 1); - auto col_z7 = col_data + 7 * oosize + osize * (osize - 1); - auto col_z8 = col_data + 8 * oosize + osize * (osize - 1); - - for (int i = 0; i < nk4; ++i) { - vst1q_f32(col_z0, zero4); - vst1q_f32(col_z1, zero4); - vst1q_f32(col_z2, zero4); - if (pad2) { - vst1q_f32(col_z6, zero4); - vst1q_f32(col_z7, zero4); - vst1q_f32(col_z8, zero4); - } - col_z0 += 4; - col_z1 += 4; - col_z2 += 4; - col_z6 += 4; - col_z7 += 4; - col_z8 += 4; - } - - for (int i = 0; i < mk4; ++i) { - col_z0[i] = 0.0; - col_z1[i] = 0.0; - col_z2[i] = 0.0; - if (pad2) { - col_z6[i] = 0.0; - col_z7[i] = 0.0; - 
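// The hand-written 3x3 stride-1/stride-2 im2col fast paths being deleted in
// this hunk are subsumed by the generic ExtractToImg path (patched above to
// use a NEON copy in the stride-1 case): for each kernel offset it copies one
// possibly strided row segment of the image into the column buffer. A scalar
// sketch of that per-offset copy; the padding-derived start/end clamping and
// the col advance used here are simplified assumptions:
static void extract_to_img_sketch(const float *im, float *col, int im_width,
                                  int start_h, int end_h, int start_w,
                                  int extract, int stride_h, int stride_w) {
  for (int h = start_h; h < end_h; h += stride_h) {
    const float *src = im + h * im_width + start_w;
    if (stride_w == 1) {
      for (int s = 0; s < extract; ++s) col[s] = src[s];  // the memcpy case
    } else {
      for (int s = 0; s < extract; ++s) col[s] = src[s * stride_w];
    }
    col += extract;
  }
}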
col_z8[i] = 0.0; - } - } - - col_data[1 * oosize + osize] = im_data[isize]; - for (int i = 1; i < osize; ++i) { - col_data[3 * oosize + i] = im_data[(i - 1) * stride[0] + 1]; - } - col_data[4 * oosize] = im_data[0]; - col_data[7 * oosize] = im_data[isize]; - - col_data += 9 * oosize; - im_data += isize * isize; - } - } else if (stride[0] <= 4 && dilation[0] == 1 && dilation[0] == dilation[1]) { + if (stride[0] <= 4 && dilation[0] == 1 && dilation[0] == dilation[1]) { int im_spatial_size = im_height * im_width; int col_spatial_size = col_height * col_width; // pad 0 diff --git a/src/operators/math/math_func_neon.h b/src/operators/math/math.h similarity index 96% rename from src/operators/math/math_func_neon.h rename to src/operators/math/math.h index 3f9245351d3bce49f852b90a4d14bab7e6a826f5..8ff5019e318fa996a388f93cdda0efc0024fe0ee 100644 --- a/src/operators/math/math_func_neon.h +++ b/src/operators/math/math.h @@ -327,4 +327,16 @@ static inline float32x4_t pow_ps(float32x4_t a, float32x4_t b) { return exp_ps(vmulq_f32(b, log_ps(a))); } +#ifndef __aarch64__ +inline float32x4_t vpaddq_f32(float32x4_t r0, float32x4_t r1) { + float32x2_t sum0 = vpadd_f32(vget_low_f32(r0), vget_high_f32(r0)); + float32x2_t sum1 = vpadd_f32(vget_low_f32(r1), vget_high_f32(r1)); + return vcombine_f32(sum0, sum1); +} +#endif + +inline float32x4_t vandq_f32_u32(float32x4_t x, uint32x4_t mask) { + return vreinterpretq_f32_u32(vandq_u32(vreinterpretq_u32_f32(x), mask)); +} + #endif // __ARM_NEON__ diff --git a/src/operators/math/math_function.cpp b/src/operators/math/math_function.cpp index b1e49e377b661cdfdefe08e8043f11b43ab0f9ee..4a35dd2a57f3bab262dd28a60bac63cb7b7e8f77 100644 --- a/src/operators/math/math_function.cpp +++ b/src/operators/math/math_function.cpp @@ -18,6 +18,7 @@ limitations under the License. */ #include "framework/data_type.h" #include "framework/tensor.h" #include "operators/math/gemm.h" +#include "operators/math/gemm/cblas.h" namespace paddle_mobile { namespace operators { @@ -55,6 +56,7 @@ void MatMul(const framework::Tensor &matrix_a, bool trans_a, int M = dim_out[0]; int N = dim_out[1]; int K = (!trans_a) ? dim_a[1] : dim_a[0]; + Gemm gemm; if (trans_a) { framework::Tensor matrix_trans; @@ -69,24 +71,11 @@ void MatMul(const framework::Tensor &matrix_a, bool trans_a, a[index++] = tmp[i * n + j]; } } - -#ifdef _OPENMP - gemm.Sgemm_omp(M, N, K, alpha, a, K, matrix_b.data(), N, beta, - matrix_out->data(), N, relu, bias); -#else - gemm.Sgemm(M, N, K, alpha, a, K, matrix_b.data(), N, beta, - matrix_out->data(), N, relu, bias); -#endif + cblas_sgemm(false, false, M, N, K, alpha, a, K, matrix_b.data(), N, + beta, matrix_out->data(), N); } else { -#ifdef _OPENMP - gemm.Sgemm_omp(M, N, K, alpha, matrix_a.data(), K, - matrix_b.data(), N, beta, matrix_out->data(), - N, relu, bias); -#else - gemm.Sgemm(M, N, K, alpha, matrix_a.data(), K, - matrix_b.data(), N, beta, matrix_out->data(), N, - relu, bias); -#endif + cblas_sgemm(false, false, M, N, K, alpha, matrix_a.data(), K, + matrix_b.data(), N, beta, matrix_out->data(), N); } } diff --git a/src/operators/math/softmax.cpp b/src/operators/math/softmax.cpp index 6b34f522ff6caf32c20971d9cf38f93730fdb727..e066b0cccddf9a43953182788508aca4769fcd27 100644 --- a/src/operators/math/softmax.cpp +++ b/src/operators/math/softmax.cpp @@ -19,7 +19,7 @@ limitations under the License. 
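// math.h above gains two small shims: a vpaddq_f32 emulation so armv7 code
// can use the AArch64 pairwise-add spelling, and vandq_f32_u32, which ANDs a
// bitmask onto floats through reinterpret casts (NEON has no float AND). The
// second pairs with the vcltq_u32 tail masks used by the pack kernels earlier
// in this patch, and softmax.cpp below simply switches to the renamed header.
// A self-contained usage sketch:
#include <arm_neon.h>

static inline float32x4_t keep_first_n_lanes(float32x4_t v, uint32_t n) {
  const uint32_t lane_idx[4] = {0, 1, 2, 3};
  // Lane i survives iff i < n; the remaining lanes are zeroed bit-wise.
  uint32x4_t mask = vcltq_u32(vld1q_u32(lane_idx), vdupq_n_u32(n));
  return vreinterpretq_f32_u32(vandq_u32(vreinterpretq_u32_f32(v), mask));
}
// keep_first_n_lanes({1,2,3,4}, 2) == {1,2,0,0} -- the same zero padding the
// packed GEMM panels apply to their out-of-range tail columns.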
*/ #include #include #include "common/types.h" -#include "operators/math/math_func_neon.h" +#include "operators/math/math.h" namespace paddle_mobile { namespace operators { diff --git a/src/operators/math/winograd/winograd_transform_f6k3.cpp b/src/operators/math/winograd/winograd_transform_f6k3.cpp index 937050ebbd8f6bab3b0c9b075e9b4fa54c25b1ba..234de599ad2cab72c471176330bc3c0aacd02d5f 100644 --- a/src/operators/math/winograd/winograd_transform_f6k3.cpp +++ b/src/operators/math/winograd/winograd_transform_f6k3.cpp @@ -15,10 +15,10 @@ limitations under the License. */ // Inspired by https://arxiv.org/abs/1509.09308 and refered from nnpack and ncnn // project. +#if defined(__ARM_NEON) || defined(__ARM_NEON__) #ifdef CONV_OP -#ifndef __aarch64__ - +#include #include "operators/math/pad.h" #include "operators/math/winograd/winograd_transform.h" @@ -51,10 +51,12 @@ void winograd_transform_weight<8, 3>(const framework::Tensor &weight, const float transform_matrix[8] = {2.f, -2.f / 9, 1.f / 90, 1.f / 180}; const float *inptr = weight.data(); - int remain_start = out_channel & 0xFFFC; -#if 0 - remain_start = 0; + +#if __aarch64__ + int remain_start = 0; #else + int remain_start = out_channel & 0xFFFC; + #pragma omp parallel for for (int oc = 0; oc < out_channel - 3; oc += 4) { float gw[96]; // gw[3][8][4] @@ -258,7 +260,7 @@ void winograd_transform_weight<8, 3>(const framework::Tensor &weight, "q13", "r0"); } } -#endif +#endif // __aarch64__ // remain output channel #pragma omp parallel for @@ -350,311 +352,8 @@ void winograd_transform_input<8, 3>(const framework::Tensor &input, size_t image_size = height * width; const float transform_matrix[8] = {5.25f, -5.f, -4.25f, -2.5f, 2.f, -1.25f, 0.5f, 0.25f}; - int remain_c_start = channel & 0xFFFC; -#if 1 - remain_c_start = 0; -#else - #pragma omp parallel for - for (int c = 0; c < channel - 3; c += 4) { - const float *in = inptr + c * image_size; - float d_bt[64 * 4]; // d * B_t - for (int h = 0; h < h_tiles; ++h) { - for (int w = 0; w < w_tiles; ++w) { - const float *in0 = in + (h * width + w) * 6; - const float *in1 = in0 + image_size; - const float *in2 = in1 + image_size; - const float *in3 = in2 + image_size; - int steps = width * sizeof(float); - float *d_bt_ptr = d_bt; - asm volatile( - "mov r0, #8 \n" - "vld1.32 {d0-d3}, [%[tm_ptr]] \n" - // row loop - "loop_r_%=: \n" - "vld1.32 {d4-d7}, [%[in0]], %[steps] \n" - "vld1.32 {d8-d11}, [%[in1]], %[steps] \n" - "vld1.32 {d12-d15}, [%[in2]], %[steps] \n" - "vld1.32 {d16-d19}, [%[in3]], %[steps] \n" - "vtrn.32 q2, q4 \n" // d0: q2 - "vtrn.32 q3, q5 \n" // d1: q4 - "vtrn.32 q6, q8 \n" // d2: q6 - "vtrn.32 q7, q9 \n" // d3: q8 - "vswp.32 d5, d12 \n" // d4: q3 - "vswp.32 d9, d16 \n" // d5: q5 - "vswp.32 d7, d14 \n" // d6: q7 - "vswp.32 d11, d18 \n" // d7: q9 - - "vsub.f32 q10, q2, q7 \n" - "vsub.f32 q11, q3, q6 \n" - "vmla.f32 q10, q11, d0[0] \n" // d0 - d6 + (d4 - - // d2) * 5.25 - "vst1.32 {d20-d21}, [%[d_bt]]! \n" - - "vadd.f32 q10, q6, q7 \n" - "vadd.f32 q11, q4, q5 \n" - "vmla.f32 q10, q3, d1[0] \n" // d2 - 4.25 * d4 + - // d6 - "vmla.f32 q11, q8, d1[0] \n" // d1 - 4.25 * d3 + - // d5 - "vadd.f32 q12, q10, q11 \n" - "vsub.f32 q13, q10, q11 \n" - "vst1.32 {d24-d27}, [%[d_bt]]! 
\n" - - "vmul.f32 q10, q6, d3[1] \n" // 0.25 * d2 - "vmul.f32 q11, q4, d3[0] \n" // 0.5 * d1 - "vadd.f32 q10, q10, q7 \n" // 0.25 * d2 + d6 - "vmla.f32 q11, q5, d2[0] \n" // 0.5 * d1 + 2 * - // d5 - "vmla.f32 q10, q3, d2[1] \n" // 0.25 * d2 + d6 - // - 1.25 * d4 - "vmla.f32 q11, q8, d1[1] \n" // 0.5 * d1 + 2 * - // d5 - 2.5 * d3 - "vadd.f32 q12, q10, q11 \n" - "vsub.f32 q13, q10, q11 \n" - "vst1.32 {d24-d27}, [%[d_bt]]! \n" - - "vmul.f32 q10, q6, d2[0] \n" // 2 * d2 - "vmul.f32 q11, q4, d2[0] \n" // 2 * d1 - "vmla.f32 q10, q3, d1[1] \n" // 2 * d2 - 2.5 * - // d4 - "vmla.f32 q11, q8, d1[1] \n" // 2 * d1 - 2.5 * - // d3 - "vmla.f32 q10, q7, d3[0] \n" // 2 * d1 - 2.5 * - // d3 + 0.5 * d6 - "vmla.f32 q11, q5, d3[0] \n" // 2 * d2 - 2.5 * - // d4 + 0.5 * d5 - "vmul.f32 q10, q10, d2[0] \n" // 4 * d1 - 5 * d3 - // + d6 - "vadd.f32 q12, q10, q11 \n" - "vsub.f32 q13, q10, q11 \n" - "vst1.32 {d24-d27}, [%[d_bt]]! \n" - - "vsub.f32 q10, q9, q4 \n" - "vsub.f32 q11, q8, q5 \n" - "vmla.f32 q10, q11, d0[0] \n" - "vst1.32 {d20-d21}, [%[d_bt]]! \n" - - "subs r0, #1 \n" - "bne loop_r_%= \n" - : [d_bt] "+r"(d_bt_ptr), [in0] "+r"(in0), [in1] "+r"(in1), - [in2] "+r"(in2), [in3] "+r"(in3) - : [tm_ptr] "r"((float *)transform_matrix), [steps] "r"(steps) - : "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", - "q8", "q9", "q10", "q11", "q12", "q13", "r0"); - - float *ptr0 = d_bt; - float *ptr1 = ptr0 + 32; - float *ptr2 = ptr1 + 32; - float *ptr3 = ptr2 + 32; - float *ptr4 = ptr3 + 32; - float *ptr5 = ptr4 + 32; - float *ptr6 = ptr5 + 32; - float *ptr7 = ptr6 + 32; - int tile_indics = h * w_tiles + w; - int tile_block = tile_indics >> 3; - int block_indics = tile_indics & 0x7; - // (tiles / 8, 64, channel, 8) - float *out0 = - outptr + (tile_block * 64 * channel + c) * 8 + block_indics; - steps = (channel - 3) * 8 * sizeof(float); - asm volatile( - "vld1.32 {d0-d3}, [%[tm_ptr]] \n" - "mov r0, 4 \n" - "mov r1, 32 \n" - "loop_col_%=: \n" - // col 0: - "vld1.32 {d4-d5}, [%[ptr0]]! \n" // q2: d0 - "vld1.32 {d6-d7}, [%[ptr1]]! \n" // q3: d1 - "vld1.32 {d8-d9}, [%[ptr2]]! \n" // q4: d2 - "vld1.32 {d10-d11}, [%[ptr3]]! \n" // q5: d3 - "vld1.32 {d12-d13}, [%[ptr4]]! \n" // q6: d4 - "vld1.32 {d14-d15}, [%[ptr5]]! \n" // q7: d5 - "vld1.32 {d16-d17}, [%[ptr6]]! \n" // q8: d6 - "vld1.32 {d18-d19}, [%[ptr7]]! 
\n" // q9: d7 - - "vsub.f32 q10, q2, q8 \n" // d0 - d6 - "vsub.f32 q11, q6, q4 \n" // d4 - d2 - "vmla.f32 q10, q11, d0[0] \n" // d0 - d6 + (d4 - - // d2) * 5.25 - "vst1.32 {d20[0]}, [%[out0]], r1 \n" - "vst1.32 {d20[1]}, [%[out0]], r1 \n" - "vst1.32 {d21[0]}, [%[out0]], r1 \n" - "vst1.32 {d21[1]}, [%[out0]], %[steps] \n" - - "vadd.f32 q10, q4, q8 \n" - "vadd.f32 q11, q3, q7 \n" - "vmla.f32 q10, q6, d1[0] \n" // d2 - 4.25 * d4 + - // d6 - "vmla.f32 q11, q5, d1[0] \n" // d1 - 4.25 * d3 + - // d5 - "vadd.f32 q12, q10, q11 \n" - "vst1.32 {d24[0]}, [%[out0]], r1 \n" - "vst1.32 {d24[1]}, [%[out0]], r1 \n" - "vst1.32 {d25[0]}, [%[out0]], r1 \n" - "vst1.32 {d25[1]}, [%[out0]], %[steps] \n" - "vsub.f32 q12, q10, q11 \n" - "vst1.32 {d24[0]}, [%[out0]], r1 \n" - "vst1.32 {d24[1]}, [%[out0]], r1 \n" - "vst1.32 {d25[0]}, [%[out0]], r1 \n" - "vst1.32 {d25[1]}, [%[out0]], %[steps] \n" - - "vmul.f32 q10, q4, d3[1] \n" // 0.25 * d2 - "vmul.f32 q11, q3, d3[0] \n" // 0.5 * d1 - "vadd.f32 q10, q10, q8 \n" // 0.25 * d2 + d6 - "vmla.f32 q11, q7, d2[0] \n" // 0.5 * d1 + 2 * - // d5 - "vmla.f32 q10, q6, d2[1] \n" // 0.25 * d2 + d6 - // - 1.25 * d4 - "vmla.f32 q11, q5, d1[1] \n" // 0.5 * d1 + 2 * - // d5 - 2.5 * d3 - "vadd.f32 q12, q10, q11 \n" - "vst1.32 {d24[0]}, [%[out0]], r1 \n" - "vst1.32 {d24[1]}, [%[out0]], r1 \n" - "vst1.32 {d25[0]}, [%[out0]], r1 \n" - "vst1.32 {d25[1]}, [%[out0]], %[steps] \n" - "vsub.f32 q12, q10, q11 \n" - "vst1.32 {d24[0]}, [%[out0]], r1 \n" - "vst1.32 {d24[1]}, [%[out0]], r1 \n" - "vst1.32 {d25[0]}, [%[out0]], r1 \n" - "vst1.32 {d25[1]}, [%[out0]], %[steps] \n" - - "vmul.f32 q10, q4, d2[0] \n" // 2 * d2 - "vmul.f32 q11, q3, d2[0] \n" // 2 * d1 - "vmla.f32 q10, q6, d1[1] \n" // 2 * d2 - 2.5 * - // d4 - "vmla.f32 q11, q5, d1[1] \n" // 2 * d1 - 2.5 * - // d3 - "vmla.f32 q10, q8, d3[0] \n" // 2 * d1 - 2.5 * - // d3 + 0.5 * d6 - "vmla.f32 q11, q7, d3[0] \n" // 2 * d2 - 2.5 * - // d4 + 0.5 * d5 - "vmul.f32 q10, q10, d2[0] \n" // 4 * d1 - 5 * d3 - // + d6 - "vadd.f32 q12, q10, q11 \n" - "vst1.32 {d24[0]}, [%[out0]], r1 \n" - "vst1.32 {d24[1]}, [%[out0]], r1 \n" - "vst1.32 {d25[0]}, [%[out0]], r1 \n" - "vst1.32 {d25[1]}, [%[out0]], %[steps] \n" - "vsub.f32 q12, q10, q11 \n" - "vst1.32 {d24[0]}, [%[out0]], r1 \n" - "vst1.32 {d24[1]}, [%[out0]], r1 \n" - "vst1.32 {d25[0]}, [%[out0]], r1 \n" - "vst1.32 {d25[1]}, [%[out0]], %[steps] \n" - - "vsub.f32 q10, q9, q3 \n" - "vsub.f32 q11, q5, q7 \n" - "vmla.f32 q10, q11, d0[0] \n" - "vst1.32 {d20[0]}, [%[out0]], r1 \n" - "vst1.32 {d20[1]}, [%[out0]], r1 \n" - "vst1.32 {d21[0]}, [%[out0]], r1 \n" - "vst1.32 {d21[1]}, [%[out0]], %[steps] \n" - - // col 1: - "vld1.32 {d4-d5}, [%[ptr0]]! \n" // q2: d0 - "vld1.32 {d6-d7}, [%[ptr1]]! \n" // q3: d1 - "vld1.32 {d8-d9}, [%[ptr2]]! \n" // q4: d2 - "vld1.32 {d10-d11}, [%[ptr3]]! \n" // q5: d3 - "vld1.32 {d12-d13}, [%[ptr4]]! \n" // q6: d4 - "vld1.32 {d14-d15}, [%[ptr5]]! \n" // q7: d5 - "vld1.32 {d16-d17}, [%[ptr6]]! \n" // q8: d6 - "vld1.32 {d18-d19}, [%[ptr7]]! 
\n" // q9: d7 - - "vsub.f32 q10, q2, q8 \n" // d0 - d6 - "vsub.f32 q11, q6, q4 \n" // d4 - d2 - "vmla.f32 q10, q11, d0[0] \n" // d0 - d6 + (d4 - - // d2) * 5.25 - "vst1.32 {d20[0]}, [%[out0]], r1 \n" - "vst1.32 {d20[1]}, [%[out0]], r1 \n" - "vst1.32 {d21[0]}, [%[out0]], r1 \n" - "vst1.32 {d21[1]}, [%[out0]], %[steps] \n" - - "vadd.f32 q10, q4, q8 \n" - "vadd.f32 q11, q3, q7 \n" - "vmla.f32 q10, q6, d1[0] \n" // d2 - 4.25 * d4 + - // d6 - "vmla.f32 q11, q5, d1[0] \n" // d1 - 4.25 * d3 + - // d5 - "vadd.f32 q12, q10, q11 \n" - "vst1.32 {d24[0]}, [%[out0]], r1 \n" - "vst1.32 {d24[1]}, [%[out0]], r1 \n" - "vst1.32 {d25[0]}, [%[out0]], r1 \n" - "vst1.32 {d25[1]}, [%[out0]], %[steps] \n" - - "vsub.f32 q12, q10, q11 \n" - "vst1.32 {d24[0]}, [%[out0]], r1 \n" - "vst1.32 {d24[1]}, [%[out0]], r1 \n" - "vst1.32 {d25[0]}, [%[out0]], r1 \n" - "vst1.32 {d25[1]}, [%[out0]], %[steps] \n" - - "vmul.f32 q10, q4, d3[1] \n" // 0.25 * d2 - "vmul.f32 q11, q3, d3[0] \n" // 0.5 * d1 - "vadd.f32 q10, q10, q8 \n" // 0.25 * d2 + d6 - "vmla.f32 q11, q7, d2[0] \n" // 0.5 * d1 + 2 * - // d5 - "vmla.f32 q10, q6, d2[1] \n" // 0.25 * d2 + d6 - // - 1.25 * d4 - "vmla.f32 q11, q5, d1[1] \n" // 0.5 * d1 + 2 * - // d5 - 2.5 * d3 - "vadd.f32 q12, q10, q11 \n" - "vst1.32 {d24[0]}, [%[out0]], r1 \n" - "vst1.32 {d24[1]}, [%[out0]], r1 \n" - "vst1.32 {d25[0]}, [%[out0]], r1 \n" - "vst1.32 {d25[1]}, [%[out0]], %[steps] \n" - - "vsub.f32 q12, q10, q11 \n" - "vst1.32 {d24[0]}, [%[out0]], r1 \n" - "vst1.32 {d24[1]}, [%[out0]], r1 \n" - "vst1.32 {d25[0]}, [%[out0]], r1 \n" - "vst1.32 {d25[1]}, [%[out0]], %[steps] \n" - - "vmul.f32 q10, q4, d2[0] \n" // 2 * d2 - "vmul.f32 q11, q3, d2[0] \n" // 2 * d1 - "vmla.f32 q10, q6, d1[1] \n" // 2 * d2 - 2.5 * - // d4 - "vmla.f32 q11, q5, d1[1] \n" // 2 * d1 - 2.5 * - // d3 - "vmla.f32 q10, q8, d3[0] \n" // 2 * d1 - 2.5 * - // d3 + 0.5 * d6 - "vmla.f32 q11, q7, d3[0] \n" // 2 * d2 - 2.5 * - // d4 + 0.5 * d5 - "vmul.f32 q10, q10, d2[0] \n" // 4 * d1 - 5 * d3 - // + d6 - "vadd.f32 q12, q10, q11 \n" - "vst1.32 {d24[0]}, [%[out0]], r1 \n" - "vst1.32 {d24[1]}, [%[out0]], r1 \n" - "vst1.32 {d25[0]}, [%[out0]], r1 \n" - "vst1.32 {d25[1]}, [%[out0]], %[steps] \n" - - "vsub.f32 q12, q10, q11 \n" - "vst1.32 {d24[0]}, [%[out0]], r1 \n" - "vst1.32 {d24[1]}, [%[out0]], r1 \n" - "vst1.32 {d25[0]}, [%[out0]], r1 \n" - "vst1.32 {d25[1]}, [%[out0]], %[steps] \n" - - "vsub.f32 q10, q9, q3 \n" - "vsub.f32 q11, q5, q7 \n" - "vmla.f32 q10, q11, d0[0] \n" - "vst1.32 {d20[0]}, [%[out0]], r1 \n" - "vst1.32 {d20[1]}, [%[out0]], r1 \n" - "vst1.32 {d21[0]}, [%[out0]], r1 \n" - "vst1.32 {d21[1]}, [%[out0]], %[steps] \n" - - "subs r0, #1 \n" - "bne loop_col_%= \n" - : [out0] "+r"(out0), [ptr0] "+r"(ptr0), [ptr1] "+r"(ptr1), - [ptr2] "+r"(ptr2), [ptr3] "+r"(ptr3), [ptr4] "+r"(ptr4), - [ptr5] "+r"(ptr5), [ptr6] "+r"(ptr6), [ptr7] "+r"(ptr7) - : [tm_ptr] "r"((float *)transform_matrix), [steps] "r"(steps) - : "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", - "q8", "q9", "q10", "q11", "q12", "q13", "r0", "r1"); - } - } - } -#endif - - // remainer channels #pragma omp parallel for - for (int c = remain_c_start; c < channel; ++c) { + for (int c = 0; c < channel; ++c) { const float *in = inptr + c * image_size; float d_bt[64]; // d * B_t for (int h = 0; h < h_tiles; ++h) { @@ -664,6 +363,90 @@ void winograd_transform_input<8, 3>(const framework::Tensor &input, const float *in2 = in1 + width; const float *in3 = in2 + width; float *d_bt_ptr = d_bt; +#if __aarch64__ + int steps = 4 * width; + float32x4_t _q0 = 
vld1q_f32(transform_matrix); + float32x4_t _q1 = vld1q_f32(transform_matrix + 4); + for (int l = 0; l < 2; ++l) { + float32x4x2_t _q23, _q45, _q67, _q89; + _q23.val[0] = vld1q_f32(in0); + _q45.val[0] = vld1q_f32(in0 + 4); + _q23.val[1] = vld1q_f32(in1); + _q45.val[1] = vld1q_f32(in1 + 4); + _q67.val[0] = vld1q_f32(in2); + _q89.val[0] = vld1q_f32(in2 + 4); + _q67.val[1] = vld1q_f32(in3); + _q89.val[1] = vld1q_f32(in3 + 4); + _q23 = vtrnq_f32(_q23.val[0], _q23.val[1]); + _q45 = vtrnq_f32(_q45.val[0], _q45.val[1]); + _q67 = vtrnq_f32(_q67.val[0], _q67.val[1]); + _q89 = vtrnq_f32(_q89.val[0], _q89.val[1]); + float32x4_t _q2 = vcombine_f32(vget_low_f32(_q23.val[0]), + vget_low_f32(_q67.val[0])); + float32x4_t _q4 = vcombine_f32(vget_low_f32(_q23.val[1]), + vget_low_f32(_q67.val[1])); + float32x4_t _q3 = vcombine_f32(vget_low_f32(_q45.val[0]), + vget_low_f32(_q89.val[0])); + float32x4_t _q5 = vcombine_f32(vget_low_f32(_q45.val[1]), + vget_low_f32(_q89.val[1])); + float32x4_t _q6 = vcombine_f32(vget_high_f32(_q23.val[0]), + vget_high_f32(_q67.val[0])); + float32x4_t _q8 = vcombine_f32(vget_high_f32(_q23.val[1]), + vget_high_f32(_q67.val[1])); + float32x4_t _q7 = vcombine_f32(vget_high_f32(_q45.val[0]), + vget_high_f32(_q89.val[0])); + float32x4_t _q9 = vcombine_f32(vget_high_f32(_q45.val[1]), + vget_high_f32(_q89.val[1])); + + float32x4_t _q10 = vsubq_f32(_q2, _q7); + float32x4_t _q11 = vsubq_f32(_q3, _q6); + _q10 = vmlaq_lane_f32(_q10, _q11, vget_low_f32(_q0), 0); + vst1q_f32(d_bt_ptr, _q10); + + _q10 = vaddq_f32(_q6, _q7); + _q11 = vaddq_f32(_q4, _q5); + _q10 = vmlaq_lane_f32(_q10, _q3, vget_high_f32(_q0), 0); + _q11 = vmlaq_lane_f32(_q11, _q8, vget_high_f32(_q0), 0); + float32x4_t _q12 = vaddq_f32(_q10, _q11); + float32x4_t _q13 = vsubq_f32(_q10, _q11); + vst1q_f32(d_bt_ptr + 4, _q12); + vst1q_f32(d_bt_ptr + 8, _q13); + + _q10 = vmulq_lane_f32(_q6, vget_high_f32(_q1), 1); + _q11 = vmulq_lane_f32(_q4, vget_high_f32(_q1), 0); + _q10 = vaddq_f32(_q10, _q7); + _q11 = vmlaq_lane_f32(_q11, _q5, vget_low_f32(_q1), 0); + _q10 = vmlaq_lane_f32(_q10, _q3, vget_low_f32(_q1), 1); + _q11 = vmlaq_lane_f32(_q11, _q8, vget_high_f32(_q0), 1); + _q12 = vaddq_f32(_q10, _q11); + _q13 = vsubq_f32(_q10, _q11); + vst1q_f32(d_bt_ptr + 12, _q12); + vst1q_f32(d_bt_ptr + 16, _q13); + + _q10 = vmulq_lane_f32(_q6, vget_low_f32(_q1), 0); + _q11 = vmulq_lane_f32(_q4, vget_low_f32(_q1), 0); + _q10 = vmlaq_lane_f32(_q10, _q3, vget_high_f32(_q0), 1); + _q11 = vmlaq_lane_f32(_q11, _q8, vget_high_f32(_q0), 1); + _q10 = vmlaq_lane_f32(_q10, _q7, vget_high_f32(_q1), 0); + _q11 = vmlaq_lane_f32(_q11, _q5, vget_high_f32(_q1), 0); + _q10 = vmulq_lane_f32(_q10, vget_low_f32(_q1), 0); + _q12 = vaddq_f32(_q10, _q11); + _q13 = vsubq_f32(_q10, _q11); + vst1q_f32(d_bt_ptr + 20, _q12); + vst1q_f32(d_bt_ptr + 24, _q13); + + _q10 = vsubq_f32(_q9, _q4); + _q11 = vsubq_f32(_q8, _q5); + _q10 = vmlaq_lane_f32(_q10, _q11, vget_low_f32(_q0), 0); + vst1q_f32(d_bt_ptr + 28, _q10); + + in0 += steps; + in1 += steps; + in2 += steps; + in3 += steps; + d_bt_ptr += 32; + } +#else int steps = 4 * width * sizeof(float); asm volatile( "vld1.32 {d0-d3}, [%[tm_ptr]] \n" @@ -740,7 +523,7 @@ void winograd_transform_input<8, 3>(const framework::Tensor &input, : [tm_ptr] "r"((float *)transform_matrix), [steps] "r"(steps) : "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9", "q10", "q11", "q12", "q13", "r0"); - +#endif // __aarch64__ float *ptr0 = d_bt; float *ptr1 = ptr0 + 32; int tile_indics = h * w_tiles + w; @@ -756,6 +539,120 @@ void 
winograd_transform_input<8, 3>(const framework::Tensor &input, float *out5 = out4 + channel * 8; float *out6 = out5 + channel * 8; float *out7 = out6 + channel * 8; +#if __aarch64__ + steps = 8 * channel * 8; + for (int l = 0; l < 2; ++l) { + float32x4x2_t _q23, _q45, _q67, _q89; + _q23.val[0] = vld1q_f32(ptr0); + _q23.val[1] = vld1q_f32(ptr0 + 4); + _q45.val[0] = vld1q_f32(ptr0 + 8); + _q45.val[1] = vld1q_f32(ptr0 + 12); + _q67.val[0] = vld1q_f32(ptr1); + _q67.val[1] = vld1q_f32(ptr1 + 4); + _q89.val[0] = vld1q_f32(ptr1 + 8); + _q89.val[1] = vld1q_f32(ptr1 + 12); + _q23 = vtrnq_f32(_q23.val[0], _q23.val[1]); + _q45 = vtrnq_f32(_q45.val[0], _q45.val[1]); + _q67 = vtrnq_f32(_q67.val[0], _q67.val[1]); + _q89 = vtrnq_f32(_q89.val[0], _q89.val[1]); + float32x4_t _q2 = vcombine_f32(vget_low_f32(_q23.val[0]), + vget_low_f32(_q45.val[0])); + float32x4_t _q4 = vcombine_f32(vget_high_f32(_q23.val[0]), + vget_high_f32(_q45.val[0])); + float32x4_t _q3 = vcombine_f32(vget_low_f32(_q23.val[1]), + vget_low_f32(_q45.val[1])); + float32x4_t _q5 = vcombine_f32(vget_high_f32(_q23.val[1]), + vget_high_f32(_q45.val[1])); + float32x4_t _q6 = vcombine_f32(vget_low_f32(_q67.val[0]), + vget_low_f32(_q89.val[0])); + float32x4_t _q8 = vcombine_f32(vget_high_f32(_q67.val[0]), + vget_high_f32(_q89.val[0])); + float32x4_t _q7 = vcombine_f32(vget_low_f32(_q67.val[1]), + vget_low_f32(_q89.val[1])); + float32x4_t _q9 = vcombine_f32(vget_high_f32(_q67.val[1]), + vget_high_f32(_q89.val[1])); + + float32x4_t _q10 = vsubq_f32(_q2, _q8); + float32x4_t _q11 = vsubq_f32(_q6, _q4); + _q10 = vmlaq_lane_f32(_q10, _q11, vget_low_f32(_q0), 0); + vst1q_lane_f32(out0, _q10, 0); + vst1q_lane_f32(out0 + steps, _q10, 1); + vst1q_lane_f32(out0 + 2 * steps, _q10, 2); + vst1q_lane_f32(out0 + 3 * steps, _q10, 3); + + _q10 = vaddq_f32(_q4, _q8); + _q11 = vaddq_f32(_q3, _q7); + _q10 = vmlaq_lane_f32(_q10, _q6, vget_high_f32(_q0), 0); + _q11 = vmlaq_lane_f32(_q11, _q5, vget_high_f32(_q0), 0); + float32x4_t _q12 = vaddq_f32(_q10, _q11); + vst1q_lane_f32(out1, _q12, 0); + vst1q_lane_f32(out1 + steps, _q12, 1); + vst1q_lane_f32(out1 + 2 * steps, _q12, 2); + vst1q_lane_f32(out1 + 3 * steps, _q12, 3); + + _q12 = vsubq_f32(_q10, _q11); + vst1q_lane_f32(out2, _q12, 0); + vst1q_lane_f32(out2 + steps, _q12, 1); + vst1q_lane_f32(out2 + 2 * steps, _q12, 2); + vst1q_lane_f32(out2 + 3 * steps, _q12, 3); + + _q10 = vmulq_lane_f32(_q4, vget_high_f32(_q1), 1); + _q11 = vmulq_lane_f32(_q3, vget_high_f32(_q1), 0); + _q10 = vaddq_f32(_q10, _q8); + _q11 = vmlaq_lane_f32(_q11, _q7, vget_low_f32(_q1), 0); + _q10 = vmlaq_lane_f32(_q10, _q6, vget_low_f32(_q1), 1); + _q11 = vmlaq_lane_f32(_q11, _q5, vget_high_f32(_q0), 1); + _q12 = vaddq_f32(_q10, _q11); + vst1q_lane_f32(out3, _q12, 0); + vst1q_lane_f32(out3 + steps, _q12, 1); + vst1q_lane_f32(out3 + 2 * steps, _q12, 2); + vst1q_lane_f32(out3 + 3 * steps, _q12, 3); + + _q12 = vsubq_f32(_q10, _q11); + vst1q_lane_f32(out4, _q12, 0); + vst1q_lane_f32(out4 + steps, _q12, 1); + vst1q_lane_f32(out4 + 2 * steps, _q12, 2); + vst1q_lane_f32(out4 + 3 * steps, _q12, 3); + + _q10 = vmulq_lane_f32(_q4, vget_low_f32(_q1), 0); + _q11 = vmulq_lane_f32(_q3, vget_low_f32(_q1), 0); + _q10 = vmlaq_lane_f32(_q10, _q6, vget_high_f32(_q0), 1); + _q11 = vmlaq_lane_f32(_q11, _q5, vget_high_f32(_q0), 1); + _q10 = vmlaq_lane_f32(_q10, _q8, vget_high_f32(_q1), 0); + _q11 = vmlaq_lane_f32(_q11, _q7, vget_high_f32(_q1), 0); + _q10 = vmulq_lane_f32(_q10, vget_low_f32(_q1), 0); + _q12 = vaddq_f32(_q10, _q11); + vst1q_lane_f32(out5, _q12, 0); + 
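// Both the aarch64 intrinsics above and the armv7 asm kept under the #else
// branches evaluate the same 1-D input transform w = B^T * d on 8-point rows,
// applied once per row and once per column for the 2-D B^T d B. A scalar
// reference, with the coefficients transcribed from the comments in the asm;
// useful as a correctness oracle for either path:
static void winograd_f6k3_input_transform_1d(const float d[8], float w[8]) {
  w[0] = d[0] - d[6] + 5.25f * (d[4] - d[2]);
  float t0 = d[2] - 4.25f * d[4] + d[6];
  float t1 = d[1] - 4.25f * d[3] + d[5];
  w[1] = t0 + t1;
  w[2] = t0 - t1;
  t0 = 0.25f * d[2] - 1.25f * d[4] + d[6];
  t1 = 0.5f * d[1] - 2.5f * d[3] + 2.f * d[5];
  w[3] = t0 + t1;
  w[4] = t0 - t1;
  t0 = 4.f * d[2] - 5.f * d[4] + d[6];
  t1 = 2.f * d[1] - 2.5f * d[3] + 0.5f * d[5];
  w[5] = t0 + t1;
  w[6] = t0 - t1;
  w[7] = d[7] - d[1] + 5.25f * (d[3] - d[5]);
}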
vst1q_lane_f32(out5 + steps, _q12, 1); + vst1q_lane_f32(out5 + 2 * steps, _q12, 2); + vst1q_lane_f32(out5 + 3 * steps, _q12, 3); + + _q12 = vsubq_f32(_q10, _q11); + vst1q_lane_f32(out6, _q12, 0); + vst1q_lane_f32(out6 + steps, _q12, 1); + vst1q_lane_f32(out6 + 2 * steps, _q12, 2); + vst1q_lane_f32(out6 + 3 * steps, _q12, 3); + + _q10 = vsubq_f32(_q9, _q3); + _q11 = vsubq_f32(_q5, _q7); + _q10 = vmlaq_lane_f32(_q10, _q11, vget_low_f32(_q0), 0); + vst1q_lane_f32(out7, _q10, 0); + vst1q_lane_f32(out7 + steps, _q10, 1); + vst1q_lane_f32(out7 + 2 * steps, _q10, 2); + vst1q_lane_f32(out7 + 3 * steps, _q10, 3); + + ptr0 += 16; + ptr1 += 16; + out0 += 4 * steps; + out1 += 4 * steps; + out2 += 4 * steps; + out3 += 4 * steps; + out4 += 4 * steps; + out5 += 4 * steps; + out6 += 4 * steps; + out7 += 4 * steps; + } +#else steps = 8 * channel * 8 * sizeof(float); asm volatile( "mov r0, #2 \n" @@ -861,6 +758,7 @@ void winograd_transform_input<8, 3>(const framework::Tensor &input, : [tm_ptr] "r"((float *)transform_matrix), [steps] "r"(steps) : "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9", "q10", "q11", "q12", "q13", "r0"); +#endif // __aarch64__ } } } @@ -893,6 +791,71 @@ void winograd_transform_output<8, 3>(const framework::Tensor &input, const float *in_ptr = input_ptr + (j * 64 + k) * in_channel * 8; int inter_channel = in_channel >> 1; int remain_channel = in_channel & 0x1; +#if __aarch64__ + asm volatile( + "dup v8.4s, wzr \n" + "dup v9.4s, wzr \n" + "dup v10.4s, wzr \n" + "dup v11.4s, wzr \n" + "dup v12.4s, wzr \n" + "dup v13.4s, wzr \n" + "dup v14.4s, wzr \n" + "dup v15.4s, wzr \n" + + "cmp %[inter], #0 \n" + "ble 2f \n" + // loop 2 channels + "1: \n" + "ld1 {v0.4s, v1.4s}, [%[w_ptr]], #32 \n" + "ld1 {v2.4s, v3.4s}, [%[in_ptr]], #32 \n" + "ld1 {v4.4s, v5.4s}, [%[in_ptr]], #32 \n" + + "fmla v8.4s, v2.4s, v0.s[0] \n" + "fmla v9.4s, v3.4s, v0.s[0] \n" + "fmla v10.4s, v2.4s, v0.s[1] \n" + "fmla v11.4s, v3.4s, v0.s[1] \n" + "fmla v12.4s, v2.4s, v0.s[2] \n" + "fmla v13.4s, v3.4s, v0.s[2] \n" + "fmla v14.4s, v2.4s, v0.s[3] \n" + "fmla v15.4s, v3.4s, v0.s[3] \n" + + "fmla v8.4s, v4.4s, v1.s[0] \n" + "fmla v9.4s, v5.4s, v1.s[0] \n" + "fmla v10.4s, v4.4s, v1.s[1] \n" + "fmla v11.4s, v5.4s, v1.s[1] \n" + "fmla v12.4s, v4.4s, v1.s[2] \n" + "fmla v13.4s, v5.4s, v1.s[2] \n" + "fmla v14.4s, v4.4s, v1.s[3] \n" + "fmla v15.4s, v5.4s, v1.s[3] \n" + + "subs %[inter], %[inter], #1 \n" + "bne 1b \n" + + // loop 1 channel + "2: \n" + "cmp %[remain], #0 \n" + "ble 3f \n" + + "ld1 {v0.4s, v1.4s}, [%[w_ptr]], #32 \n" + "ld1 {v2.4s, v3.4s}, [%[in_ptr]], #32 \n" + "fmla v8.4s, v2.4s, v0.s[0] \n" + "fmla v9.4s, v3.4s, v0.s[0] \n" + "fmla v10.4s, v2.4s, v0.s[1] \n" + "fmla v11.4s, v3.4s, v0.s[1] \n" + "fmla v12.4s, v2.4s, v0.s[2] \n" + "fmla v13.4s, v3.4s, v0.s[2] \n" + "fmla v14.4s, v2.4s, v0.s[3] \n" + "fmla v15.4s, v3.4s, v0.s[3] \n" + + "3: \n" + "st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [%[uv_ptr]], #64 \n" + "st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [%[uv_ptr]], #64 \n" + : [w_ptr] "+r"(w_ptr), [in_ptr] "+r"(in_ptr), [uv_ptr] "+r"(uv_ptr), + [inter] "+r"(inter_channel) + : [remain] "r"(remain_channel) + : "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", + "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15"); +#else asm volatile( "veor q8, q8, q8 \n" "veor q9, q9, q9 \n" @@ -957,6 +920,7 @@ void winograd_transform_output<8, 3>(const framework::Tensor &input, : [remain_channel] "r"(remain_channel) : "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9", "q10", 
"q11", "q12", "q13", "q14", "q15"); +#endif // __aarch64__ } } } @@ -992,6 +956,116 @@ void winograd_transform_output<8, 3>(const framework::Tensor &input, int tile_block = tile_indics >> 3; int block_indics = tile_indics & 0x7; const float *uv_ptr0 = uv_ptr + tile_block * 64 * 32 + block_indics; +#if __aarch64__ + float32x4_t _q0 = vld1q_f32(transform_matrix); + for (int l = 0; l < 2; ++l) { + float32x4_t _q1, _q2, _q3, _q4, _q5, _q6, _q7, _q8; + _q1 = vsetq_lane_f32(*uv_ptr0, _q1, 0); + uv_ptr0 += 32; + _q3 = vsetq_lane_f32(*uv_ptr0, _q3, 0); + uv_ptr0 += 32; + _q5 = vsetq_lane_f32(*uv_ptr0, _q5, 0); + uv_ptr0 += 32; + _q7 = vsetq_lane_f32(*uv_ptr0, _q7, 0); + uv_ptr0 += 32; + _q2 = vsetq_lane_f32(*uv_ptr0, _q2, 0); + uv_ptr0 += 32; + _q4 = vsetq_lane_f32(*uv_ptr0, _q4, 0); + uv_ptr0 += 32; + _q6 = vsetq_lane_f32(*uv_ptr0, _q6, 0); + uv_ptr0 += 32; + _q8 = vsetq_lane_f32(*uv_ptr0, _q8, 0); + uv_ptr0 += 32; + + _q1 = vsetq_lane_f32(*uv_ptr0, _q1, 1); + uv_ptr0 += 32; + _q3 = vsetq_lane_f32(*uv_ptr0, _q3, 1); + uv_ptr0 += 32; + _q5 = vsetq_lane_f32(*uv_ptr0, _q5, 1); + uv_ptr0 += 32; + _q7 = vsetq_lane_f32(*uv_ptr0, _q7, 1); + uv_ptr0 += 32; + _q2 = vsetq_lane_f32(*uv_ptr0, _q2, 1); + uv_ptr0 += 32; + _q4 = vsetq_lane_f32(*uv_ptr0, _q4, 1); + uv_ptr0 += 32; + _q6 = vsetq_lane_f32(*uv_ptr0, _q6, 1); + uv_ptr0 += 32; + _q8 = vsetq_lane_f32(*uv_ptr0, _q8, 1); + uv_ptr0 += 32; + + _q1 = vsetq_lane_f32(*uv_ptr0, _q1, 2); + uv_ptr0 += 32; + _q3 = vsetq_lane_f32(*uv_ptr0, _q3, 2); + uv_ptr0 += 32; + _q5 = vsetq_lane_f32(*uv_ptr0, _q5, 2); + uv_ptr0 += 32; + _q7 = vsetq_lane_f32(*uv_ptr0, _q7, 2); + uv_ptr0 += 32; + _q2 = vsetq_lane_f32(*uv_ptr0, _q2, 2); + uv_ptr0 += 32; + _q4 = vsetq_lane_f32(*uv_ptr0, _q4, 2); + uv_ptr0 += 32; + _q6 = vsetq_lane_f32(*uv_ptr0, _q6, 2); + uv_ptr0 += 32; + _q8 = vsetq_lane_f32(*uv_ptr0, _q8, 2); + uv_ptr0 += 32; + + _q1 = vsetq_lane_f32(*uv_ptr0, _q1, 3); + uv_ptr0 += 32; + _q3 = vsetq_lane_f32(*uv_ptr0, _q3, 3); + uv_ptr0 += 32; + _q5 = vsetq_lane_f32(*uv_ptr0, _q5, 3); + uv_ptr0 += 32; + _q7 = vsetq_lane_f32(*uv_ptr0, _q7, 3); + uv_ptr0 += 32; + _q2 = vsetq_lane_f32(*uv_ptr0, _q2, 3); + uv_ptr0 += 32; + _q4 = vsetq_lane_f32(*uv_ptr0, _q4, 3); + uv_ptr0 += 32; + _q6 = vsetq_lane_f32(*uv_ptr0, _q6, 3); + uv_ptr0 += 32; + _q8 = vsetq_lane_f32(*uv_ptr0, _q8, 3); + uv_ptr0 += 32; + + float32x4_t _q9 = vaddq_f32(_q3, _q5); + float32x4_t _q10 = vaddq_f32(_q7, _q2); + float32x4_t _q11 = vaddq_f32(_q4, _q6); + float32x4_t _q12 = vsubq_f32(_q3, _q5); + float32x4_t _q13 = vsubq_f32(_q7, _q2); + float32x4_t _q14 = vsubq_f32(_q4, _q6); + _q2 = vmulq_lane_f32(_q13, vget_low_f32(_q0), 0); + _q3 = vmulq_lane_f32(_q11, vget_low_f32(_q0), 0); + + float32x4_t _q15 = vaddq_f32(_q1, _q9); + _q15 = vaddq_f32(_q15, _q10); + _q15 = vmlaq_lane_f32(_q15, _q3, vget_high_f32(_q0), 1); + vst1q_f32(at_m_ptr, _q15); + + _q15 = vaddq_f32(_q12, _q2); + _q15 = vmlaq_lane_f32(_q15, _q14, vget_high_f32(_q0), 1); + vst1q_f32(at_m_ptr + 4, _q15); + + _q15 = vmlaq_lane_f32(_q9, _q10, vget_low_f32(_q0), 1); + _q15 = vmlaq_lane_f32(_q15, _q11, vget_high_f32(_q0), 0); + vst1q_f32(at_m_ptr + 8, _q15); + + _q15 = vmlaq_lane_f32(_q12, _q13, vget_high_f32(_q0), 0); + _q15 = vmlaq_lane_f32(_q15, _q14, vget_low_f32(_q0), 1); + vst1q_f32(at_m_ptr + 12, _q15); + + _q15 = vaddq_f32(_q9, _q3); + _q15 = vmlaq_lane_f32(_q15, _q10, vget_high_f32(_q0), 1); + vst1q_f32(at_m_ptr + 16, _q15); + + _q15 = vaddq_f32(_q12, _q8); + _q15 = vaddq_f32(_q15, _q14); + _q15 = vmlaq_lane_f32(_q15, _q2, vget_high_f32(_q0), 1); + 
vst1q_f32(at_m_ptr + 20, _q15); + + at_m_ptr += 24; + } +#else int steps = 32 * sizeof(float); asm volatile( "vld1.32 {d0-d1}, [%[tm_ptr]] \n" @@ -1077,6 +1151,7 @@ void winograd_transform_output<8, 3>(const framework::Tensor &input, : [tm_ptr] "r"((float *)transform_matrix), [steps] "r"(steps) : "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9", "q10", "q11", "q12", "q13", "q14", "q15", "r0"); +#endif // __aarch64__ float *at_m_ptr0 = at_m; float *at_m_ptr1 = at_m + 24; @@ -1088,6 +1163,134 @@ void winograd_transform_output<8, 3>(const framework::Tensor &input, float *out_ptr3 = output_tmp + 18; float *out_ptr4 = output_tmp + 24; float *out_ptr5 = output_tmp + 30; +#if __aarch64__ + float32x4_t _q0 = vld1q_f32(transform_matrix); + float32x4x2_t _q23, _q45, _q67, _q89; + _q23.val[0] = vld1q_f32(at_m_ptr0); + _q23.val[1] = vld1q_f32(at_m_ptr0 + 4); + _q45.val[0] = vld1q_f32(at_m_ptr0 + 8); + _q45.val[1] = vld1q_f32(at_m_ptr0 + 12); + _q67.val[0] = vld1q_f32(at_m_ptr1); + _q67.val[1] = vld1q_f32(at_m_ptr1 + 4); + _q89.val[0] = vld1q_f32(at_m_ptr1 + 8); + _q89.val[1] = vld1q_f32(at_m_ptr1 + 12); + _q23 = vtrnq_f32(_q23.val[0], _q23.val[1]); + _q45 = vtrnq_f32(_q45.val[0], _q45.val[1]); + _q67 = vtrnq_f32(_q67.val[0], _q67.val[1]); + _q89 = vtrnq_f32(_q89.val[0], _q89.val[1]); + float32x4_t _q1 = vcombine_f32(vget_low_f32(_q23.val[0]), + vget_low_f32(_q45.val[0])); + float32x4_t _q3 = vcombine_f32(vget_high_f32(_q23.val[0]), + vget_high_f32(_q45.val[0])); + float32x4_t _q2 = vcombine_f32(vget_low_f32(_q23.val[1]), + vget_low_f32(_q45.val[1])); + float32x4_t _q4 = vcombine_f32(vget_high_f32(_q23.val[1]), + vget_high_f32(_q45.val[1])); + float32x4_t _q5 = vcombine_f32(vget_low_f32(_q67.val[0]), + vget_low_f32(_q89.val[0])); + float32x4_t _q7 = vcombine_f32(vget_high_f32(_q67.val[0]), + vget_high_f32(_q89.val[0])); + float32x4_t _q6 = vcombine_f32(vget_low_f32(_q67.val[1]), + vget_low_f32(_q89.val[1])); + float32x4_t _q8 = vcombine_f32(vget_high_f32(_q67.val[1]), + vget_high_f32(_q89.val[1])); + + float32x4_t _q9 = vaddq_f32(_q2, _q3); + float32x4_t _q10 = vaddq_f32(_q4, _q5); + float32x4_t _q11 = vaddq_f32(_q6, _q7); + float32x4_t _q12 = vsubq_f32(_q2, _q3); + float32x4_t _q13 = vsubq_f32(_q4, _q5); + float32x4_t _q14 = vsubq_f32(_q6, _q7); + _q6 = vmulq_lane_f32(_q13, vget_low_f32(_q0), 0); + _q7 = vmulq_lane_f32(_q11, vget_low_f32(_q0), 0); + + _q1 = vaddq_f32(_q1, _q9); + _q1 = vaddq_f32(_q1, _q10); + _q1 = vmlaq_lane_f32(_q1, _q7, vget_high_f32(_q0), 1); + + _q2 = vaddq_f32(_q12, _q6); + _q2 = vmlaq_lane_f32(_q2, _q14, vget_high_f32(_q0), 1); + + _q3 = vmlaq_lane_f32(_q9, _q10, vget_low_f32(_q0), 1); + _q3 = vmlaq_lane_f32(_q3, _q11, vget_high_f32(_q0), 0); + + _q4 = vmlaq_lane_f32(_q12, _q13, vget_high_f32(_q0), 0); + _q4 = vmlaq_lane_f32(_q4, _q14, vget_low_f32(_q0), 1); + + _q23 = vtrnq_f32(_q1, _q2); + _q45 = vtrnq_f32(_q3, _q4); + vst1_f32(out_ptr0, vget_low_f32(_q23.val[0])); + vst1_f32(out_ptr0 + 2, vget_low_f32(_q45.val[0])); + vst1_f32(out_ptr1, vget_low_f32(_q23.val[1])); + vst1_f32(out_ptr1 + 2, vget_low_f32(_q45.val[1])); + vst1_f32(out_ptr2, vget_high_f32(_q23.val[0])); + vst1_f32(out_ptr2 + 2, vget_high_f32(_q45.val[0])); + vst1_f32(out_ptr3, vget_high_f32(_q23.val[1])); + vst1_f32(out_ptr3 + 2, vget_high_f32(_q45.val[1])); + + _q1 = vaddq_f32(_q9, _q7); + _q1 = vmlaq_lane_f32(_q1, _q10, vget_high_f32(_q0), 1); + _q2 = vaddq_f32(_q12, _q8); + _q2 = vaddq_f32(_q2, _q14); + _q2 = vmlaq_lane_f32(_q2, _q6, vget_high_f32(_q0), 1); + _q23 = vtrnq_f32(_q1, 
_q2); + vst1_f32(out_ptr0 + 4, vget_low_f32(_q23.val[0])); + vst1_f32(out_ptr1 + 4, vget_low_f32(_q23.val[1])); + vst1_f32(out_ptr2 + 4, vget_high_f32(_q23.val[0])); + vst1_f32(out_ptr3 + 4, vget_high_f32(_q23.val[1])); + + // remain 2 rows + _q1 = vld1q_f32(at_m_ptr0 + 16); + _q2 = vld1q_f32(at_m_ptr0 + 20); + _q3 = vld1q_f32(at_m_ptr1 + 16); + _q4 = vld1q_f32(at_m_ptr1 + 20); + _q23 = vtrnq_f32(_q1, _q2); + _q45 = vtrnq_f32(_q3, _q4); + + float32x2_t _d2 = vget_low_f32(_q23.val[0]); + float32x2_t _d3 = vget_high_f32(_q23.val[0]); + float32x2_t _d4 = vget_low_f32(_q23.val[1]); + float32x2_t _d5 = vget_high_f32(_q23.val[1]); + float32x2_t _d6 = vget_low_f32(_q45.val[0]); + float32x2_t _d7 = vget_high_f32(_q45.val[0]); + float32x2_t _d8 = vget_low_f32(_q45.val[1]); + float32x2_t _d9 = vget_high_f32(_q45.val[1]); + + float32x2_t _d10 = vadd_f32(_d4, _d3); + float32x2_t _d11 = vadd_f32(_d5, _d6); + float32x2_t _d12 = vadd_f32(_d8, _d7); + float32x2_t _d13 = vsub_f32(_d4, _d3); + float32x2_t _d14 = vsub_f32(_d5, _d6); + float32x2_t _d15 = vsub_f32(_d8, _d7); + float32x2_t _d16 = vmul_lane_f32(_d14, vget_low_f32(_q0), 0); + float32x2_t _d17 = vmul_lane_f32(_d12, vget_low_f32(_q0), 0); + + float32x2_t _d18 = vadd_f32(_d2, _d10); + float32x2_t _d20 = vadd_f32(_d13, _d16); + float32x2_t _d19 = vmla_lane_f32(_d10, _d11, vget_low_f32(_q0), 1); + float32x2_t _d21 = vmla_lane_f32(_d13, _d14, vget_high_f32(_q0), 0); + _d18 = vadd_f32(_d18, _d11); + _d18 = vmla_lane_f32(_d18, _d17, vget_high_f32(_q0), 1); + _d20 = vmla_lane_f32(_d20, _d15, vget_high_f32(_q0), 1); + _d19 = vmla_lane_f32(_d19, _d12, vget_high_f32(_q0), 0); + _d21 = vmla_lane_f32(_d21, _d15, vget_low_f32(_q0), 1); + + float32x2x2_t _d18d20 = vtrn_f32(_d18, _d20); + float32x2x2_t _d19d21 = vtrn_f32(_d19, _d21); + vst1_f32(out_ptr4, _d18d20.val[0]); + vst1_f32(out_ptr4 + 2, _d19d21.val[0]); + vst1_f32(out_ptr5, _d18d20.val[1]); + vst1_f32(out_ptr5 + 2, _d19d21.val[1]); + + _d18 = vadd_f32(_d10, _d17); + _d18 = vmla_lane_f32(_d18, _d11, vget_high_f32(_q0), 1); + _d20 = vadd_f32(_d13, _d9); + _d20 = vadd_f32(_d20, _d15); + _d20 = vmla_lane_f32(_d20, _d16, vget_high_f32(_q0), 1); + _d18d20 = vtrn_f32(_d18, _d20); + vst1_f32(out_ptr4 + 4, _d18d20.val[0]); + vst1_f32(out_ptr5 + 4, _d18d20.val[1]); +#else asm volatile( "vld1.32 {d0-d1}, [%[tm_ptr]] \n" // process 4 rows @@ -1204,6 +1407,7 @@ void winograd_transform_output<8, 3>(const framework::Tensor &input, : [tm_ptr] "r"((float *)transform_matrix) : "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9", "q10", "q11", "q12", "q13", "q14", "q15"); +#endif // __aarch64__ size_t offset = (oc * out_h + 6 * tile_h) * out_w + 6 * tile_w; float *out_ptr = output_ptr + offset; int remain_row = out_h - 6 * tile_h; @@ -1221,6 +1425,131 @@ void winograd_transform_output<8, 3>(const framework::Tensor &input, float *out_ptr3 = out_ptr2 + out_w; float *out_ptr4 = out_ptr3 + out_w; float *out_ptr5 = out_ptr4 + out_w; +#if __aarch64__ + float32x4_t _q0 = vld1q_f32(transform_matrix); + float32x4x2_t _q23, _q45, _q67, _q89; + _q23.val[0] = vld1q_f32(at_m_ptr0); + _q23.val[1] = vld1q_f32(at_m_ptr0 + 4); + _q45.val[0] = vld1q_f32(at_m_ptr0 + 8); + _q45.val[1] = vld1q_f32(at_m_ptr0 + 12); + _q67.val[0] = vld1q_f32(at_m_ptr1); + _q67.val[1] = vld1q_f32(at_m_ptr1 + 4); + _q89.val[0] = vld1q_f32(at_m_ptr1 + 8); + _q89.val[1] = vld1q_f32(at_m_ptr1 + 12); + _q23 = vtrnq_f32(_q23.val[0], _q23.val[1]); + _q45 = vtrnq_f32(_q45.val[0], _q45.val[1]); + _q67 = vtrnq_f32(_q67.val[0], _q67.val[1]); + _q89 = 
vtrnq_f32(_q89.val[0], _q89.val[1]); + float32x4_t _q1 = vcombine_f32(vget_low_f32(_q23.val[0]), + vget_low_f32(_q45.val[0])); + float32x4_t _q3 = vcombine_f32(vget_high_f32(_q23.val[0]), + vget_high_f32(_q45.val[0])); + float32x4_t _q2 = vcombine_f32(vget_low_f32(_q23.val[1]), + vget_low_f32(_q45.val[1])); + float32x4_t _q4 = vcombine_f32(vget_high_f32(_q23.val[1]), + vget_high_f32(_q45.val[1])); + float32x4_t _q5 = vcombine_f32(vget_low_f32(_q67.val[0]), + vget_low_f32(_q89.val[0])); + float32x4_t _q7 = vcombine_f32(vget_high_f32(_q67.val[0]), + vget_high_f32(_q89.val[0])); + float32x4_t _q6 = vcombine_f32(vget_low_f32(_q67.val[1]), + vget_low_f32(_q89.val[1])); + float32x4_t _q8 = vcombine_f32(vget_high_f32(_q67.val[1]), + vget_high_f32(_q89.val[1])); + + float32x4_t _q9 = vaddq_f32(_q2, _q3); + float32x4_t _q10 = vaddq_f32(_q4, _q5); + float32x4_t _q11 = vaddq_f32(_q6, _q7); + float32x4_t _q12 = vsubq_f32(_q2, _q3); + float32x4_t _q13 = vsubq_f32(_q4, _q5); + float32x4_t _q14 = vsubq_f32(_q6, _q7); + _q6 = vmulq_lane_f32(_q13, vget_low_f32(_q0), 0); + _q7 = vmulq_lane_f32(_q11, vget_low_f32(_q0), 0); + + _q1 = vaddq_f32(_q1, _q9); + _q1 = vaddq_f32(_q1, _q10); + _q1 = vmlaq_lane_f32(_q1, _q7, vget_high_f32(_q0), 1); + _q2 = vaddq_f32(_q12, _q6); + _q2 = vmlaq_lane_f32(_q2, _q14, vget_high_f32(_q0), 1); + _q3 = vmlaq_lane_f32(_q9, _q10, vget_low_f32(_q0), 1); + _q3 = vmlaq_lane_f32(_q3, _q11, vget_high_f32(_q0), 0); + _q4 = vmlaq_lane_f32(_q12, _q13, vget_high_f32(_q0), 0); + _q4 = vmlaq_lane_f32(_q4, _q14, vget_low_f32(_q0), 1); + + _q23 = vtrnq_f32(_q1, _q2); + _q45 = vtrnq_f32(_q3, _q4); + vst1_f32(out_ptr0, vget_low_f32(_q23.val[0])); + vst1_f32(out_ptr0 + 2, vget_low_f32(_q45.val[0])); + vst1_f32(out_ptr1, vget_low_f32(_q23.val[1])); + vst1_f32(out_ptr1 + 2, vget_low_f32(_q45.val[1])); + vst1_f32(out_ptr2, vget_high_f32(_q23.val[0])); + vst1_f32(out_ptr2 + 2, vget_high_f32(_q45.val[0])); + vst1_f32(out_ptr3, vget_high_f32(_q23.val[1])); + vst1_f32(out_ptr3 + 2, vget_high_f32(_q45.val[1])); + + _q1 = vaddq_f32(_q9, _q7); + _q1 = vmlaq_lane_f32(_q1, _q10, vget_high_f32(_q0), 1); + _q2 = vaddq_f32(_q12, _q8); + _q2 = vaddq_f32(_q2, _q14); + _q2 = vmlaq_lane_f32(_q2, _q6, vget_high_f32(_q0), 1); + _q23 = vtrnq_f32(_q1, _q2); + vst1_f32(out_ptr0 + 4, vget_low_f32(_q23.val[0])); + vst1_f32(out_ptr1 + 4, vget_low_f32(_q23.val[1])); + vst1_f32(out_ptr2 + 4, vget_high_f32(_q23.val[0])); + vst1_f32(out_ptr3 + 4, vget_high_f32(_q23.val[1])); + + // remain 2 rows + _q1 = vld1q_f32(at_m_ptr0 + 16); + _q2 = vld1q_f32(at_m_ptr0 + 20); + _q3 = vld1q_f32(at_m_ptr1 + 16); + _q4 = vld1q_f32(at_m_ptr1 + 20); + _q23 = vtrnq_f32(_q1, _q2); + _q45 = vtrnq_f32(_q3, _q4); + + float32x2_t _d2 = vget_low_f32(_q23.val[0]); + float32x2_t _d3 = vget_high_f32(_q23.val[0]); + float32x2_t _d4 = vget_low_f32(_q23.val[1]); + float32x2_t _d5 = vget_high_f32(_q23.val[1]); + float32x2_t _d6 = vget_low_f32(_q45.val[0]); + float32x2_t _d7 = vget_high_f32(_q45.val[0]); + float32x2_t _d8 = vget_low_f32(_q45.val[1]); + float32x2_t _d9 = vget_high_f32(_q45.val[1]); + + float32x2_t _d10 = vadd_f32(_d4, _d3); + float32x2_t _d11 = vadd_f32(_d5, _d6); + float32x2_t _d12 = vadd_f32(_d8, _d7); + float32x2_t _d13 = vsub_f32(_d4, _d3); + float32x2_t _d14 = vsub_f32(_d5, _d6); + float32x2_t _d15 = vsub_f32(_d8, _d7); + float32x2_t _d16 = vmul_lane_f32(_d14, vget_low_f32(_q0), 0); + float32x2_t _d17 = vmul_lane_f32(_d12, vget_low_f32(_q0), 0); + + float32x2_t _d18 = vadd_f32(_d2, _d10); + float32x2_t _d20 = vadd_f32(_d13, _d16); + 
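// Tiles in this branch lie fully inside the image, so the 6x6 results go
// straight to out_ptr0..out_ptr5 with stride out_w. The border tiles handled
// above go through the 36-float output_tmp (row stride 6) instead, and only
// the valid remain_row x (presumably a matching remain_col) window reaches
// the image; that copy falls between hunks, but it presumably amounts to:
static void copy_tile_clamped(const float tmp[36], float *out, int out_w,
                              int remain_row, int remain_col) {
  const int rows = remain_row < 6 ? remain_row : 6;
  const int cols = remain_col < 6 ? remain_col : 6;
  for (int r = 0; r < rows; ++r)
    for (int c = 0; c < cols; ++c)
      out[r * out_w + c] = tmp[r * 6 + c];
}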
float32x2_t _d19 = vmla_lane_f32(_d10, _d11, vget_low_f32(_q0), 1); + float32x2_t _d21 = vmla_lane_f32(_d13, _d14, vget_high_f32(_q0), 0); + _d18 = vadd_f32(_d18, _d11); + _d18 = vmla_lane_f32(_d18, _d17, vget_high_f32(_q0), 1); + _d20 = vmla_lane_f32(_d20, _d15, vget_high_f32(_q0), 1); + _d19 = vmla_lane_f32(_d19, _d12, vget_high_f32(_q0), 0); + _d21 = vmla_lane_f32(_d21, _d15, vget_low_f32(_q0), 1); + + float32x2x2_t _d18d20 = vtrn_f32(_d18, _d20); + float32x2x2_t _d19d21 = vtrn_f32(_d19, _d21); + vst1_f32(out_ptr4, _d18d20.val[0]); + vst1_f32(out_ptr4 + 2, _d19d21.val[0]); + vst1_f32(out_ptr5, _d18d20.val[1]); + vst1_f32(out_ptr5 + 2, _d19d21.val[1]); + + _d18 = vadd_f32(_d10, _d17); + _d18 = vmla_lane_f32(_d18, _d11, vget_high_f32(_q0), 1); + _d20 = vadd_f32(_d13, _d9); + _d20 = vadd_f32(_d20, _d15); + _d20 = vmla_lane_f32(_d20, _d16, vget_high_f32(_q0), 1); + _d18d20 = vtrn_f32(_d18, _d20); + vst1_f32(out_ptr4 + 4, _d18d20.val[0]); + vst1_f32(out_ptr5 + 4, _d18d20.val[1]); +#else asm volatile( "vld1.32 {d0-d1}, [%[tm_ptr]] \n" // process 4 rows @@ -1337,6 +1666,7 @@ void winograd_transform_output<8, 3>(const framework::Tensor &input, : [tm_ptr] "r"((float *)transform_matrix) : "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9", "q10", "q11", "q12", "q13", "q14", "q15"); +#endif // __aarch64__ } } } @@ -1347,5 +1677,5 @@ void winograd_transform_output<8, 3>(const framework::Tensor &input, } // namespace operators } // namespace paddle_mobile -#endif // __aarch64__ #endif // CONV_OP +#endif // __ARM_NEON__ diff --git a/src/operators/math/winograd/winograd_transform_f6k3_arm64.cpp b/src/operators/math/winograd/winograd_transform_f6k3_arm64.cpp deleted file mode 100644 index 5ef9c194f23ba28791f673137313aacc262d39d5..0000000000000000000000000000000000000000 --- a/src/operators/math/winograd/winograd_transform_f6k3_arm64.cpp +++ /dev/null @@ -1,413 +0,0 @@ -/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
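Both the new A64 intrinsics path above and the A32 inline-assembly fallback compute the same inverse Winograd transform r = A_T * m * A, turning one accumulated 8x8 tile m into a 6x6 output block; the vector paths keep the small power-of-two coefficients (2, 4, 16, 32) in transform_matrix lanes instead of spelling them out. As a readability aid, here is a minimal scalar sketch of that transform, distilled from the scalar code this patch deletes just below. The function name and the plain [8][8]/[6][6] array layout are illustrative only, not part of the patch.

// Transform one 8x8 Winograd output tile m into the 6x6 result r.
// Coefficients match the A_T rows used by the deleted scalar fallback.
static void output_transform_tile(const float m[8][8], float r[6][6]) {
  float t[6][8];  // t = A_T * m, column by column
  for (int j = 0; j < 8; ++j) {
    float v0 = m[1][j] + m[2][j], v1 = m[1][j] - m[2][j];
    float v2 = m[3][j] + m[4][j], v3 = m[3][j] - m[4][j];
    float v4 = m[5][j] + m[6][j], v5 = m[5][j] - m[6][j];
    t[0][j] = m[0][j] + v0 + v2 + 32 * v4;
    t[1][j] = v1 + 2 * v3 + 16 * v5;
    t[2][j] = v0 + 4 * v2 + 8 * v4;
    t[3][j] = v1 + 8 * v3 + 4 * v5;
    t[4][j] = v0 + 16 * v2 + 2 * v4;
    t[5][j] = v1 + 32 * v3 + v5 + m[7][j];
  }
  for (int i = 0; i < 6; ++i) {  // r = t * A, row by row
    float v0 = t[i][1] + t[i][2], v1 = t[i][1] - t[i][2];
    float v2 = t[i][3] + t[i][4], v3 = t[i][3] - t[i][4];
    float v4 = t[i][5] + t[i][6], v5 = t[i][5] - t[i][6];
    r[i][0] = t[i][0] + v0 + v2 + 32 * v4;
    r[i][1] = v1 + 2 * v3 + 16 * v5;
    r[i][2] = v0 + 4 * v2 + 8 * v4;
    r[i][3] = v1 + 8 * v3 + 4 * v5;
    r[i][4] = v0 + 16 * v2 + 2 * v4;
    r[i][5] = v1 + 32 * v3 + v5 + t[i][7];
  }
}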
diff --git a/src/operators/math/winograd/winograd_transform_f6k3_arm64.cpp b/src/operators/math/winograd/winograd_transform_f6k3_arm64.cpp
deleted file mode 100644
index 5ef9c194f23ba28791f673137313aacc262d39d5..0000000000000000000000000000000000000000
--- a/src/operators/math/winograd/winograd_transform_f6k3_arm64.cpp
+++ /dev/null
@@ -1,413 +0,0 @@
-/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License. */
-
-// We refer https://github.com/andravin/wincnn to access the winograd transform
-// matrixs
-
-#ifdef CONV_OP
-#ifdef __aarch64__
-
-#include "operators/math/winograd/winograd_transform.h"
-
-namespace paddle_mobile {
-namespace operators {
-namespace math {
-
-template <>
-void winograd_transform_weight<8, 3>(const framework::Tensor &weight,
-                                     framework::Tensor *output) {
-  // weight shape is [out_channel, in_channel, kernel_h, kernel_w]
-  int out_channel = weight.dims()[0];
-  int in_channel = weight.dims()[1];
-  // reshape and alloc transformed weight
-  framework::DDim transformed_shape =
-      framework::make_ddim(std::vector{out_channel, in_channel, 64});
-  float *outptr = output->mutable_data(transformed_shape);
-  const float *inptr = weight.data();
-  for (int oc = 0; oc < out_channel; ++oc) {
-    for (int ic = 0; ic < in_channel; ++ic) {
-      size_t offset = oc * in_channel + ic;
-      float *kout = outptr + offset * 64;
-      const float *k = inptr + offset * 9;
-
-      float gw[3][8];
-      for (int i = 0; i < 3; ++i, k += 3) {
-        float g0 = k[0];
-        float g1 = k[1];
-        float g2 = k[2];
-        float d0 = g0 + g2;
-        float d1 = g0 + 4 * g2;
-        float d2 = g2 + 4 * g0;
-        float d3 = 2 * g1;
-        gw[i][0] = g0;
-        gw[i][1] = -2.f / 9 * (d0 + g1);   // -2.f/9 * (g0 + g1 + g2)
-        gw[i][2] = -2.f / 9 * (d0 - g1);   // -2.f/9 * (g0 - g1 + g2)
-        gw[i][3] = 1.f / 90 * (d1 + d3);   // 1.f/90 * (g0 + 2 * g1 + 4 * g2)
-        gw[i][4] = 1.f / 90 * (d1 - d3);   // 1.f/90 * (g0 - 2 * g1 + 4 * g2)
-        gw[i][5] = 1.f / 180 * (d2 + d3);  // 1.f/180 * (4 * g0 + 2 * g1 + g2)
-        gw[i][6] = 1.f / 180 * (d2 - d3);  // 1.f/180 * (4 * g0 - 2 * g1 + g2)
-        gw[i][7] = g2;
-      }
-      for (int i = 0; i < 8; ++i, kout += 8) {
-        float g0 = gw[0][i];
-        float g1 = gw[1][i];
-        float g2 = gw[2][i];
-        float d0 = g0 + g2;
-        float d1 = g0 + 4 * g2;
-        float d2 = g2 + 4 * g0;
-        float d3 = 2 * g1;
-        kout[0] = g0;
-        kout[1] = -2.f / 9 * (d0 + g1);   // -2.f/9 * (k0 + k1 + k2)
-        kout[2] = -2.f / 9 * (d0 - g1);   // -2.f/9 * (k0 - k1 + k2)
-        kout[3] = 1.f / 90 * (d1 + d3);   // 1.f/90 * (k0 + 2 * k1 + 4 * k2)
-        kout[4] = 1.f / 90 * (d1 - d3);   // 1.f/90 * (k0 - 2 * k1 + 4 * k2)
-        kout[5] = 1.f / 180 * (d2 + d3);  // 8.f/45 * (4 * k0 + 2 * k1 + k2)
-        kout[6] = 1.f / 180 * (d2 - d3);  // 8.f/45 * (4 * k0 - 2 * k1 + k2)
-        kout[7] = g2;
-      }
-    }
-  }
-}
-
-template <>
-void winograd_transform_input<8, 3>(const framework::Tensor &input,
-                                    framework::Tensor *output) {
-  // tile input to [c, roundup(h/6), roundup(w/6), 64] and do transformation
-  int channel = input.dims()[1];
-  int height = input.dims()[2];
-  int width = input.dims()[3];
-  int h_tiles = (height + 3) / 6;  // (height + 5 - 2) / 6
-  int w_tiles = (width + 3) / 6;   // (width + 5 - 2) / 6
-  framework::DDim transformed_shape =
-      framework::make_ddim(std::vector{channel, h_tiles, w_tiles, 64});
-  float *outptr = output->mutable_data(transformed_shape);
-  memset(outptr, 0, channel * h_tiles * w_tiles * 64 * sizeof(float));
-  const float *inptr = input.data();
-  // pack input to tiles
-  for (int c = 0; c < channel; ++c) {
-    int inter_h = (height - 2) / 6;
-    int inter_w = (width - 2) / 6;
-    int remain_h = height - (inter_h * 6);
-    int remain_w = width - (inter_w * 6);
-    const float *in0 = inptr + c * height * width;
-    const float *in1 = in0 + width;
-    const float *in2 = in1 + width;
-    const float *in3 = in2 + width;
-    const float *in4 = in3 + width;
-    const float *in5 = in4 + width;
-    const float *in6 = in5 + width;
-    const float *in7 = in6 + width;
-    float *out = outptr + c * h_tiles * w_tiles * 64;
-
-    for (int h = 0; h < inter_h; ++h) {
-
for (int w = 0; w < inter_w; ++w) { - memcpy(out, in0, 8 * sizeof(float)); - memcpy(out + 8, in1, 8 * sizeof(float)); - memcpy(out + 16, in2, 8 * sizeof(float)); - memcpy(out + 24, in3, 8 * sizeof(float)); - memcpy(out + 32, in4, 8 * sizeof(float)); - memcpy(out + 40, in5, 8 * sizeof(float)); - memcpy(out + 48, in6, 8 * sizeof(float)); - memcpy(out + 56, in7, 8 * sizeof(float)); - in0 += 6; - in1 += 6; - in2 += 6; - in3 += 6; - in4 += 6; - in5 += 6; - in6 += 6; - in7 += 6; - out += 64; - } - // remain width - if (remain_w > 2) { - memcpy(out, in0, remain_w * sizeof(float)); - memcpy(out + 8, in1, remain_w * sizeof(float)); - memcpy(out + 16, in2, remain_w * sizeof(float)); - memcpy(out + 24, in3, remain_w * sizeof(float)); - memcpy(out + 32, in4, remain_w * sizeof(float)); - memcpy(out + 40, in5, remain_w * sizeof(float)); - memcpy(out + 48, in6, remain_w * sizeof(float)); - memcpy(out + 56, in7, remain_w * sizeof(float)); - out += 64; - } - in0 += 5 * width + remain_w; - in1 += 5 * width + remain_w; - in2 += 5 * width + remain_w; - in3 += 5 * width + remain_w; - in4 += 5 * width + remain_w; - in5 += 5 * width + remain_w; - in6 += 5 * width + remain_w; - in7 += 5 * width + remain_w; - } - // remain height - if (remain_h > 2) { - for (int w = 0; w < inter_w; ++w) { - for (int rh = 0; rh < remain_h; ++rh) { - memcpy(out + rh * 8, in0 + rh * width, 8 * sizeof(float)); - } - out += 64; - in0 += 6; - } - // remain width - if (remain_w > 2) { - for (int rh = 0; rh < remain_h; ++rh) { - memcpy(out + rh * 8, in0 + rh * width, remain_w * sizeof(float)); - } - } - } - } - // transform tiles, compute B_T * d(c, b) * B - for (int c = 0; c < channel; ++c) { - for (int tile = 0; tile < h_tiles * w_tiles; ++tile) { - float *out = outptr + (c * h_tiles * w_tiles + tile) * 64; - // compute B_T * d(c, b) - float bd[8][8]; - for (int i = 0; i < 8; ++i) { - float d0 = out[8 * i + 0]; - float d1 = out[8 * i + 1]; - float d2 = out[8 * i + 2]; - float d3 = out[8 * i + 3]; - float d4 = out[8 * i + 4]; - float d5 = out[8 * i + 5]; - float d6 = out[8 * i + 6]; - float d7 = out[8 * i + 7]; - - bd[i][0] = d0 - d6 + (d4 - d2) * 5.25; - float v1 = d2 - 4.25 * d4 + d6; - float v2 = d1 - 4.25 * d3 + d5; - // d1 + d2 - 4.25 * d3 - 4.25 * d4 + d5 + d6 - bd[i][1] = v1 + v2; - // -d1 + d2 + 4.25 * d3 - 4.25 * d4 - d5 + d6 - bd[i][2] = v1 - v2; - v1 = 0.25 * d2 - 1.25 * d4 + d6; - v2 = 0.5 * d1 - 2.5 * d3 + 2 * d5; - // 0.5 * d1 + 0.25 * d2 - 2.5 * d3 - 1.25 * d4 + 2 * d5 + d6 - bd[i][3] = v1 + v2; - // -0.5 * d1 + 0.25 * d2 + 2.5 * d3 - 1.25 * d4 - 2 * d5 + d6 - bd[i][4] = v1 - v2; - v1 = 4 * d2 - 5 * d4 + d6; - v2 = 2 * d1 - 2.5 * d3 + 0.5 * d5; - // 2 * d1 + 4 * d2 - 2.5 * d3 - 5 * d4 + 0.5 * d5 + d6 - bd[i][5] = v1 + v2; - // -2 * d1 + 4 * d2 + 2.5 * d3 - 5 * d4 - 0.5 * d5 + d6 - bd[i][6] = v1 - v2; - bd[i][7] = d7 - d1 + (d3 - d5) * 5.25; - } - // compute B_T * d(c, b) * B - for (int i = 0; i < 8; ++i, out += 8) { - float d0 = bd[0][i]; - float d1 = bd[1][i]; - float d2 = bd[2][i]; - float d3 = bd[3][i]; - float d4 = bd[4][i]; - float d5 = bd[5][i]; - float d6 = bd[6][i]; - float d7 = bd[7][i]; - - out[0] = d0 - d6 + (d4 - d2) * 5.25; - float v1 = d2 - 4.25 * d4 + d6; - float v2 = d1 - 4.25 * d3 + d5; - // d1 + d2 - 4.25 * d3 - 4.25 * d4 + d5 + d6 - out[1] = v1 + v2; - // -d1 + d2 + 4.25 * d3 - 4.25 * d4 - d5 + d6 - out[2] = v1 - v2; - v1 = 0.25 * d2 - 1.25 * d4 + d6; - v2 = 0.5 * d1 - 2.5 * d3 + 2 * d5; - // 0.5 * d1 + 0.25 * d2 - 2.5 * d3 - 1.25 * d4 + 2 * d5 + d6 - out[3] = v1 + v2; - // -0.5 * d1 + 0.25 * d2 + 2.5 * 
d3 - 1.25 * d4 - 2 * d5 + d6 - out[4] = v1 - v2; - v1 = 4 * d2 - 5 * d4 + d6; - v2 = 2 * d1 - 2.5 * d3 + 0.5 * d5; - // 2 * d1 + 4 * d2 - 2.5 * d3 - 5 * d4 + 0.5 * d5 + d6 - out[5] = v1 + v2; - // -2 * d1 + 4 * d2 + 2.5 * d3 - 5 * d4 - 0.5 * d5 + d6 - out[6] = v1 - v2; - out[7] = d7 - d1 + (d3 - d5) * 5.25; - } - } - } -} - -template <> -void winograd_transform_output<8, 3>(const framework::Tensor &input, - const framework::Tensor &weight, - framework::Tensor *output) { - // input shape is [in_channel, h_tiles, w_tiles, 64] - // weight shape is [out_channel, in_channel, 64] - int in_channel = input.dims()[0]; - int h_tiles = input.dims()[1]; - int w_tiles = input.dims()[2]; - int tiles = h_tiles * w_tiles; - int out_channel = weight.dims()[0]; - // compute U*V first - framework::Tensor output_m; - framework::DDim shape = - framework::make_ddim(std::vector{out_channel, tiles, 64}); - float *output_m_ptr = output_m.mutable_data(shape); - memset(output_m_ptr, 0, output_m.numel() * sizeof(float)); - const float *input_ptr = input.data(); - const float *weight_ptr = weight.data(); - for (int i = 0; i < out_channel; ++i) { - for (int j = 0; j < tiles; ++j) { - const float *w_ptr = weight_ptr + i * in_channel * 64; - const float *in_ptr = input_ptr + j * 64; - float *m_ptr = output_m_ptr + (i * tiles + j) * 64; - for (int c = 0; c < in_channel; ++c) { - for (int k = 0; k < 64; ++k) { - m_ptr[k] += w_ptr[k] * in_ptr[k]; - } - w_ptr += 64; - in_ptr += tiles * 64; - } - } - } - - for (int oc = 0; oc < out_channel; ++oc) { - for (int tile = 0; tile < tiles; ++tile) { - float *m = output_m_ptr + (oc * tiles + tile) * 64; - // compute A_T * m - float am[6][8]; - for (int i = 0; i < 8; ++i) { - float d0 = m[i * 8 + 0]; - float d1 = m[i * 8 + 1]; - float d2 = m[i * 8 + 2]; - float d3 = m[i * 8 + 3]; - float d4 = m[i * 8 + 4]; - float d5 = m[i * 8 + 5]; - float d6 = m[i * 8 + 6]; - float d7 = m[i * 8 + 7]; - float v0 = d1 + d2; - float v1 = d1 - d2; - float v2 = d3 + d4; - float v3 = d3 - d4; - float v4 = d5 + d6; - float v5 = d5 - d6; - - am[0][i] = d0 + v0 + v2 + 32 * v4; - am[1][i] = v1 + 2 * v3 + 16 * v5; - am[2][i] = v0 + 4 * v2 + 8 * v4; - am[3][i] = v1 + 8 * v3 + 4 * v5; - am[4][i] = v0 + 16 * v2 + 2 * v4; - am[5][i] = v1 + 32 * v3 + v5 + d7; - } - // compute A_T * m * A - for (int i = 0; i < 6; ++i, m += 8) { - float d0 = am[i][0]; - float d1 = am[i][1]; - float d2 = am[i][2]; - float d3 = am[i][3]; - float d4 = am[i][4]; - float d5 = am[i][5]; - float d6 = am[i][6]; - float d7 = am[i][7]; - float v0 = d1 + d2; - float v1 = d1 - d2; - float v2 = d3 + d4; - float v3 = d3 - d4; - float v4 = d5 + d6; - float v5 = d5 - d6; - - m[0] = d0 + v0 + v2 + 32 * v4; - m[1] = v1 + 2 * v3 + 16 * v5; - m[2] = v0 + 4 * v2 + 8 * v4; - m[3] = v1 + 8 * v3 + 4 * v5; - m[4] = v0 + 16 * v2 + 2 * v4; - m[5] = v1 + 32 * v3 + v5 + d7; - } - } - } - - int out_h = output->dims()[2]; - int out_w = output->dims()[3]; - float *output_ptr = output->mutable_data(); - // copy valid region to final output - for (int oc = 0; oc < out_channel; ++oc) { - int inter_h = out_h / 6; - int inter_w = out_w / 6; - int remain_h = out_h - inter_h * 6; - int remain_w = out_w - inter_w * 6; - - float *out_ptr0 = output_ptr + oc * out_h * out_w; - float *out_ptr1 = out_ptr0 + out_w; - float *out_ptr2 = out_ptr1 + out_w; - float *out_ptr3 = out_ptr2 + out_w; - float *out_ptr4 = out_ptr3 + out_w; - float *out_ptr5 = out_ptr4 + out_w; - const float *m_ptr = output_m_ptr + oc * tiles * 64; - for (int tile_h = 0; tile_h < inter_h; ++tile_h) { - for 
(int tile_w = 0; tile_w < inter_w; ++tile_w) {
-        const float *m = m_ptr + (tile_h * w_tiles + tile_w) * 64;
-        memcpy(out_ptr0, m, 6 * sizeof(float));
-        memcpy(out_ptr1, m + 8, 6 * sizeof(float));
-        memcpy(out_ptr2, m + 16, 6 * sizeof(float));
-        memcpy(out_ptr3, m + 24, 6 * sizeof(float));
-        memcpy(out_ptr4, m + 32, 6 * sizeof(float));
-        memcpy(out_ptr5, m + 40, 6 * sizeof(float));
-        out_ptr0 += 6;
-        out_ptr1 += 6;
-        out_ptr2 += 6;
-        out_ptr3 += 6;
-        out_ptr4 += 6;
-        out_ptr5 += 6;
-      }
-      // remain w
-      if (remain_w > 0) {
-        const float *m = m_ptr + (tile_h * w_tiles + inter_w) * 64;
-        memcpy(out_ptr0, m, remain_w * sizeof(float));
-        memcpy(out_ptr1, m + 8, remain_w * sizeof(float));
-        memcpy(out_ptr2, m + 16, remain_w * sizeof(float));
-        memcpy(out_ptr3, m + 24, remain_w * sizeof(float));
-        memcpy(out_ptr4, m + 32, remain_w * sizeof(float));
-        memcpy(out_ptr5, m + 40, remain_w * sizeof(float));
-        out_ptr0 += remain_w;
-        out_ptr1 += remain_w;
-        out_ptr2 += remain_w;
-        out_ptr3 += remain_w;
-        out_ptr4 += remain_w;
-        out_ptr5 += remain_w;
-      }
-      out_ptr0 += 5 * out_w;
-      out_ptr1 += 5 * out_w;
-      out_ptr2 += 5 * out_w;
-      out_ptr3 += 5 * out_w;
-      out_ptr4 += 5 * out_w;
-      out_ptr5 += 5 * out_w;
-    }
-    // remain h
-    if (remain_h > 0) {
-      for (int tile_w = 0; tile_w < inter_w; ++tile_w) {
-        const float *m = m_ptr + (inter_h * w_tiles + tile_w) * 64;
-        for (int rh = 0; rh < remain_h; ++rh) {
-          memcpy(out_ptr0 + rh * out_w, m + rh * 8, 6 * sizeof(float));
-        }
-        out_ptr0 += 6;
-      }
-      if (remain_w > 0) {
-        const float *m = m_ptr + (inter_h * w_tiles + inter_w) * 64;
-        for (int rh = 0; rh < remain_h; ++rh) {
-          memcpy(out_ptr0 + rh * out_w, m + rh * 8, remain_w * sizeof(float));
-        }
-      }
-    }
-  }
-}
-
-}  // namespace math
-}  // namespace operators
-}  // namespace paddle_mobile
-
-#endif  // __aarch64__
-#endif  // CONV_OP
diff --git a/src/operators/mul_op.h b/src/operators/mul_op.h
index 51e828202e8da2080f014eff2bd60472dd873884..b08cdbf99191df63221df67135fea584ad62f514 100644
--- a/src/operators/mul_op.h
+++ b/src/operators/mul_op.h
@@ -31,7 +31,7 @@ class MulOp : public framework::OperatorWithKernel<
 public:
  MulOp(const std::string &type, const VariableNameMap &inputs,
        const VariableNameMap &outputs, const framework::AttributeMap &attrs,
-        std::shared_ptr scope)
+        framework::Scope *scope)
      : framework::OperatorWithKernel,
        operators::MulKernel>(
            type, inputs, outputs, attrs, scope) {}
diff --git a/src/operators/multiclass_nms_op.h b/src/operators/multiclass_nms_op.h
index 059974ab214004bcd1423514c85353da9a9bb6b8..bba701d81a26e348ffcca12be15a017a2edbdb1e 100644
--- a/src/operators/multiclass_nms_op.h
+++ b/src/operators/multiclass_nms_op.h
@@ -34,8 +34,7 @@ class MultiClassNMSOp : public framework::OperatorWithKernel<
 public:
  MultiClassNMSOp(const std::string &type, const VariableNameMap &inputs,
                  const VariableNameMap &outputs,
-                  const framework::AttributeMap &attrs,
-                  std::shared_ptr scope)
+                  const framework::AttributeMap &attrs, framework::Scope *scope)
      : framework::OperatorWithKernel<
            DeviceType, MultiClassNMSParam,
            operators::MultiClassNMSKernel>(
diff --git a/src/operators/norm_op.h b/src/operators/norm_op.h
index 5bd6924af1d4aca125795be879cf67c09832965e..64d8e7c3ccebb66a650eac9e6b67e4dacbde7994 100644
--- a/src/operators/norm_op.h
+++ b/src/operators/norm_op.h
@@ -31,7 +31,7 @@ class NormOp
 public:
  NormOp(const string &type, const VariableNameMap &inputs,
         const VariableNameMap &outputs, const framework::AttributeMap &attrs,
-         std::shared_ptr scope)
+         framework::Scope *scope)
      : framework::OperatorWithKernel,
        NormKernel>(
            type, inputs, outputs, attrs, scope) {}
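The next two files add a new one_hot operator. Its InferShape only rewrites the trailing dimension of the input shape to the depth attribute and forwards the input's LoD. A hedged sketch of that shape rule follows; the helper name is illustrative and not part of the patch.

#include <cstdint>
#include <vector>

// Shape rule used by OnehotOp::InferShape below: every leading dimension is
// kept and the trailing dimension becomes `depth`; e.g. dims [64, 1] with
// depth = 10 become [64, 10].
std::vector<int64_t> one_hot_dims(std::vector<int64_t> dims, int depth) {
  dims.back() = static_cast<int64_t>(depth);
  return dims;
}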
diff --git a/src/operators/one_hot_op.cpp b/src/operators/one_hot_op.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..396f55318a80dc3177a0ae5f4b151eaec7806a6d
--- /dev/null
+++ b/src/operators/one_hot_op.cpp
@@ -0,0 +1,43 @@
+/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#ifdef ONE_HOT_OP
+
+#pragma once
+
+#include "operators/one_hot_op.h"
+
+namespace paddle_mobile {
+namespace operators {
+
+template
+void OnehotOp::InferShape() const {
+  const auto &x_dims = this->param_.input_->dims();
+  int depth = this->param_.depth_;
+  framework::DDim out_dims(x_dims);
+  out_dims[out_dims.size() - 1] = depth;
+  this->param_.output_->Resize(out_dims);
+  this->param_.output_->set_lod(this->param_.input_->lod());
+}
+
+}  // namespace operators
+}  // namespace paddle_mobile
+
+namespace ops = paddle_mobile::operators;
+
+#ifdef PADDLE_MOBILE_CPU
+REGISTER_OPERATOR_CPU(one_hot, ops::OnehotOp);
+#endif
+
+#endif  // ONE_HOT_OP
diff --git a/src/operators/one_hot_op.h b/src/operators/one_hot_op.h
new file mode 100644
index 0000000000000000000000000000000000000000..4b7e83bf996873de087844887839055031d97f66
--- /dev/null
+++ b/src/operators/one_hot_op.h
@@ -0,0 +1,31 @@
+/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#ifdef ONE_HOT_OP
+
+#pragma once
+
+#include
+#include "framework/operator.h"
+#include "operators/kernel/one_hot_kernel.h"
+
+namespace paddle_mobile {
+namespace operators {
+
+DECLARE_OPERATOR(Onehot, OnehotParam, OnehotKernel);
+
+}  // namespace operators
+}  // namespace paddle_mobile
+
+#endif  // ONE_HOT_OP
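The op_param.h diff below is dominated by one mechanical migration: parameter members and accessors switch from RType to GType, matching the Scope* constructor change in the operator headers above. A condensed sketch of the resulting accessor pattern follows; the class name and field set are illustrative, and binding GType through DtypeTensorTrait is an assumption about the surrounding codebase, not something shown in this patch.

template <typename Dtype>
class ExampleParam : public OpParam {
 public:
  // Assumed binding: GType is the LoD-aware tensor type for Dtype.
  typedef typename DtypeTensorTrait<Dtype>::gtype GType;

  // After this patch every accessor and member uses GType uniformly,
  // where mixed RType/GType declarations existed before.
  const GType *InputX() const { return input_x_; }
  GType *Out() const { return out_; }

 private:
  GType *input_x_;
  GType *out_;
};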
diff --git a/src/operators/op_param.h b/src/operators/op_param.h
index 99822947671423d483d61294edcbc825021f1ad0..95cb7d675370ca2afc5fae21311ed2bbbb27cc0c 100644
--- a/src/operators/op_param.h
+++ b/src/operators/op_param.h
@@ -453,11 +453,11 @@ class ConvParam : public OpParam {
    groups = OpParam::GetAttr("groups", attrs);
  }
 
-  const RType *Input() const { return input_; }
+  const GType *Input() const { return input_; }
 
-  RType *Filter() const { return filter_; }
+  GType *Filter() const { return filter_; }
 
-  RType *Output() const { return output_; }
+  GType *Output() const { return output_; }
 
  const vector &Strides() const { return strides_; }
 
@@ -468,10 +468,8 @@ class ConvParam : public OpParam {
  enum ExecMode {
    EXEC_INVALID = 0,
    EXEC_GEMM_FLOAT,
-    EXEC_DEPTHWISE3x3S1P1_FLOAT,
-    EXEC_DEPTHWISE3x3S2P0_FLOAT,
-    EXEC_DEPTHWISE3x3S2P1_FLOAT,
-    EXEC_DEPTHWISE3x3_FLOAT,
+    EXEC_DEPTHWISE3x3S1_FLOAT,
+    EXEC_DEPTHWISE3x3S2_FLOAT,
    EXEC_WINOGRAD3X3_FLOAT,
    EXEC_WINOGRAD5X5_FLOAT,
    EXEC_DEPTHWISE5x5_FLOAT,
@@ -492,10 +490,10 @@ class ConvParam : public OpParam {
 #endif
 
 public:
-  RType *input_;
-  RType *output_;
-  RType *filter_;
-  RType *transformed_filter_;
+  GType *input_;
+  GType *output_;
+  GType *filter_;
+  GType *transformed_filter_;
  vector strides_;
  vector paddings_;
  vector dilations_;
@@ -772,11 +770,11 @@ class LrnParam : public OpParam {
    data_format_ = GetStringAttr("data_format", attrs);
  }
 
-  const RType *InputX() const { return input_x_; }
+  const GType *InputX() const { return input_x_; }
 
-  RType *Out() const { return out_; }
+  GType *Out() const { return out_; }
 
-  RType *MidOut() const { return mid_out_; }
+  GType *MidOut() const { return mid_out_; }
 
  const int &N() const { return n_; }
 
@@ -789,9 +787,9 @@ class LrnParam : public OpParam {
  const string &DataFormat() const { return data_format_; }
 
 private:
-  RType *input_x_;
-  RType *out_;
-  RType *mid_out_;
+  GType *input_x_;
+  GType *out_;
+  GType *mid_out_;
  int n_;
  float alpha_;
  float beta_;
@@ -817,20 +815,20 @@ class NormParam : public OpParam {
    axis_ = GetAttr("axis", attrs);
  }
 
-  const RType *InputX() const { return input_x_; }
+  const GType *InputX() const { return input_x_; }
 
-  RType *Out() const { return out_; }
+  GType *Out() const { return out_; }
 
-  RType *OutputNorm() const { return output_norm_; }
+  GType *OutputNorm() const { return output_norm_; }
 
  const float &Epsilon() const { return epsilon_; }
 
  const int &Axis() const { return axis_; }
 
 private:
-  RType *input_x_;
-  RType *out_;
-  RType *output_norm_;
+  GType *input_x_;
+  GType *out_;
+  GType *output_norm_;
  float epsilon_;
  int axis_;
 };
@@ -857,17 +855,17 @@ class BatchNormParam : public OpParam {
    // is_test_ = GetAttr("is_test", attrs);
  }
 
-  const RType *InputX() const { return input_x_; }
+  const GType *InputX() const { return input_x_; }
 
-  RType *OutputY() const { return output_y_; }
+  GType *OutputY() const { return output_y_; }
 
-  const RType *InputBias() const { return input_bias_; }
+  const GType *InputBias() const { return input_bias_; }
 
-  const RType *InputMean() const { return input_mean_; }
+  const GType *InputMean() const { return input_mean_; }
 
-  const RType *InputScale() const { return
input_scale_; } - const RType *InputVariance() const { return input_variance_; } + const GType *InputVariance() const { return input_variance_; } const float &Epsilon() const { return epsilon_; } @@ -877,27 +875,27 @@ class BatchNormParam : public OpParam { const string &DataFormat() const { return data_format_; } - void SetNewScale(RType *new_scale) { new_scale_ = new_scale; } + void SetNewScale(GType *new_scale) { new_scale_ = new_scale; } - void SetNewBias(RType *new_bias) { new_bias_ = new_bias; } + void SetNewBias(GType *new_bias) { new_bias_ = new_bias; } - const RType *NewScale() const { return new_scale_; } + const GType *NewScale() const { return new_scale_; } - const RType *NewBias() const { return new_bias_; } + const GType *NewBias() const { return new_bias_; } private: - RType *input_x_; - RType *output_y_; - RType *input_bias_; - RType *input_mean_; - RType *input_scale_; - RType *input_variance_; + GType *input_x_; + GType *output_y_; + GType *input_bias_; + GType *input_mean_; + GType *input_scale_; + GType *input_variance_; float epsilon_; float momentum_; bool is_test_; string data_format_; - RType *new_bias_; - RType *new_scale_; + GType *new_bias_; + GType *new_scale_; }; #endif @@ -922,9 +920,9 @@ class PoolParam : public OpParam { global_pooling_ = GetAttr("global_pooling", attrs); } - const RType *Input() const { return input_; } + const GType *Input() const { return input_; } - RType *Output() const { return output_; } + GType *Output() const { return output_; } const string &PoolingType() const { return pooling_type_; } @@ -939,8 +937,8 @@ class PoolParam : public OpParam { bool isGlobalPooling() const { return global_pooling_; } private: - RType *input_; - RType *output_; + GType *input_; + GType *output_; string pooling_type_; vector ksize_; vector strides_; @@ -990,13 +988,13 @@ class PriorBoxParam : public OpParam { step_h_ = GetAttr("step_h", attrs); offset_ = GetAttr("offset", attrs); } - const RType *Input() const { return input_; } + const GType *Input() const { return input_; } - const RType *InputImage() const { return input_image_; } + const GType *InputImage() const { return input_image_; } - RType *OutputBoxes() const { return output_boxes_; } + GType *OutputBoxes() const { return output_boxes_; } - RType *OutputVariances() const { return output_variances_; } + GType *OutputVariances() const { return output_variances_; } const vector &MinSizes() const { return min_sizes_; } @@ -1021,10 +1019,10 @@ class PriorBoxParam : public OpParam { } private: - RType *input_; - RType *input_image_; - RType *output_boxes_; - RType *output_variances_; + GType *input_; + GType *input_image_; + GType *output_boxes_; + GType *output_variances_; vector min_sizes_; vector max_sizes_; vector aspect_ratios_; @@ -1054,21 +1052,21 @@ class BoxCoderParam : public OpParam { output_box_ = OutputBoxFrom(outputs, *scope); code_type_ = GetStringAttr("code_type", attrs); } - const RType *InputPriorBox() const { return input_priorbox_; } + const GType *InputPriorBox() const { return input_priorbox_; } - const RType *InputPriorBoxVar() const { return input_priorboxvar_; } + const GType *InputPriorBoxVar() const { return input_priorboxvar_; } - const RType *InputTargetBox() const { return input_targetbox_; } + const GType *InputTargetBox() const { return input_targetbox_; } - RType *OutputBox() const { return output_box_; } + GType *OutputBox() const { return output_box_; } const std::string &CodeType() const { return code_type_; } private: - RType *input_priorbox_; - RType 
*input_priorboxvar_; - RType *input_targetbox_; - RType *output_box_; + GType *input_priorbox_; + GType *input_priorboxvar_; + GType *input_targetbox_; + GType *output_box_; std::string code_type_; }; #endif @@ -1096,14 +1094,14 @@ class SoftmaxParam : public OpParam { #ifdef PADDLE_MOBILE_FPGA private: - std::shared_ptr float_input_x_; + std::shared_ptr float_input_x_; fpga::BypassArgs fpga_bypass_args; public: - RType *FloatInput() const { + GType *FloatInput() const { return float_input_x_ == nullptr ? input_x_ : float_input_x_.get(); } - void SetFloatInput(Tensor *input) { float_input_x_.reset(input); } + void SetFloatInput(LoDTensor *input) { float_input_x_.reset(input); } const fpga::BypassArgs &FpgaArgs() const { return fpga_bypass_args; } void SetFpgaArgs(const fpga::BypassArgs &args) { fpga_bypass_args = args; } #endif @@ -1123,12 +1121,12 @@ class SigmoidParam : public OpParam { input_x_ = InputXFrom(inputs, *scope); out_ = OutFrom(outputs, *scope); } - const RType *InputX() const { return input_x_; } - RType *Out() const { return out_; } + const GType *InputX() const { return input_x_; } + GType *Out() const { return out_; } private: - RType *input_x_; - RType *out_; + GType *input_x_; + GType *out_; #ifdef PADDLE_MOBILE_FPGA private: @@ -1163,11 +1161,11 @@ class MultiClassNMSParam : public OpParam { score_threshold_ = GetAttr("score_threshold", attrs); } - RType *InputBBoxes() const { return input_bboxes_; } + GType *InputBBoxes() const { return input_bboxes_; } - RType *InputScores() const { return input_scores_; } + GType *InputScores() const { return input_scores_; } - RType *Out() const { return out_; } + GType *Out() const { return out_; } const int &BackGroundLabel() const { return background_label_; } @@ -1182,9 +1180,9 @@ class MultiClassNMSParam : public OpParam { const float &ScoreThreshold() const { return score_threshold_; } private: - RType *input_bboxes_; - RType *input_scores_; - RType *out_; + GType *input_bboxes_; + GType *input_scores_; + GType *out_; int background_label_; int nms_top_k_; int keep_top_k_; @@ -1208,12 +1206,12 @@ class PolygonBoxTransformParam : public OpParam { input_ = InputFrom(inputs, *scope); output_ = OutputFrom(outputs, *scope); } - const RType *Input() const { return input_; } - RType *Output() const { return output_; } + const GType *Input() const { return input_; } + GType *Output() const { return output_; } private: - RType *input_; - RType *output_; + GType *input_; + GType *output_; }; #endif @@ -1226,24 +1224,21 @@ class FeedParam : public OpParam { FeedParam(const VariableNameMap &inputs, const VariableNameMap &outputs, const AttributeMap &attrs, Scope *scope) : OpParam(inputs, outputs, attrs, scope) { -#ifdef PADDLE_MOBILE_FPGA - static int feed_num = 0; - auto new_name = std::string("feed") + std::to_string(feed_num++); - const_cast(inputs).at("X") = {string(new_name)}; -#endif - - input_x_ = InputXFrom(inputs, *scope); + input_x_ = InputXFrom(inputs, *scope); out_ = OutFrom(outputs, *scope); + col_ = GetAttr("col", attrs); auto var = scope->FindVar("batch_size"); batch_size = var->GetValue(); } - const LoDTensor *InputX() const { return input_x_; } + const framework::LoDTensorArray *InputX() const { return input_x_; } GType *Out() const { return out_; } + const int Col() const { return col_; } const int BatchSize() const { return batch_size; } private: - LoDTensor *input_x_; + framework::LoDTensorArray *input_x_; GType *out_; + int col_; int batch_size; }; @@ -1256,30 +1251,23 @@ class FetchParam : public OpParam { 
FetchParam(const VariableNameMap &inputs, const VariableNameMap &outputs, const AttributeMap &attrs, Scope *scope) : OpParam(inputs, outputs, attrs, scope) { -#ifdef PADDLE_MOBILE_FPGA - static int fetch_num = 0; - auto new_name = std::string("fetch") + std::to_string(fetch_num++); - const_cast(outputs).at("Out") = {string(new_name)}; -#endif - input_x_ = InputXFrom(inputs, *scope); - out_ = OutFrom(outputs, *scope); + input_x_ = InputXFrom(inputs, *scope); + out_ = OutFrom(outputs, *scope); + col_ = GetAttr("col", attrs); } - const RType *InputX() const { return input_x_; } - Tensor *Out() const { return out_; } - - static Tensor *OutFrom(const VariableNameMap &outputs, const Scope &scope) { - return GetVarValue("Out", outputs, scope); - } + const framework::LoDTensor *InputX() const { return input_x_; } + framework::LoDTensorArray *Out() const { return out_; } + const int Col() const { return col_; } private: - RType *input_x_; - Tensor *out_; + framework::LoDTensor *input_x_; + framework::LoDTensorArray *out_; + int col_; #ifdef PADDLE_MOBILE_FPGA public: fpga::BypassArgs fpga_bypass_args; - #endif }; @@ -1303,7 +1291,7 @@ class FillConstantParam : public OpParam { Variable *OutVar() const { return out_var_; } - RType *Out() const { return out_; } + GType *Out() const { return out_; } const int &DataDtype() const { return dtype_; } @@ -1313,7 +1301,7 @@ class FillConstantParam : public OpParam { private: Variable *out_var_; - RType *out_; + GType *out_; int dtype_; vector shape_; float value_; @@ -1335,15 +1323,15 @@ class TransposeParam : public OpParam { axis_ = GetAttr>("axis", attrs); } - const RType *InputX() const { return input_x_; } + const GType *InputX() const { return input_x_; } - RType *Out() const { return out_; } + GType *Out() const { return out_; } const vector &Axis() const { return axis_; } private: - RType *input_x_; - RType *out_; + GType *input_x_; + GType *out_; vector axis_; }; #endif @@ -1364,18 +1352,18 @@ class Transpose2Param : public OpParam { axis_ = GetAttr>("axis", attrs); } - const RType *InputX() const { return input_x_; } + const GType *InputX() const { return input_x_; } - RType *Out() const { return out_; } + GType *Out() const { return out_; } - RType *OutputXShape() const { return output_xshape_; } + GType *OutputXShape() const { return output_xshape_; } const vector &Axis() const { return axis_; } private: - RType *input_x_; - RType *out_; - RType *output_xshape_; + GType *input_x_; + GType *out_; + GType *output_xshape_; vector axis_; }; #endif @@ -1432,8 +1420,8 @@ class CrfParam : public OpParam { const GType *InputTransition() const { return input_transition_; } const GType *InputLabel() const { return input_label_; } GType *outputVBP() const { return output_viterbipath_; } - // const RType *InputIds() const { return input_ids_; } - // RType *Out() const { return out_; } + // const GType *InputIds() const { return input_ids_; } + // GType *Out() const { return out_; } // int64_t PaddingIdx() const { return padding_idx_; } private: @@ -1442,8 +1430,8 @@ class CrfParam : public OpParam { GType *input_label_; GType *output_viterbipath_; - // RType *input_ids_; - // RType *out_; + // GType *input_ids_; + // GType *out_; // int64_t padding_idx_; }; #endif @@ -1471,20 +1459,20 @@ class ReshapeParam : public OpParam { } } - const RType *InputX() const { return input_x_; } + const GType *InputX() const { return input_x_; } - const RType *InputShape() const { return input_shape_; } + const GType *InputShape() const { return input_shape_; } - RType 
*Out() const { return out_; } + GType *Out() const { return out_; } const vector &Shape() const { return shape_; } const bool &Inplace() const { return inplace_; } private: - RType *input_x_; - RType *input_shape_; - RType *out_; + GType *input_x_; + GType *input_shape_; + GType *out_; vector shape_; bool inplace_; }; @@ -1553,11 +1541,11 @@ class ScaleParam : public OpParam { biases_ = GetAttr>("biases", attrs); } - const RType *InputX() const { return input_x_; } + const GType *InputX() const { return input_x_; } - const RType *InputBias() const { return input_bias_; } + const GType *InputBias() const { return input_bias_; } - RType *Out() const { return out_; } + GType *Out() const { return out_; } const bool &Inplace() const { return inplace_; } @@ -1568,9 +1556,9 @@ class ScaleParam : public OpParam { const vector &Biases() const { return biases_; } private: - RType *input_x_; - RType *input_bias_; - RType *out_; + GType *input_x_; + GType *input_bias_; + GType *out_; bool inplace_; bool has_bias_; vector scales_; @@ -1625,11 +1613,11 @@ class ResizeParam : public OpParam { out_width_scale_ = GetAttr("out_width_scale", attrs); } - const RType *InputX() const { return input_x_; } + const GType *InputX() const { return input_x_; } - const RType *InputShape() const { return input_shape_; } + const GType *InputShape() const { return input_shape_; } - RType *Out() const { return out_; } + GType *Out() const { return out_; } const bool &IsPyramidTest() const { return is_pyramid_test_; } @@ -1642,9 +1630,9 @@ class ResizeParam : public OpParam { const float &OutWidthScale() const { return out_width_scale_; } private: - RType *input_x_; - RType *input_shape_; - RType *out_; + GType *input_x_; + GType *input_shape_; + GType *out_; bool is_pyramid_test_; int height_; int width_; @@ -1670,13 +1658,13 @@ class ReluParamBase : public OpParam { out_ = OutFrom(outputs, *scope); } - const RType *InputX() const { return input_x_; } + const GType *InputX() const { return input_x_; } - RType *Out() const { return out_; } + GType *Out() const { return out_; } private: - RType *input_x_; - RType *out_; + GType *input_x_; + GType *out_; }; template @@ -1712,23 +1700,23 @@ class TanhParam : public OpParam { input_x_ = InputXFrom(inputs, *scope); out_ = OutFrom(outputs, *scope); } - const RType *InputX() const { return input_x_; } - RType *Out() const { return out_; } + const GType *InputX() const { return input_x_; } + GType *Out() const { return out_; } private: - RType *input_x_; - RType *out_; + GType *input_x_; + GType *out_; #ifdef PADDLE_MOBILE_FPGA private: - std::shared_ptr float_input_x_; + std::shared_ptr float_input_x_; fpga::BypassArgs fpga_bypass_args; public: - RType *FloatInput() const { + GType *FloatInput() const { return float_input_x_ == nullptr ? 
input_x_ : float_input_x_.get(); } - void SetFloatInput(Tensor *input) { float_input_x_.reset(input); } + void SetFloatInput(LoDTensor *input) { float_input_x_.reset(input); } const fpga::BypassArgs &FpgaArgs() const { return fpga_bypass_args; } void SetFpgaArgs(const fpga::BypassArgs &args) { fpga_bypass_args = args; } #endif @@ -1753,15 +1741,15 @@ class PReluParam : public OpParam { mode_ = GetStringAttr("mode", attrs); DLOG << "PReluParam mode after" << mode_; } - const RType *InputX() const { return input_x_; } - const RType *InputAlpha() const { return alpha_; } - RType *Out() const { return out_; } + const GType *InputX() const { return input_x_; } + const GType *InputAlpha() const { return alpha_; } + GType *Out() const { return out_; } const std::string &Mode() const { return mode_; } private: - RType *input_x_; - RType *out_; - RType *alpha_; + GType *input_x_; + GType *out_; + GType *alpha_; std::string mode_; }; #endif @@ -1785,9 +1773,9 @@ class FusionFcParam : public OpParam { } GType *InputX() const { return input_x_; } - RType *InputY() const { return input_y_; } + GType *InputY() const { return input_y_; } - RType *InputZ() const { return input_z_; } + GType *InputZ() const { return input_z_; } GType *Out() const { return out_; } @@ -1799,8 +1787,8 @@ class FusionFcParam : public OpParam { private: GType *input_x_; - RType *input_y_; - RType *input_z_; + GType *input_y_; + GType *input_z_; GType *out_; int x_num_col_dims_; int y_num_col_dims_; @@ -1833,18 +1821,15 @@ class FusionConvAddParam : public ConvParam { : ConvParam(inputs, outputs, attrs, scope) { bias_ = OpParam::InputYFrom(inputs, *scope); axis_ = OpParam::GetAttr("axis", attrs); - output_ = OpParam::OutFrom(outputs, *scope); + this->output_ = OpParam::OutFrom(outputs, *scope); } - RType *Bias() const { return bias_; } + GType *Bias() const { return bias_; } const int &Axis() const { return axis_; } - RType *Output() const { return output_; } - protected: - RType *bias_; + GType *bias_; int axis_; - RType *output_; }; template @@ -1877,19 +1862,17 @@ class FusionConvAddPReluParam : public ConvParam { framework::DDim dims = alpha_->dims(); bias_ = OpParam::InputYFrom(inputs, *scope); axis_ = OpParam::GetAttr("axis", attrs); - output_ = OpParam::OutFrom(outputs, *scope); + this->output_ = OpParam::OutFrom(outputs, *scope); } - const RType *InputAlpha() const { return alpha_; } + const GType *InputAlpha() const { return alpha_; } const std::string &Mode() const { return mode_; } - RType *Bias() const { return bias_; } + GType *Bias() const { return bias_; } const int &Axis() const { return axis_; } - RType *Output() const { return output_; } protected: - RType *bias_; + GType *bias_; int axis_; - RType *output_; - RType *alpha_; + GType *alpha_; std::string mode_; }; #endif @@ -1910,7 +1893,6 @@ class FusionConvAddAddPReluParam : public ConvParam { mode_ = OpParam::GetStringAttr("mode", attrs); framework::DDim dims = alpha_->dims(); bias_ = OpParam::InputYFrom(inputs, *scope); - output_ = OpParam::OutFrom(outputs, *scope); axis_ = OpParam::GetAttr("axis", attrs); keyOutput_ = OpParam::Getkey("addOut", inputs, 0); keyX1_ = OpParam::Getkey("addX", inputs, 1); @@ -1920,23 +1902,22 @@ class FusionConvAddAddPReluParam : public ConvParam { } else if (keyY1_ == keyOutput_) { bias1_ = OpParam::InputXFrom1(inputs, *scope); } + this->output_ = OpParam::OutFrom(outputs, *scope); } - const RType *InputAlpha() const { return alpha_; } + const GType *InputAlpha() const { return alpha_; } const std::string &Mode() const { return 
mode_; } - const RType *Bias1() const { return bias1_; } + const GType *Bias1() const { return bias1_; } - RType *Bias() const { return bias_; } + GType *Bias() const { return bias_; } const int &Axis() const { return axis_; } - RType *Output() const { return output_; } protected: - RType *bias_; + GType *bias_; int axis_; - RType *output_; - RType *alpha_; + GType *alpha_; std::string mode_; - RType *bias1_; + GType *bias1_; std::string keyOutput_; std::string keyX1_; std::string keyY1_; @@ -1956,56 +1937,49 @@ class FusionConvAddBNReluParam : public ConvParam { : ConvParam(inputs, outputs, attrs, scope) { bias_ = OpParam::InputYFrom(inputs, *scope); axis_ = OpParam::GetAttr("axis", attrs); - output_ = OpParam::OutFrom(outputs, *scope); input_bias_ = OpParam::InputBiasFrom(inputs, *scope); input_mean_ = OpParam::InputMeanFrom(inputs, *scope); input_scale_ = OpParam::InputScaleFrom(inputs, *scope); input_variance_ = OpParam::InputVarianceFrom(inputs, *scope); epsilon_ = OpParam::GetAttr("epsilon", attrs); momentum_ = OpParam::GetAttr("momentum", attrs); - // is_test_ = OpParam::GetAttr("is_test", attrs); + this->output_ = OpParam::OutFrom(outputs, *scope); } - RType *Bias() const { return bias_; } + GType *Bias() const { return bias_; } const int &Axis() const { return axis_; } - RType *Output() const { return output_; } - - const RType *InputBias() const { return input_bias_; } + const GType *InputBias() const { return input_bias_; } - const RType *InputMean() const { return input_mean_; } + const GType *InputMean() const { return input_mean_; } - const RType *InputScale() const { return input_scale_; } + const GType *InputScale() const { return input_scale_; } - const RType *InputVariance() const { return input_variance_; } + const GType *InputVariance() const { return input_variance_; } const float &Epsilon() const { return epsilon_; } const float &Momentum() const { return momentum_; } - const bool &IsTest() const { return is_test_; } + void SetNewScale(GType *new_scale) { new_scale_ = new_scale; } - void SetNewScale(RType *new_scale) { new_scale_ = new_scale; } + void SetNewBias(GType *new_bias) { new_bias_ = new_bias; } - void SetNewBias(RType *new_bias) { new_bias_ = new_bias; } - - const RType *NewScale() const { return new_scale_; } + const GType *NewScale() const { return new_scale_; } - const RType *NewBias() const { return new_bias_; } + const GType *NewBias() const { return new_bias_; } protected: - RType *bias_; + GType *bias_; int axis_; - RType *output_; - RType *input_bias_; - RType *input_mean_; - RType *input_scale_; - RType *input_variance_; + GType *input_bias_; + GType *input_mean_; + GType *input_scale_; + GType *input_variance_; float epsilon_; float momentum_; - bool is_test_; - RType *new_bias_; - RType *new_scale_; + GType *new_bias_; + GType *new_scale_; }; #endif @@ -2022,7 +1996,6 @@ class FusionConvBNAddReluParam : public ConvParam { : ConvParam(inputs, outputs, attrs, scope) { bias_ = OpParam::InputYFrom(inputs, *scope); axis_ = OpParam::GetAttr("axis", attrs); - output_ = OpParam::OutFrom(outputs, *scope); input_bias_ = OpParam::InputBiasFrom(inputs, *scope); input_mean_ = OpParam::InputMeanFrom(inputs, *scope); input_scale_ = OpParam::InputScaleFrom(inputs, *scope); @@ -2037,49 +2010,43 @@ class FusionConvBNAddReluParam : public ConvParam { } else if (keyY_ == keyBNY_) { bias_ = OpParam::InputXFrom(inputs, *scope); } - // is_test_ = OpParam::GetAttr("is_test", attrs); + this->output_ = OpParam::OutFrom(outputs, *scope); } - RType *Bias() const { return 
bias_; } + GType *Bias() const { return bias_; } const int &Axis() const { return axis_; } - RType *Output() const { return output_; } - - const RType *InputBias() const { return input_bias_; } + const GType *InputBias() const { return input_bias_; } - const RType *InputMean() const { return input_mean_; } + const GType *InputMean() const { return input_mean_; } - const RType *InputScale() const { return input_scale_; } + const GType *InputScale() const { return input_scale_; } - const RType *InputVariance() const { return input_variance_; } + const GType *InputVariance() const { return input_variance_; } const float &Epsilon() const { return epsilon_; } const float &Momentum() const { return momentum_; } - const bool &IsTest() const { return is_test_; } - - void SetNewScale(RType *new_scale) { new_scale_ = new_scale; } + void SetNewScale(GType *new_scale) { new_scale_ = new_scale; } - void SetNewBias(RType *new_bias) { new_bias_ = new_bias; } + void SetNewBias(GType *new_bias) { new_bias_ = new_bias; } - const RType *NewScale() const { return new_scale_; } + const GType *NewScale() const { return new_scale_; } - const RType *NewBias() const { return new_bias_; } + const GType *NewBias() const { return new_bias_; } protected: - RType *bias_; + GType *bias_; int axis_; - RType *output_; - RType *input_bias_; - RType *input_mean_; - RType *input_scale_; - RType *input_variance_; + GType *input_bias_; + GType *input_mean_; + GType *input_scale_; + GType *input_variance_; float epsilon_; float momentum_; - bool is_test_; - RType *new_bias_; - RType *new_scale_; + GType *new_bias_; + GType *new_scale_; std::string keyBNY_; std::string keyX_; std::string keyY_; @@ -2097,50 +2064,44 @@ class FusionConvBNParam : public ConvParam { const VariableNameMap &outputs, const AttributeMap &attrs, Scope *scope) : ConvParam(inputs, outputs, attrs, scope) { - output_y_ = OpParam::OutputYFrom(outputs, *scope); input_bias_ = OpParam::InputBiasFrom(inputs, *scope); input_mean_ = OpParam::InputMeanFrom(inputs, *scope); input_scale_ = OpParam::InputScaleFrom(inputs, *scope); input_variance_ = OpParam::InputVarianceFrom(inputs, *scope); epsilon_ = OpParam::GetAttr("epsilon", attrs); momentum_ = OpParam::GetAttr("momentum", attrs); - // is_test_ = OpParam::GetAttr("is_test", attrs); + this->output_ = OpParam::OutputYFrom(outputs, *scope); } - RType *Output() const { return output_y_; } - const RType *InputBias() const { return input_bias_; } + const GType *InputBias() const { return input_bias_; } - const RType *InputMean() const { return input_mean_; } + const GType *InputMean() const { return input_mean_; } - const RType *InputScale() const { return input_scale_; } + const GType *InputScale() const { return input_scale_; } - const RType *InputVariance() const { return input_variance_; } + const GType *InputVariance() const { return input_variance_; } const float &Epsilon() const { return epsilon_; } const float &Momentum() const { return momentum_; } - const bool &IsTest() const { return is_test_; } - - void SetNewScale(RType *new_scale) { new_scale_ = new_scale; } + void SetNewScale(GType *new_scale) { new_scale_ = new_scale; } - void SetNewBias(RType *new_bias) { new_bias_ = new_bias; } + void SetNewBias(GType *new_bias) { new_bias_ = new_bias; } - const RType *NewScale() const { return new_scale_; } + const GType *NewScale() const { return new_scale_; } - const RType *NewBias() const { return new_bias_; } + const GType *NewBias() const { return new_bias_; } protected: - RType *output_y_; - RType *input_bias_; 
- RType *input_mean_; - RType *input_scale_; - RType *input_variance_; + GType *input_bias_; + GType *input_mean_; + GType *input_scale_; + GType *input_variance_; float epsilon_; float momentum_; - bool is_test_; - RType *new_bias_; - RType *new_scale_; + GType *new_bias_; + GType *new_scale_; }; #endif @@ -2157,56 +2118,49 @@ class FusionConvAddBNParam : public ConvParam { : ConvParam(inputs, outputs, attrs, scope) { bias_ = OpParam::InputYFrom(inputs, *scope); axis_ = OpParam::GetAttr("axis", attrs); - output_y_ = OpParam::OutputYFrom(outputs, *scope); input_bias_ = OpParam::InputBiasFrom(inputs, *scope); input_mean_ = OpParam::InputMeanFrom(inputs, *scope); input_scale_ = OpParam::InputScaleFrom(inputs, *scope); input_variance_ = OpParam::InputVarianceFrom(inputs, *scope); epsilon_ = OpParam::GetAttr("epsilon", attrs); momentum_ = OpParam::GetAttr("momentum", attrs); - // is_test_ = OpParam::GetAttr("is_test", attrs); + this->output_ = OpParam::OutputYFrom(outputs, *scope); } - RType *Bias() const { return bias_; } + GType *Bias() const { return bias_; } const int &Axis() const { return axis_; } - RType *Output() const { return output_y_; } - - const RType *InputBias() const { return input_bias_; } + const GType *InputBias() const { return input_bias_; } - const RType *InputMean() const { return input_mean_; } + const GType *InputMean() const { return input_mean_; } - const RType *InputScale() const { return input_scale_; } + const GType *InputScale() const { return input_scale_; } - const RType *InputVariance() const { return input_variance_; } + const GType *InputVariance() const { return input_variance_; } const float &Epsilon() const { return epsilon_; } const float &Momentum() const { return momentum_; } - const bool &IsTest() const { return is_test_; } + void SetNewScale(GType *new_scale) { new_scale_ = new_scale; } - void SetNewScale(RType *new_scale) { new_scale_ = new_scale; } - - void SetNewBias(RType *new_bias) { new_bias_ = new_bias; } + void SetNewBias(GType *new_bias) { new_bias_ = new_bias; } - const RType *NewScale() const { return new_scale_; } + const GType *NewScale() const { return new_scale_; } - const RType *NewBias() const { return new_bias_; } + const GType *NewBias() const { return new_bias_; } protected: - RType *bias_; + GType *bias_; int axis_; - RType *output_y_; - RType *input_bias_; - RType *input_mean_; - RType *input_scale_; - RType *input_variance_; + GType *input_bias_; + GType *input_mean_; + GType *input_scale_; + GType *input_variance_; float epsilon_; float momentum_; - bool is_test_; - RType *new_bias_; - RType *new_scale_; + GType *new_bias_; + GType *new_scale_; }; #endif @@ -2221,50 +2175,44 @@ class FusionDWConvBNReluParam : public ConvParam { const VariableNameMap &outputs, const AttributeMap &attrs, Scope *scope) : ConvParam(inputs, outputs, attrs, scope) { - output_ = OpParam::OutFrom(outputs, *scope); input_bias_ = OpParam::InputBiasFrom(inputs, *scope); input_mean_ = OpParam::InputMeanFrom(inputs, *scope); input_scale_ = OpParam::InputScaleFrom(inputs, *scope); input_variance_ = OpParam::InputVarianceFrom(inputs, *scope); epsilon_ = OpParam::GetAttr("epsilon", attrs); momentum_ = OpParam::GetAttr("momentum", attrs); - // is_test_ = OpParam::GetAttr("is_test", attrs); + this->output_ = OpParam::OutFrom(outputs, *scope); } - RType *Output() const { return output_; } - const RType *InputBias() const { return input_bias_; } + const GType *InputBias() const { return input_bias_; } - const RType *InputMean() const { return input_mean_; } + 
const GType *InputMean() const { return input_mean_; } - const RType *InputScale() const { return input_scale_; } + const GType *InputScale() const { return input_scale_; } - const RType *InputVariance() const { return input_variance_; } + const GType *InputVariance() const { return input_variance_; } const float &Epsilon() const { return epsilon_; } const float &Momentum() const { return momentum_; } - const bool &IsTest() const { return is_test_; } - - void SetNewScale(RType *new_scale) { new_scale_ = new_scale; } + void SetNewScale(GType *new_scale) { new_scale_ = new_scale; } - void SetNewBias(RType *new_bias) { new_bias_ = new_bias; } + void SetNewBias(GType *new_bias) { new_bias_ = new_bias; } - const RType *NewScale() const { return new_scale_; } + const GType *NewScale() const { return new_scale_; } - const RType *NewBias() const { return new_bias_; } + const GType *NewBias() const { return new_bias_; } protected: - RType *output_; - RType *input_bias_; - RType *input_mean_; - RType *input_scale_; - RType *input_variance_; + GType *input_bias_; + GType *input_mean_; + GType *input_scale_; + GType *input_variance_; float epsilon_; float momentum_; - bool is_test_; - RType *new_bias_; - RType *new_scale_; + GType *new_bias_; + GType *new_scale_; }; #endif @@ -2280,50 +2228,44 @@ class FusionConvBNReluParam : public ConvParam { const VariableNameMap &outputs, const AttributeMap &attrs, Scope *scope) : ConvParam(inputs, outputs, attrs, scope) { - output_ = OpParam::OutFrom(outputs, *scope); input_bias_ = OpParam::InputBiasFrom(inputs, *scope); input_mean_ = OpParam::InputMeanFrom(inputs, *scope); input_scale_ = OpParam::InputScaleFrom(inputs, *scope); input_variance_ = OpParam::InputVarianceFrom(inputs, *scope); epsilon_ = OpParam::GetAttr("epsilon", attrs); momentum_ = OpParam::GetAttr("momentum", attrs); - // is_test_ = OpParam::GetAttr("is_test", attrs); + this->output_ = OpParam::OutFrom(outputs, *scope); } - RType *Output() const { return output_; } - const RType *InputBias() const { return input_bias_; } + const GType *InputBias() const { return input_bias_; } - const RType *InputMean() const { return input_mean_; } + const GType *InputMean() const { return input_mean_; } - const RType *InputScale() const { return input_scale_; } + const GType *InputScale() const { return input_scale_; } - const RType *InputVariance() const { return input_variance_; } + const GType *InputVariance() const { return input_variance_; } const float &Epsilon() const { return epsilon_; } const float &Momentum() const { return momentum_; } - const bool &IsTest() const { return is_test_; } + void SetNewScale(GType *new_scale) { new_scale_ = new_scale; } - void SetNewScale(RType *new_scale) { new_scale_ = new_scale; } + void SetNewBias(GType *new_bias) { new_bias_ = new_bias; } - void SetNewBias(RType *new_bias) { new_bias_ = new_bias; } + const GType *NewScale() const { return new_scale_; } - const RType *NewScale() const { return new_scale_; } - - const RType *NewBias() const { return new_bias_; } + const GType *NewBias() const { return new_bias_; } protected: - RType *output_; - RType *input_bias_; - RType *input_mean_; - RType *input_scale_; - RType *input_variance_; + GType *input_bias_; + GType *input_mean_; + GType *input_scale_; + GType *input_variance_; float epsilon_; float momentum_; - bool is_test_; - RType *new_bias_; - RType *new_scale_; + GType *new_bias_; + GType *new_scale_; }; #endif @@ -2380,15 +2322,15 @@ class DropoutParam : public OpParam { dropout_prob_ = GetAttr("dropout_prob", 
attrs); } - const RType *InputX() const { return input_x_; } + const GType *InputX() const { return input_x_; } - RType *Out() const { return out_; } + GType *Out() const { return out_; } float DropoutProb() const { return dropout_prob_; } private: - RType *input_x_; - RType *out_; + GType *input_x_; + GType *out_; float dropout_prob_; }; #endif @@ -2415,11 +2357,11 @@ class ConvTransposeParam : public OpParam { groups = GetAttr("groups", attrs); } - const RType *Input() const { return input_; } + const GType *Input() const { return input_; } - const RType *Filter() const { return filter_; } + const GType *Filter() const { return filter_; } - RType *Output() const { return output_; } + GType *Output() const { return output_; } const vector &Strides() const { return strides_; } @@ -2430,9 +2372,9 @@ class ConvTransposeParam : public OpParam { const int &Groups() const { return groups; } private: - RType *input_; - RType *output_; - RType *filter_; + GType *input_; + GType *output_; + GType *filter_; vector strides_; vector paddings_; vector dilations_; @@ -2471,16 +2413,16 @@ class FusionDeconvAddParam : public ConvTransposeParam { axis_ = OpParam::GetAttr("axis", attrs); output_ = OpParam::OutFrom(outputs, *scope); } - RType *Bias() const { return bias_; } + GType *Bias() const { return bias_; } const int &Axis() const { return axis_; } - RType *Output() const { return output_; } + GType *Output() const { return output_; } protected: - RType *bias_; + GType *bias_; int axis_; - RType *output_; + GType *output_; }; #endif @@ -2785,13 +2727,13 @@ class FlattenParam : public OpParam { out_ = OutFrom(outputs, *scope); axis = GetAttr("axis", attrs); } - const RType *InputX() const { return input_x_; } - RType *Out() const { return out_; } + const GType *InputX() const { return input_x_; } + GType *Out() const { return out_; } const int &Axis() const { return axis; } private: - RType *input_x_; - RType *out_; + GType *input_x_; + GType *out_; int axis; }; #endif @@ -2816,7 +2758,7 @@ class SplitParam : public OpParam { // out_ts_.push_back(*scope.FindVar(outs_[i])->GetMutable()); // } } - const RType *InputX() const { return input_x_; } + const GType *InputX() const { return input_x_; } std::vector Outs() const { return outs_; } int Axis() const { return axis; } int Num() const { return num; } @@ -2824,7 +2766,7 @@ class SplitParam : public OpParam { // std::vector OutTs() const { return out_ts_; } private: - RType *input_x_; + GType *input_x_; std::vector outs_; int axis; int num; @@ -2859,16 +2801,16 @@ class BilinearInterpParam : public OpParam { out_h_ = GetAttr("out_h", attrs); out_w_ = GetAttr("out_w", attrs); } - const RType *InputX() const { return input_x_; } - const RType *InputOutPutSize() const { return input_outsize_; } - RType *Out() const { return out_; } + const GType *InputX() const { return input_x_; } + const GType *InputOutPutSize() const { return input_outsize_; } + GType *Out() const { return out_; } int OutH() const { return out_h_; } int OutW() const { return out_w_; } private: - RType *input_x_; - RType *input_outsize_; - RType *out_; + GType *input_x_; + GType *input_outsize_; + GType *out_; int out_h_; int out_w_; }; @@ -2887,12 +2829,12 @@ class ShapeParam : public OpParam { input_ = InputFrom(inputs, *scope); out_ = OutFrom(outputs, *scope); } - const RType *Input() const { return input_; } - RType *Out() const { return out_; } + const GType *Input() const { return input_; } + GType *Out() const { return out_; } private: - RType *input_; - RType *out_; + GType 
*input_; + GType *out_; }; #endif @@ -2913,9 +2855,9 @@ class TopKParam : public OpParam { } public: - RType *input_; - RType *output_; - RType *indices_; + GType *input_; + GType *output_; + GType *indices_; int k_; }; #endif // TOP_K_OP @@ -2937,8 +2879,8 @@ class CastParam : public OpParam { } public: - RType *input_; - RType *output_; + GType *input_; + GType *output_; int input_type_; int output_type_; }; @@ -2975,9 +2917,9 @@ class QuantizeParam : public OpParam { GType *input_; // op output GType *output_; - RType *online_scale_; + GType *online_scale_; // quantize offline scale - RType *offline_scale_; + GType *offline_scale_; // if offine scale or not bool offline_ = false; // round method type @@ -3012,7 +2954,7 @@ class DequantizeParam : public OpParam { GType *input_; // op output GType *output_; - RType *activation_scale_; + GType *activation_scale_; float weight_scale_; }; #endif @@ -3042,10 +2984,10 @@ class FusionDequantBNParam : public DequantizeParam { public: // batch norm - RType *bn_mean_; - RType *bn_variance_; - RType *bn_scale_; - RType *bn_bias_; + GType *bn_mean_; + GType *bn_variance_; + GType *bn_scale_; + GType *bn_bias_; float epsilon_; }; #endif @@ -3072,7 +3014,7 @@ class FusionDequantAddBNParam : public FusionDequantBNParam { public: // elementwise add int axis_; - RType *bias_; + GType *bias_; }; #endif @@ -3101,9 +3043,9 @@ class FusionDequantAddBNQuantParam : public FusionDequantAddBNParam { } public: - RType *online_scale_; + GType *online_scale_; // quantize offline scale - RType *offline_scale_; + GType *offline_scale_; // if offine scale or not bool offline_ = false; // round method type @@ -3269,24 +3211,6 @@ class LogicalUnaryParam : public OpParam { }; #endif // LOGICAL_NOT_OP -// #ifdef WHILE_OP -// template -// class WhileParam : public OpParam { -// public: -// WhileParam(const VariableNameMap &inputs, -// const VariableNameMap &outputs, const AttributeMap &attrs, -// const Scope &scope) : OpParam(inputs, outputs, attrs, scope) { -// cond_ = OpParam::GetVarValue("Condition", inputs, -// scope); block_desc_ = OpParam::GetAttr("sub_block", attrs); -// } -// -// public: -// framework::LoDTensor *cond_; -// const framework::BlockDesc *block_desc_; -// }; -// #endif // WHILE_OP - #ifdef WRITE_TO_ARRAY_OP template class WriteToArrayParam : public OpParam { @@ -3365,17 +3289,17 @@ class IncrementParam : public OpParam { : OpParam(inputs, outputs, attrs, scope) { input_x_ = InputXFrom(inputs, *scope); output_ = OutFrom(outputs, *scope); - step_ = OpParam::GetAttr("step", attrs); + step_ = OpParam::GetAttr("step", attrs); } const GType *InputX() const { return input_x_; } GType *Out() const { return output_; } - int Step() const { return step_; } + float Step() const { return step_; } public: GType *input_x_; GType *output_; - int step_; + float step_; }; #endif // INCREMENT_OP #ifdef PAD2D_OP diff --git a/src/operators/pad2d_op.cpp b/src/operators/pad2d_op.cpp index e7eda00d0830f719f8d7aa76ab77544b585d9b45..8a771c36a50f5a1b458df38d73ed93be61859cd4 100644 --- a/src/operators/pad2d_op.cpp +++ b/src/operators/pad2d_op.cpp @@ -19,14 +19,15 @@ namespace paddle_mobile { namespace operators { template -void Pad2dOp::InferShape() const { - auto input_dims = this->param_.InputX()->dims(); - auto input_n = input_dims[0]; - auto input_c = input_dims[1]; - auto input_h = input_dims[2]; - auto input_w = input_dims[3]; - - this->param_.Out()->Resize({input_n, input_c, input_h + 1, input_w + 1}); +void Pad2DOp::InferShape() const { + auto input_dims = 
this->param_.input_->dims(); + const auto &paddings = this->param_.paddings_; + PADDLE_MOBILE_ENFORCE(paddings.size() == 4, + "Size of paddings should be equal to 4."); + + input_dims[2] += paddings[0] + paddings[1]; + input_dims[3] += paddings[2] + paddings[3]; + this->param_.output_->Resize(input_dims); } } // namespace operators @@ -34,10 +35,10 @@ void Pad2dOp::InferShape() const { namespace ops = paddle_mobile::operators; #ifdef PADDLE_MOBILE_CPU -REGISTER_OPERATOR_CPU(pad2d, ops::Pad2dOp); +REGISTER_OPERATOR_CPU(pad2d, ops::Pad2DOp); #endif #ifdef PADDLE_MOBILE_FPGA -REGISTER_OPERATOR_FPGA(pad2d, ops::Pad2dOp); +REGISTER_OPERATOR_FPGA(pad2d, ops::Pad2DOp); #endif -#endif +#endif // PAD2D_OP diff --git a/src/operators/pad2d_op.h b/src/operators/pad2d_op.h index 761e2b837d34b8d51629b883a8cd6797037e5d9b..1a80cbac40f0c9bea36283373763f03489d073d2 100644 --- a/src/operators/pad2d_op.h +++ b/src/operators/pad2d_op.h @@ -17,33 +17,16 @@ limitations under the License. */ #pragma once #include - #include "framework/operator.h" #include "operators/kernel/pad2d_kernel.h" #include "operators/op_param.h" namespace paddle_mobile { namespace operators { -using framework::AttributeMap; -using framework::OperatorWithKernel; -using framework::Scope; -using std::string; -template -class Pad2dOp - : public OperatorWithKernel, - operators::Pad2dKernel> { - public: - Pad2dOp(const string &type, const VariableNameMap &inputs, - const VariableNameMap &outputs, const AttributeMap &attrs, - std::shared_ptr scope) - : OperatorWithKernel, - operators::Pad2dKernel>( - type, inputs, outputs, attrs, scope) {} - void InferShape() const override; - - private: -}; + +DECLARE_OPERATOR(Pad2D, Pad2DParam, Pad2DKernel); + } // namespace operators } // namespace paddle_mobile -#endif +#endif // PAD2D_OP diff --git a/src/operators/polygon_box_transform_op.h b/src/operators/polygon_box_transform_op.h index e20765f715106d4b3c8a182d52e3ab135637c9e9..a4d1975e58b0374e776f3995fc1803419cacbfd2 100644 --- a/src/operators/polygon_box_transform_op.h +++ b/src/operators/polygon_box_transform_op.h @@ -36,7 +36,7 @@ class PolygonBoxTransformOp PolygonBoxTransformOp(const std::string &type, const VariableNameMap &inputs, const VariableNameMap &outputs, const framework::AttributeMap &attrs, - std::shared_ptr scope) + framework::Scope *scope) : framework::OperatorWithKernel< DeviceType, PolygonBoxTransformParam, operators::PolygonBoxTransformKernel>( diff --git a/src/operators/pool_op.h b/src/operators/pool_op.h index 8f3957e29ee0802576f604900f8d15f86a864d53..861430f10bab03941aea643fabc3937c18f71376 100644 --- a/src/operators/pool_op.h +++ b/src/operators/pool_op.h @@ -24,19 +24,17 @@ limitations under the License. 
*/ namespace paddle_mobile { namespace operators { -using framework::AttributeMap; -using framework::OperatorWithKernel; -using framework::Scope; -using std::string; + template -class PoolOp : public OperatorWithKernel, - operators::PoolKernel> { +class PoolOp : public framework::OperatorWithKernel< + DeviceType, PoolParam, + operators::PoolKernel> { public: - PoolOp(const string &type, const VariableNameMap &inputs, + PoolOp(const std::string &type, const VariableNameMap &inputs, const VariableNameMap &outputs, const AttributeMap &attrs, - std::shared_ptr scope) - : OperatorWithKernel, - operators::PoolKernel>( + framework::Scope *scope) + : framework::OperatorWithKernel, + operators::PoolKernel>( type, inputs, outputs, attrs, scope) {} void InferShape() const override; diff --git a/src/operators/prelu_op.h b/src/operators/prelu_op.h index 5d0458f896941ece4208ca4b4931db189b4f436e..92c2e7e62040d753d0ac40e6bc82afb5c1082e9d 100644 --- a/src/operators/prelu_op.h +++ b/src/operators/prelu_op.h @@ -34,7 +34,7 @@ class PReluOp : public framework::OperatorWithKernel< public: PReluOp(const std::string &type, const VariableNameMap &inputs, const VariableNameMap &outputs, const framework::AttributeMap &attrs, - std::shared_ptr scope) + framework::Scope *scope) : framework::OperatorWithKernel, operators::PReluKernel>( type, inputs, outputs, attrs, scope) {} diff --git a/src/operators/prior_box_op.h b/src/operators/prior_box_op.h index f7e26430a0536cde011de14f670a9f46b8f517c1..67d0cc6865fdc722a0191bc540a4d69c34ebedba 100644 --- a/src/operators/prior_box_op.h +++ b/src/operators/prior_box_op.h @@ -34,8 +34,7 @@ class PriorBoxOp : public framework::OperatorWithKernel< public: PriorBoxOp(const std::string &type, const VariableNameMap &inputs, const VariableNameMap &outputs, - const framework::AttributeMap &attrs, - std::shared_ptr scope) + const framework::AttributeMap &attrs, framework::Scope *scope) : framework::OperatorWithKernel, operators::PriorBoxKernel>( type, inputs, outputs, attrs, scope) {} diff --git a/src/operators/quantize_op.h b/src/operators/quantize_op.h index a2294cd55534382367e6b592dd1545ec3590d9c9..253113ad4bf04ba7bbc45e838211e41d2e2811fd 100644 --- a/src/operators/quantize_op.h +++ b/src/operators/quantize_op.h @@ -31,8 +31,7 @@ class QuantizeOp : public framework::OperatorWithKernel< public: QuantizeOp(const std::string &type, const VariableNameMap &inputs, const VariableNameMap &outputs, - const framework::AttributeMap &attrs, - std::shared_ptr scope) + const framework::AttributeMap &attrs, framework::Scope *scope) : framework::OperatorWithKernel, operators::QuantizeKernel>( type, inputs, outputs, attrs, scope) {} diff --git a/src/operators/reshape2_op.h b/src/operators/reshape2_op.h index 3a06c2b9b90233b6ad0bacb6176f4cc274ff1cc0..19c5e59f71d02228af019e8edf1d9f7edc75811d 100644 --- a/src/operators/reshape2_op.h +++ b/src/operators/reshape2_op.h @@ -34,8 +34,7 @@ class Reshape2Op : public framework::OperatorWithKernel< public: Reshape2Op(const std::string &type, const VariableNameMap &inputs, const VariableNameMap &outputs, - const framework::AttributeMap &attrs, - std::shared_ptr scope) + const framework::AttributeMap &attrs, framework::Scope *scope) : framework::OperatorWithKernel, operators::Reshape2Kernel>( type, inputs, outputs, attrs, scope) {} diff --git a/src/operators/reshape_op.h b/src/operators/reshape_op.h index 3109303ff0e6007d0dbec133102924ff7bb30306..67e86044ea63a250fa5c944c4470d72c3b1c1bd7 100644 --- a/src/operators/reshape_op.h +++ b/src/operators/reshape_op.h @@ 
-34,8 +34,7 @@ class ReshapeOp : public framework::OperatorWithKernel< public: ReshapeOp(const std::string &type, const VariableNameMap &inputs, const VariableNameMap &outputs, - const framework::AttributeMap &attrs, - std::shared_ptr scope) + const framework::AttributeMap &attrs, framework::Scope *scope) : framework::OperatorWithKernel, operators::ReshapeKernel>( type, inputs, outputs, attrs, scope) {} diff --git a/src/operators/resize_op.h b/src/operators/resize_op.h index 954b3a82f8d2b5ccba242045c3d5e0f28553d484..6088ad4f51d91022990a999a2d557a80d3853253 100644 --- a/src/operators/resize_op.h +++ b/src/operators/resize_op.h @@ -34,7 +34,7 @@ class ResizeOp : public framework::OperatorWithKernel< public: ResizeOp(const std::string &type, const VariableNameMap &inputs, const VariableNameMap &outputs, const framework::AttributeMap &attrs, - std::shared_ptr scope) + framework::Scope *scope) : framework::OperatorWithKernel, operators::ResizeKernel>( type, inputs, outputs, attrs, scope) {} diff --git a/src/operators/scale_op.h b/src/operators/scale_op.h index 56265259fe3a10feda67cc5c5732b2ba44e0730e..aacacd92453449f32542a26ffd60f54b464a1483 100644 --- a/src/operators/scale_op.h +++ b/src/operators/scale_op.h @@ -34,7 +34,7 @@ class ScaleOp : public framework::OperatorWithKernel< public: ScaleOp(const std::string &type, const VariableNameMap &inputs, const VariableNameMap &outputs, const framework::AttributeMap &attrs, - std::shared_ptr scope) + framework::Scope *scope) : framework::OperatorWithKernel, operators::ScaleKernel>( type, inputs, outputs, attrs, scope) {} diff --git a/src/operators/sequence_ops/sequence_expand_op.h b/src/operators/sequence_ops/sequence_expand_op.h index cd62bbefc703ed2642e076913e2538c1621c1082..f854272d7b2c424563ab70dba8a0f1113e038cf6 100644 --- a/src/operators/sequence_ops/sequence_expand_op.h +++ b/src/operators/sequence_ops/sequence_expand_op.h @@ -32,7 +32,7 @@ class SequenceExpandOp : public framework::OperatorWithKernel< SequenceExpandOp(const std::string &type, const VariableNameMap &inputs, const VariableNameMap &outputs, const framework::AttributeMap &attrs, - std::shared_ptr scope) + framework::Scope *scope) : framework::OperatorWithKernel< DeviceType, SequenceExpandParam, operators::SequenceExpandKernel>( diff --git a/src/operators/sequence_ops/sequence_pool_op.h b/src/operators/sequence_ops/sequence_pool_op.h index 724572936643abe071147edbdbee0053a29f4c20..aae892f9f3bb87d52100ee148d5653b4ab6fa657 100644 --- a/src/operators/sequence_ops/sequence_pool_op.h +++ b/src/operators/sequence_ops/sequence_pool_op.h @@ -31,8 +31,7 @@ class SequencePoolOp : public framework::OperatorWithKernel< public: SequencePoolOp(const std::string &type, const VariableNameMap &inputs, const VariableNameMap &outputs, - const framework::AttributeMap &attrs, - std::shared_ptr scope) + const framework::AttributeMap &attrs, framework::Scope *scope) : framework::OperatorWithKernel< DeviceType, SequencePoolParam, operators::SequencePoolKernel>(type, inputs, outputs, diff --git a/src/operators/sequence_ops/sequence_softmax_op.h b/src/operators/sequence_ops/sequence_softmax_op.h index 92090ba802cc8ea97bc87f0fe9567b319c3d4948..f0578f6ed36e770254b7fca9925fa0a41daefa52 100644 --- a/src/operators/sequence_ops/sequence_softmax_op.h +++ b/src/operators/sequence_ops/sequence_softmax_op.h @@ -32,7 +32,7 @@ class SequenceSoftmaxOp : public framework::OperatorWithKernel< SequenceSoftmaxOp(const std::string &type, const VariableNameMap &inputs, const VariableNameMap &outputs, const 
framework::AttributeMap &attrs, - std::shared_ptr scope) + framework::Scope *scope) : framework::OperatorWithKernel< DeviceType, SoftmaxParam, operators::SequenceSoftmaxKernel>( diff --git a/src/operators/shape_op.h b/src/operators/shape_op.h index 116751c48e9ca3cc9ec936b1bcbaa72b6950bbc5..05bc611bc555429458f04fc18e18b9475424c948 100644 --- a/src/operators/shape_op.h +++ b/src/operators/shape_op.h @@ -34,7 +34,7 @@ class ShapeOp : public framework::OperatorWithKernel< public: ShapeOp(const std::string &type, const VariableNameMap &inputs, const VariableNameMap &outputs, const framework::AttributeMap &attrs, - std::shared_ptr scope) + framework::Scope *scope) : framework::OperatorWithKernel, operators::ShapeKernel>( type, inputs, outputs, attrs, scope) {} diff --git a/src/operators/slice_op.h b/src/operators/slice_op.h index c45061696577dbe6948fb9cab7edebbaf8e15f2f..0d01705f7d7fb98f47a6960c047422e06dc68f7b 100644 --- a/src/operators/slice_op.h +++ b/src/operators/slice_op.h @@ -34,7 +34,7 @@ class SliceOp : public framework::OperatorWithKernel< public: SliceOp(const std::string &type, const VariableNameMap &inputs, const VariableNameMap &outputs, const framework::AttributeMap &attrs, - std::shared_ptr scope) + framework::Scope *scope) : framework::OperatorWithKernel, operators::SliceKernel>( type, inputs, outputs, attrs, scope) {} diff --git a/src/operators/softmax_op.cpp b/src/operators/softmax_op.cpp index e605864706a6c59a35205b3072dd432b009c5d1f..e4e6a8cf30ce946a2bf9f84ee66f06c651bfac73 100644 --- a/src/operators/softmax_op.cpp +++ b/src/operators/softmax_op.cpp @@ -21,6 +21,7 @@ namespace operators { template void SoftmaxOp::InferShape() const { this->param_.Out()->Resize(this->param_.InputX()->dims()); + this->param_.Out()->set_lod(this->param_.InputX()->lod()); } } // namespace operators diff --git a/src/operators/softmax_op.h b/src/operators/softmax_op.h index 422213feeaf2bc2301832de2f9c69827342a5062..2f9285a21d748eca90858f0d41a7998b9d2e95d2 100644 --- a/src/operators/softmax_op.h +++ b/src/operators/softmax_op.h @@ -31,8 +31,7 @@ class SoftmaxOp : public framework::OperatorWithKernel< public: SoftmaxOp(const std::string &type, const VariableNameMap &inputs, const VariableNameMap &outputs, - const framework::AttributeMap &attrs, - std::shared_ptr scope) + const framework::AttributeMap &attrs, framework::Scope *scope) : framework::OperatorWithKernel, operators::SoftmaxKernel>( type, inputs, outputs, attrs, scope) {} diff --git a/src/operators/split_op.h b/src/operators/split_op.h index fc733c18520b971107e00003b3107b8c0aa9b36d..4801defb496c1022adc46950b8781af723b90513 100644 --- a/src/operators/split_op.h +++ b/src/operators/split_op.h @@ -34,7 +34,7 @@ class SplitOp : public framework::OperatorWithKernel< public: SplitOp(const std::string &type, const VariableNameMap &inputs, const VariableNameMap &outputs, const framework::AttributeMap &attrs, - std::shared_ptr scope) + framework::Scope *scope) : framework::OperatorWithKernel, operators::SplitKernel>( type, inputs, outputs, attrs, scope) {} diff --git a/src/operators/sum_op.h b/src/operators/sum_op.h index aad8e8322b60d0e931215c9d48d97862f9b14107..3ee5465fc8c55ae7f75d90124a06664225acf153 100644 --- a/src/operators/sum_op.h +++ b/src/operators/sum_op.h @@ -30,7 +30,7 @@ class SumOp : public framework::OperatorWithKernel< public: SumOp(const string &type, const VariableNameMap &inputs, const VariableNameMap &outputs, const framework::AttributeMap &attrs, - std::shared_ptr scope) + framework::Scope *scope) : 
framework::OperatorWithKernel, operators::SumKernel>( type, inputs, outputs, attrs, scope) {} diff --git a/src/operators/top_k_op.cpp b/src/operators/top_k_op.cpp index cd902899a98a76fcc1bd471762eba006928df67c..d5cf6a37e9bfbf047592f9f11bf145748fc11b2d 100644 --- a/src/operators/top_k_op.cpp +++ b/src/operators/top_k_op.cpp @@ -27,6 +27,8 @@ void TopKOp::InferShape() const { dims[dims.size() - 1] = k; this->param_.output_->Resize(dims); this->param_.indices_->Resize(dims); + this->param_.output_->set_lod(this->param_.input_->lod()); + this->param_.indices_->set_lod(this->param_.input_->lod()); } } // namespace operators diff --git a/src/operators/top_k_op.h b/src/operators/top_k_op.h index cabae961d1d60fc47aa58669d3e1f1833bb68691..4c182d6ffec8a9d6685f516ff0579b0744719f5c 100644 --- a/src/operators/top_k_op.h +++ b/src/operators/top_k_op.h @@ -31,7 +31,7 @@ class TopKOp : public framework::OperatorWithKernel< public: TopKOp(const std::string &type, const VariableNameMap &inputs, const VariableNameMap &outputs, const framework::AttributeMap &attrs, - std::shared_ptr scope) + framework::Scope *scope) : framework::OperatorWithKernel, operators::TopKKernel>( type, inputs, outputs, attrs, scope) {} diff --git a/src/operators/transpose2_op.h b/src/operators/transpose2_op.h index f1339cc59e0c71a232eddd5dcef47f62994b80da..2552688ca6a3a0bb102842e7edbcc1ebdc777662 100644 --- a/src/operators/transpose2_op.h +++ b/src/operators/transpose2_op.h @@ -34,8 +34,7 @@ class Transpose2Op : public framework::OperatorWithKernel< public: Transpose2Op(const std::string &type, const VariableNameMap &inputs, const VariableNameMap &outputs, - const framework::AttributeMap &attrs, - std::shared_ptr scope) + const framework::AttributeMap &attrs, framework::Scope *scope) : framework::OperatorWithKernel< DeviceType, Transpose2Param, operators::Transpose2Kernel>(type, inputs, outputs, diff --git a/src/operators/transpose_op.h b/src/operators/transpose_op.h index eb98ce235491632aa1149acc158552955c2c1e0c..cf03cb382570ccb8d7278546ab3b191e70a0792d 100644 --- a/src/operators/transpose_op.h +++ b/src/operators/transpose_op.h @@ -34,8 +34,7 @@ class TransposeOp : public framework::OperatorWithKernel< public: TransposeOp(const std::string &type, const VariableNameMap &inputs, const VariableNameMap &outputs, - const framework::AttributeMap &attrs, - std::shared_ptr scope) + const framework::AttributeMap &attrs, framework::Scope *scope) : framework::OperatorWithKernel< DeviceType, TransposeParam, operators::TransposeKernel>(type, inputs, outputs, diff --git a/test/common/test_gemm_accuracy.cpp b/test/common/test_gemm_accuracy.cpp index 93cea2fd362ea3be42dbc5f53392fb47ad47d1d4..174459d3f58e82b85b5b189e8da6c0c9cb980a13 100644 --- a/test/common/test_gemm_accuracy.cpp +++ b/test/common/test_gemm_accuracy.cpp @@ -18,7 +18,7 @@ limitations under the License. 
*/ #include "../test_helper.h" #include "common/log.h" #include "memory/t_malloc.h" -#include "operators/math/gemm.h" +#include "operators/math/gemm/cblas.h" #define a(i, j) a[(i)*lda + (j)] #define b(i, j) b[(i)*ldb + (j)] @@ -36,10 +36,12 @@ void print_matrix(int m, int n, int ldc, float *c) { std::cout << std::endl; } -int do_sgemm(int m, int n, int k, bool relu, int t1, int t2, int pr) { - int lda = k; - int ldb = n; - int ldc = n; +int do_sgemm(int m, int n, int k, int pr) { + const float alpha = 1.f; + const float beta = 0.f; + const int lda = k; + const int ldb = n; + const int ldc = n; float *a = static_cast(paddle_mobile::memory::Alloc(sizeof(float) * m * k)); @@ -49,24 +51,19 @@ int do_sgemm(int m, int n, int k, bool relu, int t1, int t2, int pr) { static_cast(paddle_mobile::memory::Alloc(sizeof(float) * m * n)); float *c1 = static_cast(paddle_mobile::memory::Alloc(sizeof(float) * m * n)); - float *scale = - static_cast(paddle_mobile::memory::Alloc(sizeof(float) * m)); - float *bias = - static_cast(paddle_mobile::memory::Alloc(sizeof(float) * m)); - srand(unsigned(time(0))); + std::mt19937 rng(111); + std::uniform_real_distribution uniform_dist(0, 1); + const float lower = -10.f; + const float upper = 10.f; + for (int i = 0; i < m * k; ++i) { - a[i] = t1 + rand() % t2; + a[i] = static_cast(uniform_dist(rng) * (upper - lower) + lower); } for (int i = 0; i < k * n; ++i) { - b[i] = t1 + rand() % t2; - } - for (int i = 0; i < m; ++i) { - scale[i] = t1 + rand() % t2; - } - for (int i = 0; i < m; ++i) { - bias[i] = t1 + rand() % t2; + b[i] = static_cast(uniform_dist(rng) * (upper - lower) + lower); } + memcpy(c, c1, sizeof(float) * m * n); for (int i = 0; i < m; ++i) { for (int j = 0; j < n; ++j) { @@ -74,25 +71,20 @@ int do_sgemm(int m, int n, int k, bool relu, int t1, int t2, int pr) { for (int p = 0; p < k; p++) { r += a(i, p) * b(p, j); } - r *= scale[i]; - r += bias[i]; - if (relu && (r < 0)) { - r = 0; - } - c1(i, j) = r; + c1(i, j) = alpha * r; } } - paddle_mobile::operators::math::Gemm gemm; - gemm.SgemmWithBn(m, n, k, 1, a, lda, b, ldb, 0.3, c, ldc, relu, scale, bias, - nullptr); - int eq = 0; - int neq = 0; + std::cout << "run cblas_sgemm..." << std::endl; + paddle_mobile::operators::math::cblas_sgemm(false, false, m, n, k, alpha, a, + lda, b, ldb, 0.f, c, ldc); + + std::cout << "compare results..." 
<< std::endl;
   for (int i = 0; i < m * n; ++i) {
-    if (static_cast(c[i]) == static_cast(c1[i])) {
-      ++eq;
-    } else {
-      ++neq;
+    if (std::abs(c[i] - c1[i]) >= 1e-2) {
+      std::cout << "c[" << i << "] != c1[" << i << "]: " << c[i] << " vs "
+                << c1[i] << std::endl;
+      exit(1);
     }
   }
@@ -107,33 +99,36 @@ int do_sgemm(int m, int n, int k, bool relu, int t1, int t2, int pr) {
     print_matrix(m, n, ldc, c1);
   }
 
-  std::cout << "mnk=" << m << " " << n << " " << k << " relu=" << relu
-            << " eq=" << eq << " neq=" << neq << std::endl;
-
-  PADDLE_MOBILE_ENFORCE(neq == 0, "The execution of do_sgemm is failed!");
-
   paddle_mobile::memory::Free(a);
   paddle_mobile::memory::Free(b);
   paddle_mobile::memory::Free(c);
   paddle_mobile::memory::Free(c1);
-  paddle_mobile::memory::Free(scale);
-  paddle_mobile::memory::Free(bias);
 
   return 0;
 }
 
-int main() {
-  do_sgemm(9, 9, 9, true, 10, 10, 10);
-  do_sgemm(10, 6, 12, false, 10, 10, 0);
-  do_sgemm(512, 256, 384, false, 10, 10, 0);
-  do_sgemm(1366, 768, 256, false, 10, 10, 0);
-  do_sgemm(1255, 755, 333, false, 10, 10, 0);
-  do_sgemm(555, 777, 999, false, 10, 10, 0);
-
-  do_sgemm(10, 6, 12, true, -4, 10, 0);
-  do_sgemm(512, 256, 384, true, -4, 10, 0);
-  do_sgemm(1366, 768, 256, true, -4, 10, 0);
-  do_sgemm(1255, 755, 333, true, -4, 10, 0);
-  do_sgemm(555, 777, 999, true, -4, 10, 0);
+int main(int argc, char *argv[]) {
+  do_sgemm(1, 1, 1, 1);
+
+  do_sgemm(9, 9, 1, 1);
+  do_sgemm(999, 99, 1, 0);
+  do_sgemm(999, 1, 1, 0);
+  do_sgemm(1, 9, 9, 1);
+  do_sgemm(1, 99, 999, 0);
+  do_sgemm(1, 1, 999, 0);
+
+  do_sgemm(9, 9, 9, 1);
+  do_sgemm(10, 6, 12, 1);
+  do_sgemm(512, 256, 384, 0);
+  do_sgemm(1366, 768, 256, 0);
+  do_sgemm(1255, 755, 333, 0);
+  do_sgemm(555, 777, 999, 0);
+
+  do_sgemm(10, 6, 12, 1);
+  do_sgemm(512, 256, 384, 0);
+  do_sgemm(1366, 768, 256, 0);
+  do_sgemm(1255, 755, 333, 0);
+  do_sgemm(555, 777, 999, 0);
+
   return 0;
 }
diff --git a/test/executor_for_test.h b/test/executor_for_test.h
index 947f41032030f51e7522f4305ac395dd0c1fc211..bcb5006084dff0713cde15acd90514f3facf5ce5 100644
--- a/test/executor_for_test.h
+++ b/test/executor_for_test.h
@@ -57,31 +57,27 @@ class Executor4Test : public Executor {
       LOG(paddle_mobile::LogLevel::kLOG_ERROR) << "program_desc_ == nullptr";
     }
-    const std::vector> blocks =
+    const std::vector> &blocks =
         this->program_desc_->Blocks();
-    for (int block_id = 0; block_id < blocks.size(); ++block_id) {
-      std::vector> ops = blocks[block_id]->Ops();
-      for (int i = 0; i < ops.size(); ++i) {
-        auto op = ops[i];
-        if (op->Type() == op_type) {
-          DLOG << "matched: " << op->Type();
-
-          /// test first meeting op in program
-          std::shared_ptr>
-              op_ptr =
-                  paddle_mobile::framework::OpRegistry::CreateOp(
-                      op->Type(), op->GetInputs(), op->GetOutputs(),
-                      op->GetAttrMap(), this->program_.scope);
-          this->ops_of_block_[block_id].push_back(op_ptr);
-          break;
-        }
+    std::vector> ops = blocks[0]->Ops();
+    for (int i = 0; i < ops.size(); ++i) {
+      auto op = ops[i];
+      if (op->Type() == op_type) {
+        DLOG << "matched: " << op->Type();
+
+        /// test first meeting op in program
+        std::shared_ptr>
+            op_ptr = paddle_mobile::framework::OpRegistry::CreateOp(
+                op->Type(), op->GetInputs(), op->GetOutputs(), op->GetAttrMap(),
+                this->program_.scope.get());
+        this->ops_of_block0_.push_back(op_ptr);
+        break;
       }
     }
+
     this->InitMemory();
-    for (const auto &ops : this->ops_of_block_) {
-      for (const auto &op : ops) {
-        op->Init();
-      }
+    for (const auto &op : this->ops_of_block0_) {
+      op->Init();
     }
   }
@@ -90,7 +86,7 @@ class Executor4Test : public Executor {
       const vector &input_names, const vector
&output_names, const vector &ddims) { - auto scope = this->program_.scope; + auto scope = this->program_.scope.get(); size_t input_size = input_names.size(); size_t out_size = output_names.size(); @@ -114,10 +110,8 @@ class Executor4Test : public Executor { output_tensor_sptrs[i].reset(output_tensors[i]); } - for (auto &ops : this->ops_of_block_) { - for (auto &op : ops) { - op->Run(); - } + for (auto &op : this->ops_of_block0_) { + op->Run(); } return output_tensor_sptrs; @@ -125,7 +119,7 @@ class Executor4Test : public Executor { std::shared_ptr Predict(const Tensor &t, string input, string output, const DDim &dDim) { - auto scope = this->program_.scope; + auto scope = this->program_.scope.get(); Variable *g_feed_value = scope->Var(input); auto tensor = g_feed_value->GetMutable(); tensor->ShareDataWith(t); @@ -134,11 +128,10 @@ class Executor4Test : public Executor { auto *output_tensor = con_output->GetMutable(); output_tensor->mutable_data(dDim); - for (auto &ops : this->ops_of_block_) { - for (auto &op : ops) { - op->Run(); - } + for (auto &op : this->ops_of_block0_) { + op->Run(); } + return std::make_shared( paddle_mobile::framework::Tensor(*output_tensor)); } diff --git a/test/fpga/test_rfcn_api.cpp b/test/fpga/test_rfcn_api.cpp index 2268fc46e5eb98eeff781e6de3a57e5efb911d3d..f787d8f9acfe85ead101aeb16a4fbebe1aefee65 100644 --- a/test/fpga/test_rfcn_api.cpp +++ b/test/fpga/test_rfcn_api.cpp @@ -12,19 +12,18 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#ifndef PADDLE_MOBILE_FPGA -#define PADDLE_MOBILE_FPGA -#endif -#include #include -#include "../../src/io/paddle_inference_api.h" +#include "../test_helper.h" +#include "../test_include.h" -using namespace paddle_mobile; -using namespace paddle_mobile::fpga; +#ifdef PADDLE_MOBILE_FPGA_V1 +#include "fpga/V1/api.h" +#endif +#ifdef PADDLE_MOBILE_FPGA_V2 +#include "fpga/V2/api.h" +#endif -static const char *g_image = "../models/rfcn/data.bin"; -static const char *g_model = "../models/rfcn/model"; -static const char *g_param = "../models/rfcn/params"; +#include void readStream(std::string filename, char *buf) { std::ifstream in; @@ -38,137 +37,116 @@ void readStream(std::string filename, char *buf) { auto length = in.tellg(); // report location (this is the length) in.seekg(0, std::ios::beg); // go back to the beginning in.read(buf, length); + DLOG << length; in.close(); } -PaddleMobileConfig GetConfig() { - PaddleMobileConfig config; - config.precision = PaddleMobileConfig::FP32; - config.device = PaddleMobileConfig::kFPGA; - config.prog_file = g_model; - config.param_file = g_param; - config.thread_num = 1; - config.batch_size = 1; - config.optimize = true; - config.lod_mode = true; - config.quantification = false; - return config; -} - -PaddleMobileConfig GetConfig1() { - PaddleMobileConfig config; - config.precision = PaddleMobileConfig::FP32; - config.device = PaddleMobileConfig::kFPGA; - config.model_dir = "../models/resnet50"; - config.thread_num = 1; - config.batch_size = 1; - config.optimize = true; - config.quantification = false; - return config; +void convert_to_chw(int16_t **data_in, int channel, int height, int width, + int num, int16_t *data_tmp) { + int64_t amount_per_side = width * height; + for (int n = 0; n < num; n++) { + for (int h = 0; h < height; h++) { + for (int w = 0; w < width; w++) { + for (int c = 0; c < channel; c++) { + *(data_tmp + n * amount_per_side * channel + c * 
amount_per_side + + width * h + w) = *((*data_in)++); + } + } + } + } } -int main() { - open_device(); - - PaddleMobileConfig config = GetConfig(); - auto predictor = - CreatePaddlePredictor(config); - - std::cout << "Finishing loading model" << std::endl; - - float img_info[3] = {432, 1280, 1.0f}; - int img_length = 432 * 1280 * 3; - auto img = reinterpret_cast(fpga_malloc(img_length * sizeof(float))); - readStream(g_image, reinterpret_cast(img)); - - std::cout << "Finishing initializing data" << std::endl; - struct PaddleTensor t_img_info, t_img; - t_img.dtypeid = typeid(float); - t_img_info.layout = LAYOUT_HWC; - t_img_info.shape = std::vector({1, 3}); - t_img_info.name = "Image information"; - t_img_info.data.Reset(img_info, 3 * sizeof(float)); - - t_img.dtypeid = typeid(float); - t_img.layout = LAYOUT_HWC; - t_img.shape = std::vector({1, 432, 1280, 3}); - t_img.name = "Image information"; - t_img.data.Reset(img, img_length * sizeof(float)); - predictor->FeedPaddleTensors({t_img_info, t_img}); - - std::cout << "Finishing feeding data " << std::endl; - - predictor->Predict_From_To(0, -1); - std::cout << "Finishing predicting " << std::endl; - - std::vector v; // No need to initialize v - predictor->FetchPaddleTensors(&v); // Old data in v will be cleared - std::cout << "Output number is " << v.size() << std::endl; - std::cout << "out[0] length " << v[0].data.length() << std::endl; - std::cout << "out[1] length " << v[1].data.length() << std::endl; - std::cout << "out[2] length " << v[2].data.length() << std::endl; - - auto post_nms = v[0].data.length() / sizeof(float) / 8; - for (int num = 0; num < post_nms; num++) { - for (int i = 0; i < 8; i++) { - auto p = reinterpret_cast(v[0].data.data()); - std::cout << p[num * 8 + i] << std::endl; - } +void dump_stride_half(std::string filename, Tensor input_tensor, + const int dumpnum, bool use_chw) { + // bool use_chw = true; + if (input_tensor.dims().size() != 4) return; + int c = (input_tensor.dims())[1]; + int h = (input_tensor.dims())[2]; + int w = (input_tensor.dims())[3]; + int n = (input_tensor.dims())[0]; + auto data_ptr = input_tensor.get_data(); + auto *data_ptr_16 = reinterpret_cast(data_ptr); + auto data_tmp = data_ptr_16; + if (use_chw) { + data_tmp = + reinterpret_cast(malloc(n * c * h * w * sizeof(int16_t))); + convert_to_chw(&data_ptr_16, c, h, w, n, data_tmp); } - for (int num = 0; num < post_nms; num++) { - for (int i = 0; i < 8; i++) { - auto p = reinterpret_cast(v[1].data.data()); - std::cout << p[num * 8 + i] << std::endl; - } + std::ofstream out(filename.c_str()); + float result = 0; + int stride = input_tensor.numel() / dumpnum; + stride = stride > 0 ? 
stride : 1; + for (int i = 0; i < input_tensor.numel(); i += stride) { + result = paddle_mobile::fpga::fp16_2_fp32(data_tmp[i]); + out << result << std::endl; } - for (int num = 0; num < post_nms; num++) { - for (int i = 0; i < 4; i++) { - auto p = reinterpret_cast(v[2].data.data()); - std::cout << p[num * 4 + i] << std::endl; - } + out.close(); + if (data_tmp != data_ptr_16) { + free(data_tmp); } - std::cout << "Finish getting vector values" << std::endl; - - //////////////////////////////////////////////////// +} - PaddleTensor tensor; - predictor->GetPaddleTensor("fetch2", &tensor); - for (int i = 0; i < post_nms; i++) { - auto p = reinterpret_cast(tensor.data.data()); - std::cout << p[+i] << std::endl; +void dump_stride_float(std::string filename, Tensor input_tensor, + const int dumpnum) { + auto data_ptr = reinterpret_cast(input_tensor.get_data()); + std::ofstream out(filename.c_str()); + float result = 0; + int stride = input_tensor.numel() / dumpnum; + stride = stride > 0 ? stride : 1; + for (int i = 0; i < input_tensor.numel(); i += stride) { + result = data_ptr[i]; + out << result << std::endl; } + out.close(); +} - ////////////////////////////////////////////////////// - - PaddleMobileConfig config1 = GetConfig1(); - auto predictor1 = - CreatePaddlePredictor(config1); - - std::cout << "Finishing loading model" << std::endl; - - int img_length1 = 224 * 224 * 3; - auto img1 = - reinterpret_cast(fpga_malloc(img_length1 * sizeof(float))); - - std::cout << "Finishing initializing data" << std::endl; +void dump_stride(std::string filename, Tensor input_tensor, const int dumpnum, + bool use_chw) { + static int i = 0; + if (input_tensor.numel() == 0) { + return; + } + if (input_tensor.type() == typeid(float)) { + DLOG << "op: " << i++ << ", float data " << input_tensor.numel(); - struct PaddleTensor t_img1; + dump_stride_float(filename, input_tensor, dumpnum); + } else { + DLOG << "op: " << i++ << ", half data " << input_tensor.numel(); - t_img1.dtypeid = typeid(float); - t_img1.layout = LAYOUT_HWC; - t_img1.shape = std::vector({1, 224, 224, 3}); - t_img1.name = "Image information"; - t_img1.data.Reset(img1, img_length1 * sizeof(float)); - predictor1->FeedPaddleTensors({t_img1}); - predictor1->Predict_From_To(0, -1); - std::cout << "Finishing predicting " << std::endl; + dump_stride_half(filename, input_tensor, dumpnum, use_chw); + } + DLOG << "dump input address: " << input_tensor.get_data(); +} - std::vector v1; // No need to initialize v - predictor1->FetchPaddleTensors(&v1); // Old data in v will be cleared - std::cout << "Output number is " << v1.size() << std::endl; - std::cout << "out[0] length " << v1[0].data.length() << std::endl; +static const char *g_rfcn_combine = "../models/rfcn"; +static const char *g_image_src_float = "../models/rfcn/data.bin"; +int main() { + paddle_mobile::fpga::open_device(); + paddle_mobile::PaddleMobile paddle_mobile; + + if (paddle_mobile.Load(std::string(g_rfcn_combine) + "/model", + std::string(g_rfcn_combine) + "/params", true, false, + 1, true)) { + float img_info[3] = {768, 1536, 768.0f / 960.0f}; + auto img = reinterpret_cast( + fpga::fpga_malloc(768 * 1536 * 3 * sizeof(float))); + readStream(g_image_src_float, reinterpret_cast(img)); + + std::vector v(3, nullptr); + paddle_mobile.FeedData(std::vector({img_info, img})); + paddle_mobile.Predict_To(-1); + + for (int i = 65; i < 69; i++) { + auto tensor_ptr = paddle_mobile.FetchResult(i); + std::string saveName = "rfcn_" + std::to_string(i); + 
paddle_mobile::fpga::fpga_invalidate((*tensor_ptr).get_data(), + tensor_ptr->numel() * sizeof(float)); + dump_stride(saveName, (*tensor_ptr), tensor_ptr->numel(), true); + } + // paddle_mobile.GetResults(&v); + DLOG << "Computation done"; + fpga::fpga_free(img); + } return 0; } diff --git a/test/net/test_benchmark.cpp b/test/net/test_benchmark.cpp index 31a0850c4d531d13f7960d9857b3721ee69c6d27..4c9a36dc26371d701b1bc62840b73b2fee295224 100644 --- a/test/net/test_benchmark.cpp +++ b/test/net/test_benchmark.cpp @@ -36,7 +36,10 @@ int main(int argc, char* argv[]) { paddle_mobile::PaddleMobile paddle_mobile; paddle_mobile.SetThreadNum(thread_num); auto time1 = time(); - if (paddle_mobile.Load(fluid_model, optimize)) { + // if (paddle_mobile.Load(fluid_model, optimize, false, 1, true)) { + if (paddle_mobile.Load(std::string(fluid_model) + "/model", + std::string(fluid_model) + "/params", optimize, false, + 1, true)) { auto time2 = time(); std::cout << "load cost :" << time_diff(time1, time2) << "ms\n"; paddle_mobile::framework::Tensor input; @@ -52,13 +55,14 @@ int main(int argc, char* argv[]) { paddle_mobile::framework::make_ddim(dims); SetupTensor(&input, in_shape, 0.f, 255.f); // warmup - for (int i = 0; i < 10; ++i) { + for (int i = 0; i < 2; ++i) { paddle_mobile.Predict(input); } auto time3 = time(); for (int i = 0; i < 10; ++i) { paddle_mobile.Predict(input); } + auto time4 = time(); std::cout << "predict cost :" << time_diff(time3, time4) / 10 << "ms\n"; std::ostringstream os("output tensor size: "); @@ -68,7 +72,7 @@ int main(int argc, char* argv[]) { os << ", " << output->data()[i]; } std::string output_str = os.str(); - std::cout << output_str << std::endl; + // std::cout << output_str << std::endl; } return 0; } diff --git a/test/net/test_ocr.cpp b/test/net/test_ocr.cpp index 5b8b245e9cf2fa69b3d4691ef9cd9509a9d26d01..661b6e5cbf30b625e50a1f68d07ffc85f53e06bf 100644 --- a/test/net/test_ocr.cpp +++ b/test/net/test_ocr.cpp @@ -20,11 +20,11 @@ limitations under the License. */ void load_images(const char *image_dir, const char *images_list, std::vector *image_names, std::vector> *image_shapes) { - int height, width; + int channel, height, width; std::string filename; std::ifstream if_list(images_list, std::ios::in); while (!if_list.eof()) { - if_list >> height >> width >> filename; + if_list >> channel >> height >> width >> filename; image_shapes->push_back(std::make_pair(height, width)); image_names->push_back(filename); } @@ -32,20 +32,25 @@ void load_images(const char *image_dir, const char *images_list, } int main(int argc, char **argv) { - if (argc < 4) { - std::cerr << "Usage: ./test_ocr model_dir image_dir images_list." - << std::endl; + if (argc < 5) { + std::cerr + << "Usage: ./test_ocr model_dir image_dir images_list output_name." 
+ << std::endl; return 1; } char *model_dir = argv[1]; char *image_dir = argv[2]; char *images_list = argv[3]; + char *output_name = argv[4]; paddle_mobile::PaddleMobile paddle_mobile; - paddle_mobile.SetThreadNum(8); + paddle_mobile.SetThreadNum(1); auto isok = paddle_mobile.Load(std::string(model_dir) + "/model", std::string(model_dir) + "/params", true, false, 1, true); + // auto isok = paddle_mobile.Load(std::string(model_dir), false, + // false, 1, true); + DLOG << "pass init model"; std::vector image_names; std::vector> image_shapes; @@ -55,7 +60,7 @@ int main(int argc, char **argv) { for (int i = 0; i < image_names.size(); i++) { std::string file_name = image_names[i]; std::vector input_vec; - std::vector dims{1, 1, 48, 512}; + std::vector dims{1, 3, 48, 512}; dims[2] = image_shapes[i].first; dims[3] = image_shapes[i].second; // load input image @@ -64,23 +69,24 @@ int main(int argc, char **argv) { std::cerr << "shape = [" << dims[0] << ", " << dims[1] << ", " << dims[2] << ", " << dims[3] << "]" << std::endl; GetInput(img_path, &input_vec, dims); - framework::Tensor input(input_vec, framework::make_ddim(dims)); + // framework::Tensor input(input_vec, framework::make_ddim(dims)); // predict - paddle_mobile.Predict(input); - auto output_topk = paddle_mobile.Fetch("top_k_1.tmp_0"); - auto output_indices = paddle_mobile.Fetch("cast_68.tmp_0"); + // for (int j = 0; j < 10000; ++j) { + auto time3 = paddle_mobile::time(); + paddle_mobile.Predict(input_vec, dims); + auto output_topk = paddle_mobile.Fetch(output_name); + auto time4 = paddle_mobile::time(); + std::cerr << "average predict elapsed: " + << paddle_mobile::time_diff(time3, time4) << "ms" << std::endl; + // } + // print result - std::cerr << file_name << std::endl; + std::cerr << output_name << std::endl; std::cerr << output_topk->data()[0]; for (int j = 1; j < output_topk->numel(); ++j) { std::cerr << " " << output_topk->data()[j]; } std::cerr << std::endl; - std::cerr << output_indices->data()[0]; - for (int j = 1; j < output_indices->numel(); ++j) { - std::cerr << " " << output_indices->data()[j]; - } - std::cerr << std::endl; } return 0; } diff --git a/test/operators/test_batchnorm_op.cpp b/test/operators/test_batchnorm_op.cpp index f78aa4061205586e9d1540b65a8a3dbc32de6757..92cb7157c133f140f8b630c1edfe109e26631244 100644 --- a/test/operators/test_batchnorm_op.cpp +++ b/test/operators/test_batchnorm_op.cpp @@ -88,8 +88,8 @@ int TestBatchNormOp(const std::vector input_shape) { attrs["epsilon"].Set(eps); attrs["momentum"].Set(0.f); - auto *op = new operators::BatchNormOp("batch_norm", inputs, - outputs, attrs, scope); + auto *op = new operators::BatchNormOp( + "batch_norm", inputs, outputs, attrs, scope.get()); op->InferShape(); op->Init(); op->Run(); diff --git a/test/operators/test_box_coder_op.cpp b/test/operators/test_box_coder_op.cpp index 721e691107c2c2d0117fdedecf219484556c9541..39b8257e66ee506db3259167755b26f84f8a07af 100644 --- a/test/operators/test_box_coder_op.cpp +++ b/test/operators/test_box_coder_op.cpp @@ -49,7 +49,7 @@ class TestBoxCoderOp { std::shared_ptr> boxcoder = std::make_shared>( op->Type(), op->GetInputs(), op->GetOutputs(), - op->GetAttrMap(), program_.scope); + op->GetAttrMap(), program_.scope.get()); ops_of_block_[*block_desc.get()].push_back(boxcoder); } } @@ -59,7 +59,7 @@ class TestBoxCoderOp { std::shared_ptr predict_boxcoder(const Tensor &t1, const Tensor &t2, const Tensor &t3) { // feed - auto scope = program_.scope; + auto scope = program_.scope.get(); Variable *prior_box = 
scope->Var("concat_0.tmp_0"); auto tensor_x1 = prior_box->GetMutable(); tensor_x1->ShareDataWith(t1); diff --git a/test/operators/test_cast_op.cpp b/test/operators/test_cast_op.cpp index df6fab705bfb639f2b2d0b6d8c30bb86512b84d0..f330e07eafa1fab3da74b61bbdb29c50450610c9 100644 --- a/test/operators/test_cast_op.cpp +++ b/test/operators/test_cast_op.cpp @@ -81,8 +81,8 @@ int TestCastOp(const std::vector input_shape) { framework::AttributeMap attrs; attrs["in_dtype"].Set(TypeInt()); attrs["out_dtype"].Set(TypeInt()); - auto *op = - new operators::CastOp("cast", inputs, outputs, attrs, scope); + auto *op = new operators::CastOp("cast", inputs, outputs, attrs, + scope.get()); op->InferShape(); op->Init(); op->Run(); diff --git a/test/operators/test_concat_op.cpp b/test/operators/test_concat_op.cpp index 88ec06be6f1b5197669f7c580d935bb9d2475c5a..761d1ac51d6c5c3d47b6eb8ef3695262acb11eb2 100644 --- a/test/operators/test_concat_op.cpp +++ b/test/operators/test_concat_op.cpp @@ -27,7 +27,7 @@ using framework::Scope; using framework::make_ddim; template -void concat(const std::vector &input, LoDTensor &output, int axis) { +void concat(const std::vector &input, LoDTensor *output, int axis) { int num = input.size(); int rows = 1; @@ -45,7 +45,7 @@ void concat(const std::vector &input, LoDTensor &output, int axis) { } // computation - auto output_data = output.data(); + auto output_data = output->data(); int col_idx = 0; for (int j = 0; j < num; ++j) { int col_len = input_cols[j]; @@ -99,14 +99,14 @@ int TestConcatOP() { attrs["axis"].Set(axis_v); auto *op = new operators::ConcatOp("concat", inputs, outputs, - attrs, scope); + attrs, scope.get()); op->InferShape(); op->Run(); auto output = output_var->template Get(); const T *output_data = output->data(); LoDTensor output_cmp; output_cmp.mutable_data(output_shape); - concat(input_tensors, output_cmp, axis_v); + concat(input_tensors, &output_cmp, axis_v); const T *output_cmp_data = output_cmp.data(); // compare int eq = 0; diff --git a/test/operators/test_conv_bn_relu_op.cpp b/test/operators/test_conv_bn_relu_op.cpp index 6a09d838e0a30486569448726c255b1a6ba7f617..b51bdc07375d8dded738cf022781bc3c14d11d44 100644 --- a/test/operators/test_conv_bn_relu_op.cpp +++ b/test/operators/test_conv_bn_relu_op.cpp @@ -84,7 +84,7 @@ int TestConvBnReluOp(int in_channels, int in_height, int in_width, attrs["epsilon"].Set(1e-6); attrs["momentum"].Set(0.f); auto *op = new operators::FusionConvBNReluOp( - "fusion_conv_bn_relu", inputs, outputs, attrs, scope); + "fusion_conv_bn_relu", inputs, outputs, attrs, scope.get()); op->InferShape(); op->Init(); for (int i = 0; i < 10; ++i) { diff --git a/test/operators/test_conv_op.cpp b/test/operators/test_conv_op.cpp index 3a949daefeb89df1c72702f1207a0d0f0e652f93..c705e162fec91bebc8eb25008b66b979a7d70c06 100644 --- a/test/operators/test_conv_op.cpp +++ b/test/operators/test_conv_op.cpp @@ -182,7 +182,7 @@ int TestConvOp(int in_channels, int in_height, int in_width, int out_channels, attrs["groups"].Set(groups); auto *op = new operators::ConvOp("conv2d", inputs, outputs, attrs, - scope); + scope.get()); op->InferShape(); op->Init(); // struct timespec ts_begin, ts_end; @@ -206,11 +206,11 @@ int TestConvOp(int in_channels, int in_height, int in_width, int out_channels, const Otype *output_data = output->data(); Otype *output_cmp_data = output_cmp.data(); for (int i = 0; i < output->numel(); ++i) { - float gap = output_data[i] - output_cmp_data[i]; + float gap = abs(output_data[i] - output_cmp_data[i]); // 
PADDLE_MOBILE_ENFORCE(std::abs(gap / (output_data[i] + 1e-5)) < 1e-3, // "output[%d] = %d, output_cmp[%d] = %d", i, // output_data[i], i, output_cmp_data[i]); - if (gap > 1e-2 && std::abs(gap / (output_data[i] + 1e-5)) > 1e-3) { + if (gap > 1e-2 && (gap / (abs(output_data[i]) + 1e-5) > 1e-2)) { std::cerr << "output_data[" << i << "] = " << output_data[i] << ", output_cmp_data[" << i << "] = " << output_cmp_data[i] << std::endl; @@ -228,39 +228,43 @@ int TestAll(const int in_channels, const int in_height, const int in_width, std::cerr << "in_channels=" << in_channels << ", in_height=" << in_height << ", in_width=" << in_width << ", out_channels=" << out_channels << ", groups=" << groups << std::endl; - // // kernel = 3, pad = 0, stride = 1 - // std::cerr << "float, kernel=3, pad=0, stride=1" << std::endl; - // paddle_mobile::TestConvOp( - // in_channels, in_height, in_width, out_channels, groups); - // // kernel = 3, pad = 1, stride = 1 - // std::cerr << "float, kernel=3, pad=1, stride=1" << std::endl; - // paddle_mobile::TestConvOp( - // in_channels, in_height, in_width, out_channels, groups); - // // kernel = 3, pad = 2, stride = 1 - // std::cerr << "float, kernel=3, pad=2, stride=1" << std::endl; - // paddle_mobile::TestConvOp( - // in_channels, in_height, in_width, out_channels, groups); - // // kernel = 3, pad = 5, stride = 1 - // std::cerr << "float, kernel=3, pad=5, stride=1" << std::endl; - // paddle_mobile::TestConvOp( - // in_channels, in_height, in_width, out_channels, groups); - // - // // kernel = 3, pad = 0, stride = 2 - // std::cerr << "float, kernel=3, pad=0, stride=2" << std::endl; - // paddle_mobile::TestConvOp( - // in_channels, in_height, in_width, out_channels, groups); - // // kernel = 3, pad = 1, stride = 2 - // std::cerr << "float, kernel=3, pad=1, stride=2" << std::endl; - // paddle_mobile::TestConvOp( - // in_channels, in_height, in_width, out_channels, groups); - // // kernel = 3, pad = 2, stride = 2 - // std::cerr << "float, kernel=3, pad=2, stride=2" << std::endl; - // paddle_mobile::TestConvOp( - // in_channels, in_height, in_width, out_channels, groups); - // // kernel = 3, pad = 5, stride = 2 - // std::cerr << "float, kernel=3, pad=5, stride=2" << std::endl; - // paddle_mobile::TestConvOp( - // in_channels, in_height, in_width, out_channels, groups); + std::cerr << "float, kernel=1, pad=0, stride=1" << std::endl; + paddle_mobile::TestConvOp( + in_channels, in_height, in_width, out_channels, groups); + + // kernel = 3, pad = 0, stride = 1 + std::cerr << "float, kernel=3, pad=0, stride=1" << std::endl; + paddle_mobile::TestConvOp( + in_channels, in_height, in_width, out_channels, groups); + // kernel = 3, pad = 1, stride = 1 + std::cerr << "float, kernel=3, pad=1, stride=1" << std::endl; + paddle_mobile::TestConvOp( + in_channels, in_height, in_width, out_channels, groups); + // kernel = 3, pad = 2, stride = 1 + std::cerr << "float, kernel=3, pad=2, stride=1" << std::endl; + paddle_mobile::TestConvOp( + in_channels, in_height, in_width, out_channels, groups); + // kernel = 3, pad = 5, stride = 1 + std::cerr << "float, kernel=3, pad=5, stride=1" << std::endl; + paddle_mobile::TestConvOp( + in_channels, in_height, in_width, out_channels, groups); + + // kernel = 3, pad = 0, stride = 2 + std::cerr << "float, kernel=3, pad=0, stride=2" << std::endl; + paddle_mobile::TestConvOp( + in_channels, in_height, in_width, out_channels, groups); + // kernel = 3, pad = 1, stride = 2 + std::cerr << "float, kernel=3, pad=1, stride=2" << std::endl; + paddle_mobile::TestConvOp( + 
in_channels, in_height, in_width, out_channels, groups); + // kernel = 3, pad = 2, stride = 2 + std::cerr << "float, kernel=3, pad=2, stride=2" << std::endl; + paddle_mobile::TestConvOp( + in_channels, in_height, in_width, out_channels, groups); + // kernel = 3, pad = 5, stride = 2 + std::cerr << "float, kernel=3, pad=5, stride=2" << std::endl; + paddle_mobile::TestConvOp( + in_channels, in_height, in_width, out_channels, groups); #ifndef __aarch64__ // kernel = 3, pad = 0, stride = 1 @@ -338,6 +342,7 @@ int TestAll(const int in_channels, const int in_height, const int in_width, } int main() { + TestAll(16, 10, 10, 16, 16); TestAll(1, 5, 5, 1, 1); TestAll(1, 5, 5, 10, 1); TestAll(10, 5, 5, 10, 10); diff --git a/test/operators/test_dequantize_op.cpp b/test/operators/test_dequantize_op.cpp index 8e89d8f7af3694bcc4701c268451f28675db7fc9..41390d74a92d9cab615244600079373e84fb31d2 100644 --- a/test/operators/test_dequantize_op.cpp +++ b/test/operators/test_dequantize_op.cpp @@ -50,8 +50,8 @@ int TestDequqntizeOp() { framework::AttributeMap attrs; attrs["weight_scale"].Set(1.74); - auto* op = new operators::DequantizeOp("dequantize", inputs, - outputs, attrs, scope); + auto* op = new operators::DequantizeOp( + "dequantize", inputs, outputs, attrs, scope.get()); op->InferShape(); op->Run(); auto output = output_var->template Get(); diff --git a/test/operators/test_dwconv_bn_relu_op.cpp b/test/operators/test_dwconv_bn_relu_op.cpp index 7fcf10d903e571ac7b0f5fb0a4b1214bf55327d1..8b2e6f06b2295c7716d2b4b6fc4cdd156668e2c8 100644 --- a/test/operators/test_dwconv_bn_relu_op.cpp +++ b/test/operators/test_dwconv_bn_relu_op.cpp @@ -87,7 +87,7 @@ int TestDWConvAddBnReluOp(int in_channels, int in_height, int in_width, attrs["momentum"].Set(0.f); auto *op = new operators::FusionDWConvBNReluOp( - "fusion_dwconv_bn_relu", inputs, outputs, attrs, scope); + "fusion_dwconv_bn_relu", inputs, outputs, attrs, scope.get()); op->InferShape(); op->Init(); for (int i = 0; i < 10; ++i) { diff --git a/test/operators/test_elementwise_sub_op.cpp b/test/operators/test_elementwise_sub_op.cpp index e1030852976a68db827ebb7629caf8bb199a2456..d07d42849b1f83b3cd30b969e820d42053ba81ea 100644 --- a/test/operators/test_elementwise_sub_op.cpp +++ b/test/operators/test_elementwise_sub_op.cpp @@ -47,7 +47,7 @@ class TestElementwiseSubOp { std::shared_ptr> lrn = std::make_shared>( op->Type(), op->GetInputs(), op->GetOutputs(), - op->GetAttrMap(), program_.scope); + op->GetAttrMap(), program_.scope.get()); ops_of_block_[*block_desc.get()].push_back(lrn); } } @@ -56,7 +56,7 @@ class TestElementwiseSubOp { std::shared_ptr predict_bn(const Tensor &t1, const Tensor &t2) { // feed - auto scope = program_.scope; + auto scope = program_.scope.get(); Variable *x1_feed_value = scope->Var("tmp_0"); auto tensor_x1 = x1_feed_value->GetMutable(); tensor_x1->ShareDataWith(t1); diff --git a/test/operators/test_fill_constant_op.cpp b/test/operators/test_fill_constant_op.cpp index 9dc7bb13884efb8860a6670e088bd5af67c1f0ea..86a4bf0a3713bb4b8f1301ca9d0acf68f140060c 100644 --- a/test/operators/test_fill_constant_op.cpp +++ b/test/operators/test_fill_constant_op.cpp @@ -47,7 +47,7 @@ class TestFillConstantOp { std::shared_ptr> op_ptr = std::make_shared>( op->Type(), op->GetInputs(), op->GetOutputs(), - op->GetAttrMap(), program_.scope); + op->GetAttrMap(), program_.scope.get()); ops_of_block_[*block_desc.get()].push_back(op_ptr); } } @@ -55,7 +55,7 @@ class TestFillConstantOp { } std::shared_ptr predict() { - auto scope = program_.scope; + auto scope = 
program_.scope.get(); Variable *output = scope->Var(output_var_name); auto *output_tensor = output->GetMutable(); diff --git a/test/operators/test_fusion_fc_op.cpp b/test/operators/test_fusion_fc_op.cpp index 97bf233155a7229caee68b67c6ff7b1314ec0c6e..2cc9bcd153f057415dd84ba6591f5e387c0028fa 100644 --- a/test/operators/test_fusion_fc_op.cpp +++ b/test/operators/test_fusion_fc_op.cpp @@ -103,7 +103,7 @@ int TestFcOP() { attrs["axis"].Set(1); operators::OperatorBase *op = nullptr; op = new operators::FusionFcOp("fusion_fc", inputs, outputs, attrs, - scope); + scope.get()); op->InferShape(); op->Run(); auto output = output_var->template Get(); diff --git a/test/operators/test_gru_op.cpp b/test/operators/test_gru_op.cpp index b11ec4f5f77aca2c4997153863e70b1a6b209c32..704a4df0294e46cd661d5b010398baa6fa40a740 100644 --- a/test/operators/test_gru_op.cpp +++ b/test/operators/test_gru_op.cpp @@ -71,8 +71,8 @@ int TestGruOp(int in_channels, int out_channels, std::string opname) { attrs["gate_activation"].SetString(std::string("sigmoid")); attrs["is_reverse"].Set(false); - auto *op = - new operators::GruOp("gru", inputs, outputs, attrs, scope); + auto *op = new operators::GruOp("gru", inputs, outputs, attrs, + scope.get()); op->InferShape(); op->Init(); for (int i = 0; i < 10; ++i) { diff --git a/test/operators/test_im2sequence_op.cpp b/test/operators/test_im2sequence_op.cpp index 3cd172d99bb1bb9c24f035d501dce362476909c2..247e6a466f9a591f434ef714117eeb31999de164 100644 --- a/test/operators/test_im2sequence_op.cpp +++ b/test/operators/test_im2sequence_op.cpp @@ -47,7 +47,7 @@ class TestIm2SequenceOp { std::shared_ptr> lrn = std::make_shared>( op->Type(), op->GetInputs(), op->GetOutputs(), - op->GetAttrMap(), program_.scope); + op->GetAttrMap(), program_.scope.get()); ops_of_block_[*block_desc.get()].push_back(lrn); } } @@ -56,7 +56,7 @@ class TestIm2SequenceOp { std::shared_ptr predict_bn(const Tensor &t1) { // feed - auto scope = program_.scope; + auto scope = program_.scope.get(); Variable *x1_feed_value = scope->Var("conv2d_19.tmp_1"); auto tensor_x1 = x1_feed_value->GetMutable(); tensor_x1->ShareDataWith(t1); diff --git a/test/operators/test_increment_op.cpp b/test/operators/test_increment_op.cpp index cb1fc9478c65dd70a758dfdf4a3f795470f720af..32f6a57b60c2abd5c6ba76ff646d00909845c3db 100644 --- a/test/operators/test_increment_op.cpp +++ b/test/operators/test_increment_op.cpp @@ -41,8 +41,8 @@ int TestIncrementOp(const std::vector input_shape, int step) { framework::AttributeMap attrs; attrs["step"].Set(step); - auto *op = new operators::IncrementOp("increment", inputs, - outputs, attrs, scope); + auto *op = new operators::IncrementOp( + "increment", inputs, outputs, attrs, scope.get()); op->InferShape(); op->Init(); diff --git a/test/operators/test_is_empty_op.cpp b/test/operators/test_is_empty_op.cpp index 5283b2bd69e47ece6d7569f3b68706008c89ef94..9bf9443acdb73d36067ec5394dcf128996fa78d6 100644 --- a/test/operators/test_is_empty_op.cpp +++ b/test/operators/test_is_empty_op.cpp @@ -37,7 +37,7 @@ int TestIsEmptyOp(const std::vector input_shape) { framework::AttributeMap attrs; auto *op = new operators::IsEmptyOp("is_empty", inputs, outputs, - attrs, scope); + attrs, scope.get()); op->InferShape(); op->Init(); diff --git a/test/operators/test_less_than_op.cpp b/test/operators/test_less_than_op.cpp index 5c8fa8910d9abfb6ac0a834d9d00274b35fc790b..35f5e6fe74eb16a206185573c7a431f4f2389f77 100644 --- a/test/operators/test_less_than_op.cpp +++ b/test/operators/test_less_than_op.cpp @@ -77,7 +77,7 @@ int 
TestLessThanOp(const std::vector &x_shape, framework::AttributeMap attrs; attrs["axis"].Set(axis); auto *op = new operators::LessThanOp("less_than", inputs, outputs, - attrs, scope); + attrs, scope.get()); op->InferShape(); op->Init(); op->Run(); diff --git a/test/operators/test_log_op.cpp b/test/operators/test_log_op.cpp index 2f29e8711bb8de0e576a9a1485d96a448ec3d3c0..f0bba93d546258fdc98eb3fe8021880a0f10de6a 100644 --- a/test/operators/test_log_op.cpp +++ b/test/operators/test_log_op.cpp @@ -43,8 +43,8 @@ int TestLogOp(const std::vector input_shape) { auto output_var = scope.get()->Var("output"); framework::AttributeMap attrs; - auto *op = - new operators::LogOp("log", inputs, outputs, attrs, scope); + auto *op = new operators::LogOp("log", inputs, outputs, attrs, + scope.get()); op->InferShape(); op->Init(); op->Run(); diff --git a/test/operators/test_logical_and_op.cpp b/test/operators/test_logical_and_op.cpp index 216513cf3d7f64c865bf0931abe6a9dad2d2582d..380b253efe43ce8b4ed262fdf5b971353fcac654 100644 --- a/test/operators/test_logical_and_op.cpp +++ b/test/operators/test_logical_and_op.cpp @@ -50,8 +50,8 @@ int TestLogicalAndOp(const std::vector input_shape) { auto output_var = scope.get()->Var("output"); framework::AttributeMap attrs; - auto *op = new operators::LogicalAndOp("logical_and", inputs, - outputs, attrs, scope); + auto *op = new operators::LogicalAndOp( + "logical_and", inputs, outputs, attrs, scope.get()); op->InferShape(); op->Init(); diff --git a/test/operators/test_logical_not_op.cpp b/test/operators/test_logical_not_op.cpp index 55d48f79b72a05a74b6d13e4095f25ddfb4e8cbd..8d88362210d6591bb88b04fe0352a44e59761e2d 100644 --- a/test/operators/test_logical_not_op.cpp +++ b/test/operators/test_logical_not_op.cpp @@ -42,8 +42,8 @@ int TestLogicalNotOp(const std::vector input_shape) { auto output_var = scope.get()->Var("output"); framework::AttributeMap attrs; - auto *op = new operators::LogicalNotOp("logical_not", inputs, - outputs, attrs, scope); + auto *op = new operators::LogicalNotOp( + "logical_not", inputs, outputs, attrs, scope.get()); op->InferShape(); op->Init(); diff --git a/test/operators/test_logical_or_op.cpp b/test/operators/test_logical_or_op.cpp index 593ee35e696ebc392496846d8beb244210d1ec88..9ea555b65b3453b44de57b7bdb8b1926f34e5b66 100644 --- a/test/operators/test_logical_or_op.cpp +++ b/test/operators/test_logical_or_op.cpp @@ -50,8 +50,8 @@ int TestLogicalOrOp(const std::vector input_shape) { auto output_var = scope.get()->Var("output"); framework::AttributeMap attrs; - auto *op = new operators::LogicalOrOp("logical_or", inputs, - outputs, attrs, scope); + auto *op = new operators::LogicalOrOp( + "logical_or", inputs, outputs, attrs, scope.get()); op->InferShape(); op->Init(); diff --git a/test/operators/test_logical_xor_op.cpp b/test/operators/test_logical_xor_op.cpp index b4ca4c826727a6493ce785f4fab97a4dfa809557..a776de0e8b9ad18455b1b45839673601886ecb0b 100644 --- a/test/operators/test_logical_xor_op.cpp +++ b/test/operators/test_logical_xor_op.cpp @@ -52,8 +52,8 @@ int TestLogicalXorOp(const std::vector input_shape) { auto output_var = scope.get()->Var("output"); framework::AttributeMap attrs; - auto *op = new operators::LogicalXorOp("logical_xor", inputs, - outputs, attrs, scope); + auto *op = new operators::LogicalXorOp( + "logical_xor", inputs, outputs, attrs, scope.get()); op->InferShape(); op->Init(); diff --git a/test/operators/test_mul_op.cpp b/test/operators/test_mul_op.cpp index 
diff --git a/test/operators/test_mul_op.cpp b/test/operators/test_mul_op.cpp
index 99a2219749c7b16a2dff6a8c78621306f0aad1e6..6ac2c455647a170d9c2f2e19f6dc6403f5d822b1 100644
--- a/test/operators/test_mul_op.cpp
+++ b/test/operators/test_mul_op.cpp
@@ -54,8 +54,8 @@ int TestMulOP() {
   AttributeMap attrs;
   attrs["x_num_col_dims"].Set<int>(1);
   attrs["y_num_col_dims"].Set<int>(1);
-  auto *op =
-      new operators::MulOp<CPU, float>("mul", inputs, outputs, attrs, scope);
+  auto *op = new operators::MulOp<CPU, float>("mul", inputs, outputs, attrs,
+                                              scope.get());
   op->InferShape();
   op->Run();
   auto output = output_var->template Get<framework::LoDTensor>();
diff --git a/test/operators/test_multiclass_nms_op.cpp b/test/operators/test_multiclass_nms_op.cpp
index 32c2c1f6bd682fdac8d9b81155b8aa044b87232b..782dd6af94c75501272ec09f0bf014e0254456b3 100644
--- a/test/operators/test_multiclass_nms_op.cpp
+++ b/test/operators/test_multiclass_nms_op.cpp
@@ -55,7 +55,7 @@ class TestMultiClassNMSOp {
         std::shared_ptr<operators::MultiClassNMSOp<CPU, float>> priorbox =
             std::make_shared<operators::MultiClassNMSOp<CPU, float>>(
                 op->Type(), op->GetInputs(), op->GetOutputs(),
-                op->GetAttrMap(), program_.scope);
+                op->GetAttrMap(), program_.scope.get());
         ops_of_block_[*block_desc.get()].push_back(priorbox);
       }
     }
@@ -64,7 +64,7 @@ class TestMultiClassNMSOp {
 
   std::shared_ptr<Tensor> predict(const Tensor &t1, const Tensor &t2) {
     // feed
-    auto scope = program_.scope;
+    auto scope = program_.scope.get();
     Variable *x1_feed_value = scope->Var("box_coder_0.tmp_0");
     auto tensor_x1 = x1_feed_value->GetMutable<LoDTensor>();
     tensor_x1->ShareDataWith(t1);
diff --git a/test/operators/test_polygon_box_transform_op.cpp b/test/operators/test_polygon_box_transform_op.cpp
index 2347f06989153b9ce5994fa0e4d09673ab2698f1..bfd8fb3cc2f1724a585e8698b263e7bed5d268c0 100644
--- a/test/operators/test_polygon_box_transform_op.cpp
+++ b/test/operators/test_polygon_box_transform_op.cpp
@@ -44,7 +44,7 @@ class TestPolygonBoxTransformOp {
           op_ptr = std::make_shared<
              operators::PolygonBoxTransformOp<CPU, float>>(
              op->Type(), op->GetInputs(), op->GetOutputs(),
-             op->GetAttrMap(), program_.scope);
+             op->GetAttrMap(), program_.scope.get());
           ops_of_block_[*block_desc.get()].push_back(op_ptr);
           return;
         }
@@ -53,7 +53,7 @@ class TestPolygonBoxTransformOp {
   }
 
   std::shared_ptr<Tensor> predict(const Tensor &t) {
-    auto scope = program_.scope;
+    auto scope = program_.scope.get();
     Variable *input_feed_value = scope->Var(input_var_name);
     auto tensor_input = input_feed_value->GetMutable<LoDTensor>();
     tensor_input->ShareDataWith(t);
diff --git a/test/operators/test_pool_op.cpp b/test/operators/test_pool_op.cpp
index c7590512f92e2166ea082986fb97bed771eb2b15..7c7f54c9d0db2725629d32977e4be5b273ab61bc 100644
--- a/test/operators/test_pool_op.cpp
+++ b/test/operators/test_pool_op.cpp
@@ -64,7 +64,7 @@ int TestPoolOp(int in_channels, int in_height, int in_width) {
   attrs["global_pooling"].Set<bool>(false);
 
   auto *op = new operators::PoolOp<CPU, float>("pool2d", inputs, outputs, attrs,
-                                               scope);
+                                               scope.get());
   op->InferShape();
   op->Init();
   op->Run();
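The fixture-style tests that follow (TestPriorBoxOp, TestReshape2Op, TestSumOp, and the like) borrow the same raw pointer in two places: once when constructing ops into ops_of_block_, and once per predict call via auto scope = program_.scope.get(). Both borrows are safe for the same reason, sketched below with hypothetical names; Fixture and Program stand in for the generated test classes:

#include <memory>

struct Scope {};

struct Program {
  std::shared_ptr<Scope> scope{std::make_shared<Scope>()};
};

class Fixture {
 public:
  void predict() {
    // A raw pointer borrowed from the member shared_ptr: valid exactly as
    // long as program_ (and therefore program_.scope) is alive.
    Scope *scope = program_.scope.get();
    (void)scope;  // feed inputs, run the cached ops, fetch outputs ...
  }

 private:
  Program program_;  // owns the scope for the fixture's whole lifetime
};

int main() {
  Fixture f;
  f.predict();
  return 0;
}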
diff --git a/test/operators/test_prior_box_op.cpp b/test/operators/test_prior_box_op.cpp
index 424f2443f8627002cff0adc19600f9aba50ad0fb..b2f05a18e6e93efab288b93338fa0fefc710ff84 100644
--- a/test/operators/test_prior_box_op.cpp
+++ b/test/operators/test_prior_box_op.cpp
@@ -60,7 +60,7 @@ class TestPriorBoxOp {
         std::shared_ptr<operators::PriorBoxOp<CPU, float>> priorbox =
             std::make_shared<operators::PriorBoxOp<CPU, float>>(
                 op->Type(), op->GetInputs(), op->GetOutputs(),
-                op->GetAttrMap(), program_.scope);
+                op->GetAttrMap(), program_.scope.get());
         ops_of_block_[*block_desc.get()].push_back(priorbox);
       }
     }
@@ -69,7 +69,7 @@ class TestPriorBoxOp {
 
   std::shared_ptr<Tensor> predict_priorbox(const Tensor &t1, const Tensor &t2) {
     // feed
-    auto scope = program_.scope;
+    auto scope = program_.scope.get();
     Variable *x1_feed_value = scope->Var("image");
     auto tensor_x1 = x1_feed_value->GetMutable<LoDTensor>();
     tensor_x1->ShareDataWith(t1);
diff --git a/test/operators/test_quantize_op.cpp b/test/operators/test_quantize_op.cpp
index f3b8fd151c83d115b003b226549ba351188808da..275dedc16901a88bd146ec54ac0b377f77a47312 100644
--- a/test/operators/test_quantize_op.cpp
+++ b/test/operators/test_quantize_op.cpp
@@ -115,7 +115,7 @@ int TestQuqntizeOp(const int batch_size, const int channel, const int height,
   framework::AttributeMap attrs;
 
   auto *op = new operators::QuantizeOp<CPU, float>("quantize", inputs, outputs,
-                                                   attrs, scope);
+                                                   attrs, scope.get());
   op->InferShape();
   op->Run();
diff --git a/test/operators/test_relu6_op.cpp b/test/operators/test_relu6_op.cpp
index ceaabcb31343629fa52aef996e3906458b1e011a..fcbaa0ba89129a805809459bf18f4fd30226b13e 100644
--- a/test/operators/test_relu6_op.cpp
+++ b/test/operators/test_relu6_op.cpp
@@ -45,7 +45,7 @@ int TestRelu6Op(const std::vector<int> input_shape) {
   framework::AttributeMap attrs;
 
   auto *op = new operators::Relu6Op<CPU, float>("relu6", inputs, outputs, attrs,
-                                                scope);
+                                                scope.get());
   op->InferShape();
   op->Init();
   op->Run();
diff --git a/test/operators/test_relu_op.cpp b/test/operators/test_relu_op.cpp
index e82e681e3128177300f54326bc346313024644ef..d173845386929981d62e00208e19571a16ae636d 100644
--- a/test/operators/test_relu_op.cpp
+++ b/test/operators/test_relu_op.cpp
@@ -44,8 +44,8 @@ int TestReluOp(const std::vector<int> input_shape) {
   auto output_var = scope.get()->Var("output");
   framework::AttributeMap attrs;
 
-  auto *op =
-      new operators::ReluOp<CPU, float>("relu", inputs, outputs, attrs, scope);
+  auto *op = new operators::ReluOp<CPU, float>("relu", inputs, outputs, attrs,
+                                               scope.get());
   op->InferShape();
   op->Init();
   op->Run();
diff --git a/test/operators/test_reshape2_op.cpp b/test/operators/test_reshape2_op.cpp
index d0d51f984a617ea37713e5830adf6b5d248fb434..69edd34bf64d757d7d963fdeca9519f29bf1b56b 100644
--- a/test/operators/test_reshape2_op.cpp
+++ b/test/operators/test_reshape2_op.cpp
@@ -60,7 +60,7 @@ class TestReshape2Op {
           std::shared_ptr<operators::Reshape2Op<CPU, float>> op_ptr =
               std::make_shared<operators::Reshape2Op<CPU, float>>(
                   op->Type(), op->GetInputs(), op->GetOutputs(),
-                  op->GetAttrMap(), program_.scope);
+                  op->GetAttrMap(), program_.scope.get());
           ops_of_block_[*block_desc.get()].push_back(op_ptr);
           return;
         }
@@ -69,7 +69,7 @@ class TestReshape2Op {
   }
 
   std::shared_ptr<Tensor> predict(const Tensor &t) {
-    auto scope = program_.scope;
+    auto scope = program_.scope.get();
     Variable *input_feed_value = scope->Var(input_var_name);
     auto tensor_input = input_feed_value->GetMutable<LoDTensor>();
     tensor_input->ShareDataWith(t);
diff --git a/test/operators/test_sequence_expand_op.cpp b/test/operators/test_sequence_expand_op.cpp
index 72e8954f93f4e48524c6b78804237ece427dbae3..731fc8e9e51dd4dc96b9571635ed86e4c42098ec 100644
--- a/test/operators/test_sequence_expand_op.cpp
+++ b/test/operators/test_sequence_expand_op.cpp
@@ -45,7 +45,7 @@ int TestSequenceExpandOp(const framework::LoDTensor &input_x,
   attrs["ref_level"].Set<int>(0);
 
   auto *op = new operators::SequenceExpandOp<CPU, float>(
-      "sequence_expand", inputs, outputs, attrs, scope);
+      "sequence_expand", inputs, outputs, attrs, scope.get());
   op->InferShape();
   op->Init();
diff --git a/test/operators/test_sequence_pool_op.cpp b/test/operators/test_sequence_pool_op.cpp
index 3b377aa437b8a37041e3f30d299214e19c48ff4e..34c0815b850700d0fdc0ffbedffe96174836d86d 100644
--- a/test/operators/test_sequence_pool_op.cpp
+++ b/test/operators/test_sequence_pool_op.cpp
@@ -38,8 +38,8 @@ int TestSequencePoolOp(const framework::LoDTensor &input_x,
   framework::AttributeMap attrs;
   attrs["pooltype"].SetString(pool_type);
 
-  auto *op = new operators::SequencePoolOp<CPU, float>("sequence_pool", inputs,
-                                                       outputs, attrs, scope);
+  auto *op = new operators::SequencePoolOp<CPU, float>(
+      "sequence_pool", inputs, outputs, attrs, scope.get());
   op->InferShape();
   op->Init();
diff --git a/test/operators/test_sequence_softmax_op.cpp b/test/operators/test_sequence_softmax_op.cpp
index 17ab5d7a653a0bb57892ced969ef199a9bde16c3..d8e67f456fb7f8bb98936e27ceecd163f8966824 100644
--- a/test/operators/test_sequence_softmax_op.cpp
+++ b/test/operators/test_sequence_softmax_op.cpp
@@ -63,7 +63,7 @@ int TestSequenceSoftmaxOp(const std::vector<int> &input_shape,
   framework::AttributeMap attrs;
 
   auto *op = new operators::SequenceSoftmaxOp<CPU, float>(
-      "sequence_softmax", inputs, outputs, attrs, scope);
+      "sequence_softmax", inputs, outputs, attrs, scope.get());
   op->InferShape();
   op->Init();
diff --git a/test/operators/test_sigmoid_op.cpp b/test/operators/test_sigmoid_op.cpp
index 260dd62781ad18b46e78db3cfaccf1fe27797175..bda7a79d943ef6afa6c57a7b30eeb4ae5a880015 100644
--- a/test/operators/test_sigmoid_op.cpp
+++ b/test/operators/test_sigmoid_op.cpp
@@ -44,7 +44,7 @@ int TestSigmoidOp(const std::vector<int> input_shape) {
   framework::AttributeMap attrs;
 
   auto *op = new operators::SigmoidOp<CPU, float>("sigmoid", inputs, outputs,
                                                   attrs, scope);
+                                                  attrs, scope.get());
   op->InferShape();
   op->Init();
   op->Run();
diff --git a/test/operators/test_softmax_op.cpp b/test/operators/test_softmax_op.cpp
index 1cf39e3effeb8761fb105858ed302cac1cf95872..e9ccb260b521f76a82d170715e843bfb151cadb2 100644
--- a/test/operators/test_softmax_op.cpp
+++ b/test/operators/test_softmax_op.cpp
@@ -65,7 +65,7 @@ int TestSoftmaxOp(const std::vector<int> input_shape) {
   framework::AttributeMap attrs;
 
   auto *op = new operators::SoftmaxOp<CPU, float>("softmax", inputs, outputs,
-                                                  attrs, scope);
+                                                  attrs, scope.get());
   op->InferShape();
   op->Init();
   op->Run();
diff --git a/test/operators/test_sum_op.cpp b/test/operators/test_sum_op.cpp
index 9cabf1212525a7d4d6f36c45f81cba438694843d..225a113f9071962027473e689cd38fab53906647 100644
--- a/test/operators/test_sum_op.cpp
+++ b/test/operators/test_sum_op.cpp
@@ -46,7 +46,7 @@ class TestSumOp {
         std::shared_ptr<operators::SumOp<CPU, float>> lrn =
             std::make_shared<operators::SumOp<CPU, float>>(
                 op->Type(), op->GetInputs(), op->GetOutputs(),
-                op->GetAttrMap(), program_.scope);
+                op->GetAttrMap(), program_.scope.get());
         ops_of_block_[*block_desc.get()].push_back(lrn);
       }
     }
@@ -55,7 +55,7 @@ class TestSumOp {
 
   std::shared_ptr<Tensor> predict_bn(const Tensor &t1, const Tensor &t2) {
     // feed
-    auto scope = program_.scope;
+    auto scope = program_.scope.get();
     Variable *x1_feed_value = scope->Var("fc_2.tmp_0");
     auto tensor_x1 = x1_feed_value->GetMutable<LoDTensor>();
     tensor_x1->ShareDataWith(t1);
diff --git a/test/operators/test_tanh_op.cpp b/test/operators/test_tanh_op.cpp
index d013b0eedfbe3bdc773e263aad594c89212ad6ce..13dfd09b3bbf2debfdf90a86215e209e75942157 100644
--- a/test/operators/test_tanh_op.cpp
+++ b/test/operators/test_tanh_op.cpp
@@ -43,8 +43,8 @@ int TestTanhOp(const std::vector<int> input_shape) {
   auto output_var = scope.get()->Var("output");
   framework::AttributeMap attrs;
 
-  auto *op =
-      new operators::TanhOp<CPU, float>("tanh", inputs, outputs, attrs, scope);
+  auto *op = new operators::TanhOp<CPU, float>("tanh", inputs, outputs, attrs,
+                                               scope.get());
   op->InferShape();
   op->Init();
   op->Run();
diff --git a/test/operators/test_topk_op.cpp b/test/operators/test_topk_op.cpp
index 7244232d0737fb9fe77448331c0bdf2477b4f8e5..cf0fde37055bd8c9ed87f5ef1560294ecf865936 100644
--- a/test/operators/test_topk_op.cpp
+++ b/test/operators/test_topk_op.cpp
@@ -71,8 +71,8 @@ int TestTopKOp(const std::vector<int> input_shape, const int K) {
   framework::AttributeMap attrs;
   attrs["k"].Set<int>(K);
 
-  auto *op =
-      new operators::TopKOp<CPU, float>("top_k", inputs, outputs, attrs, scope);
+  auto *op = new operators::TopKOp<CPU, float>("top_k", inputs, outputs, attrs,
+                                               scope.get());
   op->InferShape();
   op->Init();
   op->Run();
diff --git a/test/operators/test_transpose2_op.cpp b/test/operators/test_transpose2_op.cpp
index 5da0faaf119c553e2fb019de76bb40f875f9d673..4c4f5e4c2629b2f4c333f71045ed945c6ee9564a 100644
--- a/test/operators/test_transpose2_op.cpp
+++ b/test/operators/test_transpose2_op.cpp
@@ -60,7 +60,7 @@ class TestTranspose2Op {
           std::shared_ptr<operators::Transpose2Op<CPU, float>> op_ptr =
              std::make_shared<operators::Transpose2Op<CPU, float>>(
                  op->Type(), op->GetInputs(), op->GetOutputs(),
-                 op->GetAttrMap(), program_.scope);
+                 op->GetAttrMap(), program_.scope.get());
           ops_of_block_[*block_desc.get()].push_back(op_ptr);
           return;
         }
@@ -69,7 +69,7 @@ class TestTranspose2Op {
   }
 
   std::shared_ptr<Tensor> predict(const Tensor &t) {
-    auto scope = program_.scope;
+    auto scope = program_.scope.get();
     Variable *input_feed_value = scope->Var(input_var_name);
     auto tensor_input = input_feed_value->GetMutable<LoDTensor>();
     tensor_input->ShareDataWith(t);
diff --git a/test/test_helper.h b/test/test_helper.h
index 852a60839a7aa494072dc63d844c5ea22c003ccf..775a2b8b7b0797ecc637b22539319e8c3e980dae 100644
--- a/test/test_helper.h
+++ b/test/test_helper.h
@@ -22,7 +22,7 @@ limitations under the License. */
 #include "common/common.h"
 #include "common/log.h"
 #include "framework/ddim.h"
-#include "framework/tensor.h"
+#include "framework/lod_tensor.h"
 
 static const char *g_ocr = "../models/ocr";
 static const char *g_mobilenet_ssd = "../models/mobilenet+ssd";
@@ -72,9 +72,10 @@ static const char *g_test_image_1x3x224x224_vision_mobilenet_input =
 static const char *g_test_image_1x3x416x416_vision_yolo_input =
     "../images/yolo_input";
 
+using namespace paddle_mobile;  // NOLINT
 using paddle_mobile::framework::DDim;
+using paddle_mobile::framework::LoDTensor;
 using paddle_mobile::framework::Tensor;
-using namespace paddle_mobile;  // NOLINT
 
 template <typename T>
 void SetupTensor(paddle_mobile::framework::Tensor *input,
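The test_helper.h hunk swaps framework/tensor.h for framework/lod_tensor.h and moves the blanket using namespace (with its // NOLINT) ahead of the specific using-declarations, adding one for LoDTensor. With LoDTensor visible from the helper, sequence-style tests (gru, sequence_pool, sequence_expand, ...) can build LoD inputs directly. A sketch of what that enables, written against paddle-mobile's tensor API from memory; treat the exact calls (Resize, make_ddim, mutable_data, set_lod) as assumptions rather than verified signatures:

#include "framework/lod_tensor.h"

// Build a toy input holding two sequences of length 2 (four rows total).
void BuildSequenceInput(paddle_mobile::framework::LoDTensor *input) {
  input->Resize(paddle_mobile::framework::make_ddim({4, 1}));
  float *data = input->mutable_data<float>();
  for (int i = 0; i < 4; ++i) data[i] = static_cast<float>(i);
  input->set_lod({{0, 2, 4}});  // offsets delimiting the two sequences
}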
diff --git a/tools/ci_build.sh b/tools/ci_build.sh
index d725afe4595b8e88578ec6c2f0f3c78bc0807a1b..424dc1890f1b7c04863701b1d219e59a4eccb438 100755
--- a/tools/ci_build.sh
+++ b/tools/ci_build.sh
@@ -64,7 +64,7 @@ function check_ndk() {
 }
 
 function build_android_armv7_cpu_only() {
-  rm -rf ../build/armeabi-v7a
+#  rm -rf ../build/armeabi-v7a
   cmake .. \
     -B"../build/armeabi-v7a" \
     -DANDROID_ABI="armeabi-v7a with NEON" \
@@ -74,6 +74,7 @@ function build_android_armv7_cpu_only() {
     -DANDROID_STL=c++_static \
     -DANDROID=true \
     -DWITH_LOGGING=OFF \
+    -DCPU=ON \
     -DGPU_MALI=OFF \
     -DGPU_CL=OFF \
     -DFPGA=OFF
@@ -93,6 +94,7 @@ function build_android_armv7_gpu() {
     -DANDROID_STL=c++_static \
     -DANDROID=true \
     -DWITH_LOGGING=OFF \
+    -DCPU=ON \
     -DGPU_MALI=ON \
     -DGPU_CL=ON \
     -DFPGA=OFF
@@ -112,6 +114,7 @@ function build_android_armv8_cpu_only() {
     -DANDROID_STL=c++_static \
     -DANDROID=true \
     -DWITH_LOGGING=OFF \
+    -DCPU=ON \
     -DGPU_MALI=OFF \
     -DGPU_CL=OFF \
     -DFPGA=OFF
@@ -131,6 +134,7 @@ function build_android_armv8_gpu() {
     -DANDROID_STL=c++_static \
     -DANDROID=true \
     -DWITH_LOGGING=OFF \
+    -DCPU=ON \
     -DGPU_MALI=ON \
     -DGPU_CL=ON \
     -DFPGA=OFF
@@ -149,6 +153,7 @@ function build_ios_armv8_cpu_only() {
     -DIOS_ARCH="${IOS_ARCH}" \
     -DIS_IOS=true \
     -DUSE_OPENMP=OFF \
+    -DCPU=ON \
     -DGPU_MALI=OFF \
     -DGPU_CL=OFF \
     -DFPGA=OFF
@@ -167,6 +172,7 @@ function build_ios_armv8_gpu() {
     -DIOS_ARCH="${IOS_ARCH}" \
     -DIS_IOS=true \
     -DUSE_OPENMP=OFF \
+    -DCPU=ON \
     -DGPU_MALI=OFF \
     -DGPU_CL=ON \
     -DFPGA=OFF
@@ -181,6 +187,7 @@ function build_linux_armv7_cpu_only() {
     -B"../build/armv7_linux" \
     -DCMAKE_BUILD_TYPE="MinSizeRel" \
     -DCMAKE_TOOLCHAIN_FILE="./tools/toolchains/arm-linux-gnueabihf.cmake" \
+    -DCPU=ON \
     -DGPU_MALI=OFF \
     -DGPU_CL=OFF \
     -DFPGA=OFF
@@ -195,6 +202,7 @@ function build_linux_armv7_gpu() {
     -B"../build/armv7_linux" \
     -DCMAKE_BUILD_TYPE="MinSizeRel" \
     -DCMAKE_TOOLCHAIN_FILE="./tools/toolchains/arm-linux-gnueabihf.cmake" \
+    -DCPU=ON \
     -DGPU_MALI=ON \
     -DGPU_CL=ON \
     -DFPGA=OFF
diff --git a/tools/op.cmake b/tools/op.cmake
index 190bb142bc59a10efdeebeb8a382043440731e68..84b5bb6ef03ae56dfe2aca5bb0ef9859b9be402e 100755
--- a/tools/op.cmake
+++ b/tools/op.cmake
@@ -310,6 +310,10 @@ if(NOT FOUND_MATCH)
   set(PROPOSAL_OP ON)
   set(PSROI_POOL_OP ON)
   set(ROI_PERSPECTIVE_OP ON)
+  set(BEAM_SEARCH_OP ON)
+  set(BEAM_SEARCH_DECODE_OP ON)
+  set(PAD2D_OP ON)
+  set(ONE_HOT_OP ON)
 endif()
 
 # option(BATCHNORM_OP "" ON)
@@ -615,6 +619,12 @@ endif()
 if (ROI_PERSPECTIVE_OP)
   add_definitions(-DROI_PERSPECTIVE_OP)
 endif()
+if (BEAM_SEARCH_OP)
+  add_definitions(-DBEAM_SEARCH_OP)
+endif()
+if (BEAM_SEARCH_DECODE_OP)
+  add_definitions(-DBEAM_SEARCH_DECODE_OP)
+endif()
 if (FUSION_DECONVADDBNRELU_OP)
   add_definitions(-DFUSION_DECONVADDBNRELU_OP)
 endif()
@@ -627,3 +637,6 @@ endif()
 if (PAD2D_OP)
   add_definitions(-DPAD2D_OP)
 endif()
+if (ONE_HOT_OP)
+  add_definitions(-DONE_HOT_OP)
+endif()
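The op.cmake additions follow the file's existing two-step convention: set(<NAME>_OP ON) inside the if(NOT FOUND_MATCH) block enables the op by default, and the matching if (...) add_definitions(-D<NAME>_OP) turns the switch into a preprocessor macro. Operator sources compile themselves out when the macro is absent, which keeps unused ops out of the binary. A sketch of the consuming side; the registration hook is a hypothetical stand-in for the project's real registration macros:

// one_hot_op.cpp (sketch): the whole translation unit contributes nothing
// unless tools/op.cmake set ONE_HOT_OP, which added -DONE_HOT_OP to the build.
#ifdef ONE_HOT_OP
void RegisterOneHotOp() {}  // hypothetical; real code defines and registers the op here
#endif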