提交 397d0fa4 编写于 作者: xiebaiyuan's avatar xiebaiyuan

merge nlp to main

上级 63f71678
......@@ -56,6 +56,9 @@ const char *G_OP_TYPE_REGION = "region";
const char *G_OP_TYPE_FUSION_CONV_BN = "fusion_conv_bn";
const char *G_OP_TYPE_CONV_TRANSPOSE = "conv2d_transpose";
const char *G_OP_TYPE_PRELU = "prelu";
const char *G_OP_TYPE_LOOKUP_TABLE = "lookup_table";
const char *G_OP_TYPE_GRU = "gru";
const char *G_OP_TYPE_CRF = "crf_decoding";
std::unordered_map<
std::string, std::pair<std::vector<std::string>, std::vector<std::string>>>
......@@ -97,6 +100,11 @@ std::unordered_map<
{G_OP_TYPE_FUSION_FC_RELU, {{"X", "Y", "Z"}, {"Out"}}},
{G_OP_TYPE_REGION, {{"X"}, {"Out"}}},
{G_OP_TYPE_FUSION_CONV_BN, {{"Input"}, {"Y"}}},
{G_OP_TYPE_LOOKUP_TABLE, {{"W", "Ids"}, {"Out"}}},
{G_OP_TYPE_GRU,
{{"Input", "H0", "Weight", "Bias"},
{"BatchGate", "BatchResetHiddenPrev", "BatchHidden", "Hidden"}}},
{G_OP_TYPE_CRF, {{"Emission", "Transition", "Label"}, {"ViterbiPath"}}},
{G_OP_TYPE_CONV_TRANSPOSE, {{"Input"}, {"Output"}}}};
} // namespace paddle_mobile
......@@ -62,7 +62,7 @@ void OperatorBase<Dtype>::Run() const {
vector<string> input_keys = GetInputKeys();
for (const auto key : input_keys) {
Tensor *input = GetVarValue<framework::LoDTensor>(key, inputs_, *scope_);
DLOG << type_ << " input- " << key << "=" << *input;
if (input) DLOG << type_ << " input- " << key << "=" << *input;
}
vector<string> output_keys = GetOutKeys();
for (const auto key : output_keys) {
......
......@@ -339,7 +339,12 @@ inline Print &operator<<(Print &printer, const Tensor &tensor) {
stride = stride > 0 ? stride : 1;
#ifndef PADDLE_MOBILE_FPGA
for (int i = 0; i < tensor.numel(); i += stride) {
printer << tensor.data<float>()[i] << " ";
// 这不一定是float的
if (tensor.type() == typeid(float)) {
printer << tensor.data<float>()[i] << " ";
} else if (tensor.type() == typeid(int64_t)) {
printer << tensor.data<int64_t>()[i] << " ";
}
}
#endif
......
......@@ -54,8 +54,11 @@ char *Get_binary_data(std::string filename) {
#pragma mark - executor
template <typename Dtype, Precision P>
Executor<Dtype, P>::Executor(const framework::Program<Dtype> p, int batch_size,
bool use_optimize)
: program_(p), batch_size_(batch_size), use_optimize_(use_optimize) {
bool use_optimize, bool loddable)
: program_(p),
batch_size_(batch_size),
use_optimize_(use_optimize),
loddable_(loddable) {
if (use_optimize_) {
to_predict_program_ = program_.optimizeProgram;
} else {
......@@ -79,7 +82,12 @@ Executor<Dtype, P>::Executor(const framework::Program<Dtype> p, int batch_size,
auto op_base = framework::OpRegistry<Dtype>::CreateOp(
op->Type(), op->GetInputs(), op->GetOutputs(), op->GetAttrMap(),
program_.scope);
op_base->InferShape();
DLOG << "executer in loaddable mode: " << loddable_;
// use pre_infershape to pre resize , but if u use an lod mode tensor u
// need to resize in runtime
if (!loddable_) {
op_base->InferShape();
}
ops_of_block_[*block_desc.get()].push_back(op_base);
#ifdef PADDLE_EXECUTOR_MULTITHREAD
depManager[i].analysisDep(ops_of_block_[*block_desc.get()]);
......@@ -225,9 +233,18 @@ void Executor<Dtype, P>::InitMemory() {
delete origin_data;
} else {
if (var_desc->Type() == framework::VARTYPE_TYPE_LOD_TENSOR) {
auto tensor = var->template GetMutable<framework::LoDTensor>();
tensor->template mutable_data<Ptype>();
DLOG << "var_desc->Name(): " << var_desc->Name();
DLOG << "var_desc->Tensor_desc().DataType(): "
<< var_desc->Tensor_desc().DataType();
bool is_mute_match;
framework::LoDTensor *tensor = nullptr;
is_mute_match = varInputMemory(var_desc, var, tensor);
PADDLE_MOBILE_ENFORCE(
is_mute_match,
"got unhandled var_desc->Tensor_desc().DataType(): %d",
var_desc->Tensor_desc().DataType());
}
}
}
......@@ -257,8 +274,18 @@ void Executor<Dtype, P>::InitCombineMemory() {
LoadMemory(*var_desc, tensor, &data);
} else {
if (var_desc->Type() == framework::VARTYPE_TYPE_LOD_TENSOR) {
auto tensor = var->template GetMutable<framework::LoDTensor>();
tensor->template mutable_data<Ptype>();
DLOG << "var_desc->Name(): " << var_desc->Name();
DLOG << "var_desc->Tensor_desc().DataType(): "
<< var_desc->Tensor_desc().DataType();
bool is_mute_match = false;
framework::LoDTensor *tensor;
is_mute_match = varInputMemory(var_desc, var, tensor);
PADDLE_MOBILE_ENFORCE(
is_mute_match,
"got unhandled var_desc->Tensor_desc().DataType(): %d",
var_desc->Tensor_desc().DataType());
}
}
}
......@@ -266,6 +293,46 @@ void Executor<Dtype, P>::InitCombineMemory() {
delete origin_data;
LOG(kLOG_INFO) << " end init combine memory ";
}
template <typename Dtype, Precision P>
bool Executor<Dtype, P>::varInputMemory(
const std::shared_ptr<framework::VarDesc> &var_desc, Variable *var,
framework::LoDTensor *tensor) const {
bool is_mute_match = false;
switch (var_desc->Tensor_desc().DataType()) {
case framework::VARTYPE_TYPE_FP16: {
break;
}
case framework::VARTYPE_TYPE_FP32: {
tensor = var->template GetMutable<framework::LoDTensor>();
tensor->template mutable_data<Ptype>();
is_mute_match = true;
break;
}
case framework::VARTYPE_TYPE_FP64: {
break;
}
case framework::VARTYPE_TYPE_INT32: {
break;
}
case framework::VARTYPE_TYPE_INT64: {
tensor = var->template GetMutable<framework::LoDTensor>();
tensor->template mutable_data<int64_t>();
is_mute_match = true;
break;
}
case framework::VARTYPE_TYPE_BOOL: {
break;
}
default: { break; }
}
return is_mute_match;
}
template <typename Dtype, Precision P>
std::shared_ptr<framework::Tensor> Executor<Dtype, P>::Predict(
......@@ -278,6 +345,7 @@ std::shared_ptr<framework::Tensor> Executor<Dtype, P>::Predict(
std::shared_ptr<framework::BlockDesc> to_predict_block =
to_predict_program_->Block(0);
auto &ops = ops_of_block_[*to_predict_block.get()];
#ifdef PADDLE_MOBILE_PROFILE
std::vector<ProfInfo> profile(ops.size());
#endif
......@@ -342,6 +410,7 @@ std::shared_ptr<framework::Tensor> Executor<Dtype, P>::Predict(
clock_gettime(CLOCK_MONOTONIC, &ts);
profile[i].runBegin = (uint64_t)ts.tv_sec * 1e9 + ts.tv_nsec;
#endif
DLOG << "executer Predict in3.3";
// to Run
ops[i]->Run();
......@@ -351,6 +420,8 @@ std::shared_ptr<framework::Tensor> Executor<Dtype, P>::Predict(
#endif
}
#endif
DLOG << "executer Predict in4";
auto last_op = ops.rbegin();
auto output_map = (*last_op)->Outputs();
......@@ -377,6 +448,7 @@ std::shared_ptr<framework::Tensor> Executor<Dtype, P>::Predict(
fprintf(df, "}\n");
fclose(df);
#endif
DLOG << "executer Predict in5";
// FILE *pf = fopen("profile.out", "w");
std::unordered_map<std::string, uint64_t> _tp;
......@@ -389,6 +461,7 @@ std::shared_ptr<framework::Tensor> Executor<Dtype, P>::Predict(
// pInfo.tid, pInfo.runBegin, pInfo.runEnd, timeCost);
}
// fclose(pf);
DLOG << "executer Predict in6";
printf("====================[ profile ]======================\n");
using prof_t = std::pair<std::string, uint64_t>;
......@@ -409,9 +482,184 @@ std::shared_ptr<framework::Tensor> Executor<Dtype, P>::Predict(
}
printf("====================[---------]======================\n");
#endif
DLOG << "executer Predict out";
return std::make_shared<framework::Tensor>(framework::Tensor(*output_tensor));
}
template <typename Dtype, Precision P>
std::shared_ptr<framework::LoDTensor> Executor<Dtype, P>::PredictLod(
const framework::LoDTensor &t) {
DLOG << "execute PredictLod :lod" << t.lod();
DLOG << "executer Predict in";
framework::Variable *g_feed_value = program_.scope->Var("feed");
framework::LoDTensor *feed_tensor =
g_feed_value->GetMutable<framework::LoDTensor>();
DLOG << "executer Predict in2";
feed_tensor->Resize(t.dims());
feed_tensor->ShareDataWith(t);
feed_tensor->set_lod(t.lod());
DLOG << "feed_tensor .lod : " << feed_tensor->lod();
DLOG << "executer Predict in3";
std::shared_ptr<framework::BlockDesc> to_predict_block =
to_predict_program_->Block(0);
DLOG << "executer Predict in3.1";
auto &ops = ops_of_block_[*to_predict_block.get()];
DLOG << "executer Predict in3.2";
#ifdef PADDLE_MOBILE_PROFILE
std::vector<ProfInfo> profile(ops.size());
#endif
#ifdef PADDLE_EXECUTOR_MULTITHREAD
std::mutex m;
std::condition_variable cv;
std::queue<int> next;
next.push(0);
int rsize = ops.size();
std::vector<int> status(rsize, 0);
auto &threadPool = ThreadPool::getThreadPool();
auto &dep = depManager[0];
auto finishF = [&ops, &m, &cv, &next, &status, &rsize, &dep](int opi) {
std::lock_guard<std::mutex> lk(m);
rsize--;
status[opi] = 2;
for (int i : dep.getNext(opi)) {
bool ok = true;
for (int j : dep.getDeps(i)) {
if (status[j] != 2) {
ok = false;
break;
}
}
if (ok && (status[i] == 0)) {
next.push(i);
}
}
cv.notify_one();
};
for (;;) {
std::unique_lock<std::mutex> lk(m);
cv.wait(lk, [&next, &rsize] { return rsize == 0 || !next.empty(); });
if (rsize == 0) {
break;
}
while (next.size() > 0) {
int opi = next.front();
next.pop();
status[opi] = 1;
threadPool.enqueue([opi, &ops, &finishF, &profile] {
auto &op = ops[opi];
#ifdef PADDLE_MOBILE_PROFILE
struct timespec ts;
clock_gettime(CLOCK_MONOTONIC, &ts);
profile[opi].runBegin = (uint64_t)ts.tv_sec * 1e9 + ts.tv_nsec;
profile[opi].tid = ThreadPool::getThreadPoolThreadId();
#endif
ops[opi]->Run();
#ifdef PADDLE_MOBILE_PROFILE
clock_gettime(CLOCK_MONOTONIC, &ts);
profile[opi].runEnd = (uint64_t)ts.tv_sec * 1e9 + ts.tv_nsec;
#endif
finishF(opi);
});
}
}
#else
for (int i = 0; i < ops.size(); i++) {
#ifdef PADDLE_MOBILE_PROFILE
struct timespec ts;
clock_gettime(CLOCK_MONOTONIC, &ts);
profile[i].runBegin = (uint64_t)ts.tv_sec * 1e9 + ts.tv_nsec;
#endif
DLOG << "executer Predict in3.3 infer";
if (loddable_) {
ops[i]->InferShape();
}
DLOG << "executer Predict in3.3 after infer";
// to Run
ops[i]->Run();
#ifdef PADDLE_MOBILE_PROFILE
clock_gettime(CLOCK_MONOTONIC, &ts);
profile[i].runEnd = (uint64_t)ts.tv_sec * 1e9 + ts.tv_nsec;
#endif
}
#endif
DLOG << "executer Predict in4";
auto last_op = ops.rbegin();
auto output_map = (*last_op)->Outputs();
std::vector<std::string> out_keys = (*last_op)->GetOutKeys();
PADDLE_MOBILE_ENFORCE(out_keys.size() > 0, "the last op contains no output");
framework::LoDTensor *output_tensor =
framework::GetVarValue<framework::LoDTensor>(out_keys[0], output_map,
*(program_.scope));
#ifdef PADDLE_MOBILE_PROFILE
#ifdef PADDLE_EXECUTOR_MULTITHREAD
// TODO(haipeng): expose profile info as an interface, user can get them to
// analysis
// the performance of their deepnet.
FILE *df = fopen("net.dot", "w");
fprintf(df, "digraph {\n");
for (int i = 0; i < ops.size(); i++) {
for (int j : dep.getNext(i)) {
fprintf(df, "op_%d -> op_%d\n", i, j);
}
}
for (int i = 0; i < ops.size(); i++) {
fprintf(df, "op_%d[label=\"%s (%d)\"]\n", i, ops[i]->Type().c_str(), i);
}
fprintf(df, "}\n");
fclose(df);
#endif
DLOG << "executer Predict in5";
// FILE *pf = fopen("profile.out", "w");
std::unordered_map<std::string, uint64_t> _tp;
for (int i = 0; i < profile.size(); i++) {
const auto &pInfo = profile[i];
uint64_t timeCost = pInfo.runEnd - pInfo.runBegin;
_tp[ops[i]->Type()] += timeCost;
// fprintf(pf, "%d\t%s\t%d\t%llu\t%llu\t%llu\n", i,
// ops[i]->Type().c_str(),
// pInfo.tid, pInfo.runBegin, pInfo.runEnd, timeCost);
}
// fclose(pf);
DLOG << "executer Predict in6";
printf("====================[ profile ]======================\n");
using prof_t = std::pair<std::string, uint64_t>;
std::vector<prof_t> _tv(_tp.begin(), _tp.end());
uint64_t _ptotal = 0;
for (auto const &p : _tv) {
_ptotal += p.second;
}
auto compf = [](const prof_t &a, const prof_t &b) {
return a.second > b.second;
};
std::sort(_tv.begin(), _tv.end(), compf);
_tv.push_back(std::make_pair("total", _ptotal));
for (auto const &p : _tv) {
printf("%-16s\t%-10.0f\t%-2.4f\n", p.first.c_str(),
static_cast<float>(p.second),
static_cast<float>(p.second) / _ptotal * 100.0);
}
printf("====================[---------]======================\n");
#endif
DLOG << "executer Predict out";
return std::make_shared<framework::LoDTensor>(
framework::LoDTensor(*output_tensor));
}
template <typename Dtype, Precision P>
std::shared_ptr<framework::Tensor> Executor<Dtype, P>::Predict(
const framework::Tensor &t, int block_id) {
......
......@@ -43,13 +43,17 @@ class Executor {
* @b 用 loader load 的 program 实例化 executor
* */
Executor(const framework::Program<Dtype> p, int batch_size = 1,
bool use_optimize = true);
bool use_optimize = true, bool loddable = false);
/*
* @b to predict
* */
std::shared_ptr<framework::Tensor> Predict(const framework::Tensor &t);
/*
* @b to predict
* */
std::shared_ptr<framework::LoDTensor> PredictLod(
const framework::LoDTensor &t);
/*
* @b to predict with vector and dim
*
......@@ -73,6 +77,7 @@ class Executor {
std::vector<std::shared_ptr<framework::OperatorBase<Dtype>>>>
ops_of_block_;
bool use_optimize_ = false;
bool loddable_ = false;
#ifdef PADDLE_EXECUTOR_MULTITHREAD
std::vector<depCore> depManager;
#endif
......@@ -83,6 +88,10 @@ class Executor {
uint64_t runEnd = 0UL;
};
#endif
bool varInputMemory(const std::shared_ptr<framework::VarDesc> &var_desc,
framework::Variable *var,
framework::LoDTensor *tensor) const;
};
} // namespace paddle_mobile
......@@ -26,7 +26,8 @@ void PaddleMobile<Dtype, P>::SetThreadNum(int num) {
template <typename Dtype, Precision P>
bool PaddleMobile<Dtype, P>::Load(const std::string &dirname, bool optimize,
bool quantification, int batch_size) {
bool quantification, int batch_size,
bool loddable) {
if (loader_.get() == nullptr) {
loader_ = std::make_shared<Loader<Dtype, P>>();
} else {
......@@ -35,7 +36,8 @@ bool PaddleMobile<Dtype, P>::Load(const std::string &dirname, bool optimize,
if (executor_.get() == nullptr) {
executor_ = std::make_shared<Executor<Dtype, P>>(
loader_->Load(dirname, optimize, quantification), batch_size, optimize);
loader_->Load(dirname, optimize, quantification), batch_size, optimize,
loddable);
} else {
LOG(kLOG_INFO) << "executor inited";
}
......@@ -46,7 +48,8 @@ bool PaddleMobile<Dtype, P>::Load(const std::string &dirname, bool optimize,
template <typename Dtype, Precision P>
bool PaddleMobile<Dtype, P>::Load(const std::string &model_path,
const std::string &para_path, bool optimize,
bool quantification, int batch_size) {
bool quantification, int batch_size,
bool loddable) {
if (loader_.get() == nullptr) {
loader_ = std::make_shared<Loader<Dtype, P>>();
} else {
......@@ -56,7 +59,7 @@ bool PaddleMobile<Dtype, P>::Load(const std::string &model_path,
if (executor_.get() == nullptr) {
executor_ = std::make_shared<Executor<Dtype, P>>(
loader_->Load(model_path, para_path, optimize, quantification),
batch_size, optimize);
batch_size, optimize, loddable);
} else {
LOG(kLOG_INFO) << "executor inited";
}
......@@ -96,6 +99,12 @@ std::shared_ptr<framework::Tensor> PaddleMobile<Dtype, P>::Predict(
return executor_->Predict(t);
}
template <typename Dtype, Precision P>
std::shared_ptr<framework::Tensor> PaddleMobile<Dtype, P>::PredictLod(
const framework::LoDTensor &t) {
return executor_->PredictLod(t);
}
template <typename Dtype, Precision P>
std::vector<typename PaddleMobile<Dtype, P>::Ptype>
PaddleMobile<Dtype, P>::Predict(const std::vector<Ptype> &input,
......
......@@ -39,7 +39,8 @@ class PaddleMobile {
* @b 加载分开形式的 fluid 模型
* */
bool Load(const std::string &dirname, bool optimize = false,
bool quantification = false, int batch_size = 1);
bool quantification = false, int batch_size = 1,
bool loddable = false);
/*
* @b load combine format fluid mode
......@@ -47,7 +48,7 @@ class PaddleMobile {
* */
bool Load(const std::string &model_path, const std::string &para_path,
bool optimize = false, bool quantification = false,
int batch_size = 1);
int batch_size = 1, bool loddable = false);
/*
* @b 设置线程数, 当 cmake 中开启 openmp 时生效
* */
......@@ -58,6 +59,11 @@ class PaddleMobile {
* */
std::shared_ptr<framework::Tensor> Predict(const framework::Tensor &t);
/*
* @b to predict
* */
std::shared_ptr<framework::Tensor> PredictLod(const framework::LoDTensor &t);
/*
* @b to predict with vector and dim
*
......
......@@ -353,6 +353,41 @@ JNIEXPORT jfloatArray JNICALL Java_com_baidu_paddle_PML_predictYuv(
return result;
}
JNIEXPORT jlongArray JNICALL
Java_com_baidu_paddle_PML_predictLod(JNIEnv *env, jclass thiz, jlongArray buf) {
std::lock_guard<std::mutex> lock(shared_mutex);
jlong *ddim_ptr = env->GetLongArrayElements(buf, NULL);
jsize ddim_size = env->GetArrayLength(buf);
std::vector<int64_t> ids;
for (int i = 0; i < ddim_size; ++i) {
jlong x = ddim_ptr[i];
ids.push_back((int64_t)x);
}
paddle_mobile::framework::LoDTensor words;
auto size = static_cast<int>(ids.size());
paddle_mobile::framework::LoD lod{{0, ids.size()}};
DDim dims{size, 1};
words.Resize(dims);
words.set_lod(lod);
auto *pdata = words.mutable_data<int64_t>();
size_t n = words.numel() * sizeof(int64_t);
memcpy(pdata, ids.data(), n);
auto vec_result = paddle_mobile.PredictLod(words);
int count = vec_result->numel();
jlongArray result = NULL;
ANDROIDLOGE("predict nlp size %d", count);
result = env->NewLongArray(count);
env->SetLongArrayRegion(result, 0, count, vec_result->data<int64_t>());
return result;
}
JNIEXPORT void JNICALL Java_com_baidu_paddle_PML_setThread(JNIEnv *env,
jclass thiz,
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef CRF_OP
#include <vector>
#include "common/enforce.h"
#include "operators/crf_op.h"
namespace paddle_mobile {
namespace operators {
template <typename Dtype, typename T>
void CrfOp<Dtype, T>::InferShape() const {
PADDLE_MOBILE_ENFORCE(this->param_.InputEmission(),
"Input(Emission) should be not null.");
PADDLE_MOBILE_ENFORCE(this->param_.InputTransition(),
"Input(Transition) should be not null.");
PADDLE_MOBILE_ENFORCE(this->param_.outputVBP(),
"Input(ViterbiPath) should be not null.");
auto emission_dims = this->param_.InputEmission()->dims();
PADDLE_MOBILE_ENFORCE(emission_dims.size() == 2U,
"The Input(Emission) should be a 2-D tensor.");
PADDLE_MOBILE_ENFORCE(emission_dims[0],
"An empty mini-batch is not allowed.");
this->param_.outputVBP()->Resize(
{this->param_.InputEmission()->dims()[0], 1});
}
} // namespace operators
} // namespace paddle_mobile
namespace ops = paddle_mobile::operators;
#ifdef PADDLE_MOBILE_CPU
REGISTER_OPERATOR_CPU(crf_decoding, ops::CrfOp);
#endif
#ifdef PADDLE_MOBILE_MALI_GPU
#endif
#ifdef PADDLE_MOBILE_FPGA
#endif
#endif
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef CRF_OP
#pragma once
#include <string>
#include "framework/operator.h"
#include "operators/kernel/crf_kernel.h"
#include "operators/op_param.h"
namespace paddle_mobile {
namespace operators {
using paddle_mobile::framework::Tensor;
template <typename DeviceType, typename T>
class CrfOp : public framework::OperatorWithKernel<
DeviceType, CrfParam<DeviceType>,
operators::CrfKernel<DeviceType, T>> {
public:
CrfOp(const std::string &type, const VariableNameMap &inputs,
const VariableNameMap &outputs, const framework::AttributeMap &attrs,
std::shared_ptr<framework::Scope> scope)
: framework::OperatorWithKernel<DeviceType, CrfParam<DeviceType>,
operators::CrfKernel<DeviceType, T>>(
type, inputs, outputs, attrs, scope) {}
using framework::OperatorWithKernel<
DeviceType, CrfParam<DeviceType>,
operators::CrfKernel<DeviceType, T>>::OperatorWithKernel;
void InferShape() const override;
};
} // namespace operators
} // namespace paddle_mobile
#ifdef PADDLE_MOBILE_CPU
USE_OP_CPU(crf_decoding);
#endif
#ifdef PADDLE_MOBILE_MALI_GPU
#endif
#ifdef PADDLE_MOBILE_FPGA
#endif
#endif
......@@ -35,6 +35,10 @@ class FeedOp : public framework::OperatorBase<DeviceType> {
auto out_dims = param_.Out()->dims();
out_dims[0] = param_.BatchSize();
param_.Out()->Resize(out_dims);
// note : mobile infershape iscalled when executer is created. so do not
// pass lod here .
// it is empty
}
#ifdef PADDLE_MOBILE_FPGA
......@@ -67,7 +71,10 @@ class FeedOp : public framework::OperatorBase<DeviceType> {
#else
void Init() {}
void RunImpl() const { param_.Out()->ShareDataWith(*param_.InputX()); }
void RunImpl() const {
param_.Out()->ShareDataWith(*param_.InputX());
param_.Out()->set_lod(param_.InputX()->lod());
}
#endif
protected:
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef GRU_OP
#include <vector>
#include "common/enforce.h"
#include "operators/gru_op.h"
namespace paddle_mobile {
namespace operators {
template <typename Dtype, typename T>
void GruOp<Dtype, T>::InferShape() const {
auto lod_size = this->param_.InputInput()->lod().size();
PADDLE_MOBILE_ENFORCE((lod_size == 1),
"Current LoD only supports one dimension.");
auto input_dims = this->param_.InputInput()->dims();
auto weight_dims = this->param_.InputWeight()->dims();
int input_size = input_dims[1];
int frame_size = weight_dims[0];
PADDLE_MOBILE_ENFORCE(
(input_size == frame_size * 3),
"The input_size must be 3 times of frame_size in GRUOp.");
PADDLE_MOBILE_ENFORCE(
(weight_dims[1] == frame_size * 3),
"The shape of Weight matrix must be [frame_size, frame_size * 3].");
if (this->param_.InputH0()) {
auto h0_dims = this->param_.InputH0()->dims();
PADDLE_MOBILE_ENFORCE((h0_dims[1] == frame_size),
"The width of H0 must be equal to frame_size.");
}
if (this->param_.InputBias()) {
auto bias_dims = this->param_.InputBias()->dims();
int bias_height = bias_dims[0];
int bias_width = bias_dims[1];
PADDLE_MOBILE_ENFORCE((bias_height == 1),
"The shape of Bias must be [1, frame_size * 3].");
PADDLE_MOBILE_ENFORCE((bias_width == frame_size * 3),
"The shape of Bias must be [1, frame_size * 3].");
}
this->param_.OutBatchGate()->Resize(input_dims);
this->param_.OutBatchResetHiddenPrev()->Resize({input_dims[0], frame_size});
this->param_.OutBatchHidden()->Resize({input_dims[0], frame_size});
this->param_.OutHidden()->Resize({input_dims[0], frame_size});
}
} // namespace operators
} // namespace paddle_mobile
namespace ops = paddle_mobile::operators;
#ifdef PADDLE_MOBILE_CPU
REGISTER_OPERATOR_CPU(gru, ops::GruOp);
#endif
#ifdef PADDLE_MOBILE_MALI_GPU
#endif
#ifdef PADDLE_MOBILE_FPGA
#endif
#endif
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef GRU_OP
#pragma once
#include <string>
#include "framework/operator.h"
#include "operators/kernel/gru_kernel.h"
#include "operators/op_param.h"
namespace paddle_mobile {
namespace operators {
using paddle_mobile::framework::Tensor;
template <typename DeviceType, typename T>
class GruOp : public framework::OperatorWithKernel<
DeviceType, GruParam<DeviceType>,
operators::GruKernel<DeviceType, T>> {
public:
GruOp(const std::string &type, const VariableNameMap &inputs,
const VariableNameMap &outputs, const framework::AttributeMap &attrs,
std::shared_ptr<framework::Scope> scope)
: framework::OperatorWithKernel<DeviceType, GruParam<DeviceType>,
operators::GruKernel<DeviceType, T>>(
type, inputs, outputs, attrs, scope) {}
using framework::OperatorWithKernel<
DeviceType, GruParam<DeviceType>,
operators::GruKernel<DeviceType, T>>::OperatorWithKernel;
void InferShape() const override;
};
} // namespace operators
} // namespace paddle_mobile
#ifdef PADDLE_MOBILE_CPU
USE_OP_CPU(gru);
#endif
#ifdef PADDLE_MOBILE_MALI_GPU
#endif
#ifdef PADDLE_MOBILE_FPGA
#endif
#endif
......@@ -28,6 +28,7 @@ bool ConcatKernel<CPU, float>::Init(ConcatParam<CPU> *param) {
template <>
void ConcatKernel<CPU, float>::Compute(const ConcatParam<CPU> &param) const {
ConcatCompute<float>(param);
param.Out()->set_lod(param.Inputs()[0]->lod());
}
} // namespace operators
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef CRF_OP
#include "operators/kernel/crf_kernel.h"
#include "common/types.h"
#include "operators/kernel/central-arm-func/crf_arm_func.h"
namespace paddle_mobile {
namespace operators {
template <>
bool CrfKernel<CPU, float>::Init(CrfParam<CPU> *param) {
return true;
}
template <>
void CrfKernel<CPU, float>::Compute(const CrfParam<CPU> &param) const {
CrfCompute<float>(param);
}
template class CrfKernel<CPU, float>;
} // namespace operators
} // namespace paddle_mobile
#endif
......@@ -29,6 +29,7 @@ template <>
void ElementwiseAddKernel<CPU, float>::Compute(
const ElementwiseAddParam<CPU> &param) const {
ElementwiseAddCompute<float>(param);
param.Out()->set_lod(param.InputX()->lod());
}
} // namespace operators
......
......@@ -29,6 +29,7 @@ template <>
void FusionFcKernel<CPU, float>::Compute(
const FusionFcParam<CPU> &param) const {
FusionFcCompute<float>(param);
param.Out()->set_lod(param.InputX()->lod());
}
} // namespace operators
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef GRU_OP
#include "operators/kernel/gru_kernel.h"
#include "operators/kernel/central-arm-func/gru_arm_func.h"
namespace paddle_mobile {
namespace operators {
template <>
bool GruKernel<CPU, float>::Init(GruParam<CPU> *param) {
return true;
}
template <>
void GruKernel<CPU, float>::Compute(const GruParam<CPU> &param) const {
GruCompute<float>(param);
param.OutHidden()->set_lod(param.InputInput()->lod());
// DLOG << "________________" << param.OutHidden()->dims();
// DLOG << "________________" << param.OutHidden()->numel();
// auto *hiden_data = param.OutHidden()->data<float>();
// for (int64_t i = 0; i < 10; i++) {
// DLOG << "****************" << hiden_data[i];
// }
}
template class GruKernel<CPU, float>;
} // namespace operators
} // namespace paddle_mobile
#endif
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef LOOKUP_OP
#include "operators/kernel/lookup_kernel.h"
#include "operators/kernel/central-arm-func/lookup_arm_func.h"
namespace paddle_mobile {
namespace operators {
template <>
bool LookupKernel<CPU, float>::Init(LookupParam<CPU> *param) {
return true;
}
template <>
void LookupKernel<CPU, float>::Compute(const LookupParam<CPU> &param) const {
LookupCompute<float>(param);
param.Out()->set_lod(param.InputIds()->lod());
}
} // namespace operators
} // namespace paddle_mobile
#endif
......@@ -28,6 +28,7 @@ bool MulKernel<CPU, float>::Init(MulParam<CPU> *param) {
template <>
void MulKernel<CPU, float>::Compute(const MulParam<CPU> &param) const {
MulCompute<float>(param);
param.Out()->set_lod(param.InputX()->lod());
}
} // namespace operators
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef CRF_OP
#pragma once
#include <limits>
#include <vector>
#include "operators/op_param.h"
namespace paddle_mobile {
namespace operators {
template <typename P>
void Decode(const Tensor& emission_weights, const Tensor& transition_weights,
Tensor* decoded_path) {
auto emission_dims = emission_weights.dims();
const size_t seq_len = emission_dims[0];
const size_t tag_num = emission_dims[1];
const size_t state_trans_base_idx = 2;
const P* x = emission_weights.data<P>();
const P* w = transition_weights.data<P>();
int64_t* path = decoded_path->data<int64_t>();
// alpha is a memo table. An element alpha(k, v) records the score of the
// best sequence of tags from position 1 to position k with v being the end
// tag.
Tensor alpha;
P* alpha_value = alpha.mutable_data<P>(emission_dims);
Tensor track;
int* track_value = track.mutable_data<int>(emission_dims);
for (size_t i = 0; i < tag_num; ++i) alpha_value[i] = w[i] + x[i];
for (size_t k = 1; k < seq_len; ++k) {
for (size_t i = 0; i < tag_num; ++i) {
P max_score = -std::numeric_limits<P>::max();
int max_j = 0;
for (size_t j = 0; j < tag_num; ++j) {
P score = alpha_value[(k - 1) * tag_num + j] +
w[(j + state_trans_base_idx) * tag_num + i];
if (score > max_score) {
max_score = score;
max_j = j;
}
}
alpha_value[k * tag_num + i] = max_score + x[k * tag_num + i];
track_value[k * tag_num + i] = max_j;
}
}
P max_score = -std::numeric_limits<P>::max();
int max_i = 0;
for (size_t i = 0; i < tag_num; ++i) {
P score = alpha_value[(seq_len - 1) * tag_num + i] + w[tag_num + i];
if (score > max_score) {
max_score = score;
max_i = i;
}
}
path[seq_len - 1] = max_i;
for (int k = seq_len - 1; k >= 1; --k) {
path[k - 1] = max_i = track_value[k * tag_num + max_i];
}
}
template <typename P>
void CrfCompute(const CrfParam<CPU>& param) {
auto* emission = param.InputEmission();
auto* transition = param.InputTransition();
auto* label = param.InputLabel();
auto* decoded_path = param.outputVBP();
// DLOG<<*emission;
// DLOG<<*transition;
// DLOG<<*label;
PADDLE_MOBILE_ENFORCE(emission->NumLevels() == 1U,
"The Input(Emission) should be a sequence.");
auto lod = emission->lod();
PADDLE_MOBILE_ENFORCE(lod.size(),
"The Input(Emission) should be a sequence.");
const size_t level = 0;
const size_t seq_num = lod[level].size() - 1;
int64_t* path = decoded_path->mutable_data<int64_t>();
int numel = decoded_path->numel();
memset(static_cast<void*>(path), 0, sizeof(int64_t) * numel);
for (size_t i = 0; i < seq_num; ++i) {
int start_pos = static_cast<int>(lod[level][i]);
int end_pos = static_cast<int>(lod[level][i + 1]);
Tensor decoded_path_one_seq = decoded_path->Slice(start_pos, end_pos);
Decode<P>(emission->Slice(start_pos, end_pos), *transition,
&decoded_path_one_seq);
}
if (label) {
PADDLE_MOBILE_ENFORCE(label->NumLevels() == 1U,
"The Input(Label) should be a sequence.");
const int64_t* label_value = label->data<int64_t>();
size_t batch_size = emission->dims()[0];
for (size_t i = 0; i < batch_size; ++i) {
path[i] = label_value[i] == path[i] ? 1 : 0;
}
}
}
} // namespace operators
} // namespace paddle_mobile
#endif
......@@ -30,6 +30,9 @@ void FusionFcCompute(const FusionFcParam<CPU> &param) {
int axis = param.Axis();
Tensor *out = param.Out();
auto *out_data = out->mutable_data<float>();
// int m = out->dims()[0];
// int n = out->dims()[1];
const Tensor x_matrix =
input_x->dims().size() > 2
? framework::ReshapeToMatrix(*input_x, param.XNumColDims())
......@@ -57,6 +60,7 @@ void FusionFcCompute(const FusionFcParam<CPU> &param) {
// for (int i = 0; i < out->numel(); i++) {
// DLOG << out_data[i];
// }
// bias_data的维度和out的维度一致
math::matmul<float>(x_matrix, false, y_matrix, false, static_cast<float>(1),
out, static_cast<float>(1), false);
PADDLE_MOBILE_ENFORCE(out_dim.size() == 2, " out_dim.size must be 2.");
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef GRU_OP
#pragma once
#include <operators/math/sequence2batch.h>
#include <vector>
#include "common/types.h"
#include "operators/math/gru_compute.h"
#include "operators/math/math_function.h"
#include "operators/op_param.h"
namespace paddle_mobile {
namespace operators {
using LoDTensor = framework::LoDTensor;
using Tensor = framework::Tensor;
template <typename DeviceType, typename T>
inline void ReorderInitState(const framework::Tensor& src,
std::vector<size_t> index_lod,
framework::Tensor* dst, bool indexed_src) {
math::CopyMatrixRowsFunctor<DeviceType, T> row_shuffle;
dst->mutable_data<T>(src.dims());
row_shuffle(src, index_lod, dst, indexed_src);
}
template <typename P>
void GruCompute(const GruParam<CPU>& param) {
auto* input = param.InputInput();
auto* h0 = param.InputH0();
auto* weight = param.InputWeight();
const auto* weight_data = weight->data<float>();
auto* bias = param.InputBias();
auto* batch_gate = param.OutBatchGate();
batch_gate->mutable_data<float>();
auto* batch_reset_hidden_prev = param.OutBatchResetHiddenPrev();
batch_reset_hidden_prev->mutable_data<float>();
auto* batch_hidden = param.OutBatchHidden();
batch_hidden->mutable_data<float>();
auto* hidden = param.OutHidden();
hidden->mutable_data<float>();
auto hidden_dims = hidden->dims();
bool is_reverse = param.IsReverse();
math::LoDTensor2BatchFunctor<CPU, float> to_batch;
to_batch(*input, batch_gate, true, is_reverse);
// math::ClearTensor<CPU, float> clearTensor;
// clearTensor(batch_gate);
if (bias) {
math::RowwiseAdd<CPU, float> add_bias;
add_bias(*batch_gate, *bias, batch_gate);
}
int frame_size = hidden_dims[1];
math::GRUMetaValue<float> gru_value;
gru_value.gate_weight = const_cast<float*>(weight_data);
gru_value.state_weight =
const_cast<float*>(weight_data + 2 * frame_size * frame_size);
Tensor ordered_h0;
std::vector<size_t> order(batch_gate->lod()[2]);
if (h0) {
// Since the batch computing for GRU reorders the input sequences
// according to their length. The initialized cell state also needs
// to reorder.
ReorderInitState<CPU, float>(*h0, order, &ordered_h0, true);
gru_value.prev_out_value = ordered_h0.data<float>();
} else {
gru_value.prev_out_value = nullptr;
}
auto batch_starts = batch_gate->lod()[0];
size_t seq_len = batch_starts.size() - 1;
auto active_node = math::GetActivationType(param.Activation());
auto active_gate = math::GetActivationType(param.GateActivation());
for (size_t n = 0; n < seq_len; n++) {
int bstart = static_cast<int>(batch_starts[n]);
int bend = static_cast<int>(batch_starts[n + 1]);
int cur_batch_size = bend - bstart;
Tensor gate_t = batch_gate->Slice(bstart, bend); // BUG
Tensor reset_hidden_prev_t = batch_reset_hidden_prev->Slice(bstart, bend);
Tensor hidden_t = batch_hidden->Slice(bstart, bend);
gru_value.output_value = hidden_t.data<float>();
gru_value.gate_value = gate_t.data<float>();
gru_value.reset_output_value = reset_hidden_prev_t.data<float>();
math::GRUUnitFunctor<CPU, float>::compute(
gru_value, frame_size, cur_batch_size, active_node, active_gate);
gru_value.prev_out_value = gru_value.output_value;
}
math::Batch2LoDTensorFunctor<CPU, float> to_seq;
batch_hidden->set_lod(batch_gate->lod());
to_seq(*batch_hidden, hidden);
}
} // namespace operators
} // namespace paddle_mobile
#endif
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef LOOKUP_OP
#pragma once
#include <vector>
#include "framework/ddim.h"
#include "operators/op_param.h"
constexpr int64_t kNoPadding = -1;
namespace paddle_mobile {
namespace operators {
template <typename P>
void LookupCompute(const LookupParam<CPU> &param) {
auto *ids_t = param.InputIds();
auto *table_t = param.InputW();
auto *output_t = param.Out();
int64_t padding_idx = param.PaddingIdx();
const framework::DDim &table_dim = table_t->dims();
int64_t ids_numel;
const auto *ids = ids_t->data<int64_t>();
ids_numel = ids_t->numel();
int64_t row_number = table_t->dims()[0];
int64_t row_width = table_t->dims()[1];
auto *table = table_t->data<float>();
auto *output = output_t->mutable_data<float>();
for (int64_t i = 0; i < ids_numel; ++i) {
if (padding_idx != kNoPadding && ids[i] == padding_idx) {
memset(output + i * row_width, 0, row_width * sizeof(float));
} else {
PADDLE_MOBILE_ENFORCE(ids[i] < row_number,
"look uptable ids[i] <row_number check failed");
PADDLE_MOBILE_ENFORCE(ids[i] >= 0,
"lookuptable ids[i] >= 0 check failed");
memcpy(output + i * row_width, table + ids[i] * row_width,
row_width * sizeof(float));
}
}
}
} // namespace operators
} // namespace paddle_mobile
#endif
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef CRF_OP
#pragma once
#include <vector>
#include "framework/operator.h"
#include "operators/op_param.h"
namespace paddle_mobile {
namespace operators {
template <typename DeviceType, typename T>
class CrfKernel
: public framework::OpKernelBase<DeviceType, CrfParam<DeviceType>> {
public:
void Compute(const CrfParam<DeviceType>& param) const;
bool Init(CrfParam<DeviceType>* param);
};
} // namespace operators
} // namespace paddle_mobile
#endif
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef GRU_OP
#pragma once
#include <vector>
#include "framework/operator.h"
#include "operators/op_param.h"
namespace paddle_mobile {
namespace operators {
template <typename DeviceType, typename T>
class GruKernel
: public framework::OpKernelBase<DeviceType, GruParam<DeviceType>> {
public:
void Compute(const GruParam<DeviceType>& param) const;
bool Init(GruParam<DeviceType>* param);
};
} // namespace operators
} // namespace paddle_mobile
#endif
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef LOOKUP_OP
#pragma once
#include <vector>
#include "framework/operator.h"
#include "operators/op_param.h"
namespace paddle_mobile {
namespace operators {
template <typename DeviceType, typename T>
class LookupKernel
: public framework::OpKernelBase<DeviceType, LookupParam<DeviceType>> {
public:
void Compute(const LookupParam<DeviceType>& param) const;
bool Init(LookupParam<DeviceType>* param);
};
} // namespace operators
} // namespace paddle_mobile
#endif
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef LOOKUP_OP
#include <vector>
#include "common/enforce.h"
#include "operators/lookup_op.h"
namespace paddle_mobile {
namespace operators {
template <typename Dtype, typename T>
void LookupOp<Dtype, T>::InferShape() const {
PADDLE_MOBILE_ENFORCE(this->param_.InputW() != nullptr,
"Input(W) of LookupTableOp should not be null.");
auto *ids_t = this->param_.InputIds();
PADDLE_MOBILE_ENFORCE(ids_t != nullptr,
"Input(Ids) of LookupTableOp should not be null.");
PADDLE_MOBILE_ENFORCE(this->param_.Out() != nullptr,
"Output(Out) of LookupTableOp should not be null.");
// this->param__.InputW()->
auto table_dims = this->param_.InputW()->dims();
auto ids_dims = ids_t->dims();
int ids_rank = ids_dims.size();
PADDLE_MOBILE_ENFORCE(table_dims.size() == 2,
"table_dims.size()==2 check failed");
PADDLE_MOBILE_ENFORCE(ids_dims[ids_rank - 1] == 1,
"The last dimension of the 'Ids' tensor must be 1.");
auto output_dims =
framework::vectorize(framework::slice_ddim(ids_dims, 0, ids_rank - 1));
output_dims.push_back(table_dims[1]);
this->param_.Out()->Resize(framework::make_ddim(output_dims));
}
} // namespace operators
} // namespace paddle_mobile
namespace ops = paddle_mobile::operators;
#ifdef PADDLE_MOBILE_CPU
REGISTER_OPERATOR_CPU(lookup_table, ops::LookupOp);
#endif
#ifdef PADDLE_MOBILE_MALI_GPU
#endif
#ifdef PADDLE_MOBILE_FPGA
#endif
#endif
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef LOOKUP_OP
#pragma once
#include <string>
#include "framework/operator.h"
#include "operators/kernel/lookup_kernel.h"
#include "operators/op_param.h"
namespace paddle_mobile {
namespace operators {
using paddle_mobile::framework::Tensor;
template <typename DeviceType, typename T>
class LookupOp : public framework::OperatorWithKernel<
DeviceType, LookupParam<DeviceType>,
operators::LookupKernel<DeviceType, T>> {
public:
LookupOp(const std::string &type, const VariableNameMap &inputs,
const VariableNameMap &outputs, const framework::AttributeMap &attrs,
std::shared_ptr<framework::Scope> scope)
: framework::OperatorWithKernel<DeviceType, LookupParam<DeviceType>,
operators::LookupKernel<DeviceType, T>>(
type, inputs, outputs, attrs, scope) {}
using framework::OperatorWithKernel<
DeviceType, LookupParam<DeviceType>,
operators::LookupKernel<DeviceType, T>>::OperatorWithKernel;
void InferShape() const override;
};
} // namespace operators
} // namespace paddle_mobile
#ifdef PADDLE_MOBILE_CPU
USE_OP_CPU(lookup_table);
#endif
#ifdef PADDLE_MOBILE_MALI_GPU
#endif
#ifdef PADDLE_MOBILE_FPGA
#endif
#endif
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <math.h>
#include <string>
#include "common/enforce.h"
namespace paddle_mobile {
namespace operators {
namespace math {
#define SIGMOID_THRESHOLD_MIN -40.0
#define SIGMOID_THRESHOLD_MAX 13.0
#define EXP_MAX_INPUT 40.0
enum ActivationType {
kSigmoid,
kReLU,
kTanh,
kIdentity,
};
inline ActivationType GetActivationType(const std::string &type) {
if (type == "sigmoid") {
return ActivationType::kSigmoid;
} else if (type == "relu") {
return ActivationType::kReLU;
} else if (type == "tanh") {
return ActivationType::kTanh;
} else if (type == "identity" || type == "") {
return ActivationType::kIdentity;
}
PADDLE_MOBILE_THROW_EXCEPTION("Not support activation type.");
}
namespace forward {
template <typename T>
T Identity(const T a) {
return a;
}
template <typename T>
T Relu(const T a) {
return a > static_cast<T>(0.0) ? a : static_cast<T>(0.0);
}
template <typename T>
T Sigmoid(const T a) {
const T min = SIGMOID_THRESHOLD_MIN;
const T max = SIGMOID_THRESHOLD_MAX;
T tmp = (a < min) ? min : ((a > max) ? max : a);
return static_cast<T>(1.0) / (static_cast<T>(1.0) + exp(-tmp));
}
template <typename T>
T Tanh(const T a) {
T tmp = -2.0 * a;
tmp = (tmp > EXP_MAX_INPUT) ? EXP_MAX_INPUT : tmp;
return (2.0 / (1.0 + exp(tmp))) - 1.0;
}
} // namespace forward
template <typename T>
struct Active {
typedef T (*Act)(T);
};
static Active<float>::Act kActFloat[] = {
&forward::Sigmoid<float>, &forward::Relu<float>, &forward::Tanh<float>,
&forward::Identity<float>};
namespace forward {
inline float activation(float a, int index) { return kActFloat[index](a); }
} // namespace forward
} // namespace math
} // namespace operators
} // namespace paddle_mobile
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef GRU_OP
#include "operators/math/gru_compute.h"
#include "common/types.h"
#include "operators/math/activation_functions.h"
#include "operators/math/gemm.h"
#include "operators/math/gru_cpu_kernel.h"
#include "operators/math/gru_kernel.h"
namespace paddle_mobile {
namespace operators {
namespace math {
template <typename T>
struct GRUUnitFunctor<CPU, T> {
static void compute(GRUMetaValue<T> value, int frame_size, int batch_size,
const ActivationType active_node,
const ActivationType active_gate) {
if (value.prev_out_value) {
Sgemm(batch_size, frame_size * 2, frame_size, 1, value.prev_out_value,
frame_size, value.gate_weight, frame_size * 2, 1, value.gate_value,
frame_size * 3, false, nullptr);
}
forward_reset_output(forward::gru_resetOutput<T>(), value, frame_size,
batch_size, active_gate);
if (value.prev_out_value) {
Sgemm(batch_size, frame_size, frame_size, 1, value.reset_output_value,
frame_size, value.state_weight, frame_size, 1,
value.gate_value + frame_size * 2, frame_size * 3, false, nullptr);
}
forward_final_output(forward::gru_finalOutput<T>(), value, frame_size,
batch_size, active_node);
}
};
template struct GRUUnitFunctor<CPU, float>;
} // namespace math
} // namespace operators
} // namespace paddle_mobile
#endif
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef GRU_OP
#pragma once
#include "operators/math/activation_functions.h"
namespace paddle_mobile {
namespace operators {
namespace math {
template <typename T>
struct GRUMetaValue {
T *gate_weight;
T *state_weight;
T *gate_value;
T *reset_output_value;
T *output_value;
T *prev_out_value;
};
template <typename DeviceType, typename T>
struct GRUUnitFunctor {
static void compute(GRUMetaValue<T> value, int frame_size, int batch_size,
const ActivationType active_node,
const ActivationType active_gate);
};
} // namespace math
} // namespace operators
} // namespace paddle_mobile
#endif
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef GRU_OP
#pragma once
#include <type_traits>
#include "operators/math/activation_functions.h"
#include "operators/math/gru_compute.h"
namespace paddle_mobile {
namespace operators {
namespace math {
template <class OpResetOutput, typename T>
void hl_naive_gru_forward_reset_output(OpResetOutput op_reset_output,
T *gate_value, T *reset_output_value,
T *prev_output_value, int frame_size,
ActivationType active_gate) {
T r_value_update_gate;
T r_value_reset_gate;
T r_value_reset_output;
T r_prev_out = 0;
T *update_gate = gate_value;
T *reset_gate = gate_value + frame_size;
for (int i = 0; i < frame_size; i++) {
r_value_update_gate = update_gate[i];
r_value_reset_gate = reset_gate[i];
if (prev_output_value) {
r_prev_out = prev_output_value[i];
}
op_reset_output(&r_value_update_gate, &r_value_reset_gate, &r_prev_out,
&r_value_reset_output, active_gate);
update_gate[i] = r_value_update_gate;
reset_gate[i] = r_value_reset_gate;
reset_output_value[i] = r_value_reset_output;
}
}
template <class OpFinalOutput, typename T>
void hl_naive_gru_forward_final_output(OpFinalOutput op_final_output,
T *gate_value, T *prev_output_value,
T *output_value, int frame_size,
ActivationType active_node) {
T r_value_update_gate;
T r_value_frame_state;
T r_prev_out = 0;
T r_output;
T *update_gate = gate_value;
T *frame_state = gate_value + frame_size * 2;
for (int i = 0; i < frame_size; i++) {
r_value_update_gate = update_gate[i];
r_value_frame_state = frame_state[i];
if (prev_output_value) {
r_prev_out = prev_output_value[i];
}
op_final_output(&r_value_update_gate, &r_value_frame_state, &r_prev_out,
&r_output, active_node);
frame_state[i] = r_value_frame_state;
output_value[i] = r_output;
}
}
template <class OpResetOutput, typename T>
inline void forward_reset_output(OpResetOutput op_reset_output,
GRUMetaValue<T> value, int frame_size,
int batch_size, ActivationType active_gate) {
for (int b = 0; b < batch_size; b++) {
hl_naive_gru_forward_reset_output(
op_reset_output, value.gate_value, value.reset_output_value,
value.prev_out_value, frame_size, active_gate);
value.gate_value += frame_size * 3;
value.reset_output_value += frame_size;
if (value.prev_out_value) {
value.prev_out_value += frame_size;
}
}
}
template <class OpFinalOutput, typename T>
inline void forward_final_output(OpFinalOutput op_final_output,
GRUMetaValue<T> value, int frame_size,
int batch_size, ActivationType active_node) {
for (int b = 0; b < batch_size; b++) {
hl_naive_gru_forward_final_output(op_final_output, value.gate_value,
value.prev_out_value, value.output_value,
frame_size, active_node);
value.gate_value += frame_size * 3;
value.output_value += frame_size;
if (value.prev_out_value) {
value.prev_out_value += frame_size;
}
}
}
} // namespace math
} // namespace operators
} // namespace paddle_mobile
#endif
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef GRU_OP
#pragma once
#include <type_traits>
#include "operators/math/activation_functions.h"
namespace paddle_mobile {
namespace operators {
namespace math {
namespace forward {
template <typename T>
class gru_resetOutput {
public:
void operator()(T *value_update_gate, T *value_reset_gate, T *prev_out,
T *value_reset_output, ActivationType act_gate) {
*value_update_gate = activation(*value_update_gate, act_gate);
*value_reset_gate = activation(*value_reset_gate, act_gate);
*value_reset_output = (*prev_out) * (*value_reset_gate);
}
};
template <typename T>
class gru_finalOutput {
public:
void operator()(T *value_update_gate, T *value_frame_state, T *prev_out,
T *value_output, ActivationType act_input) {
*value_frame_state = activation(*value_frame_state, act_input);
*value_output = *prev_out - ((*value_update_gate) * (*prev_out)) +
((*value_update_gate) * (*value_frame_state));
}
};
} // namespace forward
} // namespace math
} // namespace operators
} // namespace paddle_mobile
#endif
......@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "operators/math/math_function.h"
#include <cstring>
#include "operators/math/gemm.h"
namespace paddle_mobile {
......@@ -119,6 +120,40 @@ void matmulWithPRelu(const framework::Tensor &matrix_a, bool trans_a,
#endif
}
template <typename T>
struct ClearTensor<CPU, T> {
void operator()(framework::Tensor *tensor) {
auto size = tensor->numel();
auto *tensor_data = tensor->data<float>();
memset((void *)tensor_data, 0, sizeof(T) * size);
}
};
template <typename T>
struct RowwiseAdd<CPU, T> {
void operator()(const framework::Tensor &input,
const framework::Tensor &vector, framework::Tensor *output) {
auto in_dims = input.dims();
auto size = input.numel() / in_dims[0];
PADDLE_MOBILE_ENFORCE((vector.numel() == size),
"vector.numel() must be equal to size.");
PADDLE_MOBILE_ENFORCE((output->dims() == in_dims),
"output->dims() must be equal to in_dims.");
auto *input_data = input.data<float>();
auto *out_data = output->data<float>();
auto *vec_data = vector.data<float>();
for (int64_t i = 0; i < in_dims[0]; ++i) {
for (int64_t j = 0; j < size; ++j) {
out_data[i * size + j] = input_data[i * size + j] + vec_data[j];
}
}
}
};
template struct RowwiseAdd<CPU, float>;
template struct ClearTensor<CPU, float>;
} // namespace math
} // namespace operators
} // namespace paddle_mobile
......@@ -38,6 +38,17 @@ void matmulWithPRelu(const framework::Tensor &matrix_a, bool trans_a,
const framework::Tensor &matrix_b, bool trans_b,
framework::Tensor *matrix_out, float *p, std::string mode,
float *bias, float *bias1);
template <typename DeviceType, typename T>
struct ClearTensor {
void operator()(framework::Tensor *tensor);
};
template <typename DeviceType, typename T>
struct RowwiseAdd {
void operator()(const framework::Tensor &input, const framework::Tensor &vec,
framework::Tensor *output);
};
} // namespace math
} // namespace operators
} // namespace paddle_mobile
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "operators/math/sequence2batch.h"
#include <cstring>
#include "common/types.h"
namespace paddle_mobile {
namespace operators {
namespace math {
template <typename T>
class CopyMatrixRowsFunctor<CPU, T> {
public:
void operator()(const framework::Tensor& src, std::vector<size_t> index_lod,
framework::Tensor* dst, bool is_src_index) {
size_t* index = index_lod.data();
auto src_dims = src.dims();
auto dst_dims = dst->dims();
PADDLE_MOBILE_ENFORCE((src_dims.size() == 2UL),
"The src must be matrix with rank 2.");
PADDLE_MOBILE_ENFORCE((dst_dims.size() == 2UL),
"The dst must be matrix with rank 2.");
PADDLE_MOBILE_ENFORCE((src_dims[1] == dst_dims[1]),
"The width of src and dst must be same.");
auto height = dst_dims[0];
auto width = dst_dims[1];
auto* src_data = src.data<T>();
auto* dst_data = dst->data<T>();
for (int i = 0; i < height; ++i) {
if (is_src_index) {
memcpy(dst_data + i * width, src_data + index[i] * width,
width * sizeof(T));
} else {
memcpy(dst_data + index[i] * width, src_data + i * width,
width * sizeof(T));
}
}
}
};
template class CopyMatrixRowsFunctor<CPU, float>;
template class LoDTensor2BatchFunctor<CPU, float>;
template class Batch2LoDTensorFunctor<CPU, float>;
} // namespace math
} // namespace operators
} // namespace paddle_mobile
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <algorithm>
#include <vector>
#include "framework/lod_tensor.h"
#include "framework/tensor.h"
namespace paddle_mobile {
namespace operators {
namespace math {
template <typename DeviceType, typename T>
class CopyMatrixRowsFunctor {
public:
// If is_src_index is true,
// copy the indexed rows of input src to the output dst.
// If is_src_index is false,
// copy the input src to the indexed rows of output dst.
// The indexed rows are based on the input index.
void operator()(const framework::Tensor& src, std::vector<size_t> index_lod,
framework::Tensor* dst, bool is_src_index);
};
template <typename DeviceType, typename T>
class LoDTensor2BatchFunctor {
// Calculate the length of each sequence and
// sort sequence index by the length.
// example: sequences = {s0, s1, s2}
// s0: 0 0 0 0, s1: 1 1 1 1 1, s2: 2 2 2
// seq_info[3] = {(4, 5, 1), (0, 4, 0), (9, 3, 2)}
//
struct SeqInfo {
SeqInfo(int start, int length, int seq_idx)
: start(start), length(length), seq_idx(seq_idx) {}
int start;
int length;
int seq_idx;
};
public:
void operator()(const framework::LoDTensor& lod_tensor,
framework::LoDTensor* batch, bool is_cal_batch_lod,
bool is_reverse = false) {
if (!is_cal_batch_lod) {
auto lods = batch->lod();
PADDLE_MOBILE_ENFORCE(
(lods.size() > 2UL),
"The LoD of LoDTensor should inlcude at least 2-level "
"sequence information.");
PADDLE_MOBILE_ENFORCE(
(lods[1].size() == static_cast<size_t>(lod_tensor.dims()[0])),
"The LoD information should be consistent with the dims.");
CopyMatrixRowsFunctor<DeviceType, T> to_batch;
to_batch(lod_tensor, lods[1], batch, true);
return;
}
auto lods = lod_tensor.lod();
PADDLE_MOBILE_ENFORCE((lods.size() == 1UL),
"Only support one level sequence now.");
const auto& lod = lods[0];
std::vector<SeqInfo> seq_info;
for (size_t seq_id = 0; seq_id < lod.size() - 1; ++seq_id) {
int length = lod[seq_id + 1] - lod[seq_id];
seq_info.emplace_back(lod[seq_id], length, seq_id);
}
std::sort(seq_info.begin(), seq_info.end(),
[](SeqInfo a, SeqInfo b) { return a.length > b.length; });
// Calculate the start position of each batch.
// example: sequences = {s0, s1, s2}
// s0: 0 0 0 0, s1: 1 1 1 1 1, s2: 2 2 2
// num_batch = 5,
// batchIndex = {b0, b1, b2, b3, b4}
// b0: 1 0 2, b1: 1 0 2, b2: 1 0 2, b3: 1 0, b4: 1
// batch_start_positions[6] = {0, 3, 6, 9, 11, 12}
// batch_start_positions[0] = len(b0)
// batch_start_positions[1] = len(b0) + len(b1)
// batch_start_positions[2] = len(b0) + len(b1) + len(b2)
// ...
// seq2batch_idx[12] = {4, 0, 9,
// 5, 1, 10,
// 6, 2, 11,
// 7, 3,
// 8}
// seq_order = {1, 0, 2}, the sort order.
// where 1 is the second sequence,
// 0 is the first sequence,
// 2 is the third sequence.
// The num_batch represents batch size after rearranging the
// input LodTensor. It is also the maximum length of input sequence.
framework::LoD batch_lods;
batch_lods.emplace_back(std::vector<size_t>{0});
batch_lods.emplace_back(std::vector<size_t>{0});
batch_lods.emplace_back(std::vector<size_t>{0});
// batch_lods[0] is the start positions for batch LoDTensor
int num_batch = seq_info[0].length;
batch_lods[0].resize(static_cast<size_t>(num_batch + 1));
// batch_lods[1] is the raw index in the input LoDTensor
batch_lods[1].resize(static_cast<size_t>(lod_tensor.dims()[0]));
// batch_lods[2] is the sort order for the input LoDTensor.
batch_lods[2].resize(seq_info.size());
size_t* batch_starts = batch_lods[0].data();
size_t* seq2batch_idx = batch_lods[1].data();
batch_starts[0] = 0;
for (int n = 0; n < num_batch; n++) {
auto batch_id = static_cast<int>(batch_starts[n]);
for (size_t i = 0; i < seq_info.size(); ++i) {
int seq_len = seq_info[i].length;
int start = seq_info[i].start;
if (n < seq_len) {
seq2batch_idx[batch_id] =
is_reverse ? start + seq_len - 1 - n : start + n;
batch_id++;
} else {
break;
}
}
batch_starts[n + 1] = static_cast<size_t>(batch_id);
}
size_t* seq_order = batch_lods[2].data();
for (size_t i = 0; i < seq_info.size(); ++i) {
seq_order[i] = seq_info[i].seq_idx;
}
batch->set_lod(batch_lods);
CopyMatrixRowsFunctor<DeviceType, T> to_batch;
to_batch(lod_tensor, batch_lods[1], batch, true);
}
};
template <typename DeviceType, typename T>
class Batch2LoDTensorFunctor {
public:
void operator()(const framework::LoDTensor& batch,
framework::LoDTensor* lod_tensor) {
auto in_lod = batch.lod();
PADDLE_MOBILE_ENFORCE(
(in_lod.size() > 2UL),
"The LoD of LoDTensor should inlcude at least 2-level "
"sequence information.");
PADDLE_MOBILE_ENFORCE(
(in_lod[1].size() == static_cast<size_t>(lod_tensor->dims()[0])),
"The LoD information should be consistent with the dims.");
CopyMatrixRowsFunctor<DeviceType, T> to_seq;
to_seq(batch, in_lod[1], lod_tensor, false);
}
};
} // namespace math
} // namespace operators
} // namespace paddle_mobile
......@@ -73,6 +73,10 @@ struct DtypeTensorTrait<GPU_MALI> {
class OpParam {
protected:
template <typename T>
static T *InputH0From(const VariableNameMap &inputs, const Scope &scope) {
return GetVarValue<T>("H0", inputs, scope);
}
template <typename T>
static T *InputAlphaFrom(const VariableNameMap &inputs, const Scope &scope) {
return GetVarValue<T>("Alpha", inputs, scope);
......@@ -87,6 +91,33 @@ class OpParam {
static T *InputXFrom(const VariableNameMap &inputs, const Scope &scope) {
return GetVarValue<T>("X", inputs, scope);
}
template <typename T>
static T *InputWFrom(const VariableNameMap &inputs, const Scope &scope) {
return GetVarValue<T>("W", inputs, scope);
}
template <typename T>
static T *InputIdsFrom(const VariableNameMap &inputs, const Scope &scope) {
return GetVarValue<T>("Ids", inputs, scope);
}
template <typename T>
static T *InputEmissionFrom(const VariableNameMap &inputs,
const Scope &scope) {
return GetVarValue<T>("Emission", inputs, scope);
}
template <typename T>
static T *InputTransitionFrom(const VariableNameMap &inputs,
const Scope &scope) {
return GetVarValue<T>("Transition", inputs, scope);
}
template <typename T>
static T *InputLabelFrom(const VariableNameMap &inputs, const Scope &scope) {
return GetVarValue<T>("Label", inputs, scope);
}
template <typename T>
static T *InputXFrom1(const VariableNameMap &inputs, const Scope &scope) {
return GetVarValue1<T>("addX", inputs, scope);
......@@ -112,6 +143,10 @@ class OpParam {
return GetVarValue<T>("Bias", inputs, scope);
}
template <typename T>
static T *InputWeightFrom(const VariableNameMap &inputs, const Scope &scope) {
return GetVarValue<T>("Weight", inputs, scope);
}
template <typename T>
static T *InputVarianceFrom(const VariableNameMap &inputs,
const Scope &scope) {
return GetVarValue<T>("Variance", inputs, scope);
......@@ -166,6 +201,35 @@ class OpParam {
return GetMultiVarValue<T>("X", inputs, scope);
}
template <typename T>
static T *OutputBatchGateFrom(const VariableNameMap &outputs,
const Scope &scope) {
return GetVarValue<T>("BatchGate", outputs, scope);
}
template <typename T>
static T *OutputViterbiPathFrom(const VariableNameMap &outputs,
const Scope &scope) {
return GetVarValue<T>("ViterbiPath", outputs, scope);
}
template <typename T>
static T *OutputBatchResetHiddenPrevFrom(const VariableNameMap &outputs,
const Scope &scope) {
return GetVarValue<T>("BatchResetHiddenPrev", outputs, scope);
}
template <typename T>
static T *OutputBatchHiddenFrom(const VariableNameMap &outputs,
const Scope &scope) {
return GetVarValue<T>("BatchHidden", outputs, scope);
}
template <typename T>
static T *OutputHiddenFrom(const VariableNameMap &outputs,
const Scope &scope) {
return GetVarValue<T>("Hidden", outputs, scope);
}
template <typename T>
static T *OutputFrom(const VariableNameMap &outputs, const Scope &scope) {
return GetVarValue<T>("Output", outputs, scope);
......@@ -326,18 +390,18 @@ class ElementwiseAddParam : OpParam {
axis_ = GetAttr<int>("axis", attrs);
}
const RType *InputX() const { return input_x_; }
const GType *InputX() const { return input_x_; }
const RType *InputY() const { return input_y_; }
const GType *InputY() const { return input_y_; }
RType *Out() const { return out_; }
GType *Out() const { return out_; }
const int &Axis() const { return axis_; }
private:
RType *input_x_;
RType *input_y_;
RType *out_;
GType *input_x_;
GType *input_y_;
GType *out_;
int axis_;
#ifdef PADDLE_MOBILE_FPGA
......@@ -371,20 +435,20 @@ class MulParam : OpParam {
y_num_col_dims_ = GetAttr<int>("y_num_col_dims", attrs);
}
const RType *InputX() const { return input_x_; }
const GType *InputX() const { return input_x_; }
const RType *InputY() const { return input_y_; }
const GType *InputY() const { return input_y_; }
RType *Out() const { return out_; }
GType *Out() const { return out_; }
const int &XNumColDims() const { return x_num_col_dims_; }
const int &YNumColDims() const { return y_num_col_dims_; }
private:
RType *input_x_;
RType *input_y_;
RType *out_;
GType *input_x_;
GType *input_y_;
GType *out_;
int x_num_col_dims_;
int y_num_col_dims_;
};
......@@ -406,13 +470,13 @@ class ConcatParam : public OpParam {
vector<GType *> Inputs() const { return inputs_; }
RType *Out() const { return out_; }
GType *Out() const { return out_; }
const int &Axis() const { return axis_; }
private:
vector<GType *> inputs_;
RType *out_;
GType *out_;
int axis_;
};
#endif
......@@ -797,13 +861,13 @@ class FeedParam : public OpParam {
auto var = scope->Var("batch_size");
batch_size = var->GetValue<int>();
}
const RType *InputX() const { return input_x_; }
RType *Out() const { return out_; }
const GType *InputX() const { return input_x_; }
GType *Out() const { return out_; }
const int BatchSize() const { return batch_size; }
private:
RType *input_x_;
RType *out_;
GType *input_x_;
GType *out_;
int batch_size;
};
......@@ -853,6 +917,72 @@ class TransposeParam : public OpParam {
};
#endif
#ifdef LOOKUP_OP
template <typename Dtype>
class LookupParam : public OpParam {
typedef typename DtypeTensorTrait<Dtype>::gtype GType;
typedef typename DtypeTensorTrait<Dtype>::rtype RType;
public:
LookupParam(const VariableNameMap &inputs, const VariableNameMap &outputs,
const AttributeMap &attrs, const Scope &scope) {
input_w_ = InputWFrom<GType>(inputs, scope);
input_ids_ = InputIdsFrom<GType>(inputs, scope);
out_ = OutFrom<GType>(outputs, scope);
padding_idx_ = GetAttr<int64_t>("padding_idx", attrs);
}
const GType *InputW() const { return input_w_; }
const GType *InputIds() const { return input_ids_; }
GType *Out() const { return out_; }
int64_t PaddingIdx() const { return padding_idx_; }
private:
GType *input_w_;
GType *input_ids_;
GType *out_;
int64_t padding_idx_;
};
#endif
#ifdef CRF_OP
template <typename Dtype>
class CrfParam : public OpParam {
typedef typename DtypeTensorTrait<Dtype>::gtype GType;
typedef typename DtypeTensorTrait<Dtype>::rtype RType;
public:
// {G_OP_TYPE_CRF, {{"Emission", "Transition", "Label"}, {"ViterbiPath"}}},
CrfParam(const VariableNameMap &inputs, const VariableNameMap &outputs,
const AttributeMap &attrs, const Scope &scope) {
// todo crf params
input_emission_ = InputEmissionFrom<GType>(inputs, scope);
input_transition_ = InputTransitionFrom<GType>(inputs, scope);
input_label_ = InputLabelFrom<GType>(inputs, scope);
output_viterbipath_ = OutputViterbiPathFrom<GType>(outputs, scope);
// padding_idx_ = GetAttr<int64_t>("padding_idx", attrs);
}
const GType *InputEmission() const { return input_emission_; }
const GType *InputTransition() const { return input_transition_; }
const GType *InputLabel() const { return input_label_; }
GType *outputVBP() const { return output_viterbipath_; }
// const RType *InputIds() const { return input_ids_; }
// RType *Out() const { return out_; }
// int64_t PaddingIdx() const { return padding_idx_; }
private:
GType *input_emission_;
GType *input_transition_;
GType *input_label_;
GType *output_viterbipath_;
// RType *input_ids_;
// RType *out_;
// int64_t padding_idx_;
};
#endif
#ifdef RESHAPE_OP
template <typename Dtype>
class ReshapeParam : public OpParam {
......@@ -1095,7 +1225,7 @@ class FusionFcParam : public OpParam {
y_num_col_dims_ = GetAttr<int>("y_num_col_dims", attrs);
axis_ = GetAttr<int>("axis", attrs);
}
const RType *InputX() const { return input_x_; }
const GType *InputX() const { return input_x_; }
#ifdef PADDLE_MOBILE_FPGA
RType *InputY() const { return input_y_; }
......@@ -1105,7 +1235,7 @@ class FusionFcParam : public OpParam {
const RType *InputZ() const { return input_z_; }
RType *Out() const { return out_; }
GType *Out() const { return out_; }
const int &XNumColDims() const { return x_num_col_dims_; }
......@@ -1114,10 +1244,10 @@ class FusionFcParam : public OpParam {
const int &Axis() const { return axis_; }
private:
RType *input_x_;
GType *input_x_;
RType *input_y_;
RType *input_z_;
RType *out_;
GType *out_;
int x_num_col_dims_;
int y_num_col_dims_;
int axis_;
......@@ -2062,5 +2192,65 @@ class ConvTransposeParam : public OpParam {
};
#endif
#ifdef GRU_OP
template <typename Dtype>
class GruParam : public OpParam {
typedef typename DtypeTensorTrait<Dtype>::gtype GType;
public:
/**
*
* @param inputs
* @param outputs
* @param attrs
* @param scope
* */
GruParam(const VariableNameMap &inputs, const VariableNameMap &outputs,
const AttributeMap &attrs, const Scope &scope) {
input_input_ = InputFrom<GType>(inputs, scope);
input_h0_ = InputH0From<GType>(inputs, scope);
input_bias_ = InputBiasFrom<GType>(inputs, scope);
input_weight_ = InputWeightFrom<GType>(inputs, scope);
output_batch_gate_ = OutputBatchGateFrom<GType>(outputs, scope);
output_batch_reset_hidden_prev_ =
OutputBatchResetHiddenPrevFrom<GType>(outputs, scope);
output_batch_hidden_ = OutputBatchHiddenFrom<GType>(outputs, scope);
output_hidden_ = OutputHiddenFrom<GType>(outputs, scope);
activation_ = GetAttr<std::string>("activation", attrs);
gate_activation_ = GetAttr<std::string>("gate_activation", attrs);
is_reverse_ = GetAttr<bool>("is_reverse", attrs);
}
const GType *InputInput() const { return input_input_; }
const GType *InputWeight() const { return input_weight_; }
const GType *InputH0() const { return input_h0_; }
const GType *InputBias() const { return input_bias_; }
const std::string &Activation() const { return activation_; }
const std::string &GateActivation() const { return gate_activation_; }
const bool &IsReverse() const { return is_reverse_; }
GType *OutBatchGate() const { return output_batch_gate_; }
GType *OutBatchResetHiddenPrev() const {
return output_batch_reset_hidden_prev_;
}
GType *OutBatchHidden() const { return output_batch_hidden_; }
GType *OutHidden() const { return output_hidden_; }
private:
GType *input_input_;
GType *input_h0_;
GType *input_bias_;
GType *input_weight_;
GType *output_batch_gate_;
GType *output_batch_reset_hidden_prev_;
GType *output_batch_hidden_;
GType *output_hidden_;
std::string activation_;
std::string gate_activation_;
bool is_reverse_;
};
#endif
} // namespace operators
} // namespace paddle_mobile
......@@ -190,6 +190,14 @@ else ()
ADD_EXECUTABLE(test-conv-add-bn-relu-op operators/test_fusion_conv_add_bn_relu_op.cpp test_helper.h test_include.h executor_for_test.h)
target_link_libraries(test-conv-add-bn-relu-op paddle-mobile)
# gen test
ADD_EXECUTABLE(test-nlp net/test_nlp.cpp test_helper.h test_include.h executor_for_test.h)
target_link_libraries(test-nlp paddle-mobile)
# gen test
ADD_EXECUTABLE(test-gru-op operators/test_gru_op.cpp test_helper.h test_include.h)
target_link_libraries(test-gru-op paddle-mobile)
#add_library(test-lib-size SHARED common/test_lib_size.h common/test_lib_size.cpp)
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <iostream>
#include "../test_helper.h"
#include "../test_include.h"
int main() {
paddle_mobile::PaddleMobile<paddle_mobile::CPU> paddle_mobile;
paddle_mobile.SetThreadNum(4);
auto time1 = time();
// auto isok = paddle_mobile.Load(std::string(g_mobilenet_detect) + "/model",
// std::string(g_mobilenet_detect) + "/params", true);
auto isok = paddle_mobile.Load(g_nlp, true, false, 1, true);
// auto isok = paddle_mobile.Load(std::string(g_nlp) + "/model",
// std::string(g_nlp) + "/params", false);
if (isok) {
auto time2 = time();
std::cout << "load cost :" << time_diff(time1, time1) << "ms" << std::endl;
// 1064 1603 644 699 2878 1219 867 1352 8 1 13 312 479
std::vector<int64_t> ids{1064, 1603, 644, 699, 2878, 1219, 867,
1352, 8, 1, 13, 312, 479};
paddle_mobile::framework::LoDTensor words;
auto size = static_cast<int>(ids.size());
paddle_mobile::framework::LoD lod{{0, ids.size()}};
DDim dims{size, 1};
words.Resize(dims);
words.set_lod(lod);
DLOG << "words lod : " << words.lod();
auto *pdata = words.mutable_data<int64_t>();
size_t n = words.numel() * sizeof(int64_t);
DLOG << "n :" << n;
memcpy(pdata, ids.data(), n);
DLOG << "words lod 22: " << words.lod();
auto time3 = time();
for (int i = 0; i < 1; ++i) {
auto vec_result = paddle_mobile.PredictLod(words);
DLOG << *vec_result;
}
auto time4 = time();
std::cout << "predict cost :" << time_diff(time3, time4) / 1 << "ms"
<< std::endl;
}
return 0;
}
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "../test_include.h"
#include "operators/gru_op.h"
int main() {
paddle_mobile::Loader<paddle_mobile::CPU> loader;
auto program = loader.Load(g_nlp);
PADDLE_MOBILE_ENFORCE(program.originProgram != nullptr,
"program file read fail");
Executor4Test<paddle_mobile::CPU,
paddle_mobile::operators::GruOp<paddle_mobile::CPU, float>>
executor(program, "gru");
return 0;
}
......@@ -33,6 +33,7 @@ static const char *g_mobilenet_detect = "../models/mobilenet-detect";
static const char *g_squeezenet = "../models/squeezenet";
static const char *g_googlenet = "../models/googlenet";
static const char *g_mobilenet = "../models/mobilenet";
static const char *g_nlp = "../models/nlp";
static const char *g_resnet_50 = "../models/resnet_50";
static const char *g_resnet = "../models/resnet";
static const char *g_googlenet_combine = "../models/googlenet_combine";
......
......@@ -111,7 +111,7 @@ if ("FPGAnets" IN_LIST NET)
set(FUSION_CONVBN_OP ON)
set(FUSION_CONVADD_OP ON)
set(FOUND_MATCH ON)
set(FOUND_MATCH ON)
endif()
......@@ -149,6 +149,9 @@ if(NOT FOUND_MATCH)
set(SLICE_OP ON)
set(DROPOUT_OP ON)
set(IM2SEQUENCE_OP ON)
set(LOOKUP_OP ON)
set(GRU_OP ON)
set(CRF_OP ON)
endif()
# option(BATCHNORM_OP "" ON)
......@@ -288,3 +291,15 @@ endif()
if (CONV_TRANSPOSE_OP)
add_definitions(-DCONV_TRANSPOSE)
endif()
if (LOOKUP_OP)
add_definitions(-DLOOKUP_OP)
endif()
if (GRU_OP)
add_definitions(-DGRU_OP)
endif()
if (CRF_OP)
add_definitions(-DCRF_OP)
endif()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册