diff --git a/src/common/types.cpp b/src/common/types.cpp index 6503f6383d22c7342c7446c44fab436810a7c46f..965ab8c4a99d9737efd3b61afb8a4a071c385787 100644 --- a/src/common/types.cpp +++ b/src/common/types.cpp @@ -56,6 +56,9 @@ const char *G_OP_TYPE_REGION = "region"; const char *G_OP_TYPE_FUSION_CONV_BN = "fusion_conv_bn"; const char *G_OP_TYPE_CONV_TRANSPOSE = "conv2d_transpose"; const char *G_OP_TYPE_PRELU = "prelu"; +const char *G_OP_TYPE_LOOKUP_TABLE = "lookup_table"; +const char *G_OP_TYPE_GRU = "gru"; +const char *G_OP_TYPE_CRF = "crf_decoding"; std::unordered_map< std::string, std::pair, std::vector>> @@ -97,6 +100,11 @@ std::unordered_map< {G_OP_TYPE_FUSION_FC_RELU, {{"X", "Y", "Z"}, {"Out"}}}, {G_OP_TYPE_REGION, {{"X"}, {"Out"}}}, {G_OP_TYPE_FUSION_CONV_BN, {{"Input"}, {"Y"}}}, + {G_OP_TYPE_LOOKUP_TABLE, {{"W", "Ids"}, {"Out"}}}, + {G_OP_TYPE_GRU, + {{"Input", "H0", "Weight", "Bias"}, + {"BatchGate", "BatchResetHiddenPrev", "BatchHidden", "Hidden"}}}, + {G_OP_TYPE_CRF, {{"Emission", "Transition", "Label"}, {"ViterbiPath"}}}, {G_OP_TYPE_CONV_TRANSPOSE, {{"Input"}, {"Output"}}}}; } // namespace paddle_mobile diff --git a/src/framework/operator.cpp b/src/framework/operator.cpp index 765103c241a82ac224d707340f8b66ace827e335..a5ec5e8a333fb6a9ecfc04695a4155213db9e810 100644 --- a/src/framework/operator.cpp +++ b/src/framework/operator.cpp @@ -62,7 +62,7 @@ void OperatorBase::Run() const { vector input_keys = GetInputKeys(); for (const auto key : input_keys) { Tensor *input = GetVarValue(key, inputs_, *scope_); - DLOG << type_ << " input- " << key << "=" << *input; + if (input) DLOG << type_ << " input- " << key << "=" << *input; } vector output_keys = GetOutKeys(); for (const auto key : output_keys) { diff --git a/src/framework/tensor.h b/src/framework/tensor.h index 8d743c0f63eb9f3107603baeab59c41a0f95f1c2..c5572dcbfdbd665994be7ebe005b6c9c98b5bca9 100644 --- a/src/framework/tensor.h +++ b/src/framework/tensor.h @@ -339,7 +339,12 @@ inline Print &operator<<(Print &printer, const Tensor &tensor) { stride = stride > 0 ? stride : 1; #ifndef PADDLE_MOBILE_FPGA for (int i = 0; i < tensor.numel(); i += stride) { - printer << tensor.data()[i] << " "; + // 这不一定是float的 + if (tensor.type() == typeid(float)) { + printer << tensor.data()[i] << " "; + } else if (tensor.type() == typeid(int64_t)) { + printer << tensor.data()[i] << " "; + } } #endif diff --git a/src/io/executor.cpp b/src/io/executor.cpp index 8ef199c4ea35db52d7bc26f4d4268179e765ea17..04df71739e1357ba6963a2c98ae728891b277924 100644 --- a/src/io/executor.cpp +++ b/src/io/executor.cpp @@ -54,8 +54,11 @@ char *Get_binary_data(std::string filename) { #pragma mark - executor template Executor::Executor(const framework::Program p, int batch_size, - bool use_optimize) - : program_(p), batch_size_(batch_size), use_optimize_(use_optimize) { + bool use_optimize, bool loddable) + : program_(p), + batch_size_(batch_size), + use_optimize_(use_optimize), + loddable_(loddable) { if (use_optimize_) { to_predict_program_ = program_.optimizeProgram; } else { @@ -79,7 +82,12 @@ Executor::Executor(const framework::Program p, int batch_size, auto op_base = framework::OpRegistry::CreateOp( op->Type(), op->GetInputs(), op->GetOutputs(), op->GetAttrMap(), program_.scope); - op_base->InferShape(); + DLOG << "executer in loaddable mode: " << loddable_; + // use pre_infershape to pre resize , but if u use an lod mode tensor u + // need to resize in runtime + if (!loddable_) { + op_base->InferShape(); + } ops_of_block_[*block_desc.get()].push_back(op_base); #ifdef PADDLE_EXECUTOR_MULTITHREAD depManager[i].analysisDep(ops_of_block_[*block_desc.get()]); @@ -225,9 +233,18 @@ void Executor::InitMemory() { delete origin_data; } else { if (var_desc->Type() == framework::VARTYPE_TYPE_LOD_TENSOR) { - auto tensor = var->template GetMutable(); - - tensor->template mutable_data(); + DLOG << "var_desc->Name(): " << var_desc->Name(); + DLOG << "var_desc->Tensor_desc().DataType(): " + << var_desc->Tensor_desc().DataType(); + bool is_mute_match; + framework::LoDTensor *tensor = nullptr; + + is_mute_match = varInputMemory(var_desc, var, tensor); + + PADDLE_MOBILE_ENFORCE( + is_mute_match, + "got unhandled var_desc->Tensor_desc().DataType(): %d", + var_desc->Tensor_desc().DataType()); } } } @@ -257,8 +274,18 @@ void Executor::InitCombineMemory() { LoadMemory(*var_desc, tensor, &data); } else { if (var_desc->Type() == framework::VARTYPE_TYPE_LOD_TENSOR) { - auto tensor = var->template GetMutable(); - tensor->template mutable_data(); + DLOG << "var_desc->Name(): " << var_desc->Name(); + DLOG << "var_desc->Tensor_desc().DataType(): " + << var_desc->Tensor_desc().DataType(); + bool is_mute_match = false; + framework::LoDTensor *tensor; + + is_mute_match = varInputMemory(var_desc, var, tensor); + + PADDLE_MOBILE_ENFORCE( + is_mute_match, + "got unhandled var_desc->Tensor_desc().DataType(): %d", + var_desc->Tensor_desc().DataType()); } } } @@ -266,6 +293,46 @@ void Executor::InitCombineMemory() { delete origin_data; LOG(kLOG_INFO) << " end init combine memory "; } +template +bool Executor::varInputMemory( + const std::shared_ptr &var_desc, Variable *var, + framework::LoDTensor *tensor) const { + bool is_mute_match = false; + switch (var_desc->Tensor_desc().DataType()) { + case framework::VARTYPE_TYPE_FP16: { + break; + } + + case framework::VARTYPE_TYPE_FP32: { + tensor = var->template GetMutable(); + tensor->template mutable_data(); + is_mute_match = true; + break; + } + + case framework::VARTYPE_TYPE_FP64: { + break; + } + + case framework::VARTYPE_TYPE_INT32: { + break; + } + + case framework::VARTYPE_TYPE_INT64: { + tensor = var->template GetMutable(); + tensor->template mutable_data(); + is_mute_match = true; + break; + } + case framework::VARTYPE_TYPE_BOOL: { + break; + } + + default: { break; } + } + + return is_mute_match; +} template std::shared_ptr Executor::Predict( @@ -278,6 +345,7 @@ std::shared_ptr Executor::Predict( std::shared_ptr to_predict_block = to_predict_program_->Block(0); auto &ops = ops_of_block_[*to_predict_block.get()]; + #ifdef PADDLE_MOBILE_PROFILE std::vector profile(ops.size()); #endif @@ -342,6 +410,7 @@ std::shared_ptr Executor::Predict( clock_gettime(CLOCK_MONOTONIC, &ts); profile[i].runBegin = (uint64_t)ts.tv_sec * 1e9 + ts.tv_nsec; #endif + DLOG << "executer Predict in3.3"; // to Run ops[i]->Run(); @@ -351,6 +420,8 @@ std::shared_ptr Executor::Predict( #endif } #endif + DLOG << "executer Predict in4"; + auto last_op = ops.rbegin(); auto output_map = (*last_op)->Outputs(); @@ -377,6 +448,7 @@ std::shared_ptr Executor::Predict( fprintf(df, "}\n"); fclose(df); #endif + DLOG << "executer Predict in5"; // FILE *pf = fopen("profile.out", "w"); std::unordered_map _tp; @@ -389,6 +461,7 @@ std::shared_ptr Executor::Predict( // pInfo.tid, pInfo.runBegin, pInfo.runEnd, timeCost); } // fclose(pf); + DLOG << "executer Predict in6"; printf("====================[ profile ]======================\n"); using prof_t = std::pair; @@ -409,9 +482,184 @@ std::shared_ptr Executor::Predict( } printf("====================[---------]======================\n"); #endif + DLOG << "executer Predict out"; return std::make_shared(framework::Tensor(*output_tensor)); } + +template +std::shared_ptr Executor::PredictLod( + const framework::LoDTensor &t) { + DLOG << "execute PredictLod :lod" << t.lod(); + + DLOG << "executer Predict in"; + framework::Variable *g_feed_value = program_.scope->Var("feed"); + framework::LoDTensor *feed_tensor = + g_feed_value->GetMutable(); + + DLOG << "executer Predict in2"; + + feed_tensor->Resize(t.dims()); + feed_tensor->ShareDataWith(t); + feed_tensor->set_lod(t.lod()); + DLOG << "feed_tensor .lod : " << feed_tensor->lod(); + + DLOG << "executer Predict in3"; + + std::shared_ptr to_predict_block = + to_predict_program_->Block(0); + DLOG << "executer Predict in3.1"; + + auto &ops = ops_of_block_[*to_predict_block.get()]; + DLOG << "executer Predict in3.2"; + +#ifdef PADDLE_MOBILE_PROFILE + std::vector profile(ops.size()); +#endif +#ifdef PADDLE_EXECUTOR_MULTITHREAD + std::mutex m; + std::condition_variable cv; + std::queue next; + next.push(0); + int rsize = ops.size(); + std::vector status(rsize, 0); + auto &threadPool = ThreadPool::getThreadPool(); + auto &dep = depManager[0]; + auto finishF = [&ops, &m, &cv, &next, &status, &rsize, &dep](int opi) { + std::lock_guard lk(m); + rsize--; + status[opi] = 2; + for (int i : dep.getNext(opi)) { + bool ok = true; + for (int j : dep.getDeps(i)) { + if (status[j] != 2) { + ok = false; + break; + } + } + if (ok && (status[i] == 0)) { + next.push(i); + } + } + cv.notify_one(); + }; + for (;;) { + std::unique_lock lk(m); + cv.wait(lk, [&next, &rsize] { return rsize == 0 || !next.empty(); }); + if (rsize == 0) { + break; + } + while (next.size() > 0) { + int opi = next.front(); + next.pop(); + status[opi] = 1; + threadPool.enqueue([opi, &ops, &finishF, &profile] { + auto &op = ops[opi]; +#ifdef PADDLE_MOBILE_PROFILE + struct timespec ts; + clock_gettime(CLOCK_MONOTONIC, &ts); + profile[opi].runBegin = (uint64_t)ts.tv_sec * 1e9 + ts.tv_nsec; + profile[opi].tid = ThreadPool::getThreadPoolThreadId(); +#endif + ops[opi]->Run(); +#ifdef PADDLE_MOBILE_PROFILE + clock_gettime(CLOCK_MONOTONIC, &ts); + profile[opi].runEnd = (uint64_t)ts.tv_sec * 1e9 + ts.tv_nsec; +#endif + finishF(opi); + }); + } + } +#else + for (int i = 0; i < ops.size(); i++) { +#ifdef PADDLE_MOBILE_PROFILE + struct timespec ts; + clock_gettime(CLOCK_MONOTONIC, &ts); + profile[i].runBegin = (uint64_t)ts.tv_sec * 1e9 + ts.tv_nsec; +#endif + DLOG << "executer Predict in3.3 infer"; + if (loddable_) { + ops[i]->InferShape(); + } + + DLOG << "executer Predict in3.3 after infer"; + + // to Run + ops[i]->Run(); +#ifdef PADDLE_MOBILE_PROFILE + clock_gettime(CLOCK_MONOTONIC, &ts); + profile[i].runEnd = (uint64_t)ts.tv_sec * 1e9 + ts.tv_nsec; +#endif + } +#endif + DLOG << "executer Predict in4"; + + auto last_op = ops.rbegin(); + + auto output_map = (*last_op)->Outputs(); + std::vector out_keys = (*last_op)->GetOutKeys(); + PADDLE_MOBILE_ENFORCE(out_keys.size() > 0, "the last op contains no output"); + framework::LoDTensor *output_tensor = + framework::GetVarValue(out_keys[0], output_map, + *(program_.scope)); +#ifdef PADDLE_MOBILE_PROFILE +#ifdef PADDLE_EXECUTOR_MULTITHREAD + // TODO(haipeng): expose profile info as an interface, user can get them to + // analysis + // the performance of their deepnet. + FILE *df = fopen("net.dot", "w"); + fprintf(df, "digraph {\n"); + for (int i = 0; i < ops.size(); i++) { + for (int j : dep.getNext(i)) { + fprintf(df, "op_%d -> op_%d\n", i, j); + } + } + for (int i = 0; i < ops.size(); i++) { + fprintf(df, "op_%d[label=\"%s (%d)\"]\n", i, ops[i]->Type().c_str(), i); + } + fprintf(df, "}\n"); + fclose(df); +#endif + DLOG << "executer Predict in5"; + + // FILE *pf = fopen("profile.out", "w"); + std::unordered_map _tp; + for (int i = 0; i < profile.size(); i++) { + const auto &pInfo = profile[i]; + uint64_t timeCost = pInfo.runEnd - pInfo.runBegin; + _tp[ops[i]->Type()] += timeCost; + // fprintf(pf, "%d\t%s\t%d\t%llu\t%llu\t%llu\n", i, + // ops[i]->Type().c_str(), + // pInfo.tid, pInfo.runBegin, pInfo.runEnd, timeCost); + } + // fclose(pf); + DLOG << "executer Predict in6"; + + printf("====================[ profile ]======================\n"); + using prof_t = std::pair; + std::vector _tv(_tp.begin(), _tp.end()); + uint64_t _ptotal = 0; + for (auto const &p : _tv) { + _ptotal += p.second; + } + auto compf = [](const prof_t &a, const prof_t &b) { + return a.second > b.second; + }; + std::sort(_tv.begin(), _tv.end(), compf); + _tv.push_back(std::make_pair("total", _ptotal)); + for (auto const &p : _tv) { + printf("%-16s\t%-10.0f\t%-2.4f\n", p.first.c_str(), + static_cast(p.second), + static_cast(p.second) / _ptotal * 100.0); + } + printf("====================[---------]======================\n"); +#endif + DLOG << "executer Predict out"; + + return std::make_shared( + framework::LoDTensor(*output_tensor)); +} + template std::shared_ptr Executor::Predict( const framework::Tensor &t, int block_id) { diff --git a/src/io/executor.h b/src/io/executor.h index f8f2a8ad5657fdb3cf6cb249e32537bd5e866913..6074942c1850934f976840c7b6aad617d96904cd 100644 --- a/src/io/executor.h +++ b/src/io/executor.h @@ -43,13 +43,17 @@ class Executor { * @b 用 loader load 的 program 实例化 executor * */ Executor(const framework::Program p, int batch_size = 1, - bool use_optimize = true); + bool use_optimize = true, bool loddable = false); /* * @b to predict * */ std::shared_ptr Predict(const framework::Tensor &t); - + /* + * @b to predict + * */ + std::shared_ptr PredictLod( + const framework::LoDTensor &t); /* * @b to predict with vector and dim * @@ -73,6 +77,7 @@ class Executor { std::vector>>> ops_of_block_; bool use_optimize_ = false; + bool loddable_ = false; #ifdef PADDLE_EXECUTOR_MULTITHREAD std::vector depManager; #endif @@ -83,6 +88,10 @@ class Executor { uint64_t runEnd = 0UL; }; #endif + + bool varInputMemory(const std::shared_ptr &var_desc, + framework::Variable *var, + framework::LoDTensor *tensor) const; }; } // namespace paddle_mobile diff --git a/src/io/paddle_mobile.cpp b/src/io/paddle_mobile.cpp index 420a35c2139626ce1f8d57bb3f7d891a26e2fa9f..a69af82427cacb0cf4a90850168c26ebc717f7aa 100644 --- a/src/io/paddle_mobile.cpp +++ b/src/io/paddle_mobile.cpp @@ -26,7 +26,8 @@ void PaddleMobile::SetThreadNum(int num) { template bool PaddleMobile::Load(const std::string &dirname, bool optimize, - bool quantification, int batch_size) { + bool quantification, int batch_size, + bool loddable) { if (loader_.get() == nullptr) { loader_ = std::make_shared>(); } else { @@ -35,7 +36,8 @@ bool PaddleMobile::Load(const std::string &dirname, bool optimize, if (executor_.get() == nullptr) { executor_ = std::make_shared>( - loader_->Load(dirname, optimize, quantification), batch_size, optimize); + loader_->Load(dirname, optimize, quantification), batch_size, optimize, + loddable); } else { LOG(kLOG_INFO) << "executor inited"; } @@ -46,7 +48,8 @@ bool PaddleMobile::Load(const std::string &dirname, bool optimize, template bool PaddleMobile::Load(const std::string &model_path, const std::string ¶_path, bool optimize, - bool quantification, int batch_size) { + bool quantification, int batch_size, + bool loddable) { if (loader_.get() == nullptr) { loader_ = std::make_shared>(); } else { @@ -56,7 +59,7 @@ bool PaddleMobile::Load(const std::string &model_path, if (executor_.get() == nullptr) { executor_ = std::make_shared>( loader_->Load(model_path, para_path, optimize, quantification), - batch_size, optimize); + batch_size, optimize, loddable); } else { LOG(kLOG_INFO) << "executor inited"; } @@ -96,6 +99,12 @@ std::shared_ptr PaddleMobile::Predict( return executor_->Predict(t); } +template +std::shared_ptr PaddleMobile::PredictLod( + const framework::LoDTensor &t) { + return executor_->PredictLod(t); +} + template std::vector::Ptype> PaddleMobile::Predict(const std::vector &input, diff --git a/src/io/paddle_mobile.h b/src/io/paddle_mobile.h index 2617407d0fd0e47e2e9df589c6d750a8b60ca90e..2ea8614cc755a0ffe57147604e6ddd39bdfacb36 100644 --- a/src/io/paddle_mobile.h +++ b/src/io/paddle_mobile.h @@ -39,7 +39,8 @@ class PaddleMobile { * @b 加载分开形式的 fluid 模型 * */ bool Load(const std::string &dirname, bool optimize = false, - bool quantification = false, int batch_size = 1); + bool quantification = false, int batch_size = 1, + bool loddable = false); /* * @b load combine format fluid mode @@ -47,7 +48,7 @@ class PaddleMobile { * */ bool Load(const std::string &model_path, const std::string ¶_path, bool optimize = false, bool quantification = false, - int batch_size = 1); + int batch_size = 1, bool loddable = false); /* * @b 设置线程数, 当 cmake 中开启 openmp 时生效 * */ @@ -58,6 +59,11 @@ class PaddleMobile { * */ std::shared_ptr Predict(const framework::Tensor &t); + /* + * @b to predict + * */ + std::shared_ptr PredictLod(const framework::LoDTensor &t); + /* * @b to predict with vector and dim * diff --git a/src/jni/paddle_mobile_jni.cpp b/src/jni/paddle_mobile_jni.cpp index 0da56305f978dc874666a2be26c15a9de47b3757..111ec35def78afc52360f163450ab8003430121b 100644 --- a/src/jni/paddle_mobile_jni.cpp +++ b/src/jni/paddle_mobile_jni.cpp @@ -353,6 +353,41 @@ JNIEXPORT jfloatArray JNICALL Java_com_baidu_paddle_PML_predictYuv( return result; } +JNIEXPORT jlongArray JNICALL +Java_com_baidu_paddle_PML_predictLod(JNIEnv *env, jclass thiz, jlongArray buf) { + std::lock_guard lock(shared_mutex); + + jlong *ddim_ptr = env->GetLongArrayElements(buf, NULL); + jsize ddim_size = env->GetArrayLength(buf); + std::vector ids; + + for (int i = 0; i < ddim_size; ++i) { + jlong x = ddim_ptr[i]; + ids.push_back((int64_t)x); + } + + paddle_mobile::framework::LoDTensor words; + + auto size = static_cast(ids.size()); + + paddle_mobile::framework::LoD lod{{0, ids.size()}}; + DDim dims{size, 1}; + words.Resize(dims); + words.set_lod(lod); + auto *pdata = words.mutable_data(); + size_t n = words.numel() * sizeof(int64_t); + memcpy(pdata, ids.data(), n); + auto vec_result = paddle_mobile.PredictLod(words); + int count = vec_result->numel(); + jlongArray result = NULL; + ANDROIDLOGE("predict nlp size %d", count); + + result = env->NewLongArray(count); + + env->SetLongArrayRegion(result, 0, count, vec_result->data()); + + return result; +} JNIEXPORT void JNICALL Java_com_baidu_paddle_PML_setThread(JNIEnv *env, jclass thiz, diff --git a/src/operators/crf_op.cpp b/src/operators/crf_op.cpp new file mode 100644 index 0000000000000000000000000000000000000000..61f9a54352e236a7fcb7b2765ab11055fbec95ab --- /dev/null +++ b/src/operators/crf_op.cpp @@ -0,0 +1,56 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#ifdef CRF_OP + +#include + +#include "common/enforce.h" +#include "operators/crf_op.h" + +namespace paddle_mobile { +namespace operators { + +template +void CrfOp::InferShape() const { + PADDLE_MOBILE_ENFORCE(this->param_.InputEmission(), + "Input(Emission) should be not null."); + PADDLE_MOBILE_ENFORCE(this->param_.InputTransition(), + "Input(Transition) should be not null."); + PADDLE_MOBILE_ENFORCE(this->param_.outputVBP(), + "Input(ViterbiPath) should be not null."); + + auto emission_dims = this->param_.InputEmission()->dims(); + PADDLE_MOBILE_ENFORCE(emission_dims.size() == 2U, + "The Input(Emission) should be a 2-D tensor."); + PADDLE_MOBILE_ENFORCE(emission_dims[0], + "An empty mini-batch is not allowed."); + + this->param_.outputVBP()->Resize( + {this->param_.InputEmission()->dims()[0], 1}); +} + +} // namespace operators +} // namespace paddle_mobile + +namespace ops = paddle_mobile::operators; +#ifdef PADDLE_MOBILE_CPU +REGISTER_OPERATOR_CPU(crf_decoding, ops::CrfOp); +#endif +#ifdef PADDLE_MOBILE_MALI_GPU +#endif +#ifdef PADDLE_MOBILE_FPGA +#endif + +#endif diff --git a/src/operators/crf_op.h b/src/operators/crf_op.h new file mode 100644 index 0000000000000000000000000000000000000000..9c966c9077273282bbcb4f25674e8df401956967 --- /dev/null +++ b/src/operators/crf_op.h @@ -0,0 +1,58 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#ifdef CRF_OP + +#pragma once + +#include +#include "framework/operator.h" +#include "operators/kernel/crf_kernel.h" +#include "operators/op_param.h" + +namespace paddle_mobile { +namespace operators { + +using paddle_mobile::framework::Tensor; + +template +class CrfOp : public framework::OperatorWithKernel< + DeviceType, CrfParam, + operators::CrfKernel> { + public: + CrfOp(const std::string &type, const VariableNameMap &inputs, + const VariableNameMap &outputs, const framework::AttributeMap &attrs, + std::shared_ptr scope) + : framework::OperatorWithKernel, + operators::CrfKernel>( + type, inputs, outputs, attrs, scope) {} + + using framework::OperatorWithKernel< + DeviceType, CrfParam, + operators::CrfKernel>::OperatorWithKernel; + void InferShape() const override; +}; + +} // namespace operators +} // namespace paddle_mobile + +#ifdef PADDLE_MOBILE_CPU +USE_OP_CPU(crf_decoding); +#endif +#ifdef PADDLE_MOBILE_MALI_GPU +#endif +#ifdef PADDLE_MOBILE_FPGA +#endif + +#endif diff --git a/src/operators/feed_op.h b/src/operators/feed_op.h index dad0880ea69be8449239657af66553065db05321..7982735030690a9d3fe75cbadeb45f0f70a78836 100644 --- a/src/operators/feed_op.h +++ b/src/operators/feed_op.h @@ -35,6 +35,10 @@ class FeedOp : public framework::OperatorBase { auto out_dims = param_.Out()->dims(); out_dims[0] = param_.BatchSize(); param_.Out()->Resize(out_dims); + + // note : mobile infershape iscalled when executer is created. so do not + // pass lod here . + // it is empty } #ifdef PADDLE_MOBILE_FPGA @@ -67,7 +71,10 @@ class FeedOp : public framework::OperatorBase { #else void Init() {} - void RunImpl() const { param_.Out()->ShareDataWith(*param_.InputX()); } + void RunImpl() const { + param_.Out()->ShareDataWith(*param_.InputX()); + param_.Out()->set_lod(param_.InputX()->lod()); + } #endif protected: diff --git a/src/operators/gru_op.cpp b/src/operators/gru_op.cpp new file mode 100644 index 0000000000000000000000000000000000000000..c141cbc06531fabcf5e29546e832480cff850b8c --- /dev/null +++ b/src/operators/gru_op.cpp @@ -0,0 +1,72 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#ifdef GRU_OP + +#include + +#include "common/enforce.h" +#include "operators/gru_op.h" + +namespace paddle_mobile { +namespace operators { + +template +void GruOp::InferShape() const { + auto lod_size = this->param_.InputInput()->lod().size(); + PADDLE_MOBILE_ENFORCE((lod_size == 1), + "Current LoD only supports one dimension."); + auto input_dims = this->param_.InputInput()->dims(); + auto weight_dims = this->param_.InputWeight()->dims(); + int input_size = input_dims[1]; + int frame_size = weight_dims[0]; + PADDLE_MOBILE_ENFORCE( + (input_size == frame_size * 3), + "The input_size must be 3 times of frame_size in GRUOp."); + PADDLE_MOBILE_ENFORCE( + (weight_dims[1] == frame_size * 3), + "The shape of Weight matrix must be [frame_size, frame_size * 3]."); + if (this->param_.InputH0()) { + auto h0_dims = this->param_.InputH0()->dims(); + PADDLE_MOBILE_ENFORCE((h0_dims[1] == frame_size), + "The width of H0 must be equal to frame_size."); + } + if (this->param_.InputBias()) { + auto bias_dims = this->param_.InputBias()->dims(); + int bias_height = bias_dims[0]; + int bias_width = bias_dims[1]; + PADDLE_MOBILE_ENFORCE((bias_height == 1), + "The shape of Bias must be [1, frame_size * 3]."); + PADDLE_MOBILE_ENFORCE((bias_width == frame_size * 3), + "The shape of Bias must be [1, frame_size * 3]."); + } + this->param_.OutBatchGate()->Resize(input_dims); + this->param_.OutBatchResetHiddenPrev()->Resize({input_dims[0], frame_size}); + this->param_.OutBatchHidden()->Resize({input_dims[0], frame_size}); + this->param_.OutHidden()->Resize({input_dims[0], frame_size}); +} + +} // namespace operators +} // namespace paddle_mobile + +namespace ops = paddle_mobile::operators; +#ifdef PADDLE_MOBILE_CPU +REGISTER_OPERATOR_CPU(gru, ops::GruOp); +#endif +#ifdef PADDLE_MOBILE_MALI_GPU +#endif +#ifdef PADDLE_MOBILE_FPGA +#endif + +#endif diff --git a/src/operators/gru_op.h b/src/operators/gru_op.h new file mode 100644 index 0000000000000000000000000000000000000000..d348b6c52431f93673f1b772f8c8a9462878cfd5 --- /dev/null +++ b/src/operators/gru_op.h @@ -0,0 +1,58 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#ifdef GRU_OP + +#pragma once + +#include +#include "framework/operator.h" +#include "operators/kernel/gru_kernel.h" +#include "operators/op_param.h" + +namespace paddle_mobile { +namespace operators { + +using paddle_mobile::framework::Tensor; + +template +class GruOp : public framework::OperatorWithKernel< + DeviceType, GruParam, + operators::GruKernel> { + public: + GruOp(const std::string &type, const VariableNameMap &inputs, + const VariableNameMap &outputs, const framework::AttributeMap &attrs, + std::shared_ptr scope) + : framework::OperatorWithKernel, + operators::GruKernel>( + type, inputs, outputs, attrs, scope) {} + + using framework::OperatorWithKernel< + DeviceType, GruParam, + operators::GruKernel>::OperatorWithKernel; + void InferShape() const override; +}; + +} // namespace operators +} // namespace paddle_mobile + +#ifdef PADDLE_MOBILE_CPU +USE_OP_CPU(gru); +#endif +#ifdef PADDLE_MOBILE_MALI_GPU +#endif +#ifdef PADDLE_MOBILE_FPGA +#endif + +#endif diff --git a/src/operators/kernel/arm/concat_kernel.cpp b/src/operators/kernel/arm/concat_kernel.cpp index 33cee24900ed185e62b72895ead83f8170463253..04c590e6b432fbf88cd136eac942485adf9a9003 100644 --- a/src/operators/kernel/arm/concat_kernel.cpp +++ b/src/operators/kernel/arm/concat_kernel.cpp @@ -28,6 +28,7 @@ bool ConcatKernel::Init(ConcatParam *param) { template <> void ConcatKernel::Compute(const ConcatParam ¶m) const { ConcatCompute(param); + param.Out()->set_lod(param.Inputs()[0]->lod()); } } // namespace operators diff --git a/src/operators/kernel/arm/crf_kernel.cpp b/src/operators/kernel/arm/crf_kernel.cpp new file mode 100644 index 0000000000000000000000000000000000000000..89769c50a6fc05b28192ebf584ba3cb12f19ac2c --- /dev/null +++ b/src/operators/kernel/arm/crf_kernel.cpp @@ -0,0 +1,39 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#ifdef CRF_OP + +#include "operators/kernel/crf_kernel.h" +#include "common/types.h" +#include "operators/kernel/central-arm-func/crf_arm_func.h" + +namespace paddle_mobile { +namespace operators { + +template <> +bool CrfKernel::Init(CrfParam *param) { + return true; +} + +template <> +void CrfKernel::Compute(const CrfParam ¶m) const { + CrfCompute(param); +} + +template class CrfKernel; + +} // namespace operators +} // namespace paddle_mobile + +#endif diff --git a/src/operators/kernel/arm/elementwise_add_kernel.cpp b/src/operators/kernel/arm/elementwise_add_kernel.cpp index af6067f80c98ab7ce62f14c67021bc67a2bf2a9c..9c6f4a3316385b803a8fdb833490f1fe9e7f41ac 100644 --- a/src/operators/kernel/arm/elementwise_add_kernel.cpp +++ b/src/operators/kernel/arm/elementwise_add_kernel.cpp @@ -29,6 +29,7 @@ template <> void ElementwiseAddKernel::Compute( const ElementwiseAddParam ¶m) const { ElementwiseAddCompute(param); + param.Out()->set_lod(param.InputX()->lod()); } } // namespace operators diff --git a/src/operators/kernel/arm/fusion_fc_kernel.cpp b/src/operators/kernel/arm/fusion_fc_kernel.cpp index 1178c980f97792b9574c16f94528f65a277cda80..d9d112e7a762705efe041c74eea9ddb7d5162918 100644 --- a/src/operators/kernel/arm/fusion_fc_kernel.cpp +++ b/src/operators/kernel/arm/fusion_fc_kernel.cpp @@ -29,6 +29,7 @@ template <> void FusionFcKernel::Compute( const FusionFcParam ¶m) const { FusionFcCompute(param); + param.Out()->set_lod(param.InputX()->lod()); } } // namespace operators diff --git a/src/operators/kernel/arm/gru_kernel.cpp b/src/operators/kernel/arm/gru_kernel.cpp new file mode 100644 index 0000000000000000000000000000000000000000..168471185e07a9c1814c708238996a82c1ee0891 --- /dev/null +++ b/src/operators/kernel/arm/gru_kernel.cpp @@ -0,0 +1,45 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#ifdef GRU_OP + +#include "operators/kernel/gru_kernel.h" +#include "operators/kernel/central-arm-func/gru_arm_func.h" + +namespace paddle_mobile { +namespace operators { + +template <> +bool GruKernel::Init(GruParam *param) { + return true; +} + +template <> +void GruKernel::Compute(const GruParam ¶m) const { + GruCompute(param); + param.OutHidden()->set_lod(param.InputInput()->lod()); + // DLOG << "________________" << param.OutHidden()->dims(); + // DLOG << "________________" << param.OutHidden()->numel(); + // auto *hiden_data = param.OutHidden()->data(); + // for (int64_t i = 0; i < 10; i++) { + // DLOG << "****************" << hiden_data[i]; + // } +} + +template class GruKernel; + +} // namespace operators +} // namespace paddle_mobile + +#endif diff --git a/src/operators/kernel/arm/lookup_kernel.cpp b/src/operators/kernel/arm/lookup_kernel.cpp new file mode 100644 index 0000000000000000000000000000000000000000..584c497c701bd0598e0a151774fe60b7c7fee718 --- /dev/null +++ b/src/operators/kernel/arm/lookup_kernel.cpp @@ -0,0 +1,36 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ +#ifdef LOOKUP_OP + +#include "operators/kernel/lookup_kernel.h" +#include "operators/kernel/central-arm-func/lookup_arm_func.h" + +namespace paddle_mobile { +namespace operators { + +template <> +bool LookupKernel::Init(LookupParam *param) { + return true; +} + +template <> +void LookupKernel::Compute(const LookupParam ¶m) const { + LookupCompute(param); + param.Out()->set_lod(param.InputIds()->lod()); +} + +} // namespace operators +} // namespace paddle_mobile + +#endif diff --git a/src/operators/kernel/arm/mul_kernel.cpp b/src/operators/kernel/arm/mul_kernel.cpp index e8ef8fe80cd92cd725be8c0251aa4a4f239dbcee..aa3ee7077eb7db440c8493eae5b95f03a42196a4 100644 --- a/src/operators/kernel/arm/mul_kernel.cpp +++ b/src/operators/kernel/arm/mul_kernel.cpp @@ -28,6 +28,7 @@ bool MulKernel::Init(MulParam *param) { template <> void MulKernel::Compute(const MulParam ¶m) const { MulCompute(param); + param.Out()->set_lod(param.InputX()->lod()); } } // namespace operators diff --git a/src/operators/kernel/central-arm-func/crf_arm_func.h b/src/operators/kernel/central-arm-func/crf_arm_func.h new file mode 100644 index 0000000000000000000000000000000000000000..2cf95081e9678325046d49f86ebf072a14a76795 --- /dev/null +++ b/src/operators/kernel/central-arm-func/crf_arm_func.h @@ -0,0 +1,118 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#ifdef CRF_OP +#pragma once + +#include +#include +#include "operators/op_param.h" + +namespace paddle_mobile { +namespace operators { +template +void Decode(const Tensor& emission_weights, const Tensor& transition_weights, + Tensor* decoded_path) { + auto emission_dims = emission_weights.dims(); + const size_t seq_len = emission_dims[0]; + const size_t tag_num = emission_dims[1]; + + const size_t state_trans_base_idx = 2; + + const P* x = emission_weights.data

(); + const P* w = transition_weights.data

(); + int64_t* path = decoded_path->data(); + + // alpha is a memo table. An element alpha(k, v) records the score of the + // best sequence of tags from position 1 to position k with v being the end + // tag. + Tensor alpha; + P* alpha_value = alpha.mutable_data

(emission_dims); + Tensor track; + int* track_value = track.mutable_data(emission_dims); + for (size_t i = 0; i < tag_num; ++i) alpha_value[i] = w[i] + x[i]; + + for (size_t k = 1; k < seq_len; ++k) { + for (size_t i = 0; i < tag_num; ++i) { + P max_score = -std::numeric_limits

::max(); + int max_j = 0; + for (size_t j = 0; j < tag_num; ++j) { + P score = alpha_value[(k - 1) * tag_num + j] + + w[(j + state_trans_base_idx) * tag_num + i]; + if (score > max_score) { + max_score = score; + max_j = j; + } + } + + alpha_value[k * tag_num + i] = max_score + x[k * tag_num + i]; + track_value[k * tag_num + i] = max_j; + } + } + P max_score = -std::numeric_limits

::max(); + int max_i = 0; + for (size_t i = 0; i < tag_num; ++i) { + P score = alpha_value[(seq_len - 1) * tag_num + i] + w[tag_num + i]; + if (score > max_score) { + max_score = score; + max_i = i; + } + } + path[seq_len - 1] = max_i; + for (int k = seq_len - 1; k >= 1; --k) { + path[k - 1] = max_i = track_value[k * tag_num + max_i]; + } +} +template +void CrfCompute(const CrfParam& param) { + auto* emission = param.InputEmission(); + auto* transition = param.InputTransition(); + auto* label = param.InputLabel(); + auto* decoded_path = param.outputVBP(); + // DLOG<<*emission; + // DLOG<<*transition; + // DLOG<<*label; + + PADDLE_MOBILE_ENFORCE(emission->NumLevels() == 1U, + "The Input(Emission) should be a sequence."); + auto lod = emission->lod(); + PADDLE_MOBILE_ENFORCE(lod.size(), + "The Input(Emission) should be a sequence."); + const size_t level = 0; + const size_t seq_num = lod[level].size() - 1; + int64_t* path = decoded_path->mutable_data(); + int numel = decoded_path->numel(); + memset(static_cast(path), 0, sizeof(int64_t) * numel); + for (size_t i = 0; i < seq_num; ++i) { + int start_pos = static_cast(lod[level][i]); + int end_pos = static_cast(lod[level][i + 1]); + Tensor decoded_path_one_seq = decoded_path->Slice(start_pos, end_pos); + Decode

(emission->Slice(start_pos, end_pos), *transition, + &decoded_path_one_seq); + } + if (label) { + PADDLE_MOBILE_ENFORCE(label->NumLevels() == 1U, + "The Input(Label) should be a sequence."); + const int64_t* label_value = label->data(); + size_t batch_size = emission->dims()[0]; + for (size_t i = 0; i < batch_size; ++i) { + path[i] = label_value[i] == path[i] ? 1 : 0; + } + } +} +} // namespace operators + +} // namespace paddle_mobile + +#endif diff --git a/src/operators/kernel/central-arm-func/fusion_fc_arm_func.h b/src/operators/kernel/central-arm-func/fusion_fc_arm_func.h index 0d8d793cccf1b8de596bffa023ba367fb1b46155..42c01d2825e052a52e7021a1b2a97997fb9c915b 100644 --- a/src/operators/kernel/central-arm-func/fusion_fc_arm_func.h +++ b/src/operators/kernel/central-arm-func/fusion_fc_arm_func.h @@ -30,6 +30,9 @@ void FusionFcCompute(const FusionFcParam ¶m) { int axis = param.Axis(); Tensor *out = param.Out(); auto *out_data = out->mutable_data(); + // int m = out->dims()[0]; + // int n = out->dims()[1]; + const Tensor x_matrix = input_x->dims().size() > 2 ? framework::ReshapeToMatrix(*input_x, param.XNumColDims()) @@ -57,6 +60,7 @@ void FusionFcCompute(const FusionFcParam ¶m) { // for (int i = 0; i < out->numel(); i++) { // DLOG << out_data[i]; // } + // bias_data的维度和out的维度一致 math::matmul(x_matrix, false, y_matrix, false, static_cast(1), out, static_cast(1), false); PADDLE_MOBILE_ENFORCE(out_dim.size() == 2, " out_dim.size must be 2."); diff --git a/src/operators/kernel/central-arm-func/gru_arm_func.h b/src/operators/kernel/central-arm-func/gru_arm_func.h new file mode 100644 index 0000000000000000000000000000000000000000..2e00e839ff10da0d40612c9f63d5d0f7e059a0fe --- /dev/null +++ b/src/operators/kernel/central-arm-func/gru_arm_func.h @@ -0,0 +1,111 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#ifdef GRU_OP +#pragma once + +#include +#include +#include "common/types.h" +#include "operators/math/gru_compute.h" +#include "operators/math/math_function.h" +#include "operators/op_param.h" + +namespace paddle_mobile { +namespace operators { + +using LoDTensor = framework::LoDTensor; +using Tensor = framework::Tensor; + +template +inline void ReorderInitState(const framework::Tensor& src, + std::vector index_lod, + framework::Tensor* dst, bool indexed_src) { + math::CopyMatrixRowsFunctor row_shuffle; + dst->mutable_data(src.dims()); + row_shuffle(src, index_lod, dst, indexed_src); +} +template +void GruCompute(const GruParam& param) { + auto* input = param.InputInput(); + auto* h0 = param.InputH0(); + auto* weight = param.InputWeight(); + const auto* weight_data = weight->data(); + auto* bias = param.InputBias(); + auto* batch_gate = param.OutBatchGate(); + batch_gate->mutable_data(); + auto* batch_reset_hidden_prev = param.OutBatchResetHiddenPrev(); + batch_reset_hidden_prev->mutable_data(); + auto* batch_hidden = param.OutBatchHidden(); + batch_hidden->mutable_data(); + auto* hidden = param.OutHidden(); + hidden->mutable_data(); + + auto hidden_dims = hidden->dims(); + + bool is_reverse = param.IsReverse(); + math::LoDTensor2BatchFunctor to_batch; + to_batch(*input, batch_gate, true, is_reverse); + // math::ClearTensor clearTensor; + // clearTensor(batch_gate); + if (bias) { + math::RowwiseAdd add_bias; + add_bias(*batch_gate, *bias, batch_gate); + } + int frame_size = hidden_dims[1]; + math::GRUMetaValue gru_value; + gru_value.gate_weight = const_cast(weight_data); + gru_value.state_weight = + const_cast(weight_data + 2 * frame_size * frame_size); + Tensor ordered_h0; + std::vector order(batch_gate->lod()[2]); + if (h0) { + // Since the batch computing for GRU reorders the input sequences + // according to their length. The initialized cell state also needs + // to reorder. + ReorderInitState(*h0, order, &ordered_h0, true); + gru_value.prev_out_value = ordered_h0.data(); + } else { + gru_value.prev_out_value = nullptr; + } + auto batch_starts = batch_gate->lod()[0]; + size_t seq_len = batch_starts.size() - 1; + auto active_node = math::GetActivationType(param.Activation()); + auto active_gate = math::GetActivationType(param.GateActivation()); + for (size_t n = 0; n < seq_len; n++) { + int bstart = static_cast(batch_starts[n]); + int bend = static_cast(batch_starts[n + 1]); + int cur_batch_size = bend - bstart; + Tensor gate_t = batch_gate->Slice(bstart, bend); // BUG + Tensor reset_hidden_prev_t = batch_reset_hidden_prev->Slice(bstart, bend); + Tensor hidden_t = batch_hidden->Slice(bstart, bend); + gru_value.output_value = hidden_t.data(); + gru_value.gate_value = gate_t.data(); + gru_value.reset_output_value = reset_hidden_prev_t.data(); + + math::GRUUnitFunctor::compute( + gru_value, frame_size, cur_batch_size, active_node, active_gate); + + gru_value.prev_out_value = gru_value.output_value; + } + math::Batch2LoDTensorFunctor to_seq; + batch_hidden->set_lod(batch_gate->lod()); + to_seq(*batch_hidden, hidden); +} + +} // namespace operators + +} // namespace paddle_mobile + +#endif diff --git a/src/operators/kernel/central-arm-func/lookup_arm_func.h b/src/operators/kernel/central-arm-func/lookup_arm_func.h new file mode 100644 index 0000000000000000000000000000000000000000..917973822f90b5015ea6b49aef0b7437ce8988e1 --- /dev/null +++ b/src/operators/kernel/central-arm-func/lookup_arm_func.h @@ -0,0 +1,58 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#ifdef LOOKUP_OP +#pragma once + +#include +#include "framework/ddim.h" +#include "operators/op_param.h" + +constexpr int64_t kNoPadding = -1; + +namespace paddle_mobile { +namespace operators { + +template +void LookupCompute(const LookupParam ¶m) { + auto *ids_t = param.InputIds(); + auto *table_t = param.InputW(); + auto *output_t = param.Out(); + int64_t padding_idx = param.PaddingIdx(); + const framework::DDim &table_dim = table_t->dims(); + int64_t ids_numel; + const auto *ids = ids_t->data(); + ids_numel = ids_t->numel(); + int64_t row_number = table_t->dims()[0]; + int64_t row_width = table_t->dims()[1]; + auto *table = table_t->data(); + auto *output = output_t->mutable_data(); + for (int64_t i = 0; i < ids_numel; ++i) { + if (padding_idx != kNoPadding && ids[i] == padding_idx) { + memset(output + i * row_width, 0, row_width * sizeof(float)); + } else { + PADDLE_MOBILE_ENFORCE(ids[i] < row_number, + "look uptable ids[i] = 0, + "lookuptable ids[i] >= 0 check failed"); + + memcpy(output + i * row_width, table + ids[i] * row_width, + row_width * sizeof(float)); + } + } +} +} // namespace operators +} // namespace paddle_mobile + +#endif diff --git a/src/operators/kernel/crf_kernel.h b/src/operators/kernel/crf_kernel.h new file mode 100644 index 0000000000000000000000000000000000000000..71c07cf0384d482522de3a6652c6d24a22af656a --- /dev/null +++ b/src/operators/kernel/crf_kernel.h @@ -0,0 +1,37 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#ifdef CRF_OP + +#pragma once + +#include + +#include "framework/operator.h" +#include "operators/op_param.h" + +namespace paddle_mobile { +namespace operators { + +template +class CrfKernel + : public framework::OpKernelBase> { + public: + void Compute(const CrfParam& param) const; + bool Init(CrfParam* param); +}; +} // namespace operators +} // namespace paddle_mobile + +#endif diff --git a/src/operators/kernel/gru_kernel.h b/src/operators/kernel/gru_kernel.h new file mode 100644 index 0000000000000000000000000000000000000000..6b02663bd0e2982bdb2480c54632d2a8da9f67fc --- /dev/null +++ b/src/operators/kernel/gru_kernel.h @@ -0,0 +1,37 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#ifdef GRU_OP + +#pragma once + +#include + +#include "framework/operator.h" +#include "operators/op_param.h" + +namespace paddle_mobile { +namespace operators { + +template +class GruKernel + : public framework::OpKernelBase> { + public: + void Compute(const GruParam& param) const; + bool Init(GruParam* param); +}; +} // namespace operators +} // namespace paddle_mobile + +#endif diff --git a/src/operators/kernel/lookup_kernel.h b/src/operators/kernel/lookup_kernel.h new file mode 100644 index 0000000000000000000000000000000000000000..73f6cfcced078382b40526eae1f6560d7d168b97 --- /dev/null +++ b/src/operators/kernel/lookup_kernel.h @@ -0,0 +1,37 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#ifdef LOOKUP_OP + +#pragma once + +#include + +#include "framework/operator.h" +#include "operators/op_param.h" + +namespace paddle_mobile { +namespace operators { + +template +class LookupKernel + : public framework::OpKernelBase> { + public: + void Compute(const LookupParam& param) const; + bool Init(LookupParam* param); +}; +} // namespace operators +} // namespace paddle_mobile + +#endif diff --git a/src/operators/lookup_op.cpp b/src/operators/lookup_op.cpp new file mode 100644 index 0000000000000000000000000000000000000000..33f2b434adaec19acd36aab0d5157138ebd3e91e --- /dev/null +++ b/src/operators/lookup_op.cpp @@ -0,0 +1,67 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#ifdef LOOKUP_OP + +#include + +#include "common/enforce.h" +#include "operators/lookup_op.h" + +namespace paddle_mobile { +namespace operators { + +template +void LookupOp::InferShape() const { + PADDLE_MOBILE_ENFORCE(this->param_.InputW() != nullptr, + "Input(W) of LookupTableOp should not be null."); + auto *ids_t = this->param_.InputIds(); + + PADDLE_MOBILE_ENFORCE(ids_t != nullptr, + "Input(Ids) of LookupTableOp should not be null."); + PADDLE_MOBILE_ENFORCE(this->param_.Out() != nullptr, + "Output(Out) of LookupTableOp should not be null."); + // this->param__.InputW()-> + + auto table_dims = this->param_.InputW()->dims(); + auto ids_dims = ids_t->dims(); + + int ids_rank = ids_dims.size(); + + PADDLE_MOBILE_ENFORCE(table_dims.size() == 2, + "table_dims.size()==2 check failed"); + + PADDLE_MOBILE_ENFORCE(ids_dims[ids_rank - 1] == 1, + "The last dimension of the 'Ids' tensor must be 1."); + + auto output_dims = + framework::vectorize(framework::slice_ddim(ids_dims, 0, ids_rank - 1)); + output_dims.push_back(table_dims[1]); + + this->param_.Out()->Resize(framework::make_ddim(output_dims)); +} + +} // namespace operators +} // namespace paddle_mobile + +namespace ops = paddle_mobile::operators; +#ifdef PADDLE_MOBILE_CPU +REGISTER_OPERATOR_CPU(lookup_table, ops::LookupOp); +#endif +#ifdef PADDLE_MOBILE_MALI_GPU +#endif +#ifdef PADDLE_MOBILE_FPGA +#endif + +#endif diff --git a/src/operators/lookup_op.h b/src/operators/lookup_op.h new file mode 100644 index 0000000000000000000000000000000000000000..9c9d03c8d10e9b01ad958c12d31a49908075eb27 --- /dev/null +++ b/src/operators/lookup_op.h @@ -0,0 +1,58 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#ifdef LOOKUP_OP + +#pragma once + +#include +#include "framework/operator.h" +#include "operators/kernel/lookup_kernel.h" +#include "operators/op_param.h" + +namespace paddle_mobile { +namespace operators { + +using paddle_mobile::framework::Tensor; + +template +class LookupOp : public framework::OperatorWithKernel< + DeviceType, LookupParam, + operators::LookupKernel> { + public: + LookupOp(const std::string &type, const VariableNameMap &inputs, + const VariableNameMap &outputs, const framework::AttributeMap &attrs, + std::shared_ptr scope) + : framework::OperatorWithKernel, + operators::LookupKernel>( + type, inputs, outputs, attrs, scope) {} + + using framework::OperatorWithKernel< + DeviceType, LookupParam, + operators::LookupKernel>::OperatorWithKernel; + void InferShape() const override; +}; + +} // namespace operators +} // namespace paddle_mobile + +#ifdef PADDLE_MOBILE_CPU +USE_OP_CPU(lookup_table); +#endif +#ifdef PADDLE_MOBILE_MALI_GPU +#endif +#ifdef PADDLE_MOBILE_FPGA +#endif + +#endif diff --git a/src/operators/math/activation_functions.h b/src/operators/math/activation_functions.h new file mode 100644 index 0000000000000000000000000000000000000000..8604065a2570cc17c970c487fcaa898f78c72a85 --- /dev/null +++ b/src/operators/math/activation_functions.h @@ -0,0 +1,92 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once +#include +#include +#include "common/enforce.h" +namespace paddle_mobile { +namespace operators { +namespace math { + +#define SIGMOID_THRESHOLD_MIN -40.0 +#define SIGMOID_THRESHOLD_MAX 13.0 +#define EXP_MAX_INPUT 40.0 + +enum ActivationType { + kSigmoid, + kReLU, + kTanh, + kIdentity, +}; + +inline ActivationType GetActivationType(const std::string &type) { + if (type == "sigmoid") { + return ActivationType::kSigmoid; + } else if (type == "relu") { + return ActivationType::kReLU; + } else if (type == "tanh") { + return ActivationType::kTanh; + } else if (type == "identity" || type == "") { + return ActivationType::kIdentity; + } + PADDLE_MOBILE_THROW_EXCEPTION("Not support activation type."); +} + +namespace forward { + +template +T Identity(const T a) { + return a; +} + +template +T Relu(const T a) { + return a > static_cast(0.0) ? a : static_cast(0.0); +} + +template +T Sigmoid(const T a) { + const T min = SIGMOID_THRESHOLD_MIN; + const T max = SIGMOID_THRESHOLD_MAX; + T tmp = (a < min) ? min : ((a > max) ? max : a); + return static_cast(1.0) / (static_cast(1.0) + exp(-tmp)); +} + +template +T Tanh(const T a) { + T tmp = -2.0 * a; + tmp = (tmp > EXP_MAX_INPUT) ? EXP_MAX_INPUT : tmp; + return (2.0 / (1.0 + exp(tmp))) - 1.0; +} + +} // namespace forward + +template +struct Active { + typedef T (*Act)(T); +}; + +static Active::Act kActFloat[] = { + &forward::Sigmoid, &forward::Relu, &forward::Tanh, + &forward::Identity}; + +namespace forward { +inline float activation(float a, int index) { return kActFloat[index](a); } + +} // namespace forward + +} // namespace math +} // namespace operators +} // namespace paddle_mobile diff --git a/src/operators/math/gru_compute.cpp b/src/operators/math/gru_compute.cpp new file mode 100644 index 0000000000000000000000000000000000000000..2f71ec3a34d83cd65626c671ace41ae071c95ce2 --- /dev/null +++ b/src/operators/math/gru_compute.cpp @@ -0,0 +1,55 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ +#ifdef GRU_OP +#include "operators/math/gru_compute.h" +#include "common/types.h" +#include "operators/math/activation_functions.h" +#include "operators/math/gemm.h" +#include "operators/math/gru_cpu_kernel.h" +#include "operators/math/gru_kernel.h" + +namespace paddle_mobile { +namespace operators { +namespace math { + +template +struct GRUUnitFunctor { + static void compute(GRUMetaValue value, int frame_size, int batch_size, + const ActivationType active_node, + const ActivationType active_gate) { + if (value.prev_out_value) { + Sgemm(batch_size, frame_size * 2, frame_size, 1, value.prev_out_value, + frame_size, value.gate_weight, frame_size * 2, 1, value.gate_value, + frame_size * 3, false, nullptr); + } + + forward_reset_output(forward::gru_resetOutput(), value, frame_size, + batch_size, active_gate); + + if (value.prev_out_value) { + Sgemm(batch_size, frame_size, frame_size, 1, value.reset_output_value, + frame_size, value.state_weight, frame_size, 1, + value.gate_value + frame_size * 2, frame_size * 3, false, nullptr); + } + + forward_final_output(forward::gru_finalOutput(), value, frame_size, + batch_size, active_node); + } +}; + +template struct GRUUnitFunctor; +} // namespace math +} // namespace operators +} // namespace paddle_mobile +#endif diff --git a/src/operators/math/gru_compute.h b/src/operators/math/gru_compute.h new file mode 100644 index 0000000000000000000000000000000000000000..89cac1b8e49cd11eec551ba60f54e72f3912c846 --- /dev/null +++ b/src/operators/math/gru_compute.h @@ -0,0 +1,40 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ +#ifdef GRU_OP +#pragma once + +#include "operators/math/activation_functions.h" + +namespace paddle_mobile { +namespace operators { +namespace math { + +template +struct GRUMetaValue { + T *gate_weight; + T *state_weight; + T *gate_value; + T *reset_output_value; + T *output_value; + T *prev_out_value; +}; + +template +struct GRUUnitFunctor { + static void compute(GRUMetaValue value, int frame_size, int batch_size, + const ActivationType active_node, + const ActivationType active_gate); +}; + +} // namespace math +} // namespace operators +} // namespace paddle_mobile +#endif diff --git a/src/operators/math/gru_cpu_kernel.h b/src/operators/math/gru_cpu_kernel.h new file mode 100644 index 0000000000000000000000000000000000000000..ea24c4f1d97ebfbc5454e118121a3c79f28008c6 --- /dev/null +++ b/src/operators/math/gru_cpu_kernel.h @@ -0,0 +1,116 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ +#ifdef GRU_OP +#pragma once +#include +#include "operators/math/activation_functions.h" +#include "operators/math/gru_compute.h" + +namespace paddle_mobile { +namespace operators { +namespace math { + +template +void hl_naive_gru_forward_reset_output(OpResetOutput op_reset_output, + T *gate_value, T *reset_output_value, + T *prev_output_value, int frame_size, + ActivationType active_gate) { + T r_value_update_gate; + T r_value_reset_gate; + T r_value_reset_output; + T r_prev_out = 0; + T *update_gate = gate_value; + T *reset_gate = gate_value + frame_size; + + for (int i = 0; i < frame_size; i++) { + r_value_update_gate = update_gate[i]; + r_value_reset_gate = reset_gate[i]; + if (prev_output_value) { + r_prev_out = prev_output_value[i]; + } + + op_reset_output(&r_value_update_gate, &r_value_reset_gate, &r_prev_out, + &r_value_reset_output, active_gate); + + update_gate[i] = r_value_update_gate; + reset_gate[i] = r_value_reset_gate; + reset_output_value[i] = r_value_reset_output; + } +} + +template +void hl_naive_gru_forward_final_output(OpFinalOutput op_final_output, + T *gate_value, T *prev_output_value, + T *output_value, int frame_size, + ActivationType active_node) { + T r_value_update_gate; + T r_value_frame_state; + T r_prev_out = 0; + T r_output; + T *update_gate = gate_value; + T *frame_state = gate_value + frame_size * 2; + + for (int i = 0; i < frame_size; i++) { + r_value_update_gate = update_gate[i]; + r_value_frame_state = frame_state[i]; + if (prev_output_value) { + r_prev_out = prev_output_value[i]; + } + + op_final_output(&r_value_update_gate, &r_value_frame_state, &r_prev_out, + &r_output, active_node); + + frame_state[i] = r_value_frame_state; + output_value[i] = r_output; + } +} + +template +inline void forward_reset_output(OpResetOutput op_reset_output, + GRUMetaValue value, int frame_size, + int batch_size, ActivationType active_gate) { + for (int b = 0; b < batch_size; b++) { + hl_naive_gru_forward_reset_output( + op_reset_output, value.gate_value, value.reset_output_value, + value.prev_out_value, frame_size, active_gate); + + value.gate_value += frame_size * 3; + value.reset_output_value += frame_size; + if (value.prev_out_value) { + value.prev_out_value += frame_size; + } + } +} + +template +inline void forward_final_output(OpFinalOutput op_final_output, + GRUMetaValue value, int frame_size, + int batch_size, ActivationType active_node) { + for (int b = 0; b < batch_size; b++) { + hl_naive_gru_forward_final_output(op_final_output, value.gate_value, + value.prev_out_value, value.output_value, + frame_size, active_node); + + value.gate_value += frame_size * 3; + value.output_value += frame_size; + if (value.prev_out_value) { + value.prev_out_value += frame_size; + } + } +} + +} // namespace math +} // namespace operators +} // namespace paddle_mobile +#endif diff --git a/src/operators/math/gru_kernel.h b/src/operators/math/gru_kernel.h new file mode 100644 index 0000000000000000000000000000000000000000..6113ce8da997eaa5720886d637a9cc9261ea5227 --- /dev/null +++ b/src/operators/math/gru_kernel.h @@ -0,0 +1,51 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ +#ifdef GRU_OP +#pragma once +#include +#include "operators/math/activation_functions.h" + +namespace paddle_mobile { +namespace operators { +namespace math { + +namespace forward { + +template +class gru_resetOutput { + public: + void operator()(T *value_update_gate, T *value_reset_gate, T *prev_out, + T *value_reset_output, ActivationType act_gate) { + *value_update_gate = activation(*value_update_gate, act_gate); + *value_reset_gate = activation(*value_reset_gate, act_gate); + *value_reset_output = (*prev_out) * (*value_reset_gate); + } +}; + +template +class gru_finalOutput { + public: + void operator()(T *value_update_gate, T *value_frame_state, T *prev_out, + T *value_output, ActivationType act_input) { + *value_frame_state = activation(*value_frame_state, act_input); + *value_output = *prev_out - ((*value_update_gate) * (*prev_out)) + + ((*value_update_gate) * (*value_frame_state)); + } +}; +} // namespace forward + +} // namespace math +} // namespace operators +} // namespace paddle_mobile +#endif diff --git a/src/operators/math/math_function.cpp b/src/operators/math/math_function.cpp index 576b06422cd0665d9e211633ce2f559e73c11fb5..1ef06372292cd2e8311dfd25ae84b22be03676cd 100644 --- a/src/operators/math/math_function.cpp +++ b/src/operators/math/math_function.cpp @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "operators/math/math_function.h" +#include #include "operators/math/gemm.h" namespace paddle_mobile { @@ -119,6 +120,40 @@ void matmulWithPRelu(const framework::Tensor &matrix_a, bool trans_a, #endif } +template +struct ClearTensor { + void operator()(framework::Tensor *tensor) { + auto size = tensor->numel(); + auto *tensor_data = tensor->data(); + memset((void *)tensor_data, 0, sizeof(T) * size); + } +}; + +template +struct RowwiseAdd { + void operator()(const framework::Tensor &input, + const framework::Tensor &vector, framework::Tensor *output) { + auto in_dims = input.dims(); + auto size = input.numel() / in_dims[0]; + PADDLE_MOBILE_ENFORCE((vector.numel() == size), + "vector.numel() must be equal to size."); + PADDLE_MOBILE_ENFORCE((output->dims() == in_dims), + "output->dims() must be equal to in_dims."); + + auto *input_data = input.data(); + auto *out_data = output->data(); + auto *vec_data = vector.data(); + for (int64_t i = 0; i < in_dims[0]; ++i) { + for (int64_t j = 0; j < size; ++j) { + out_data[i * size + j] = input_data[i * size + j] + vec_data[j]; + } + } + } +}; + +template struct RowwiseAdd; +template struct ClearTensor; + } // namespace math } // namespace operators } // namespace paddle_mobile diff --git a/src/operators/math/math_function.h b/src/operators/math/math_function.h index 8d97f8628fb4f71cdd7664161983225136ec7c7f..de19e3df2ab69c8ac490b09af2852bf2fa806c64 100644 --- a/src/operators/math/math_function.h +++ b/src/operators/math/math_function.h @@ -38,6 +38,17 @@ void matmulWithPRelu(const framework::Tensor &matrix_a, bool trans_a, const framework::Tensor &matrix_b, bool trans_b, framework::Tensor *matrix_out, float *p, std::string mode, float *bias, float *bias1); +template +struct ClearTensor { + void operator()(framework::Tensor *tensor); +}; + +template +struct RowwiseAdd { + void operator()(const framework::Tensor &input, const framework::Tensor &vec, + framework::Tensor *output); +}; + } // namespace math } // namespace operators } // namespace paddle_mobile diff --git a/src/operators/math/sequence2batch.cpp b/src/operators/math/sequence2batch.cpp new file mode 100644 index 0000000000000000000000000000000000000000..097a258dddd513294cd1c1d2f4c9ddb0dd530052 --- /dev/null +++ b/src/operators/math/sequence2batch.cpp @@ -0,0 +1,60 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "operators/math/sequence2batch.h" +#include +#include "common/types.h" + +namespace paddle_mobile { +namespace operators { +namespace math { + +template +class CopyMatrixRowsFunctor { + public: + void operator()(const framework::Tensor& src, std::vector index_lod, + framework::Tensor* dst, bool is_src_index) { + size_t* index = index_lod.data(); + auto src_dims = src.dims(); + auto dst_dims = dst->dims(); + PADDLE_MOBILE_ENFORCE((src_dims.size() == 2UL), + "The src must be matrix with rank 2."); + PADDLE_MOBILE_ENFORCE((dst_dims.size() == 2UL), + "The dst must be matrix with rank 2."); + PADDLE_MOBILE_ENFORCE((src_dims[1] == dst_dims[1]), + "The width of src and dst must be same."); + auto height = dst_dims[0]; + auto width = dst_dims[1]; + auto* src_data = src.data(); + auto* dst_data = dst->data(); + for (int i = 0; i < height; ++i) { + if (is_src_index) { + memcpy(dst_data + i * width, src_data + index[i] * width, + width * sizeof(T)); + } else { + memcpy(dst_data + index[i] * width, src_data + i * width, + width * sizeof(T)); + } + } + } +}; + +template class CopyMatrixRowsFunctor; + +template class LoDTensor2BatchFunctor; +template class Batch2LoDTensorFunctor; + +} // namespace math +} // namespace operators +} // namespace paddle_mobile diff --git a/src/operators/math/sequence2batch.h b/src/operators/math/sequence2batch.h new file mode 100644 index 0000000000000000000000000000000000000000..42b369f7dc48718846b7d8e039b876693f9770df --- /dev/null +++ b/src/operators/math/sequence2batch.h @@ -0,0 +1,169 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once +#include +#include +#include "framework/lod_tensor.h" +#include "framework/tensor.h" + +namespace paddle_mobile { +namespace operators { +namespace math { +template +class CopyMatrixRowsFunctor { + public: + // If is_src_index is true, + // copy the indexed rows of input src to the output dst. + // If is_src_index is false, + // copy the input src to the indexed rows of output dst. + // The indexed rows are based on the input index. + void operator()(const framework::Tensor& src, std::vector index_lod, + framework::Tensor* dst, bool is_src_index); +}; + +template +class LoDTensor2BatchFunctor { + // Calculate the length of each sequence and + // sort sequence index by the length. + // example: sequences = {s0, s1, s2} + // s0: 0 0 0 0, s1: 1 1 1 1 1, s2: 2 2 2 + // seq_info[3] = {(4, 5, 1), (0, 4, 0), (9, 3, 2)} + // + struct SeqInfo { + SeqInfo(int start, int length, int seq_idx) + : start(start), length(length), seq_idx(seq_idx) {} + int start; + int length; + int seq_idx; + }; + + public: + void operator()(const framework::LoDTensor& lod_tensor, + framework::LoDTensor* batch, bool is_cal_batch_lod, + bool is_reverse = false) { + if (!is_cal_batch_lod) { + auto lods = batch->lod(); + PADDLE_MOBILE_ENFORCE( + (lods.size() > 2UL), + "The LoD of LoDTensor should inlcude at least 2-level " + "sequence information."); + PADDLE_MOBILE_ENFORCE( + (lods[1].size() == static_cast(lod_tensor.dims()[0])), + "The LoD information should be consistent with the dims."); + CopyMatrixRowsFunctor to_batch; + to_batch(lod_tensor, lods[1], batch, true); + return; + } + + auto lods = lod_tensor.lod(); + PADDLE_MOBILE_ENFORCE((lods.size() == 1UL), + "Only support one level sequence now."); + + const auto& lod = lods[0]; + + std::vector seq_info; + for (size_t seq_id = 0; seq_id < lod.size() - 1; ++seq_id) { + int length = lod[seq_id + 1] - lod[seq_id]; + seq_info.emplace_back(lod[seq_id], length, seq_id); + } + + std::sort(seq_info.begin(), seq_info.end(), + [](SeqInfo a, SeqInfo b) { return a.length > b.length; }); + + // Calculate the start position of each batch. + // example: sequences = {s0, s1, s2} + // s0: 0 0 0 0, s1: 1 1 1 1 1, s2: 2 2 2 + // num_batch = 5, + // batchIndex = {b0, b1, b2, b3, b4} + // b0: 1 0 2, b1: 1 0 2, b2: 1 0 2, b3: 1 0, b4: 1 + // batch_start_positions[6] = {0, 3, 6, 9, 11, 12} + // batch_start_positions[0] = len(b0) + // batch_start_positions[1] = len(b0) + len(b1) + // batch_start_positions[2] = len(b0) + len(b1) + len(b2) + // ... + // seq2batch_idx[12] = {4, 0, 9, + // 5, 1, 10, + // 6, 2, 11, + // 7, 3, + // 8} + // seq_order = {1, 0, 2}, the sort order. + // where 1 is the second sequence, + // 0 is the first sequence, + // 2 is the third sequence. + // The num_batch represents batch size after rearranging the + // input LodTensor. It is also the maximum length of input sequence. + + framework::LoD batch_lods; + batch_lods.emplace_back(std::vector{0}); + batch_lods.emplace_back(std::vector{0}); + batch_lods.emplace_back(std::vector{0}); + + // batch_lods[0] is the start positions for batch LoDTensor + int num_batch = seq_info[0].length; + batch_lods[0].resize(static_cast(num_batch + 1)); + // batch_lods[1] is the raw index in the input LoDTensor + batch_lods[1].resize(static_cast(lod_tensor.dims()[0])); + // batch_lods[2] is the sort order for the input LoDTensor. + batch_lods[2].resize(seq_info.size()); + + size_t* batch_starts = batch_lods[0].data(); + size_t* seq2batch_idx = batch_lods[1].data(); + batch_starts[0] = 0; + for (int n = 0; n < num_batch; n++) { + auto batch_id = static_cast(batch_starts[n]); + for (size_t i = 0; i < seq_info.size(); ++i) { + int seq_len = seq_info[i].length; + int start = seq_info[i].start; + if (n < seq_len) { + seq2batch_idx[batch_id] = + is_reverse ? start + seq_len - 1 - n : start + n; + batch_id++; + } else { + break; + } + } + batch_starts[n + 1] = static_cast(batch_id); + } + size_t* seq_order = batch_lods[2].data(); + for (size_t i = 0; i < seq_info.size(); ++i) { + seq_order[i] = seq_info[i].seq_idx; + } + batch->set_lod(batch_lods); + + CopyMatrixRowsFunctor to_batch; + to_batch(lod_tensor, batch_lods[1], batch, true); + } +}; + +template +class Batch2LoDTensorFunctor { + public: + void operator()(const framework::LoDTensor& batch, + framework::LoDTensor* lod_tensor) { + auto in_lod = batch.lod(); + PADDLE_MOBILE_ENFORCE( + (in_lod.size() > 2UL), + "The LoD of LoDTensor should inlcude at least 2-level " + "sequence information."); + PADDLE_MOBILE_ENFORCE( + (in_lod[1].size() == static_cast(lod_tensor->dims()[0])), + "The LoD information should be consistent with the dims."); + CopyMatrixRowsFunctor to_seq; + to_seq(batch, in_lod[1], lod_tensor, false); + } +}; +} // namespace math +} // namespace operators +} // namespace paddle_mobile diff --git a/src/operators/op_param.h b/src/operators/op_param.h index a6077812a0a4f56b58e666617e880b91f7c19b97..10fe5c2494bb3e5ddbb6876525db8017fe0c910c 100644 --- a/src/operators/op_param.h +++ b/src/operators/op_param.h @@ -73,6 +73,10 @@ struct DtypeTensorTrait { class OpParam { protected: + template + static T *InputH0From(const VariableNameMap &inputs, const Scope &scope) { + return GetVarValue("H0", inputs, scope); + } template static T *InputAlphaFrom(const VariableNameMap &inputs, const Scope &scope) { return GetVarValue("Alpha", inputs, scope); @@ -87,6 +91,33 @@ class OpParam { static T *InputXFrom(const VariableNameMap &inputs, const Scope &scope) { return GetVarValue("X", inputs, scope); } + + template + static T *InputWFrom(const VariableNameMap &inputs, const Scope &scope) { + return GetVarValue("W", inputs, scope); + } + + template + static T *InputIdsFrom(const VariableNameMap &inputs, const Scope &scope) { + return GetVarValue("Ids", inputs, scope); + } + + template + static T *InputEmissionFrom(const VariableNameMap &inputs, + const Scope &scope) { + return GetVarValue("Emission", inputs, scope); + } + + template + static T *InputTransitionFrom(const VariableNameMap &inputs, + const Scope &scope) { + return GetVarValue("Transition", inputs, scope); + } + template + static T *InputLabelFrom(const VariableNameMap &inputs, const Scope &scope) { + return GetVarValue("Label", inputs, scope); + } + template static T *InputXFrom1(const VariableNameMap &inputs, const Scope &scope) { return GetVarValue1("addX", inputs, scope); @@ -112,6 +143,10 @@ class OpParam { return GetVarValue("Bias", inputs, scope); } template + static T *InputWeightFrom(const VariableNameMap &inputs, const Scope &scope) { + return GetVarValue("Weight", inputs, scope); + } + template static T *InputVarianceFrom(const VariableNameMap &inputs, const Scope &scope) { return GetVarValue("Variance", inputs, scope); @@ -166,6 +201,35 @@ class OpParam { return GetMultiVarValue("X", inputs, scope); } + template + static T *OutputBatchGateFrom(const VariableNameMap &outputs, + const Scope &scope) { + return GetVarValue("BatchGate", outputs, scope); + } + + template + static T *OutputViterbiPathFrom(const VariableNameMap &outputs, + const Scope &scope) { + return GetVarValue("ViterbiPath", outputs, scope); + } + template + static T *OutputBatchResetHiddenPrevFrom(const VariableNameMap &outputs, + const Scope &scope) { + return GetVarValue("BatchResetHiddenPrev", outputs, scope); + } + + template + static T *OutputBatchHiddenFrom(const VariableNameMap &outputs, + const Scope &scope) { + return GetVarValue("BatchHidden", outputs, scope); + } + + template + static T *OutputHiddenFrom(const VariableNameMap &outputs, + const Scope &scope) { + return GetVarValue("Hidden", outputs, scope); + } + template static T *OutputFrom(const VariableNameMap &outputs, const Scope &scope) { return GetVarValue("Output", outputs, scope); @@ -326,18 +390,18 @@ class ElementwiseAddParam : OpParam { axis_ = GetAttr("axis", attrs); } - const RType *InputX() const { return input_x_; } + const GType *InputX() const { return input_x_; } - const RType *InputY() const { return input_y_; } + const GType *InputY() const { return input_y_; } - RType *Out() const { return out_; } + GType *Out() const { return out_; } const int &Axis() const { return axis_; } private: - RType *input_x_; - RType *input_y_; - RType *out_; + GType *input_x_; + GType *input_y_; + GType *out_; int axis_; #ifdef PADDLE_MOBILE_FPGA @@ -371,20 +435,20 @@ class MulParam : OpParam { y_num_col_dims_ = GetAttr("y_num_col_dims", attrs); } - const RType *InputX() const { return input_x_; } + const GType *InputX() const { return input_x_; } - const RType *InputY() const { return input_y_; } + const GType *InputY() const { return input_y_; } - RType *Out() const { return out_; } + GType *Out() const { return out_; } const int &XNumColDims() const { return x_num_col_dims_; } const int &YNumColDims() const { return y_num_col_dims_; } private: - RType *input_x_; - RType *input_y_; - RType *out_; + GType *input_x_; + GType *input_y_; + GType *out_; int x_num_col_dims_; int y_num_col_dims_; }; @@ -406,13 +470,13 @@ class ConcatParam : public OpParam { vector Inputs() const { return inputs_; } - RType *Out() const { return out_; } + GType *Out() const { return out_; } const int &Axis() const { return axis_; } private: vector inputs_; - RType *out_; + GType *out_; int axis_; }; #endif @@ -797,13 +861,13 @@ class FeedParam : public OpParam { auto var = scope->Var("batch_size"); batch_size = var->GetValue(); } - const RType *InputX() const { return input_x_; } - RType *Out() const { return out_; } + const GType *InputX() const { return input_x_; } + GType *Out() const { return out_; } const int BatchSize() const { return batch_size; } private: - RType *input_x_; - RType *out_; + GType *input_x_; + GType *out_; int batch_size; }; @@ -853,6 +917,72 @@ class TransposeParam : public OpParam { }; #endif +#ifdef LOOKUP_OP +template +class LookupParam : public OpParam { + typedef typename DtypeTensorTrait::gtype GType; + typedef typename DtypeTensorTrait::rtype RType; + + public: + LookupParam(const VariableNameMap &inputs, const VariableNameMap &outputs, + const AttributeMap &attrs, const Scope &scope) { + input_w_ = InputWFrom(inputs, scope); + input_ids_ = InputIdsFrom(inputs, scope); + out_ = OutFrom(outputs, scope); + padding_idx_ = GetAttr("padding_idx", attrs); + } + + const GType *InputW() const { return input_w_; } + const GType *InputIds() const { return input_ids_; } + GType *Out() const { return out_; } + int64_t PaddingIdx() const { return padding_idx_; } + + private: + GType *input_w_; + GType *input_ids_; + GType *out_; + int64_t padding_idx_; +}; +#endif + +#ifdef CRF_OP +template +class CrfParam : public OpParam { + typedef typename DtypeTensorTrait::gtype GType; + typedef typename DtypeTensorTrait::rtype RType; + + public: + // {G_OP_TYPE_CRF, {{"Emission", "Transition", "Label"}, {"ViterbiPath"}}}, + + CrfParam(const VariableNameMap &inputs, const VariableNameMap &outputs, + const AttributeMap &attrs, const Scope &scope) { + // todo crf params + input_emission_ = InputEmissionFrom(inputs, scope); + input_transition_ = InputTransitionFrom(inputs, scope); + input_label_ = InputLabelFrom(inputs, scope); + output_viterbipath_ = OutputViterbiPathFrom(outputs, scope); + // padding_idx_ = GetAttr("padding_idx", attrs); + } + const GType *InputEmission() const { return input_emission_; } + const GType *InputTransition() const { return input_transition_; } + const GType *InputLabel() const { return input_label_; } + GType *outputVBP() const { return output_viterbipath_; } + // const RType *InputIds() const { return input_ids_; } + // RType *Out() const { return out_; } + // int64_t PaddingIdx() const { return padding_idx_; } + + private: + GType *input_emission_; + GType *input_transition_; + GType *input_label_; + GType *output_viterbipath_; + + // RType *input_ids_; + // RType *out_; + // int64_t padding_idx_; +}; +#endif + #ifdef RESHAPE_OP template class ReshapeParam : public OpParam { @@ -1095,7 +1225,7 @@ class FusionFcParam : public OpParam { y_num_col_dims_ = GetAttr("y_num_col_dims", attrs); axis_ = GetAttr("axis", attrs); } - const RType *InputX() const { return input_x_; } + const GType *InputX() const { return input_x_; } #ifdef PADDLE_MOBILE_FPGA RType *InputY() const { return input_y_; } @@ -1105,7 +1235,7 @@ class FusionFcParam : public OpParam { const RType *InputZ() const { return input_z_; } - RType *Out() const { return out_; } + GType *Out() const { return out_; } const int &XNumColDims() const { return x_num_col_dims_; } @@ -1114,10 +1244,10 @@ class FusionFcParam : public OpParam { const int &Axis() const { return axis_; } private: - RType *input_x_; + GType *input_x_; RType *input_y_; RType *input_z_; - RType *out_; + GType *out_; int x_num_col_dims_; int y_num_col_dims_; int axis_; @@ -2062,5 +2192,65 @@ class ConvTransposeParam : public OpParam { }; #endif +#ifdef GRU_OP +template +class GruParam : public OpParam { + typedef typename DtypeTensorTrait::gtype GType; + + public: + /** + * + * @param inputs + * @param outputs + * @param attrs + * @param scope + * */ + GruParam(const VariableNameMap &inputs, const VariableNameMap &outputs, + const AttributeMap &attrs, const Scope &scope) { + input_input_ = InputFrom(inputs, scope); + input_h0_ = InputH0From(inputs, scope); + input_bias_ = InputBiasFrom(inputs, scope); + input_weight_ = InputWeightFrom(inputs, scope); + + output_batch_gate_ = OutputBatchGateFrom(outputs, scope); + output_batch_reset_hidden_prev_ = + OutputBatchResetHiddenPrevFrom(outputs, scope); + output_batch_hidden_ = OutputBatchHiddenFrom(outputs, scope); + output_hidden_ = OutputHiddenFrom(outputs, scope); + activation_ = GetAttr("activation", attrs); + gate_activation_ = GetAttr("gate_activation", attrs); + is_reverse_ = GetAttr("is_reverse", attrs); + } + const GType *InputInput() const { return input_input_; } + const GType *InputWeight() const { return input_weight_; } + const GType *InputH0() const { return input_h0_; } + const GType *InputBias() const { return input_bias_; } + const std::string &Activation() const { return activation_; } + const std::string &GateActivation() const { return gate_activation_; } + const bool &IsReverse() const { return is_reverse_; } + + GType *OutBatchGate() const { return output_batch_gate_; } + GType *OutBatchResetHiddenPrev() const { + return output_batch_reset_hidden_prev_; + } + GType *OutBatchHidden() const { return output_batch_hidden_; } + GType *OutHidden() const { return output_hidden_; } + + private: + GType *input_input_; + GType *input_h0_; + GType *input_bias_; + GType *input_weight_; + + GType *output_batch_gate_; + GType *output_batch_reset_hidden_prev_; + GType *output_batch_hidden_; + GType *output_hidden_; + std::string activation_; + std::string gate_activation_; + bool is_reverse_; +}; +#endif + } // namespace operators } // namespace paddle_mobile diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index bfc9fb3eb3e835ea3aebf549f579b451bcb422e8..55920770f566bc742f2e6f0a1fa9c262c17db8c2 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -190,6 +190,14 @@ else () ADD_EXECUTABLE(test-conv-add-bn-relu-op operators/test_fusion_conv_add_bn_relu_op.cpp test_helper.h test_include.h executor_for_test.h) target_link_libraries(test-conv-add-bn-relu-op paddle-mobile) + # gen test + ADD_EXECUTABLE(test-nlp net/test_nlp.cpp test_helper.h test_include.h executor_for_test.h) + target_link_libraries(test-nlp paddle-mobile) + + # gen test + ADD_EXECUTABLE(test-gru-op operators/test_gru_op.cpp test_helper.h test_include.h) + target_link_libraries(test-gru-op paddle-mobile) + #add_library(test-lib-size SHARED common/test_lib_size.h common/test_lib_size.cpp) diff --git a/test/net/test_nlp.cpp b/test/net/test_nlp.cpp new file mode 100644 index 0000000000000000000000000000000000000000..ca5f6571c8786a23017bd846890d6f78345121c3 --- /dev/null +++ b/test/net/test_nlp.cpp @@ -0,0 +1,60 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include +#include "../test_helper.h" +#include "../test_include.h" + +int main() { + paddle_mobile::PaddleMobile paddle_mobile; + paddle_mobile.SetThreadNum(4); + auto time1 = time(); + // auto isok = paddle_mobile.Load(std::string(g_mobilenet_detect) + "/model", + // std::string(g_mobilenet_detect) + "/params", true); + + auto isok = paddle_mobile.Load(g_nlp, true, false, 1, true); + + // auto isok = paddle_mobile.Load(std::string(g_nlp) + "/model", + // std::string(g_nlp) + "/params", false); + if (isok) { + auto time2 = time(); + std::cout << "load cost :" << time_diff(time1, time1) << "ms" << std::endl; + // 1064 1603 644 699 2878 1219 867 1352 8 1 13 312 479 + + std::vector ids{1064, 1603, 644, 699, 2878, 1219, 867, + 1352, 8, 1, 13, 312, 479}; + + paddle_mobile::framework::LoDTensor words; + auto size = static_cast(ids.size()); + paddle_mobile::framework::LoD lod{{0, ids.size()}}; + DDim dims{size, 1}; + words.Resize(dims); + words.set_lod(lod); + DLOG << "words lod : " << words.lod(); + auto *pdata = words.mutable_data(); + size_t n = words.numel() * sizeof(int64_t); + DLOG << "n :" << n; + memcpy(pdata, ids.data(), n); + DLOG << "words lod 22: " << words.lod(); + auto time3 = time(); + for (int i = 0; i < 1; ++i) { + auto vec_result = paddle_mobile.PredictLod(words); + DLOG << *vec_result; + } + auto time4 = time(); + std::cout << "predict cost :" << time_diff(time3, time4) / 1 << "ms" + << std::endl; + } + return 0; +} diff --git a/test/operators/test_gru_op.cpp b/test/operators/test_gru_op.cpp new file mode 100644 index 0000000000000000000000000000000000000000..52ab8b54d709391ea263b74a395a635ce50a18af --- /dev/null +++ b/test/operators/test_gru_op.cpp @@ -0,0 +1,29 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "../test_include.h" +#include "operators/gru_op.h" + +int main() { + paddle_mobile::Loader loader; + auto program = loader.Load(g_nlp); + PADDLE_MOBILE_ENFORCE(program.originProgram != nullptr, + "program file read fail"); + + Executor4Test> + executor(program, "gru"); + + return 0; +} diff --git a/test/test_helper.h b/test/test_helper.h index 69ffa58847f2395dec59d87abae4128d885dd19a..9c592cf1032a8f4e5c08ef7f4c7e738b6cf0b122 100644 --- a/test/test_helper.h +++ b/test/test_helper.h @@ -33,6 +33,7 @@ static const char *g_mobilenet_detect = "../models/mobilenet-detect"; static const char *g_squeezenet = "../models/squeezenet"; static const char *g_googlenet = "../models/googlenet"; static const char *g_mobilenet = "../models/mobilenet"; +static const char *g_nlp = "../models/nlp"; static const char *g_resnet_50 = "../models/resnet_50"; static const char *g_resnet = "../models/resnet"; static const char *g_googlenet_combine = "../models/googlenet_combine"; diff --git a/tools/op.cmake b/tools/op.cmake index 5965cf030fb935c89a5fb42fa72b5e810288552b..7c948e9636747690804a07ee76efe6f66c09c820 100644 --- a/tools/op.cmake +++ b/tools/op.cmake @@ -111,7 +111,7 @@ if ("FPGAnets" IN_LIST NET) set(FUSION_CONVBN_OP ON) set(FUSION_CONVADD_OP ON) - set(FOUND_MATCH ON) + set(FOUND_MATCH ON) endif() @@ -149,6 +149,9 @@ if(NOT FOUND_MATCH) set(SLICE_OP ON) set(DROPOUT_OP ON) set(IM2SEQUENCE_OP ON) + set(LOOKUP_OP ON) + set(GRU_OP ON) + set(CRF_OP ON) endif() # option(BATCHNORM_OP "" ON) @@ -288,3 +291,15 @@ endif() if (CONV_TRANSPOSE_OP) add_definitions(-DCONV_TRANSPOSE) endif() + +if (LOOKUP_OP) + add_definitions(-DLOOKUP_OP) +endif() + +if (GRU_OP) + add_definitions(-DGRU_OP) +endif() + +if (CRF_OP) + add_definitions(-DCRF_OP) +endif()