Unverified commit fbaf76ef, authored by Ray Liu, committed by GitHub

Merge pull request #1352 from hjchen2/ocr_ctc

Refactor predict api, optimize softmax, add top_k and cast operators
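For context on the refactored predict API, here is a minimal usage sketch of the new Feed / Predict / Fetch flow. It relies only on the signatures declared in io/paddle_mobile.h in this diff; the model directory, thread count, and input shape are placeholders, and PMSuccess is assumed to be reachable from the paddle_mobile namespace (it is used unqualified in io/paddle_mobile.cpp).

```cpp
#include <algorithm>
#include "io/paddle_mobile.h"

int main() {
  // The old API returned the output tensor directly from Predict(t);
  // the refactored API feeds named inputs, runs, then fetches named outputs.
  paddle_mobile::PaddleMobile<paddle_mobile::CPU, float> paddle_mobile;
  paddle_mobile.SetThreadNum(4);  // placeholder thread count
  if (paddle_mobile.Load("model_dir", /*optimize=*/true) !=
      paddle_mobile::PMSuccess) {
    return -1;
  }

  // Placeholder input: a 1x3x224x224 float tensor filled with zeros.
  paddle_mobile::framework::Tensor input;
  input.Resize(paddle_mobile::framework::make_ddim({1, 3, 224, 224}));
  float *in_ptr = input.mutable_data<float>();
  std::fill(in_ptr, in_ptr + input.numel(), 0.f);

  paddle_mobile.Feed(input, "feed");    // bind input to the "feed" variable
  paddle_mobile.Predict();              // run the whole program
  auto output = paddle_mobile.Fetch();  // Fetch() defaults to "fetch"
  return output != nullptr ? 0 : 1;
}
```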
......@@ -22,6 +22,8 @@ const char *G_OP_TYPE_BATCHNORM = "batch_norm";
const char *G_OP_TYPE_BOX_CODER = "box_coder";
const char *G_OP_TYPE_CONCAT = "concat";
const char *G_OP_TYPE_ELEMENTWISE_ADD = "elementwise_add";
const char *G_OP_TYPE_ELEMENTWISE_SUB = "elementwise_sub";
const char *G_OP_TYPE_ELEMENTWISE_MUL = "elementwise_mul";
const char *G_OP_TYPE_FILL_CONSTANT = "fill_constant";
const char *G_OP_TYPE_FUSION_CONV_ADD_RELU = "fusion_conv_add_relu";
const char *G_OP_TYPE_FUSION_CONV_ADD_PRELU = "fusion_conv_add_prelu";
......@@ -67,8 +69,9 @@ const char *G_OP_TYPE_CRF = "crf_decoding";
const char *G_OP_TYPE_BILINEAR_INTERP = "bilinear_interp";
const char *G_OP_TYPE_FLATTEN = "flatten";
const char *G_OP_TYPE_SHAPE = "shape";
const char *G_OP_TYPE_ELEMENTWISE_MUL = "elementwise_mul";
const char *G_OP_TYPE_SUM = "sum";
const char *G_OP_TYPE_TOP_K = "top_k";
const char *G_OP_TYPE_CAST = "cast";
const char *G_OP_TYPE_QUANTIZE = "quantize";
const char *G_OP_TYPE_DEQUANTIZE = "dequantize";
......@@ -100,6 +103,8 @@ std::unordered_map<
{G_OP_TYPE_SIGMOID, {{"X"}, {"Out"}}},
{G_OP_TYPE_MUL, {{"X"}, {"Out"}}},
{G_OP_TYPE_ELEMENTWISE_ADD, {{"X", "Y"}, {"Out"}}},
{G_OP_TYPE_ELEMENTWISE_SUB, {{"X", "Y"}, {"Out"}}},
{G_OP_TYPE_ELEMENTWISE_MUL, {{"X", "Y"}, {"Out"}}},
{G_OP_TYPE_POOL2D, {{"X"}, {"Out"}}},
{G_OP_TYPE_BATCHNORM, {{"X"}, {"Y"}}},
{G_OP_TYPE_LRN, {{"X"}, {"Out"}}},
......@@ -142,7 +147,8 @@ std::unordered_map<
{G_OP_TYPE_SHAPE, {{"Input"}, {"Out"}}},
{G_OP_TYPE_CONV_TRANSPOSE, {{"Input"}, {"Output"}}},
{G_OP_TYPE_SUM, {{"X"}, {"Out"}}},
{G_OP_TYPE_ELEMENTWISE_MUL, {{"X", "Y"}, {"Out"}}},
{G_OP_TYPE_TOP_K, {{"X"}, {"Out", "Indices"}}},
{G_OP_TYPE_CAST, {{"X"}, {"Out"}}},
{G_OP_TYPE_QUANTIZE, {{"X"}, {"Out", "OutScale"}}},
{G_OP_TYPE_DEQUANTIZE, {{"X", "Scale"}, {"Out"}}},
{G_OP_TYPE_FUSION_DEQUANT_BN, {{"X", "Scale"}, {"Out"}}},
......
......@@ -112,6 +112,8 @@ extern const char *G_OP_TYPE_BATCHNORM;
extern const char *G_OP_TYPE_BOX_CODER;
extern const char *G_OP_TYPE_CONCAT;
extern const char *G_OP_TYPE_ELEMENTWISE_ADD;
extern const char *G_OP_TYPE_ELEMENTWISE_SUB;
extern const char *G_OP_TYPE_ELEMENTWISE_MUL;
extern const char *G_OP_TYPE_FUSION_CONV_ADD_RELU;
extern const char *G_OP_TYPE_FUSION_CONV_ADD_PRELU;
extern const char *G_OP_TYPE_FUSION_CONV_ADD_ADD_PRELU;
......@@ -149,7 +151,8 @@ extern const char *G_OP_TYPE_FUSION_CONV_BN;
extern const char *G_OP_TYPE_CONV_TRANSPOSE;
extern const char *G_OP_TYPE_PRELU;
extern const char *G_OP_TYPE_SUM;
extern const char *G_OP_TYPE_ELEMENTWISE_MUL;
extern const char *G_OP_TYPE_TOP_K;
extern const char *G_OP_TYPE_CAST;
extern const char *G_OP_TYPE_QUANTIZE;
extern const char *G_OP_TYPE_DEQUANTIZE;
......
......@@ -28,6 +28,10 @@ extern _PaddleMobile__Framework__Proto__VarType__Type ToDataType(
extern std::type_index ToTypeIndex(
_PaddleMobile__Framework__Proto__VarType__Type type);
inline _PaddleMobile__Framework__Proto__VarType__Type ToDataType(int type) {
return static_cast<_PaddleMobile__Framework__Proto__VarType__Type>(type);
}
template <typename Visitor>
inline void VisitDataType(_PaddleMobile__Framework__Proto__VarType__Type type,
Visitor visitor) {
......
This diff has been collapsed.
......@@ -17,6 +17,7 @@ limitations under the License. */
#include <map>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "common/types.h"
#include "common/util.h"
......@@ -28,41 +29,29 @@ limitations under the License. */
namespace paddle_mobile {
namespace framework {
template <typename Dtype = CPU, Precision P = Precision::FP32>
template <typename Device, typename T = float>
class Executor {
public:
typedef typename PrecisionTrait<P>::ptype Ptype;
// executor constructor
// @param program program converted from the PaddlePaddle proto program
// @param use_optimize whether to use operator fusion to speed up execution
// @param loddable whether the model takes LoD tensor inputs
Executor(const framework::Program<Dtype> program, int batch_size = 1,
const bool use_optimize = true, const bool loddable = false);
// predict with tensor input
// @param t input tensor to do prediction
// @return predicted tensor
std::shared_ptr<framework::Tensor> Predict(const framework::Tensor &t);
// predict with lod tensor input
// @param t input lod tensor to do prediction
// @return predicted lod tensor
std::shared_ptr<framework::LoDTensor> PredictLod(
const framework::LoDTensor &t);
// predict with vector input and dims
// @param input input data, flattened into a vector
// @param dims shape of the input tensor
// @return vector flattened from the predicted tensor
std::vector<Ptype> Predict(const std::vector<Ptype> &input,
const std::vector<int64_t> &dims);
Executor(const Program<Device> &program, int batch_size = 1,
const bool use_optimize = true, const bool lod_mode = false);
PMStatus Predict(const std::vector<std::pair<std::string, Tensor>> &inputs);
PMStatus Predict(
const std::vector<std::pair<std::string, LoDTensor>> &inputs);
std::vector<T> Predict(const std::vector<T> &input,
const std::vector<int64_t> &dims);
PMStatus Predict();
void SetInput(const Tensor &input, const std::string &var_name);
void SetInput(const LoDTensor &input, const std::string &var_name);
std::shared_ptr<LoDTensor> GetOutput(const std::string &var_name);
#ifdef PADDLE_MOBILE_FPGA
void InjectVariable(const framework::Tensor &t, std::string var_name);
void FeedData(const framework::Tensor &t);
std::shared_ptr<framework::Tensor> FetchResult(int id = -1);
void InjectVariable(const Tensor &t, std::string var_name);
void FeedData(const Tensor &t);
std::shared_ptr<Tensor> FetchResult(int id = -1);
void Predict_From_To(int start = 0, int end = -1);
void Predict_From(int start);
void Predict_To(int end);
......@@ -70,26 +59,28 @@ class Executor {
protected:
Executor() = default;
std::shared_ptr<framework::Tensor> Predict(const framework::Tensor &t,
int block_id);
bool varInputMemory(const std::shared_ptr<framework::VarDesc> &var_desc,
framework::Variable *var,
framework::LoDTensor *tensor) const;
bool varInputMemory(const std::shared_ptr<VarDesc> &var_desc, Variable *var,
LoDTensor *tensor) const;
void InitMemory();
void InitCombineMemory();
void LoadMemory(void **data,
const std::shared_ptr<framework::VarDesc> var_desc,
framework::LoDTensor *tensor);
void LoadMemory(void **data, const std::shared_ptr<VarDesc> var_desc,
LoDTensor *tensor);
#ifdef PADDLE_MOBILE_CL
void LoadMemory(const framework::VarDesc var_desc, float *tensorInput,
char **data);
void LoadMemory(const VarDesc var_desc, float *tensorInput, char **data);
#endif
framework::Program<Dtype> program_;
int batch_size_ = 1;
std::shared_ptr<framework::ProgramDesc> to_predict_program_;
std::map<framework::BlockDesc,
std::vector<std::shared_ptr<framework::OperatorBase<Dtype>>>>
ops_of_block_;
int batch_size_;
bool use_optimize_;
bool lod_mode_;
Program<Device> program_;
std::shared_ptr<ProgramDesc> program_desc_;
typedef std::shared_ptr<OperatorBase<Device>> OperatorBasePtr;
std::vector<std::vector<OperatorBasePtr>> ops_of_block_;
// operators list
std::vector<OperatorBasePtr> ops_list_;
#ifdef PADDLE_MOBILE_PROFILE
struct ProfInfo {
int tid = 0;
......@@ -97,8 +88,6 @@ class Executor {
uint64_t runEnd = 0UL;
};
#endif
bool use_optimize_ = false;
bool loddable_ = false;
};
} // namespace framework
......
......@@ -228,6 +228,12 @@ LOAD_FUSION_MATCHER(fusion_conv_bn);
#ifdef ELEMENTWISESUB_OP
LOAD_OP1(elementwise_sub, CPU)
#endif
#ifdef TOP_K_OP
LOAD_OP1(top_k, CPU)
#endif
#ifdef CAST_OP
LOAD_OP1(cast, CPU)
#endif
#ifdef QUANT_OP
LOAD_OP1(quantize, CPU);
#endif
......
......@@ -23,14 +23,8 @@ limitations under the License. */
namespace paddle_mobile {
namespace framework {
/**
* create and resize tensors according to originProgramDesc and scope in loadParams
*
* @param originProgramDesc
* @param scope
*/
template <typename Dtype, Precision P>
void Loader<Dtype, P>::InitMemoryFromProgram(
template <typename Device, typename T>
void Loader<Device, T>::InitMemoryFromProgram(
const std::shared_ptr<ProgramDesc> &originProgramDesc,
const std::shared_ptr<Scope> &scope) {
for (const auto &block : originProgramDesc.get()->Blocks()) {
......@@ -43,8 +37,6 @@ void Loader<Dtype, P>::InitMemoryFromProgram(
tensor->Resize(make_ddim(dim));
} else {
auto dim = var_desc->Tensor_desc().Dims();
// PADDLE_MOBILE_ENFORCE(dim.size() > 0, "dim size is 0");
// dim[0] = 1;
if (dim.size() == 0) {
auto tensor = var->GetMutable<LoDTensor>();
framework::DDim dDim = {0};
......@@ -60,7 +52,7 @@ void Loader<Dtype, P>::InitMemoryFromProgram(
}
}
} else {
// TODO(codeWorm): some.
// TODO(codeWorm)
}
}
}
......@@ -68,7 +60,7 @@ void Loader<Dtype, P>::InitMemoryFromProgram(
#ifdef PADDLE_MOBILE_CL
template <>
void Loader<GPU_CL, Precision::FP32>::InitMemoryFromProgram(
void Loader<GPU_CL, float>::InitMemoryFromProgram(
const std::shared_ptr<ProgramDesc> &originProgramDesc,
const std::shared_ptr<Scope> &scope) {
for (const auto &block : originProgramDesc.get()->Blocks()) {
......@@ -77,7 +69,6 @@ void Loader<GPU_CL, Precision::FP32>::InitMemoryFromProgram(
if (var_desc->Type() == VARTYPE_TYPE_LOD_TENSOR) {
if (var_desc->Persistable()) {
auto dim = var_desc->Tensor_desc().Dims();
// auto tensor = var->GetMutable<LoDTensor>();
auto cl_image = var->GetMutable<framework::CLImage>();
cl_image->Resize(make_ddim(dim));
} else {
......@@ -88,14 +79,13 @@ void Loader<GPU_CL, Precision::FP32>::InitMemoryFromProgram(
cl_image->Resize(make_ddim(dim));
}
} else {
// TODO(codeWorm): some.
// TODO(codeWorm)
}
}
}
}
template <>
const Program<GPU_CL, Precision::FP32>
Loader<GPU_CL, Precision::FP32>::LoadCombinedMemory(
const Program<GPU_CL, float> Loader<GPU_CL, float>::LoadCombinedMemory(
size_t read_size, const uint8_t *buf, size_t combined_params_len,
uint8_t *combined_params_buf, bool optimize, bool quantification) {
bool can_add_split = false;
......@@ -113,7 +103,7 @@ Loader<GPU_CL, Precision::FP32>::LoadCombinedMemory(
auto originProgramDesc = std::make_shared<ProgramDesc>(c_program);
Program<GPU_CL, Precision::FP32> program;
Program<GPU_CL, float> program;
program.combined = true;
program.originProgram = originProgramDesc;
program.quantification = quantification;
......@@ -145,16 +135,16 @@ Loader<GPU_CL, Precision::FP32>::LoadCombinedMemory(
/**
* run fusion passes and print some info
* @tparam Dtype
* @tparam Device
* @tparam P
* @param optimize
* @param can_add_split
* @param program
* @param originProgramDesc
*/
template <typename Dtype, Precision P>
template <typename Device, typename T>
void FusionAndPrintInfos(
bool optimize, bool can_add_split, Program<Dtype, P> *program,
bool optimize, bool can_add_split, Program<Device, T> *program,
const std::shared_ptr<ProgramDesc> &originProgramDesc) {
if (optimize) {
ProgramOptimize program_optimize;
......@@ -193,22 +183,22 @@ static size_t ReadBuffer(const char *file_name, uint8_t **out) {
return cur_len;
}
template <typename Dtype, Precision P>
const Program<Dtype, P> Loader<Dtype, P>::Load(const std::string &dirname,
bool optimize,
bool quantification,
bool can_add_split) {
template <typename Device, typename T>
const Program<Device, T> Loader<Device, T>::Load(const std::string &dirname,
bool optimize,
bool quantification,
bool can_add_split) {
auto program = this->LoadProgram(dirname + "/__model__", optimize,
quantification, can_add_split);
program.model_path = dirname;
return program;
}
template <typename Dtype, Precision P>
const Program<Dtype, P> Loader<Dtype, P>::Load(const std::string &model_path,
const std::string &para_path,
bool optimize,
bool quantification) {
template <typename Device, typename T>
const Program<Device, T> Loader<Device, T>::Load(const std::string &model_path,
const std::string &para_path,
bool optimize,
bool quantification) {
auto program = this->LoadProgram(model_path, optimize, quantification);
program.para_path = para_path;
......@@ -217,8 +207,8 @@ const Program<Dtype, P> Loader<Dtype, P>::Load(const std::string &model_path,
return program;
}
template <typename Dtype, Precision P>
const Program<Dtype, P> Loader<Dtype, P>::LoadProgram(
template <typename Device, typename T>
const Program<Device, T> Loader<Device, T>::LoadProgram(
const std::string &model_path, bool optimize, bool quantification,
bool can_add_split) {
std::string model_filename = model_path;
......@@ -237,7 +227,7 @@ const Program<Dtype, P> Loader<Dtype, P>::LoadProgram(
//
auto originProgramDesc = std::make_shared<ProgramDesc>(c_program);
Program<Dtype, P> program;
Program<Device, T> program;
program.originProgram = originProgramDesc;
program.quantification = quantification;
program.combined_params_len = 0;
......@@ -254,8 +244,8 @@ const Program<Dtype, P> Loader<Dtype, P>::LoadProgram(
return program;
}
template <typename Dtype, Precision P>
const Program<Dtype, P> Loader<Dtype, P>::LoadCombinedMemory(
template <typename Device, typename T>
const Program<Device, T> Loader<Device, T>::LoadCombinedMemory(
size_t read_size, const uint8_t *buf, size_t combined_params_len,
uint8_t *combined_params_buf, bool optimize, bool quantification) {
bool can_add_split = false;
......@@ -273,7 +263,7 @@ const Program<Dtype, P> Loader<Dtype, P>::LoadCombinedMemory(
auto originProgramDesc = std::make_shared<ProgramDesc>(c_program);
Program<Dtype, P> program;
Program<Device, T> program;
program.combined = true;
program.originProgram = originProgramDesc;
program.quantification = quantification;
......@@ -289,13 +279,13 @@ const Program<Dtype, P> Loader<Dtype, P>::LoadCombinedMemory(
return program;
}
template class Loader<CPU, Precision::FP32>;
template class Loader<CPU, float>;
template class Loader<FPGA, Precision::FP32>;
template class Loader<FPGA, float>;
template class Loader<GPU_MALI, Precision::FP32>;
template class Loader<GPU_MALI, float>;
template class Loader<GPU_CL, Precision::FP32>;
template class Loader<GPU_CL, float>;
} // namespace framework
} // namespace paddle_mobile
......@@ -22,39 +22,39 @@ limitations under the License. */
namespace paddle_mobile {
namespace framework {
template <typename Dtype = CPU, Precision P = Precision::FP32>
template <typename Device = CPU, typename T = float>
class Loader {
public:
/*
* @b load separate format fluid model
* @b load a fluid model in the separate-file form
* @b load a fluid model stored as separate files
* */
const Program<Dtype, P> Load(const std::string &dirname,
bool optimize = false,
bool quantification = false,
bool can_add_split = false);
const Program<Device, T> Load(const std::string &dirname,
bool optimize = false,
bool quantification = false,
bool can_add_split = false);
/*
* @b load combined format fluid model
* @b load a model in the combined format
* @b load a fluid model stored in a single combined file
* */
const Program<Dtype, P> Load(const std::string &model_path,
const std::string &para_path,
bool optimize = false,
bool quantification = false);
const Program<Device, T> Load(const std::string &model_path,
const std::string &para_path,
bool optimize = false,
bool quantification = false);
const Program<Dtype, P> LoadCombinedMemory(size_t model_len,
const uint8_t *model_buf,
size_t combined_params_len,
uint8_t *combined_params_buf,
bool optimize = false,
bool quantification = false);
const Program<Device, T> LoadCombinedMemory(size_t model_len,
const uint8_t *model_buf,
size_t combined_params_len,
uint8_t *combined_params_buf,
bool optimize = false,
bool quantification = false);
private:
const Program<Dtype, P> LoadProgram(const std::string &model_path,
bool optimize = false,
bool quantification = false,
bool can_add_split = false);
const Program<Device, T> LoadProgram(const std::string &model_path,
bool optimize = false,
bool quantification = false,
bool can_add_split = false);
void InitMemoryFromProgram(
const std::shared_ptr<ProgramDesc> &originProgramDesc,
......
......@@ -16,12 +16,12 @@ limitations under the License. */
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "tensor.h"
#include "tensor_util.h"
#include "framework/tensor.h"
#include "framework/tensor_util.h"
namespace paddle_mobile {
namespace framework {
/*
......@@ -202,5 +202,29 @@ void SerializeToStream(std::ostream &os, const LoDTensor &tensor);
void DeserializeFromStream(std::istream &is, LoDTensor *tensor);
#ifdef PADDLE_MOBILE_DEBUG
inline Print &operator<<(Print &printer, const LoDTensor &tensor) {
printer << " dims: " << tensor.dims() << "\n";
int stride = tensor.numel() / 20;
stride = stride > 0 ? stride : 1;
#ifndef PADDLE_MOBILE_FPGA
for (int i = 0; i < tensor.numel(); i += stride) {
if (tensor.type() == typeid(float)) {
printer << tensor.data<float>()[i] << " ";
} else if (tensor.type() == typeid(int32_t)) {
printer << tensor.data<int32_t>()[i] << " ";
} else if (tensor.type() == typeid(int64_t)) {
printer << tensor.data<int64_t>()[i] << " ";
} else if (tensor.type() == typeid(int8_t)) {
printer << static_cast<int>(tensor.data<int8_t>()[i]) << " ";
}
}
#endif // PADDLE_MOBILE_FPGA
return printer;
}
#endif // PADDLE_MOBILE_DEBUG
} // namespace framework
} // namespace paddle_mobile
......@@ -14,16 +14,15 @@ limitations under the License. */
#pragma once
#include <string>
#include "common/types.h"
#include "framework/program/program_desc.h"
#include "framework/scope.h"
#include <string>
namespace paddle_mobile {
namespace framework {
template <typename Dtype, Precision P = Precision::FP32>
template <typename Device, typename T = float>
class Program {
public:
std::shared_ptr<ProgramDesc> originProgram;
......
......@@ -26,6 +26,7 @@ limitations under the License. */
namespace paddle_mobile {
namespace framework {
class Scope {
public:
Scope() = default;
......
......@@ -148,8 +148,8 @@ class Tensor : public TensorBase {
PADDLE_MOBILE_ENFORCE(
(std::is_same<T, void>::value ||
holder_->type().hash_code() == typeid(T).hash_code()),
"Tensor holds the wrong type, it holds %s",
this->holder_->type().name());
"Tensor holds the wrong type, it holds %s, requested %s",
this->holder_->type().name(), typeid(T).name());
return reinterpret_cast<T *>(reinterpret_cast<uintptr_t>(holder_->ptr()) +
offset_);
......@@ -162,7 +162,7 @@ class Tensor : public TensorBase {
PADDLE_MOBILE_ENFORCE(
(std::is_same<T, void>::value ||
holder_->type().hash_code() == typeid(T).hash_code()),
"Tensor holds the wrong type, it holds %s ,requested:%s",
"Tensor holds the wrong type, it holds %s, requested %s",
this->holder_->type().name(), typeid(T).name());
return reinterpret_cast<const T *>(
......@@ -226,7 +226,6 @@ inline Print &operator<<(Print &printer, const Tensor &tensor) {
}
}
#endif
return printer;
}
......
......@@ -18,17 +18,17 @@
namespace paddle_mobile {
template <typename Dtype, Precision P>
PaddleMobilePredictor<Dtype, P>::PaddleMobilePredictor(
template <typename Device, typename T>
PaddleMobilePredictor<Device, T>::PaddleMobilePredictor(
const PaddleMobileConfig &config) {
PADDLE_MOBILE_ENFORCE(Init(config) == true,
"paddle mobile predictor init failed!");
config_ = config;
}
template <typename Dtype, Precision P>
bool PaddleMobilePredictor<Dtype, P>::Init(const PaddleMobileConfig &config) {
paddle_mobile_.reset(new PaddleMobile<Dtype, P>());
template <typename Device, typename T>
bool PaddleMobilePredictor<Device, T>::Init(const PaddleMobileConfig &config) {
paddle_mobile_.reset(new PaddleMobile<Device, T>());
#ifdef PADDLE_MOBILE_CL
paddle_mobile_->SetCLPath(config.cl_path);
#endif
......@@ -52,8 +52,8 @@ bool PaddleMobilePredictor<Dtype, P>::Init(const PaddleMobileConfig &config) {
paddle_mobile_->SetThreadNum(config.thread_num);
return true;
}
template <typename Dtype, Precision P>
bool PaddleMobilePredictor<Dtype, P>::Run(
template <typename Device, typename T>
bool PaddleMobilePredictor<Device, T>::Run(
const std::vector<PaddleTensor> &inputs,
std::vector<PaddleTensor> *output_data, int batch_size) {
if (inputs.empty()) {
......@@ -78,12 +78,12 @@ bool PaddleMobilePredictor<Dtype, P>::Run(
framework::Tensor input_tensor;
input_tensor.Resize(ddim);
int input_length = framework::product(ddim);
typedef typename PrecisionTrait<P>::ptype PType;
auto input_ptr = input_tensor.mutable_data<PType>();
auto input_ptr = input_tensor.mutable_data<T>();
memcpy(input_ptr, static_cast<PType *>(input.data.data()),
input_length * sizeof(PType));
auto output_tensor = paddle_mobile_->Predict(input_tensor);
memcpy(input_ptr, static_cast<T *>(input.data.data()),
input_length * sizeof(T));
paddle_mobile_->Predict(input_tensor);
auto output_tensor = paddle_mobile_->Fetch();
if (output_data->empty()) {
LOG(kLOG_ERROR) << "At least one output should be set with tensors' names.";
......@@ -99,18 +99,18 @@ bool PaddleMobilePredictor<Dtype, P>::Run(
output.shape.push_back(static_cast<int>(d));
}
if (output.data.length() < output_length * sizeof(PType)) {
output.data.Resize(output_length * sizeof(PType));
if (output.data.length() < output_length * sizeof(T)) {
output.data.Resize(output_length * sizeof(T));
}
memcpy(output.data.data(), output_tensor->template data<PType>(),
output_length * sizeof(PType));
memcpy(output.data.data(), output_tensor->template data<T>(),
output_length * sizeof(T));
return true;
}
template <typename Dtype, Precision P>
PaddleMobilePredictor<Dtype, P>::~PaddleMobilePredictor() {
template <typename Device, typename T>
PaddleMobilePredictor<Device, T>::~PaddleMobilePredictor() {
paddle_mobile_->Clear();
}
......@@ -122,13 +122,13 @@ CreatePaddlePredictor<PaddleMobileConfig, PaddleEngineKind::kPaddleMobile>(
std::unique_ptr<PaddlePredictor> x;
if (config.precision == PaddleMobileConfig::FP32) {
if (config.device == PaddleMobileConfig::kCPU) {
x.reset(new PaddleMobilePredictor<CPU, Precision::FP32>(config));
x.reset(new PaddleMobilePredictor<CPU, float>(config));
} else if (config.device == PaddleMobileConfig::kFPGA) {
x.reset(new PaddleMobilePredictor<FPGA, Precision::FP32>(config));
x.reset(new PaddleMobilePredictor<FPGA, float>(config));
} else if (config.device == PaddleMobileConfig::kGPU_MALI) {
x.reset(new PaddleMobilePredictor<GPU_MALI, Precision::FP32>(config));
x.reset(new PaddleMobilePredictor<GPU_MALI, float>(config));
} else if (config.device == PaddleMobileConfig::kGPU_CL) {
x.reset(new PaddleMobilePredictor<GPU_CL, Precision::FP32>(config));
x.reset(new PaddleMobilePredictor<GPU_CL, float>(config));
} else {
LOG(kLOG_ERROR) << "unsupport device type!";
return nullptr;
......
......@@ -29,7 +29,7 @@ limitations under the License. */
namespace paddle_mobile {
template <typename Dtype = CPU, Precision P = Precision::FP32>
template <typename Device = CPU, typename T = float>
class PaddleMobilePredictor : public PaddlePredictor {
public:
PaddleMobilePredictor() = delete;
......@@ -43,7 +43,7 @@ class PaddleMobilePredictor : public PaddlePredictor {
~PaddleMobilePredictor() override;
private:
std::unique_ptr<PaddleMobile<Dtype, P>> paddle_mobile_;
std::unique_ptr<PaddleMobile<Device, T>> paddle_mobile_;
bool Init(const PaddleMobileConfig& config);
PaddleMobileConfig config_;
......
......@@ -48,7 +48,7 @@
@interface PaddleMobileCPU()
{
paddle_mobile::PaddleMobile<paddle_mobile::CPU, paddle_mobile::Precision::FP32> *pam_;
paddle_mobile::PaddleMobile<paddle_mobile::CPU, float> *pam_;
BOOL loaded_;
}
@end
......@@ -59,7 +59,7 @@ static std::mutex shared_mutex;
- (instancetype)init {
if (self = [super init]) {
pam_ = new paddle_mobile::PaddleMobile<paddle_mobile::CPU, paddle_mobile::Precision::FP32>();
pam_ = new paddle_mobile::PaddleMobile<paddle_mobile::CPU, float>();
}
return self;
}
......@@ -220,7 +220,8 @@ static std::mutex shared_mutex;
memcpy(input_ptr, input,
numel * sizeof(float));
std::shared_ptr<paddle_mobile::framework::Tensor> output = pam_->Predict(input_tensor);
pam_->Predict(input_tensor);
std::shared_ptr<paddle_mobile::framework::Tensor> output = pam_->Fetch();
float *output_pointer = new float[output->numel()];
......
......@@ -16,21 +16,23 @@ limitations under the License. */
#include "paddle_mobile_jni.h"
#include <cmath>
#include <string>
#include <vector>
#include "common/log.h"
#include "framework/tensor.h"
#include "io/paddle_mobile.h"
#ifdef ENABLE_EXCEPTION
#include "common/enforce.h"
#endif
#ifdef __cplusplus
extern "C" {
#endif
namespace paddle_mobile {
namespace jni {
using framework::DDim;
using framework::Program;
using framework::Tensor;
......@@ -200,7 +202,8 @@ JNIEXPORT jfloatArray JNICALL Java_com_baidu_paddle_PML_predictImage(
for (int i = 0; i < length; i++) {
input_ptr[i] = dataPointer[i];
}
auto output = getPaddleMobileInstance()->Predict(input);
getPaddleMobileInstance()->Predict(input);
auto output = getPaddleMobileInstance()->Fetch();
count = output->numel();
result = env->NewFloatArray(count);
env->SetFloatArrayRegion(result, 0, count, output->data<float>());
......@@ -233,7 +236,8 @@ JNIEXPORT jfloatArray JNICALL Java_com_baidu_paddle_PML_predictImage(
for (int i = 0; i < length; i++) {
input_ptr[i] = dataPointer[i];
}
auto output = getPaddleMobileInstance()->Predict(input);
getPaddleMobileInstance()->Predict(input);
auto output = getPaddleMobileInstance()->Fetch();
count = output->numel();
result = env->NewFloatArray(count);
env->SetFloatArrayRegion(result, 0, count, output->data<float>());
......@@ -328,7 +332,8 @@ JNIEXPORT jfloatArray JNICALL Java_com_baidu_paddle_PML_predictYuv(
for (int i = 0; i < length; i++) {
input_ptr[i] = matrix[i];
}
auto output = getPaddleMobileInstance()->Predict(input);
getPaddleMobileInstance()->Predict(input);
auto output = getPaddleMobileInstance()->Fetch();
count = output->numel();
result = env->NewFloatArray(count);
env->SetFloatArrayRegion(result, 0, count, output->data<float>());
......@@ -363,7 +368,8 @@ JNIEXPORT jfloatArray JNICALL Java_com_baidu_paddle_PML_predictYuv(
for (int i = 0; i < length; i++) {
input_ptr[i] = matrix[i];
}
auto output = getPaddleMobileInstance()->Predict(input);
getPaddleMobileInstance()->Predict(input);
auto output = getPaddleMobileInstance()->Fetch();
count = output->numel();
result = env->NewFloatArray(count);
env->SetFloatArrayRegion(result, 0, count, output->data<float>());
......@@ -399,7 +405,8 @@ Java_com_baidu_paddle_PML_predictLod(JNIEnv *env, jclass thiz, jlongArray buf) {
auto *pdata = words.mutable_data<int64_t>();
size_t n = words.numel() * sizeof(int64_t);
memcpy(pdata, ids.data(), n);
auto vec_result = paddle_mobile.PredictLod(words);
paddle_mobile.Predict(words);
auto vec_result = paddle_mobile.Fetch();
int count = vec_result->numel();
jlongArray result = NULL;
ANDROIDLOGE("predict nlp size %d", count);
......
......@@ -13,81 +13,81 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "io/paddle_mobile.h"
#include <utility>
#include "common/common.h"
#ifdef PADDLE_MOBILE_CL
#include <CL/cl.h>
#include "framework/cl/cl_tensor.h"
#endif
#include "common/common.h"
#include "operators/math/gemm.h"
namespace paddle_mobile {
template <typename Dtype, Precision P>
void PaddleMobile<Dtype, P>::SetThreadNum(int num) {
template <typename Device, typename T>
void PaddleMobile<Device, T>::SetThreadNum(int num) {
#ifdef _OPENMP
omp_set_num_threads(num);
#endif
}
template <typename Dtype, Precision P>
bool PaddleMobile<Dtype, P>::Load(const std::string &dirname, bool optimize,
bool quantification, int batch_size,
bool loddable) {
template <typename Device, typename T>
PMStatus PaddleMobile<Device, T>::Load(const std::string &dirname,
bool optimize, bool quantification,
int batch_size, bool loddable) {
if (loader_.get() == nullptr) {
loader_ = std::make_shared<framework::Loader<Dtype, P>>();
loader_ = std::make_shared<framework::Loader<Device, T>>();
} else {
LOG(kLOG_INFO) << "loader inited";
}
if (executor_.get() == nullptr) {
executor_ = std::make_shared<framework::Executor<Dtype, P>>(
executor_ = std::make_shared<framework::Executor<Device, T>>(
loader_->Load(dirname, optimize, quantification), batch_size, optimize,
loddable);
} else {
LOG(kLOG_INFO) << "executor inited";
}
return true;
return PMSuccess;
}
template <typename Dtype, Precision P>
bool PaddleMobile<Dtype, P>::Load(const std::string &model_path,
const std::string &para_path, bool optimize,
bool quantification, int batch_size,
bool loddable) {
template <typename Device, typename T>
PMStatus PaddleMobile<Device, T>::Load(const std::string &model_path,
const std::string &para_path,
bool optimize, bool quantification,
int batch_size, bool loddable) {
if (loader_.get() == nullptr) {
loader_ = std::make_shared<framework::Loader<Dtype, P>>();
loader_ = std::make_shared<framework::Loader<Device, T>>();
} else {
LOG(kLOG_INFO) << "loader inited";
}
if (executor_.get() == nullptr) {
executor_ = std::make_shared<framework::Executor<Dtype, P>>(
executor_ = std::make_shared<framework::Executor<Device, T>>(
loader_->Load(model_path, para_path, optimize, quantification),
batch_size, optimize, loddable);
} else {
LOG(kLOG_INFO) << "executor inited";
}
return true;
return PMSuccess;
}
template <typename Dtype, Precision P>
bool PaddleMobile<Dtype, P>::LoadCombinedMemory(size_t model_len,
const uint8_t *model_buf,
size_t combined_params_len,
uint8_t *combined_params_buf) {
template <typename Device, typename T>
bool PaddleMobile<Device, T>::LoadCombinedMemory(size_t model_len,
const uint8_t *model_buf,
size_t combined_params_len,
uint8_t *combined_params_buf) {
int batch_size = 1;
bool optimise = true;
bool quantification = false;
if (loader_.get() == nullptr) {
loader_ = std::make_shared<framework::Loader<Dtype, P>>();
loader_ = std::make_shared<framework::Loader<Device, T>>();
} else {
LOG(kLOG_INFO) << "loader inited";
}
if (executor_.get() == nullptr) {
executor_ = std::make_shared<framework::Executor<Dtype, P>>(
executor_ = std::make_shared<framework::Executor<Device, T>>(
loader_->LoadCombinedMemory(model_len, model_buf, combined_params_len,
combined_params_buf, optimise,
quantification),
......@@ -96,38 +96,76 @@ bool PaddleMobile<Dtype, P>::LoadCombinedMemory(size_t model_len,
LOG(kLOG_INFO) << "executor inited";
}
return true;
return PMSuccess;
}
template <typename Device, typename T>
PMStatus PaddleMobile<Device, T>::Predict(const framework::Tensor &input) {
std::vector<std::pair<std::string, framework::Tensor>> inputs;
inputs.push_back(std::make_pair("feed", input));
return this->Predict(inputs);
}
template <typename Dtype, Precision P>
std::shared_ptr<framework::Tensor> PaddleMobile<Dtype, P>::Predict(
const framework::Tensor &t) {
return executor_->Predict(t);
template <typename Device, typename T>
PMStatus PaddleMobile<Device, T>::Predict(const framework::LoDTensor &input) {
std::vector<std::pair<std::string, framework::LoDTensor>> inputs;
inputs.push_back(std::make_pair("feed", input));
return this->Predict(inputs);
}
template <typename Device, typename T>
PMStatus PaddleMobile<Device, T>::Predict(
const std::vector<std::pair<std::string, framework::Tensor>> &inputs) {
return executor_->Predict(inputs);
}
template <typename Dtype, Precision P>
std::shared_ptr<framework::Tensor> PaddleMobile<Dtype, P>::PredictLod(
const framework::LoDTensor &t) {
return executor_->PredictLod(t);
template <typename Device, typename T>
PMStatus PaddleMobile<Device, T>::Predict(
const std::vector<std::pair<std::string, framework::LoDTensor>> &inputs) {
return executor_->Predict(inputs);
}
template <typename Dtype, Precision P>
std::vector<typename PaddleMobile<Dtype, P>::Ptype>
PaddleMobile<Dtype, P>::Predict(const std::vector<Ptype> &input,
const std::vector<int64_t> &dims) {
template <typename Device, typename T>
std::vector<T> PaddleMobile<Device, T>::Predict(
const std::vector<T> &input, const std::vector<int64_t> &dims) {
return executor_->Predict(input, dims);
}
template <typename Dtype, Precision P>
void PaddleMobile<Dtype, P>::Clear() {
template <typename Device, typename T>
PMStatus PaddleMobile<Device, T>::Predict() {
return executor_->Predict();
}
template <typename Device, typename T>
void PaddleMobile<Device, T>::Feed(const framework::Tensor &input,
const std::string &var_name) {
executor_->SetInput(input, var_name);
}
template <typename Device, typename T>
void PaddleMobile<Device, T>::Feed(const framework::LoDTensor &input,
const std::string &var_name) {
executor_->SetInput(input, var_name);
}
typedef std::shared_ptr<framework::LoDTensor> LoDTensorPtr;
template <typename Device, typename T>
LoDTensorPtr PaddleMobile<Device, T>::Fetch(const std::string &var_name) {
return executor_->GetOutput(var_name);
}
template <typename Device, typename T>
void PaddleMobile<Device, T>::Clear() {
executor_ = nullptr;
loader_ = nullptr;
}
template <typename Dtype, Precision P>
double PaddleMobile<Dtype, P>::GetPredictTime() {}
template <typename Device, typename T>
double PaddleMobile<Device, T>::GetPredictTime() {}
#ifdef PADDLE_MOBILE_CPU
template <>
double PaddleMobile<CPU, Precision::FP32>::GetPredictTime() {
double PaddleMobile<CPU, float>::GetPredictTime() {
int m = 32;
int n = 224 * 224;
int k = 27;
......@@ -148,7 +186,8 @@ double PaddleMobile<CPU, Precision::FP32>::GetPredictTime() {
for (int i = 0; i < k * n; ++i) {
b[i] = t1 + rand() % t2; // NOLINT
}
paddle_mobile::operators::math::Gemm gemm;
operators::math::Gemm gemm;
auto time1 = paddle_mobile::time();
gemm.Sgemm(m, n, k, static_cast<float>(1), a, lda, b, ldb,
static_cast<float>(0), c, ldc, false,
......@@ -162,57 +201,51 @@ double PaddleMobile<CPU, Precision::FP32>::GetPredictTime() {
}
#endif
template <typename Dtype, Precision P>
PaddleMobile<Dtype, P>::~PaddleMobile() {
executor_ = nullptr;
loader_ = nullptr;
}
#ifdef PADDLE_MOBILE_FPGA
template <typename Dtype, Precision P>
void PaddleMobile<Dtype, P>::InjectVariable(const framework::Tensor &t,
std::string var_name) {
template <typename Device, typename T>
void PaddleMobile<Device, T>::InjectVariable(const framework::Tensor &t,
std::string var_name) {
executor_->InjectVariable(t, var_name);
}
template <typename Dtype, Precision P>
void PaddleMobile<Dtype, P>::FeedData(const framework::Tensor &t) {
template <typename Device, typename T>
void PaddleMobile<Device, T>::FeedData(const framework::Tensor &t) {
executor_->FeedData(t);
}
template <typename Dtype, Precision P>
std::shared_ptr<framework::Tensor> PaddleMobile<Dtype, P>::FetchResult(int id) {
template <typename Device, typename T>
std::shared_ptr<framework::Tensor> PaddleMobile<Device, T>::FetchResult(
int id) {
return executor_->FetchResult(id);
}
template <typename Dtype, Precision P>
void PaddleMobile<Dtype, P>::Predict_From_To(int start, int end) {
template <typename Device, typename T>
void PaddleMobile<Device, T>::Predict_From_To(int start, int end) {
executor_->Predict_From_To(start, end);
}
template <typename Dtype, Precision P>
void PaddleMobile<Dtype, P>::Predict_From(int start) {
template <typename Device, typename T>
void PaddleMobile<Device, T>::Predict_From(int start) {
executor_->Predict_From(start);
}
template <typename Dtype, Precision P>
void PaddleMobile<Dtype, P>::Predict_To(int end) {
template <typename Device, typename T>
void PaddleMobile<Device, T>::Predict_To(int end) {
executor_->Predict_To(end);
}
#endif
#ifdef PADDLE_MOBILE_CL
static std::mutex lc;
template <typename Dtype, Precision P>
void PaddleMobile<Dtype, P>::SetCLPath(std::string path) {
template <typename Device, typename T>
void PaddleMobile<Device, T>::SetCLPath(std::string path) {
std::lock_guard<std::mutex> lock(lc);
if (framework::CLEngine::Instance()->GetCLPath() == "") {
framework::CLEngine::Instance()->setClPath(path);
}
}
template <>
double PaddleMobile<GPU_CL, Precision::FP32>::GetPredictTime() {
double PaddleMobile<GPU_CL, float>::GetPredictTime() {
cl_int status;
cl_uint nPlatform;
clGetPlatformIDs(0, NULL, &nPlatform);
......@@ -410,8 +443,8 @@ double PaddleMobile<GPU_CL, Precision::FP32>::GetPredictTime() {
return -1;
}
}
template <typename Dtype, Precision P>
int PaddleMobile<Dtype, P>::readText(
template <typename Device, typename T>
int PaddleMobile<Device, T>::readText(
const char *kernelPath,
char **pcode) {  // read the text file into pcode, return the string length
FILE *fp;
......@@ -440,13 +473,11 @@ int PaddleMobile<Dtype, P>::readText(
fclose(fp);
return size + 1;
}
#endif
template class PaddleMobile<CPU, Precision::FP32>;
template class PaddleMobile<FPGA, Precision::FP32>;
template class PaddleMobile<GPU_MALI, Precision::FP32>;
template class PaddleMobile<GPU_CL, Precision::FP32>;
template class PaddleMobile<CPU, float>;
template class PaddleMobile<FPGA, float>;
template class PaddleMobile<GPU_MALI, float>;
template class PaddleMobile<GPU_CL, float>;
} // namespace paddle_mobile
......@@ -16,6 +16,7 @@ limitations under the License. */
#include <memory>
#include <string>
#include <utility>
#include <vector>
#ifdef _OPENMP
#include <omp.h>
......@@ -32,43 +33,52 @@ limitations under the License. */
namespace paddle_mobile {
template <typename Dtype = CPU, Precision P = Precision::FP32>
template <typename Device, typename T = float>
class PaddleMobile {
typedef typename PrecisionTrait<P>::ptype Ptype;
public:
PaddleMobile() {
#ifndef PADDLE_MOBILE_CL
bool is_gpu = std::is_same<DeviceType<kGPU_CL>, Dtype>::value;
PADDLE_MOBILE_ENFORCE(!is_gpu,
"Not Enable GPU in CmakeList but run gpu codes ");
bool is_gpu = std::is_same<DeviceType<kGPU_CL>, Device>::value;
PADDLE_MOBILE_ENFORCE(!is_gpu, "Please recompile with GPU_CL turned on");
#endif
}
bool Load(const std::string &dirname, bool optimize = false,
bool quantification = false, int batch_size = 1,
bool loddable = false);
~PaddleMobile() {}
PMStatus Load(const std::string &dirname, const bool optimize = false,
const bool quantification = false, const int batch_size = 1,
const bool lod = false);
PMStatus Load(const std::string &model_path, const std::string &para_path,
const bool optimize = false, const bool quantification = false,
const int batch_size = 1, const bool lod = false);
PMStatus Predict(const framework::Tensor &input);
PMStatus Predict(const framework::LoDTensor &input);
bool Load(const std::string &model_path, const std::string &para_path,
bool optimize = false, bool quantification = false,
int batch_size = 1, bool loddable = false);
PMStatus Predict(
const std::vector<std::pair<std::string, framework::Tensor>> &inputs);
PMStatus Predict(
const std::vector<std::pair<std::string, framework::LoDTensor>> &inputs);
std::shared_ptr<framework::Tensor> Predict(const framework::Tensor &t);
std::vector<T> Predict(const std::vector<T> &input,
const std::vector<int64_t> &dims);
PMStatus Predict();
std::shared_ptr<framework::Tensor> PredictLod(const framework::LoDTensor &t);
void Feed(const framework::LoDTensor &input, const std::string &var_name);
void Feed(const framework::Tensor &input, const std::string &var_name);
std::vector<Ptype> Predict(const std::vector<Ptype> &input,
const std::vector<int64_t> &dims);
typedef std::shared_ptr<framework::LoDTensor> LoDTensorPtr;
LoDTensorPtr Fetch(const std::string &var_name);
LoDTensorPtr Fetch() { return Fetch("fetch"); }
bool LoadCombinedMemory(size_t model_len, const uint8_t *model_buf,
size_t combined_params_len,
uint8_t *combined_params_buf);
void SetThreadNum(int num);
void SetThreadNum(int count);
void Clear();
double GetPredictTime();
~PaddleMobile();
#ifdef PADDLE_MOBILE_FPGA
void InjectVariable(const framework::Tensor &t, std::string var_name);
void FeedData(const framework::Tensor &t);
......@@ -79,15 +89,15 @@ class PaddleMobile {
#endif
#ifdef PADDLE_MOBILE_CL
public:
public: // NOLINT
void SetCLPath(std::string cl_path);
int readText(const char *kernelPath,
char **pcode);  // read the text file into pcode, return the string length
#endif
private:
std::shared_ptr<framework::Loader<Dtype, P>> loader_;
std::shared_ptr<framework::Executor<Dtype, P>> executor_;
std::shared_ptr<framework::Loader<Device, T>> loader_;
std::shared_ptr<framework::Executor<Device, T>> executor_;
};
} // namespace paddle_mobile
......@@ -14,10 +14,12 @@ limitations under the License. */
#include "io/paddle_test_inference_api.h"
#include "io/paddle_mobile.h"
namespace paddle_mobile {
template <typename Dtype, Precision P>
double PaddleTester<Dtype, P>::CaculatePredictTime(std::string *cl_path) {
PaddleMobile<Dtype, P> paddle_mobile;
template <typename Device, typename T>
double PaddleTester<Device, T>::CaculatePredictTime(std::string *cl_path) {
PaddleMobile<Device, T> paddle_mobile;
#ifdef PADDLE_MOBILE_CL
if (cl_path) {
paddle_mobile.SetCLPath(*cl_path);
......@@ -26,10 +28,10 @@ double PaddleTester<Dtype, P>::CaculatePredictTime(std::string *cl_path) {
#endif
return paddle_mobile.GetPredictTime();
}
template class PaddleTester<CPU, Precision::FP32>;
template class PaddleTester<FPGA, Precision::FP32>;
template class PaddleTester<GPU_MALI, Precision::FP32>;
template class PaddleTester<CPU, float>;
template class PaddleTester<FPGA, float>;
template class PaddleTester<GPU_MALI, float>;
template class PaddleTester<GPU_CL, Precision::FP32>;
template class PaddleTester<GPU_CL, float>;
} // namespace paddle_mobile
......@@ -20,10 +20,13 @@ limitations under the License. */
*/
#pragma once
#include "common/types.h"
#include "string"
namespace paddle_mobile {
template <typename Dtype, Precision P = Precision::FP32>
template <typename Device, typename T = float>
class PaddleTester {
public:
double CaculatePredictTime(std::string *cl_path = nullptr);
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef CAST_OP
#include "operators/cast_op.h"
namespace paddle_mobile {
namespace operators {
template <typename DeviceType, typename T>
void CastOp<DeviceType, T>::InferShape() const {
const auto &dims = this->param_.input_->dims();
this->param_.output_->Resize(dims);
}
} // namespace operators
} // namespace paddle_mobile
namespace ops = paddle_mobile::operators;
#ifdef PADDLE_MOBILE_CPU
REGISTER_OPERATOR_CPU(cast, ops::CastOp);
#endif
#endif // CAST_OP
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef CAST_OP
#pragma once
#include <string>
#include "framework/operator.h"
#include "operators/kernel/kernels.h"
#include "operators/op_param.h"
namespace paddle_mobile {
namespace operators {
template <typename DeviceType, typename T>
class CastOp : public framework::OperatorWithKernel<
DeviceType, CastParam<DeviceType>,
operators::CastKernel<DeviceType, T>> {
public:
CastOp(const std::string &type, const VariableNameMap &inputs,
const VariableNameMap &outputs, const framework::AttributeMap &attrs,
std::shared_ptr<framework::Scope> scope)
: framework::OperatorWithKernel<DeviceType, CastParam<DeviceType>,
operators::CastKernel<DeviceType, T>>(
type, inputs, outputs, attrs, scope) {}
// infer the output shape
void InferShape() const override;
};
} // namespace operators
} // namespace paddle_mobile
#endif // CAST_OP
......@@ -33,4 +33,4 @@ namespace ops = paddle_mobile::operators;
REGISTER_OPERATOR_CPU(dequantize, ops::DequantizeOp);
#endif
#endif
#endif // DEQUANT_OP
......@@ -44,4 +44,4 @@ class DequantizeOp
} // namespace operators
} // namespace paddle_mobile
#endif
#endif // DEQUANT_OP
......@@ -25,12 +25,11 @@ limitations under the License. */
namespace paddle_mobile {
namespace operators {
using std::string;
template <typename DeviceType, typename T>
class FillConstantOp : public framework::OperatorBase<DeviceType> {
public:
FillConstantOp(const string &type, const VariableNameMap &inputs,
FillConstantOp(const std::string &type, const VariableNameMap &inputs,
const VariableNameMap &outputs,
const framework::AttributeMap attrs,
std::shared_ptr<framework::Scope> scope)
......@@ -58,7 +57,7 @@ class FillConstantOp : public framework::OperatorBase<DeviceType> {
tensor->Resize(framework::make_ddim(param_.Shape()));
tensor->mutable_data(framework::ToTypeIndex(data_type));
math::set_constant(tensor, value);
math::SetConstant(tensor, value);
}
void Init() {}
......
......@@ -14,19 +14,15 @@ limitations under the License. */
#ifdef GRU_OP
#include "operators/gru_op.h"
#include <vector>
#include "common/enforce.h"
#include "operators/gru_op.h"
namespace paddle_mobile {
namespace operators {
template <typename Dtype, typename T>
void GruOp<Dtype, T>::InferShape() const {
auto lod_size = this->param_.InputInput()->lod().size();
PADDLE_MOBILE_ENFORCE((lod_size == 1),
"Current LoD only supports one dimension.");
auto input_dims = this->param_.InputInput()->dims();
auto weight_dims = this->param_.InputWeight()->dims();
int input_size = input_dims[1];
......
......@@ -15,6 +15,7 @@ limitations under the License. */
#ifdef IM2SEQUENCE_OP
#include "operators/im2sequence_op.h"
#include <vector>
namespace paddle_mobile {
namespace operators {
......@@ -29,20 +30,16 @@ int Im2SequenceOutputSize(int input_size, int kernel, int padding_1,
template <typename Dtype, typename T>
void Im2SequenceOp<Dtype, T>::InferShape() const {
auto in_x_dims = this->param_.Input()->dims();
const std::vector<int> &kernels = this->param_.Kernels();
const std::vector<int> &strides = this->param_.Strides();
std::vector<int> paddings = this->param_.Paddings();
std::vector<int64_t> output_shape({in_x_dims[0], in_x_dims[1]});
for (size_t i = 0; i < strides.size(); ++i) {
output_shape.push_back(Im2SequenceOutputSize(in_x_dims[i + 2], kernels[i],
paddings[i], paddings[i + 2],
strides[i]));
}
framework::DDim ddim = framework::make_ddim(output_shape);
this->param_.Output()->Resize(ddim);
}
......@@ -54,9 +51,5 @@ namespace ops = paddle_mobile::operators;
#ifdef PADDLE_MOBILE_CPU
REGISTER_OPERATOR_CPU(im2sequence, ops::Im2SequenceOp);
#endif
#ifdef PADDLE_MOBILE_MALI_GPU
#endif
#ifdef PADDLE_MOBILE_FPGA
#endif
#endif
#endif // IM2SEQUENCE_OP
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef CAST_OP
#include <algorithm>
#include <vector>
#include "framework/data_type.h"
#include "operators/kernel/kernels.h"
namespace paddle_mobile {
namespace operators {
template <typename InT>
struct CastOutOpFunctor {
const framework::Tensor* in_;
framework::Tensor* out_;
CastOutOpFunctor(const framework::Tensor* in, framework::Tensor* out)
: in_(in), out_(out) {}
template <typename OutT>
void apply() const {
const InT* input = in_->data<InT>();
OutT* output = out_->mutable_data<OutT>();
size_t numel = in_->numel();
for (int i = 0; i < numel; ++i) {
output[i] = static_cast<OutT>(input[i]);
}
}
};
struct CastOpFunctor {
const framework::Tensor* in_;
framework::Tensor* out_;
int output_type_;
CastOpFunctor(const framework::Tensor* in, framework::Tensor* out,
const int output_type)
: in_(in), out_(out), output_type_(output_type) {}
template <typename InT>
void apply() const {
framework::VisitDataType(framework::ToDataType(output_type_),
CastOutOpFunctor<InT>(in_, out_));
}
};
template <>
bool CastKernel<CPU, float>::Init(CastParam<CPU>* param) {
return true;
}
template <>
void CastKernel<CPU, float>::Compute(const CastParam<CPU>& param) {
const Tensor* input = param.input_;
Tensor* output = param.output_;
framework::VisitDataType(framework::ToDataType(param.input_type_),
CastOpFunctor(input, output, param.output_type_));
}
} // namespace operators
} // namespace paddle_mobile
#endif // CAST_OP
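As a side note, the cast kernel above resolves both element types at runtime through two nested framework::VisitDataType calls. Below is a standalone sketch of that double-dispatch pattern under simplified assumptions: the enum, the Visit() helper, and the raw buffers are stand-ins for the framework's VarType values, VisitDataType, and Tensor storage, so none of these names come from the actual code.

```cpp
// Standalone sketch (not framework code): the outer visit fixes the input
// type InT, the inner visit fixes the output type OutT, and apply() does the
// element-wise static_cast, mirroring CastOpFunctor / CastOutOpFunctor.
#include <cstdint>
#include <cstdio>

enum class DType { INT32, FP32 };

template <typename Visitor>
void Visit(DType type, Visitor visitor) {
  switch (type) {
    case DType::INT32: visitor.template apply<int32_t>(); break;
    case DType::FP32:  visitor.template apply<float>();   break;
  }
}

template <typename InT>
struct CastTo {
  const void *in;
  void *out;
  size_t n;
  template <typename OutT>
  void apply() const {
    const InT *x = static_cast<const InT *>(in);
    OutT *y = static_cast<OutT *>(out);
    for (size_t i = 0; i < n; ++i) y[i] = static_cast<OutT>(x[i]);
  }
};

struct CastFrom {
  const void *in;
  void *out;
  size_t n;
  DType out_type;
  template <typename InT>
  void apply() const {
    Visit(out_type, CastTo<InT>{in, out, n});
  }
};

int main() {
  const float src[3] = {1.5f, 2.5f, -3.5f};
  int32_t dst[3] = {};
  Visit(DType::FP32, CastFrom{src, dst, 3, DType::INT32});
  std::printf("%d %d %d\n", dst[0], dst[1], dst[2]);  // prints: 1 2 -3
  return 0;
}
```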
......@@ -55,10 +55,9 @@ bool ConvKernel<CPU, float>::Init(ConvParam<CPU> *param) {
param->Input()->dims()[2] <= 140 /* referred from ncnn */) {
param->ExecMode() = ConvParam<CPU>::EXEC_WINOGRAD3X3_FLOAT;
// transform weight
framework::Tensor transformed_weight;
operators::math::winograd_transform_weight<8, 3>(*param->Filter(),
&transformed_weight);
framework::TensorCopy(transformed_weight, param->Filter());
param->transformed_filter_ = new framework::Tensor;
operators::math::winograd_transform_weight<8, 3>(
*param->Filter(), param->transformed_filter_);
#endif
} else {
param->ExecMode() = ConvParam<CPU>::EXEC_GEMM_FLOAT;
......
......@@ -30,8 +30,8 @@ bool DequantizeKernel<CPU, float>::Init(DequantizeParam<CPU> *param) {
template <>
void DequantizeKernel<CPU, float>::Compute(const DequantizeParam<CPU> &param) {
const Tensor *input = param.input_;
Tensor *output = param.output_;
const LoDTensor *input = param.input_;
LoDTensor *output = param.output_;
float activation_scale = param.activation_scale_->data<float>()[0];
float weight_scale = param.weight_scale_;
const int32_t *x = input->data<const int32_t>();
......@@ -72,6 +72,7 @@ void DequantizeKernel<CPU, float>::Compute(const DequantizeParam<CPU> &param) {
for (size_t i = 0; i < size; ++i) {
y[i] = x[i] * scale;
}
output->set_lod(input->lod());
}
} // namespace operators
......
......@@ -29,12 +29,6 @@ template <>
void GruKernel<CPU, float>::Compute(const GruParam<CPU> &param) {
GruCompute<float>(param);
param.OutHidden()->set_lod(param.InputInput()->lod());
// DLOG << "________________" << param.OutHidden()->dims();
// DLOG << "________________" << param.OutHidden()->numel();
// auto *hiden_data = param.OutHidden()->data<float>();
// for (int64_t i = 0; i < 10; i++) {
// DLOG << "****************" << hiden_data[i];
// }
}
template class GruKernel<CPU, float>;
......
......@@ -186,8 +186,8 @@ bool QuantizeKernel<CPU, float>::Init(QuantizeParam<CPU> *param) {
template <>
void QuantizeKernel<CPU, float>::Compute(const QuantizeParam<CPU> &param) {
const Tensor *input = param.input_;
Tensor *output = param.output_;
const LoDTensor *input = param.input_;
LoDTensor *output = param.output_;
Tensor *output_scale = param.online_scale_;
float max_abs = 0.f;
if (param.offline_) {
......@@ -212,6 +212,7 @@ void QuantizeKernel<CPU, float>::Compute(const QuantizeParam<CPU> &param) {
LOG(kLOG_ERROR) << "round type is not supported.";
break;
}
output->set_lod(input->lod());
}
} // namespace operators
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef TOP_K_OP
#include <algorithm>
#include <iostream>
#include <vector>
#include "operators/kernel/kernels.h"
namespace paddle_mobile {
namespace operators {
template <>
bool TopKKernel<CPU, float>::Init(TopKParam<CPU> *param) {
return true;
}
template <>
void TopKKernel<CPU, float>::Compute(const TopKParam<CPU> &param) {
const Tensor *input = param.input_;
Tensor *output = param.output_;
Tensor *indices = param.indices_;
const float *input_data = input->data<float>();
float *output_data = output->mutable_data<float>();
int64_t *indices_data = indices->mutable_data<int64_t>();
framework::DDim input_dims = input->dims();
const size_t row = framework::product(
framework::slice_ddim(input_dims, 0, input_dims.size() - 1));
const size_t col = input_dims[input_dims.size() - 1];
#pragma omp parallel for
for (size_t i = 0; i < row; i++) {
std::vector<std::pair<float, size_t>> vec(col);
const float *input_ptr = input_data + i * col;
float *output_ptr = output_data + i * param.k_;
int64_t *indices_ptr = indices_data + i * param.k_;
for (size_t j = 0; j < col; j++) {
vec[j] = std::move(std::pair<float, size_t>(input_ptr[j], j));
}
std::partial_sort(
vec.begin(), vec.begin() + param.k_, vec.end(),
[](const std::pair<float, size_t> &l,
const std::pair<float, size_t> &r) { return l.first > r.first; });
for (int j = 0; j < param.k_; ++j) {
output_ptr[j] = vec[j].first;
indices_ptr[j] = static_cast<int64_t>(vec[j].second);
}
}
}
} // namespace operators
} // namespace paddle_mobile
#endif // TOP_K_OP
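For reference, a minimal standalone sketch of the per-row selection performed by the top_k kernel above: indices are paired with values and std::partial_sort keeps the k largest pairs in descending order. This is an illustration only, not framework code.

```cpp
// Row-wise top-k on a single row, same selection logic as the kernel.
#include <algorithm>
#include <cstdio>
#include <utility>
#include <vector>

int main() {
  const std::vector<float> row = {0.1f, 0.9f, 0.3f, 0.7f};
  const int k = 2;
  std::vector<std::pair<float, size_t>> vec(row.size());
  for (size_t j = 0; j < row.size(); ++j) vec[j] = {row[j], j};
  std::partial_sort(vec.begin(), vec.begin() + k, vec.end(),
                    [](const std::pair<float, size_t> &l,
                       const std::pair<float, size_t> &r) {
                      return l.first > r.first;
                    });
  for (int j = 0; j < k; ++j) {
    std::printf("value=%.1f index=%zu\n", vec[j].first, vec[j].second);
  }
  // Expected output: value=0.9 index=1, then value=0.7 index=3.
  return 0;
}
```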
......@@ -18,283 +18,63 @@ limitations under the License. */
#include <cmath>
#include "operators/op_param.h"
#if defined(__ARM_NEON__) || defined(__ARM_NEON)
#include <arm_neon.h>
#endif // __ARM_NEON__
namespace paddle_mobile {
namespace operators {
template <typename P>
void BatchnormCompute(const BatchNormParam<CPU> &param) {
const Tensor *input_x = param.InputX();
auto input_x_ptr = input_x->data<float>();
const auto &x_dims = input_x->dims();
const int N = x_dims[0];
const int C = x_dims[1];
const int H = x_dims[2];
const int W = x_dims[3];
const int stride0 = C * H * W;
const int stride1 = H * W;
const int stride2 = W;
Tensor *out = param.OutputY();
auto out_ptr = out->mutable_data<float>();
const float epsilon = param.Epsilon();
const Tensor *mean = param.InputMean();
const Tensor *variance = param.InputVariance();
const Tensor *scale = param.InputScale();
const Tensor *bias = param.InputBias();
auto mean_ptr = mean->data<float>();
auto variance_ptr = variance->data<float>();
auto scale_ptr = scale->data<float>();
auto bias_ptr = bias->data<float>();
// Tensor inv_std;
// auto inv_std_ptr = inv_std.mutable_data<float>(make_ddim({C}));
PADDLE_MOBILE_ENFORCE(C == variance->numel(),
"C must equal to variance.numel()");
int HXW = H * W;
#if __ARM_NEON
#if __aarch64__
float *inv_std_ptr = new float[C];
for (int i = 0; i < C; i++) {
inv_std_ptr[i] =
1 / static_cast<float>(pow((variance_ptr[i] + epsilon), 0.5));
}
Tensor new_scale;
auto new_scale_ptr = new_scale.mutable_data<float>(framework::make_ddim({C}));
Tensor new_bias;
auto new_bias_ptr = new_bias.mutable_data<float>(framework::make_ddim({C}));
/// ((x - est_mean) * (inv_var) * scale + bias equal to
/// (x * inv_var * scale) + (bias - est_mean * inv_var * scale)
for (int i = 0; i < C; i++) {
new_scale_ptr[i] = inv_std_ptr[i] * scale_ptr[i];
new_bias_ptr[i] = bias_ptr[i] - mean_ptr[i] * inv_std_ptr[i] * scale_ptr[i];
{
for (int n = 0; n < N; n++) {
for (int h = 0; h < H; h++) {
int tmp_index = n * stride0 + i * stride1 + h * stride2;
for (int w = 0; w < W; w++) {
int index = tmp_index + w;
out_ptr[index] =
input_x_ptr[index] * new_scale_ptr[i] + new_bias_ptr[i];
}
}
const float *mean_ptr = param.InputMean()->data<float>();
const float *variance_ptr = param.InputVariance()->data<float>();
const float *scale_ptr = param.InputScale()->data<float>();
const float *bias_ptr = param.InputBias()->data<float>();
const framework::Tensor *input = param.InputX();
const float *input_ptr = input->data<float>();
framework::Tensor *output = param.OutputY();
float *output_ptr = output->mutable_data<float>();
size_t spatial_size = output->dims()[2] * output->dims()[3];
int channels = output->dims()[1];
#pragma omp parallel for collapse(2)
for (int batch = 0; batch < output->dims()[0]; ++batch) {
for (int c = 0; c < channels; ++c) {
float inv_scale = 1.f / (std::sqrt(variance_ptr[c] + epsilon));
float bias = bias_ptr[c] - inv_scale * scale_ptr[c] * mean_ptr[c];
float scale = inv_scale * scale_ptr[c];
size_t offset = (batch * channels + c) * spatial_size;
const float *x = input_ptr + offset;
float *y = output_ptr + offset;
size_t remain = spatial_size;
#if defined(__ARM_NEON__) || defined(__ARM_NEON)
int loop = spatial_size >> 4;
remain = spatial_size & 0xF;
float32x4_t __scale = vdupq_n_f32(scale);
float32x4_t __bias = vdupq_n_f32(bias);
for (int k = 0; k < loop; ++k, x += 16, y += 16) {
float32x4_t r0 = vld1q_f32(x);
float32x4_t r1 = vld1q_f32(x + 4);
float32x4_t r2 = vld1q_f32(x + 8);
float32x4_t r3 = vld1q_f32(x + 12);
r0 = vmlaq_f32(__bias, __scale, r0);
r1 = vmlaq_f32(__bias, __scale, r1);
r2 = vmlaq_f32(__bias, __scale, r2);
r3 = vmlaq_f32(__bias, __scale, r3);
vst1q_f32(y, r0);
vst1q_f32(y + 4, r1);
vst1q_f32(y + 8, r2);
vst1q_f32(y + 12, r3);
}
}
}
delete[] inv_std_ptr;
#else
if (HXW > 32) {
int NXC = N * C;
float *inv_std_ptr = new float[NXC * 4];
float *volatile new_scale_ptr = new float[NXC * 4];
float *volatile new_bias_ptr = new float[NXC * 4];
/// std = (var + epsilon).sqrt();
/// inv_std = 1 / std;
for (int i = 0; i < C * 4; i += 4) {
int index = i / 4;
inv_std_ptr[i] =
1 / static_cast<float>(pow((variance_ptr[index] + epsilon), 0.5));
inv_std_ptr[i + 1] = inv_std_ptr[i];
inv_std_ptr[i + 2] = inv_std_ptr[i];
inv_std_ptr[i + 3] = inv_std_ptr[i];
new_scale_ptr[i] = inv_std_ptr[i] * scale_ptr[index];
new_scale_ptr[i + 1] = new_scale_ptr[i];
new_scale_ptr[i + 2] = new_scale_ptr[i];
new_scale_ptr[i + 3] = new_scale_ptr[i];
new_bias_ptr[i] =
bias_ptr[index] - mean_ptr[index] * inv_std_ptr[i] * scale_ptr[index];
new_bias_ptr[i + 1] = new_bias_ptr[i];
new_bias_ptr[i + 2] = new_bias_ptr[i];
new_bias_ptr[i + 3] = new_bias_ptr[i];
}
for (int j = C * 4; j < NXC * 4; ++j) {
new_scale_ptr[j] = new_scale_ptr[j - C * 4];
new_bias_ptr[j] = new_bias_ptr[j - C * 4];
}
asm volatile(
"subs %[N], %[N], #1 \n\t"
"blt end_n_%= \n\t"
"loop_n_%=: \n\t"
"subs %[C], %[C], #1 \n\t"
"blt end_c_%= \n\t"
"loop_c_%=: \n\t"
"vld1.32 {q9}, [%[new_scale_ptr]]! \n\t"
"vld1.32 {q10}, [%[new_bias_ptr]]! \n\t"
"mov r6, %[HXW] \n\t"
"subs r6, r6, #32 \n\t"
"blt end_hw_%= \n\t"
"loop_hw_%=: \n\t"
"vld1.32 {q1, q2}, [%[input_x_ptr]]! \n\t"
"vld1.32 {q3, q4}, [%[input_x_ptr]]! \n\t"
"vld1.32 {q5, q6}, [%[input_x_ptr]]! \n\t"
"vld1.32 {q7, q8}, [%[input_x_ptr]]! \n\t"
"vmul.f32 q1, q1, q9 \n\t"
"vmul.f32 q2, q2, q9 \n\t"
"vmul.f32 q3, q3, q9 \n\t"
"vmul.f32 q4, q4, q9 \n\t"
"vmul.f32 q5, q5, q9 \n\t"
"vmul.f32 q6, q6, q9 \n\t"
"vmul.f32 q7, q7, q9 \n\t"
"vmul.f32 q8, q8, q9 \n\t"
"vadd.f32 q1, q1, q10 \n\t"
"vadd.f32 q2, q2, q10 \n\t"
"vadd.f32 q3, q3, q10 \n\t"
"vadd.f32 q4, q4, q10 \n\t"
"vadd.f32 q5, q5, q10 \n\t"
"vadd.f32 q6, q6, q10 \n\t"
"vadd.f32 q7, q7, q10 \n\t"
"vadd.f32 q8, q8, q10 \n\t"
"vst1.32 {q1, q2}, [%[out_ptr]]! \n\t"
"vst1.32 {q3, q4}, [%[out_ptr]]! \n\t"
"vst1.32 {q5, q6}, [%[out_ptr]]! \n\t"
"vst1.32 {q7, q8}, [%[out_ptr]]! \n\t"
"subs r6, r6, #32 \n\t"
"bge loop_hw_%= \n\t"
"end_hw_%=: \n\t"
"cmp r6, #0 \n\t"
"bge end_remainder_%= \n\t"
"mov r5, #4 \n\t"
"mul r6, r6, r5 \n\t"
"add %[input_x_ptr], %[input_x_ptr], r6 \n\t"
"vld1.32 {q1, q2}, [%[input_x_ptr]]! \n\t"
"vld1.32 {q3, q4}, [%[input_x_ptr]]! \n\t"
"vld1.32 {q5, q6}, [%[input_x_ptr]]! \n\t"
"vld1.32 {q7, q8}, [%[input_x_ptr]]! \n\t"
"vmul.f32 q1, q1, q9 \n\t"
"vmul.f32 q2, q2, q9 \n\t"
"vmul.f32 q3, q3, q9 \n\t"
"vmul.f32 q4, q4, q9 \n\t"
"vmul.f32 q5, q5, q9 \n\t"
"vmul.f32 q6, q6, q9 \n\t"
"vmul.f32 q7, q7, q9 \n\t"
"vmul.f32 q8, q8, q9 \n\t"
"vadd.f32 q1, q1, q10 \n\t"
"vadd.f32 q2, q2, q10 \n\t"
"vadd.f32 q3, q3, q10 \n\t"
"vadd.f32 q4, q4, q10 \n\t"
"vadd.f32 q5, q5, q10 \n\t"
"vadd.f32 q6, q6, q10 \n\t"
"vadd.f32 q7, q7, q10 \n\t"
"vadd.f32 q8, q8, q10 \n\t"
"add %[out_ptr], %[out_ptr], r6 \n\t"
"vst1.32 {q1, q2}, [%[out_ptr]]! \n\t"
"vst1.32 {q3, q4}, [%[out_ptr]]! \n\t"
"vst1.32 {q5, q6}, [%[out_ptr]]! \n\t"
"vst1.32 {q7, q8}, [%[out_ptr]]! \n\t"
"end_remainder_%=: \n\t"
"subs %[C], %[C], #1 \n\t"
"bge loop_c_%= \n\t"
"end_c_%=: \n\t"
"subs %[N], %[N], #1 \n\t"
"bge loop_n_%= \n\t"
"end_n_%=: \n\t"
:
: [input_x_ptr] "r"(input_x_ptr), [out_ptr] "r"(out_ptr),
[new_scale_ptr] "r"(new_scale_ptr), [new_bias_ptr] "r"(new_bias_ptr),
[N] "r"(N), [C] "r"(C), [HXW] "r"(HXW)
: "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9",
"q10", "r5", "r6");
delete[] inv_std_ptr;
delete[] new_scale_ptr;
delete[] new_bias_ptr;
} else {
float *inv_std_ptr = new float[C];
for (int i = 0; i < C; i++) {
inv_std_ptr[i] =
1 / static_cast<float>(pow((variance_ptr[i] + epsilon), 0.5));
}
Tensor new_scale;
auto new_scale_ptr =
new_scale.mutable_data<float>(framework::make_ddim({C}));
Tensor new_bias;
auto new_bias_ptr = new_bias.mutable_data<float>(framework::make_ddim({C}));
/// ((x - est_mean) * (inv_var) * scale + bias is equal to
/// (x * inv_var * scale) + (bias - est_mean * inv_var * scale)
for (int i = 0; i < C; i++) {
new_scale_ptr[i] = inv_std_ptr[i] * scale_ptr[i];
new_bias_ptr[i] =
bias_ptr[i] - mean_ptr[i] * inv_std_ptr[i] * scale_ptr[i];
{
for (int n = 0; n < N; n++) {
for (int h = 0; h < H; h++) {
int tmp_index = n * stride0 + i * stride1 + h * stride2;
for (int w = 0; w < W; w++) {
int index = tmp_index + w;
out_ptr[index] =
input_x_ptr[index] * new_scale_ptr[i] + new_bias_ptr[i];
}
}
}
}
}
delete[] inv_std_ptr;
}
#endif
#else
float *inv_std_ptr = new float[C];
for (int i = 0; i < C; i++) {
inv_std_ptr[i] =
1 / static_cast<float>(pow((variance_ptr[i] + epsilon), 0.5));
}
Tensor new_scale;
auto new_scale_ptr = new_scale.mutable_data<float>(framework::make_ddim({C}));
Tensor new_bias;
auto new_bias_ptr = new_bias.mutable_data<float>(framework::make_ddim({C}));
/// ((x - est_mean) * (inv_var) * scale + bias is equal to
/// (x * inv_var * scale) + (bias - est_mean * inv_var * scale)
for (int i = 0; i < C; i++) {
new_scale_ptr[i] = inv_std_ptr[i] * scale_ptr[i];
new_bias_ptr[i] = bias_ptr[i] - mean_ptr[i] * inv_std_ptr[i] * scale_ptr[i];
{
for (int n = 0; n < N; n++) {
for (int h = 0; h < H; h++) {
int tmp_index = n * stride0 + i * stride1 + h * stride2;
for (int w = 0; w < W; w++) {
int index = tmp_index + w;
out_ptr[index] =
input_x_ptr[index] * new_scale_ptr[i] + new_bias_ptr[i];
}
}
#endif // __ARM_NEON__
for (int k = 0; k < remain; ++k) {
y[k] = scale * x[k] + bias;
}
}
}
delete[] inv_std_ptr;
#endif
}
} // namespace operators
......
......@@ -13,8 +13,9 @@ See the License for the specific language governing permissions and
limitations under the License. */
#ifdef FUSION_CONVADDADDPRELU_OP
#pragma once
#include <string>
#include <vector>
#include "operators/math/conv_func.h"
#include "operators/math/im2col.h"
......@@ -115,20 +116,7 @@ void ConvAddAddPReluCompute(const FusionConvAddAddPReluParam<CPU> &param) {
Tensor filter_slice = filter.Slice(g * out_step, (g + 1) * out_step);
Tensor bias1_slice = bias1_batch.Slice(g * out_step, (g + 1) * out_step);
float *biase_data1 = bias1_slice.data<float>();
// int n = bias1_slice.dims()[0];
// int m = bias1_slice.dims()[1];
// for(int i=0;i<n*m;i++){
// if(biase_data1[i]!=0)
// DLOG<<biase_data1[i]<<",yangfei";
// }
// math::matmul<float>(filter_slice, false, col_matrix,
// false,
// static_cast<float>(1),
// &out_slice,
// static_cast<float>(1), true,
// biase_data);
math::matmulWithPRelu(filter_slice, false, col_matrix, false, &out_slice,
math::MatMulWithPRelu(filter_slice, false, col_matrix, false, &out_slice,
p, mode, biase_data, biase_data1);
}
}
......@@ -137,4 +125,4 @@ void ConvAddAddPReluCompute(const FusionConvAddAddPReluParam<CPU> &param) {
} // namespace operators
} // namespace paddle_mobile
#endif
#endif // FUSION_CONVADDADDPRELU_OP
......@@ -107,7 +107,7 @@ void ConvAddBasic(const FusionConvAddParam<CPU> &param) {
// gemm
Tensor out_slice = out_batch.Slice(g * out_step, (g + 1) * out_step);
Tensor filter_slice = filter.Slice(g * out_step, (g + 1) * out_step);
math::matmul<float, float>(filter_slice, false, col_matrix, false,
math::MatMul<float, float>(filter_slice, false, col_matrix, false,
static_cast<float>(1), &out_slice,
static_cast<float>(1), false, biase_data);
}
......
......@@ -25,6 +25,7 @@ limitations under the License. */
namespace paddle_mobile {
namespace operators {
void ConvAddBNReluBasic(const FusionConvAddBNReluParam<CPU> &param) {
const Tensor *input = param.Input();
Tensor filter = *param.Filter();
......@@ -105,12 +106,13 @@ void ConvAddBNReluBasic(const FusionConvAddBNReluParam<CPU> &param) {
Tensor out_slice = out_batch.Slice(g * out_step, (g + 1) * out_step);
Tensor filter_slice = filter.Slice(g * out_step, (g + 1) * out_step);
math::matmulWithBn<float>(
filter_slice, false, col_matrix, false, static_cast<float>(1),
&out_slice, static_cast<float>(0), true, &new_scale, &new_bias, g);
math::MatMulWithBn(filter_slice, false, col_matrix, false,
static_cast<float>(1), &out_slice,
static_cast<float>(0), true, &new_scale, &new_bias, g);
}
}
}
template <typename P>
void ConvAddBNReluCompute(const FusionConvAddBNReluParam<CPU> &param) {
Tensor Bias;
......@@ -126,9 +128,6 @@ void ConvAddBNReluCompute(const FusionConvAddBNReluParam<CPU> &param) {
param.Input()->dims()[1] == param.Output()->dims()[1] &&
param.Filter()->dims()[2] == param.Filter()->dims()[3] &&
param.Filter()->dims()[2] == 3 && param.Strides()[0] == 2) {
// math::DepthwiseConvAddBNRelu3x3s2p1(param.Input(), param.Filter(),
// param.Output(), param.NewScale(),
// param.NewBias(), 1);
math::DepthwiseConvAddBNRelu3x3s2p1v2(param.Input(), param.Filter(),
param.Output(), param.NewScale(),
param.NewBias(), true);
......
......@@ -13,8 +13,9 @@ See the License for the specific language governing permissions and
limitations under the License. */
#ifdef FUSION_CONVADDPRELU_OP
#pragma once
#include <string>
#include <vector>
#include "operators/math/conv_func.h"
#include "operators/math/im2col.h"
......@@ -30,8 +31,6 @@ void ConvAddPReluCompute(const FusionConvAddPReluParam<CPU> &param) {
const Tensor *input = param.Input();
Tensor filter = *param.Filter();
Tensor bias = *param.Bias();
// DLOG<<"yangfei";
// DLOG<<bias.dims();
int axis = param.Axis();
Tensor *output = param.Output();
float *biase_data = bias.data<float>();
......@@ -112,13 +111,7 @@ void ConvAddPReluCompute(const FusionConvAddPReluParam<CPU> &param) {
// gemm
Tensor out_slice = out_batch.Slice(g * out_step, (g + 1) * out_step);
Tensor filter_slice = filter.Slice(g * out_step, (g + 1) * out_step);
// math::matmul<float>(filter_slice, false, col_matrix,
// false,
// static_cast<float>(1),
// &out_slice,
// static_cast<float>(1), true,
// biase_data);
math::matmulWithPRelu(filter_slice, false, col_matrix, false, &out_slice,
math::MatMulWithPRelu(filter_slice, false, col_matrix, false, &out_slice,
p, mode, biase_data, nullptr);
}
}
......@@ -127,4 +120,4 @@ void ConvAddPReluCompute(const FusionConvAddPReluParam<CPU> &param) {
} // namespace operators
} // namespace paddle_mobile
#endif
#endif // FUSION_CONVADDPRELU_OP
......@@ -112,7 +112,7 @@ void ConvAddReluCompute(const FusionConvAddReluParam<CPU> &param) {
Tensor out_slice = out_batch.Slice(g * out_step, (g + 1) * out_step);
Tensor filter_slice = filter.Slice(g * out_step, (g + 1) * out_step);
math::matmul<Itype, Otype>(filter_slice, false, col_matrix, false, alpha,
math::MatMul<Itype, Otype>(filter_slice, false, col_matrix, false, alpha,
&out_slice, beta, true, bias_data);
}
}
......
......@@ -106,7 +106,7 @@ inline void GemmConv(const ConvParam<CPU> &param) {
// gemm
Tensor out_slice = out_batch.Slice(g * out_step, (g + 1) * out_step);
Tensor filter_slice = filter.Slice(g * out_step, (g + 1) * out_step);
math::matmul<Itype, Otype>(filter_slice, false, col_matrix, false,
math::MatMul<Itype, Otype>(filter_slice, false, col_matrix, false,
static_cast<float>(1), &out_slice,
static_cast<float>(0), false,
static_cast<Otype *>(nullptr));
......@@ -117,7 +117,7 @@ inline void GemmConv(const ConvParam<CPU> &param) {
template <int tile, int kernel>
inline void WinogradConv3x3(const ConvParam<CPU> &param) {
const Tensor *input = param.Input();
const Tensor *filter = param.Filter();
const Tensor *filter = param.transformed_filter_;
Tensor *output = param.Output();
output->mutable_data<float>();
int batch_size = input->dims()[0];
......
......@@ -108,10 +108,10 @@ void ConvBNAddReluBasic(const FusionConvBNAddReluParam<CPU> &param) {
Tensor out_slice = out_batch.Slice(g * out_step, (g + 1) * out_step);
Tensor filter_slice = filter.Slice(g * out_step, (g + 1) * out_step);
Tensor bias_data = bias_batch.Slice(g * out_step, (g + 1) * out_step);
math::matmulWithBn<float>(filter_slice, false, col_matrix, false,
static_cast<float>(1), &out_slice,
static_cast<float>(1), true, &new_scale,
&new_bias, g, bias_data.data<float>());
math::MatMulWithBn(filter_slice, false, col_matrix, false,
static_cast<float>(1), &out_slice,
static_cast<float>(1), true, &new_scale, &new_bias, g,
bias_data.data<float>());
}
}
}
......
......@@ -107,9 +107,9 @@ void ConvBNReluBasic(const FusionConvBNReluParam<CPU> &param) {
Tensor out_slice = out_batch.Slice(g * out_step, (g + 1) * out_step);
Tensor filter_slice = filter.Slice(g * out_step, (g + 1) * out_step);
math::matmulWithBn<float>(
filter_slice, false, col_matrix, false, static_cast<float>(1),
&out_slice, static_cast<float>(0), true, &new_scale, &new_bias, g);
math::MatMulWithBn(filter_slice, false, col_matrix, false,
static_cast<float>(1), &out_slice,
static_cast<float>(0), true, &new_scale, &new_bias, g);
}
}
}
......
......@@ -93,7 +93,7 @@ void ConvTransposeCompute(const ConvTransposeParam<CPU> &param) {
Tensor filter_slice = filter.Slice(g * in_step, (g + 1) * in_step);
Tensor out_slice = output_batch.Slice(g * out_step, (g + 1) * out_step);
math::matmul<P, P>(filter_slice, true, in_slice, false,
math::MatMul<P, P>(filter_slice, true, in_slice, false,
static_cast<P>(1.0), &col_matrix, static_cast<P>(0.0));
if (data_dim == 2U) {
col2im(col, dilations, strides,
......
......@@ -106,9 +106,9 @@ void DWConvBNReluBasic(const FusionDWConvBNReluParam<CPU> &param) {
// gemm
Tensor out_slice = out_batch.Slice(g * out_step, (g + 1) * out_step);
Tensor filter_slice = filter.Slice(g * out_step, (g + 1) * out_step);
math::matmulWithBn<float>(
filter_slice, false, col_matrix, false, static_cast<float>(1),
&out_slice, static_cast<float>(0), true, &new_scale, &new_bias, g);
math::MatMulWithBn(filter_slice, false, col_matrix, false,
static_cast<float>(1), &out_slice,
static_cast<float>(0), true, &new_scale, &new_bias, g);
}
}
}
......
......@@ -26,18 +26,12 @@ namespace paddle_mobile {
namespace operators {
template <typename T>
struct AddFunctor {
inline T operator()(T a, T b) const { return a + b; }
};
template <typename P>
void ElementwiseAddCompute(const ElementwiseAddParam<CPU> &param) {
const Tensor *input_x = param.InputX();
const Tensor *input_y = param.InputY();
Tensor *Out = param.Out();
Out->mutable_data<float>();
inline void ElementwiseAddCompute(const ElementwiseAddParam<CPU> &param) {
const framework::Tensor *input_x = param.InputX();
const framework::Tensor *input_y = param.InputY();
framework::Tensor *Out = param.Out();
int axis = param.Axis();
#if defined(__ARM_NEON__) || defined(__ARM_NEON)
const auto &x_dims = input_x->dims();
const auto &y_dims = input_y->dims();
/// axis = -1 represents the last dimension.
......@@ -57,18 +51,20 @@ void ElementwiseAddCompute(const ElementwiseAddParam<CPU> &param) {
const float *bias_data = input_y->data<float>();
const float *input_data = input_x->data<float>();
float *output_data = Out->mutable_data<float>();
#pragma omp parallel for collapse(2)
for (int i = 0; i < batch; ++i) {
#pragma omp parallel for
for (int j = 0; j < channels; ++j) {
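// input_y is broadcast along the channel axis: each channel j adds the single
// scalar bias_data[j] to its elementwise_num contiguous output values.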
size_t offset = (i * channels + j) * elementwise_num;
const float *input = input_data + offset;
const float *bias = bias_data + j;
const float bias = bias_data[j];
float *output = output_data + offset;
int remain = elementwise_num;
#if defined(__ARM_NEON__) || defined(__ARM_NEON)
int loop = elementwise_num >> 0x4;
int remain = elementwise_num & 0xF;
remain = elementwise_num & 0xF;
for (int k = 0; k < loop; ++k) {
float32x4_t rb = vdupq_n_f32(*bias);
float32x4_t rb = vdupq_n_f32(bias);
float32x4_t r0 = vld1q_f32(input);
float32x4_t r1 = vld1q_f32(input + 4);
float32x4_t r2 = vld1q_f32(input + 8);
......@@ -84,15 +80,12 @@ void ElementwiseAddCompute(const ElementwiseAddParam<CPU> &param) {
input += 16;
output += 16;
}
#endif
for (int k = 0; k < remain; ++k) {
output[k] = input[k] + *bias;
output[k] = input[k] + bias;
}
}
}
#else
ElementwiseComputeEx<AddFunctor<float>, float>(input_x, input_y, axis,
AddFunctor<float>(), Out);
#endif
}
template class ElementwiseAddKernel<CPU, float>;
......
......@@ -57,7 +57,7 @@ void FusionFcCompute(const FusionFcParam<CPU> &param) {
for (int i = 0; i < out_dim[0]; i++) {
memory::Copy(out_data + i * classes, input_z_data, sizeof(Otype) * classes);
}
math::matmul<Itype, Otype>(x_matrix, false, y_matrix, false,
math::MatMul<Itype, Otype>(x_matrix, false, y_matrix, false,
static_cast<float>(1), out, static_cast<float>(1),
false);
}
......
......@@ -25,18 +25,16 @@ limitations under the License. */
namespace paddle_mobile {
namespace operators {
using LoDTensor = framework::LoDTensor;
using Tensor = framework::Tensor;
template <typename DeviceType, typename T>
template <typename Device, typename T>
inline void ReorderInitState(const framework::Tensor& src,
std::vector<size_t> index_lod,
framework::Tensor* dst, bool indexed_src) {
math::CopyMatrixRowsFunctor<DeviceType, T> row_shuffle;
math::CopyMatrixRowsFunctor<Device, T> row_shuffle;
dst->mutable_data<T>(src.dims());
row_shuffle(src, index_lod, dst, indexed_src);
}
template <typename P>
template <typename T>
void GruCompute(const GruParam<CPU>& param) {
auto* input = param.InputInput();
auto* h0 = param.InputH0();
......@@ -57,8 +55,6 @@ void GruCompute(const GruParam<CPU>& param) {
bool is_reverse = param.IsReverse();
math::LoDTensor2BatchFunctor<CPU, float> to_batch;
to_batch(*input, batch_gate, true, is_reverse);
// math::ClearTensor<CPU, float> clearTensor;
// clearTensor(batch_gate);
if (bias) {
math::RowwiseAdd<CPU, float> add_bias;
add_bias(*batch_gate, *bias, batch_gate);
......@@ -68,7 +64,7 @@ void GruCompute(const GruParam<CPU>& param) {
gru_value.gate_weight = const_cast<float*>(weight_data);
gru_value.state_weight =
const_cast<float*>(weight_data + 2 * frame_size * frame_size);
Tensor ordered_h0;
framework::Tensor ordered_h0;
std::vector<size_t> order(batch_gate->lod()[2]);
if (h0) {
// Since the batch computing for GRU reorders the input sequences
......@@ -87,9 +83,10 @@ void GruCompute(const GruParam<CPU>& param) {
int bstart = static_cast<int>(batch_starts[n]);
int bend = static_cast<int>(batch_starts[n + 1]);
int cur_batch_size = bend - bstart;
Tensor gate_t = batch_gate->Slice(bstart, bend); // BUG
Tensor reset_hidden_prev_t = batch_reset_hidden_prev->Slice(bstart, bend);
Tensor hidden_t = batch_hidden->Slice(bstart, bend);
framework::Tensor gate_t = batch_gate->Slice(bstart, bend);
framework::Tensor reset_hidden_prev_t =
batch_reset_hidden_prev->Slice(bstart, bend);
framework::Tensor hidden_t = batch_hidden->Slice(bstart, bend);
gru_value.output_value = hidden_t.data<float>();
gru_value.gate_value = gate_t.data<float>();
gru_value.reset_output_value = reset_hidden_prev_t.data<float>();
......@@ -105,7 +102,6 @@ void GruCompute(const GruParam<CPU>& param) {
}
} // namespace operators
} // namespace paddle_mobile
#endif
#endif // GRU_OP
......@@ -19,40 +19,6 @@ limitations under the License. */
namespace paddle_mobile {
namespace operators {
// 1. If x and y are both 2-D, e.g.
//    x = [[1,2],   y = [[5,6],
//         [3,4]]        [7,8]]
//    the result is an ordinary matrix product:
//    out = [[1*5+2*7, 1*6+2*8], [3*5+4*7, 3*6+4*8]]
//
// 2. If x or y has more than 2 dimensions, e.g. x with shape (2,3,4) and y
//    with shape (4,1,2):
//    x = [[[1,2,3,4],
//          [2,3,4,5],
//          [3,4,5,6]],
//         [[1,2,3,4],
//          [2,3,4,5],
//          [3,4,5,6]]]
//    y = [[[1,2]],
//         [[3,4]],
//         [[5,6]],
//         [[7,8]]]
//    x_num_col_dims and y_num_col_dims are used to flatten x and y to 2-D.
//    The model provides x_num_col_dims = 2 and y_num_col_dims = 1; all index
//    ranges below are half-open, [begin, end).
//    (1) For x = (2,3,4): multiply the dims in [0, x_num_col_dims), 2*3 = 6,
//        and the dims in [x_num_col_dims, xdim.size()), which gives 4, so the
//        dims of Tensor x are rewritten as (6,4).
//    (2) For y = (4,1,2): multiply the dims in [0, y_num_col_dims), which
//        gives 4, and the dims in [y_num_col_dims, ydim.size()), 1*2 = 2, so
//        the dims of Tensor y are rewritten as (4,2).
//    The in-memory layout of x and y is unchanged.
//    x = [[1,2,3,4],             y = [[1,2],
//         [2,3,4,5],                  [3,4],
//         [3,4,5,6],   matrix mul     [5,6],
//         [1,2,3,4],                  [7,8]]
//         [2,3,4,5],
//         [3,4,5,6]]
//    The result is x (6 rows, 4 cols) times y (4 rows, 2 cols), multiplied as
//    in case 1, giving out with 6 rows and 2 cols.
template <typename P>
void MulCompute(const MulParam<CPU> &param) {
const Tensor *input_x = param.InputX();
......@@ -73,12 +39,12 @@ void MulCompute(const MulParam<CPU> &param) {
}
if (param.InputX()->type() == typeid(int8_t)) {
out->mutable_data<int32_t>();
math::matmul<int8_t, int32_t>(x_matrix, false, y_matrix, false,
math::MatMul<int8_t, int32_t>(x_matrix, false, y_matrix, false,
static_cast<float>(1), out,
static_cast<float>(0));
} else {
out->mutable_data<float>();
math::matmul<float, float>(x_matrix, false, y_matrix, false,
math::MatMul<float, float>(x_matrix, false, y_matrix, false,
static_cast<float>(1), out,
static_cast<float>(0));
}
......
......@@ -294,11 +294,6 @@ void MultiClassNMSCompute(const MultiClassNMSParam<CPU>& param) {
}
}
}
// framework::LoD lod;
// lod.emplace_back(batch_starts);
//
// outs->set_lod(lod);
}
} // namespace operators
......
......@@ -14,6 +14,7 @@ limitations under the License. */
#include "operators/kernel/feed_kernel.h"
#include "framework/cl/cl_tensor.h"
namespace paddle_mobile {
namespace operators {
......@@ -43,8 +44,8 @@ void FeedKernel<GPU_CL, float>::Compute(const FeedParam<GPU_CL> &param) {
const int Stride2 = out_C * out_H * out_W;
const int Stride1 = out_H * out_W;
const int Stride0 = out_W;
CLTensor input_cl_tensor(this->cl_helper_.CLContext(),
this->cl_helper_.CLCommandQueue());
framework::CLTensor input_cl_tensor(this->cl_helper_.CLContext(),
this->cl_helper_.CLCommandQueue());
input_cl_tensor.Resize(input->dims());
cl_mem inputBuffer = input_cl_tensor.mutable_with_data<float>(input_data);
......
......@@ -94,27 +94,20 @@ void FusionFcCompute(const FusionFcParam<GPU_CL> &param, cl_context context,
memory::Copy(out_data + i * classes, input_z_data, sizeof(float) * classes);
}
// for (int i = 0; i < out->numel(); i++) {
// DLOG << out_data[i];
// }
// bias_data has the same dimensions as out
math::matmul<float>(x_matrix, false, y_matrix, false, static_cast<float>(1),
out, static_cast<float>(1), false);
math::MatMul<float, float>(x_matrix, false, y_matrix, false,
static_cast<float>(1), out, static_cast<float>(1),
false);
out_image->InitEmptyImage(context, commandQueue, out->dims());
framework::TensorToCLImage(out, out_image, context, commandQueue, kernel1);
DLOG << *out;
delete (input_x);
delete (input_y);
delete (input_z);
delete (out);
PADDLE_MOBILE_ENFORCE(out_dim.size() == 2, " out_dim.size must be 2.");
// if (out_dim.size() != 2) {
// out->Resize(out_dim);
// }
}
template <>
void FusionFcKernel<GPU_CL, float>::Compute(
const FusionFcParam<GPU_CL> &param) {
......
......@@ -15,6 +15,7 @@ limitations under the License. */
#ifdef FUSION_CONVBN_OP
#include "operators/kernel/conv_bn_kernel.h"
#include <cmath>
namespace paddle_mobile {
namespace operators {
......
......@@ -15,6 +15,7 @@ limitations under the License. */
#ifdef FUSION_CONVBNRELU_OP
#include "operators/kernel/conv_bn_relu_kernel.h"
#include <cmath>
namespace paddle_mobile {
namespace operators {
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "framework/operator.h"
#include "operators/op_param.h"
namespace paddle_mobile {
namespace operators {
#define DECLARE_KERNEL(KernelClass, KernelParam) \
template <typename DeviceType, typename T> \
class KernelClass \
: public framework::OpKernelBase<DeviceType, KernelParam<DeviceType>> { \
public: \
bool Init(KernelParam<DeviceType> *param); \
void Compute(const KernelParam<DeviceType> &param); \
};
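// For example, DECLARE_KERNEL(TopKKernel, TopKParam) declares a TopKKernel
// class template deriving from OpKernelBase<DeviceType, TopKParam<DeviceType>>
// with the usual Init/Compute interface; the definitions are supplied by the
// per-device kernel sources.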
#ifdef TOP_K_OP
DECLARE_KERNEL(TopKKernel, TopKParam)
#endif // TOP_K_OP
#ifdef CAST_OP
DECLARE_KERNEL(CastKernel, CastParam)
#endif // CAST_OP
} // namespace operators
} // namespace paddle_mobile
......@@ -61,7 +61,7 @@ void FusionFcKernel<GPU_MALI, float>::Compute(
for (int i = 0; i < out->numel(); i++) {
DLOG << out_data[i];
}
math::matmul<float>(x_matrix, false, y_matrix, false, static_cast<float>(1),
math::MatMul<float>(x_matrix, false, y_matrix, false, static_cast<float>(1),
out, static_cast<float>(1));
PADDLE_MOBILE_ENFORCE(out_dim.size() == 2, " out_dim.size must be 2.");
// if (out_dim.size() != 2) {
......
......@@ -44,7 +44,7 @@ void MulKernel<GPU_MALI, float>::Compute(const MulParam<GPU_MALI> &param) {
if (out_dim.size() != 2) {
out->Resize({x_matrix.dims()[0], y_matrix.dims()[1]});
}
math::matmul<float>(x_matrix, false, y_matrix, false, static_cast<float>(1),
math::MatMul<float>(x_matrix, false, y_matrix, false, static_cast<float>(1),
out, static_cast<float>(0));
if (out_dim.size() != 2) {
out->Resize(out_dim);
......
......@@ -38,7 +38,11 @@ limitations under the License. */
*
* (this is the zlib license)
*/
#if defined(__ARM_NEON__) || defined(__ARM_NEON)
#pragma once
#include <arm_neon.h>
#define c_inv_mant_mask ~0x7f800000u
......@@ -316,11 +320,11 @@ static inline float32x4_t cos_ps(float32x4_t x) {
static inline float32x4_t div_ps(float32x4_t a, float32x4_t b) {
float32x4_t reciprocal = vrecpeq_f32(b);
reciprocal = vmulq_f32(vrecpsq_f32(b, reciprocal), reciprocal);
// reciprocal = vmulq_f32(vrecpsq_f32(b, reciprocal), reciprocal);
return vmulq_f32(a, reciprocal);
}
static inline float32x4_t pow_ps(float32x4_t a, float32x4_t b) {
// pow(x, m) = exp(m * log(x))
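// Note: the identity only holds for a > 0, since log_ps is undefined for
// non-positive inputs.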
return exp_ps(vmulq_f32(b, log_ps(a)));
}
#endif // __ARM_NEON__
......@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "operators/math/math_function.h"
#include <cstring>
#include <string>
#include "common/enforce.h"
#include "framework/data_type.h"
#include "framework/tensor.h"
#include "operators/math/gemm.h"
......@@ -35,13 +35,13 @@ struct TensorSetConstant {
float value_;
};
void set_constant(framework::Tensor *tensor, float value) {
void SetConstant(framework::Tensor *tensor, float value) {
framework::VisitDataType(framework::ToDataType(tensor->type()),
TensorSetConstant(tensor, value));
}
template <>
void matmul<float, float>(const framework::Tensor &matrix_a, bool trans_a,
void MatMul<float, float>(const framework::Tensor &matrix_a, bool trans_a,
const framework::Tensor &matrix_b, bool trans_b,
float alpha, framework::Tensor *matrix_out,
float beta, bool relu, float *bias) {
......@@ -50,20 +50,19 @@ void matmul<float, float>(const framework::Tensor &matrix_a, bool trans_a,
auto dim_out = matrix_out->dims();
PADDLE_MOBILE_ENFORCE(
dim_a.size() == 2 && dim_b.size() == 2 && dim_out.size() == 2,
"The input and output of matmul be matrix");
"The input and output of MatMul be matrix");
int M = dim_out[0];
int N = dim_out[1];
int K = (!trans_a) ? dim_a[1] : dim_a[0];
Gemm gemm;
if (trans_a) {
framework::Tensor matrix_trans;
int numel = matrix_a.numel();
int m = matrix_a.dims()[0];
int n = matrix_a.dims()[1];
float *tmp = (float *)(matrix_a.data<float>()); // NOLINT
float *a = static_cast<float *>(
paddle_mobile::memory::Alloc(sizeof(float) * numel));
float *a = matrix_trans.mutable_data<float>(matrix_a.dims());
int index = 0;
for (int j = 0; j < n; j++) {
for (int i = 0; i < m; i++) {
......@@ -72,7 +71,6 @@ void matmul<float, float>(const framework::Tensor &matrix_a, bool trans_a,
}
#ifdef _OPENMP
gemm.Sgemm_omp(M, N, K, alpha, a, K, matrix_b.data<float>(), N, beta,
matrix_out->data<float>(), N, relu, bias);
#else
......@@ -92,19 +90,18 @@ void matmul<float, float>(const framework::Tensor &matrix_a, bool trans_a,
}
}
template <>
void matmulWithBn<float>(const framework::Tensor &matrix_a, bool trans_a,
const framework::Tensor &matrix_b, bool trans_b,
float alpha, framework::Tensor *matrix_out, float beta,
bool relu, framework::Tensor *new_scale,
framework::Tensor *new_bias, int group, float *bias) {
void MatMulWithBn(const framework::Tensor &matrix_a, bool trans_a,
const framework::Tensor &matrix_b, bool trans_b, float alpha,
framework::Tensor *matrix_out, float beta, bool relu,
framework::Tensor *new_scale, framework::Tensor *new_bias,
int group, float *bias) {
Gemm gemm;
auto dim_a = matrix_a.dims();
auto dim_b = matrix_b.dims();
auto dim_out = matrix_out->dims();
PADDLE_MOBILE_ENFORCE(
dim_a.size() == 2 && dim_b.size() == 2 && dim_out.size() == 2,
"The input and output of matmul be matrix");
"The input and output of MatMul be matrix");
int M = dim_out[0];
int N = dim_out[1];
......@@ -122,7 +119,7 @@ void matmulWithBn<float>(const framework::Tensor &matrix_a, bool trans_a,
new_bias->data<float>() + group, bias);
#endif
}
void matmulWithPRelu(const framework::Tensor &matrix_a, bool trans_a,
void MatMulWithPRelu(const framework::Tensor &matrix_a, bool trans_a,
const framework::Tensor &matrix_b, bool trans_b,
framework::Tensor *matrix_out, float *p, std::string mode,
float *bias, float *bias1) {
......@@ -132,7 +129,7 @@ void matmulWithPRelu(const framework::Tensor &matrix_a, bool trans_a,
auto dim_out = matrix_out->dims();
PADDLE_MOBILE_ENFORCE(
dim_a.size() == 2 && dim_b.size() == 2 && dim_out.size() == 2,
"The input and output of matmul be matrix");
"The input and output of MatMul be matrix");
int M = dim_out[0];
int N = dim_out[1];
......@@ -146,7 +143,6 @@ void matmulWithPRelu(const framework::Tensor &matrix_a, bool trans_a,
gemm.SgemmWithPRelu(M, N, K, matrix_a.data<float>(), K,
matrix_b.data<float>(), N, matrix_out->data<float>(), N,
p, mode, bias, bias1);
#endif
}
......
......@@ -14,7 +14,6 @@ limitations under the License. */
#pragma once
#include <cmath>
#include <string>
#include "framework/tensor.h"
......@@ -22,37 +21,37 @@ namespace paddle_mobile {
namespace operators {
namespace math {
void set_constant(framework::Tensor *tensor, float value);
void SetConstant(framework::Tensor *tensor, float value);
template <typename Itype, typename Otype>
void matmul(const framework::Tensor &matrix_a, bool trans_a,
void MatMul(const framework::Tensor &matrix_a, bool trans_a,
const framework::Tensor &matrix_b, bool trans_b, float alpha,
framework::Tensor *matrix_out, float beta, bool relu = false,
Otype *bias = nullptr);
template <typename Itype, typename Otype>
void matmul(const framework::Tensor &matrix_a, bool trans_a,
void MatMul(const framework::Tensor &matrix_a, bool trans_a,
const framework::Tensor &matrix_b, bool trans_b, float alpha,
framework::Tensor *matrix_out, float beta, bool relu, Otype *bias,
bool addOnRow);
template <typename T>
void matmulWithBn(const framework::Tensor &matrix_a, bool trans_a,
void MatMulWithBn(const framework::Tensor &matrix_a, bool trans_a,
const framework::Tensor &matrix_b, bool trans_b, float alpha,
framework::Tensor *matrix_out, float beta, bool relu,
framework::Tensor *new_scale, framework::Tensor *new_bias,
int group, T *bias = nullptr);
int group, float *bias = nullptr);
void matmulWithPRelu(const framework::Tensor &matrix_a, bool trans_a,
void MatMulWithPRelu(const framework::Tensor &matrix_a, bool trans_a,
const framework::Tensor &matrix_b, bool trans_b,
framework::Tensor *matrix_out, float *p, std::string mode,
float *bias, float *bias1);
template <typename DeviceType, typename T>
template <typename Device, typename T>
struct ClearTensor {
void operator()(framework::Tensor *tensor);
};
template <typename DeviceType, typename T>
template <typename Device, typename T>
struct RowwiseAdd {
void operator()(const framework::Tensor &input, const framework::Tensor &vec,
framework::Tensor *output);
......
......@@ -22,7 +22,7 @@ namespace operators {
namespace math {
template <>
void matmul<int8_t, int32_t>(const framework::Tensor &matrix_a, bool trans_a,
void MatMul<int8_t, int32_t>(const framework::Tensor &matrix_a, bool trans_a,
const framework::Tensor &matrix_b, bool trans_b,
float alpha, framework::Tensor *matrix_out,
float beta, bool relu, int32_t *bias,
......@@ -32,7 +32,7 @@ void matmul<int8_t, int32_t>(const framework::Tensor &matrix_a, bool trans_a,
auto dim_out = matrix_out->dims();
PADDLE_MOBILE_ENFORCE(
dim_a.size() == 2 && dim_b.size() == 2 && dim_out.size() == 2,
"The input and output of matmul be matrix");
"The input and output of MatMul be matrix");
int32_t M = dim_out[0];
int32_t N = dim_out[1];
......@@ -96,11 +96,11 @@ void matmul<int8_t, int32_t>(const framework::Tensor &matrix_a, bool trans_a,
}
template <>
void matmul<int8_t, int32_t>(const framework::Tensor &matrix_a, bool trans_a,
void MatMul<int8_t, int32_t>(const framework::Tensor &matrix_a, bool trans_a,
const framework::Tensor &matrix_b, bool trans_b,
float alpha, framework::Tensor *matrix_out,
float beta, bool relu, int32_t *bias) {
matmul<int8_t, int32_t>(matrix_a, trans_a, matrix_b, trans_b, alpha,
MatMul<int8_t, int32_t>(matrix_a, trans_a, matrix_b, trans_b, alpha,
matrix_out, beta, relu, bias, false);
}
......
......@@ -69,10 +69,10 @@ class LoDTensor2BatchFunctor {
auto lods = lod_tensor.lod();
PADDLE_MOBILE_ENFORCE((lods.size() == 1UL),
"Only support one level sequence now.");
"Only support 1 level sequence, but %d is given",
lods.size());
const auto& lod = lods[0];
std::vector<SeqInfo> seq_info;
for (size_t seq_id = 0; seq_id < lod.size() - 1; ++seq_id) {
int length = lod[seq_id + 1] - lod[seq_id];
......
......@@ -15,154 +15,131 @@ limitations under the License. */
#ifdef SOFTMAX_OP
#include "operators/math/softmax.h"
#include "common/types.h"
#ifdef __ARM_NEON
#include <math.h>
#include <algorithm>
#include <limits>
#include "common/types.h"
#include "operators/math/math_func_neon.h"
#endif
namespace paddle_mobile {
namespace operators {
namespace math {
using framework::DDim;
using framework::Tensor;
template <typename T>
class SoftmaxFuntor<CPU, T> {
#ifdef __ARM_NEON
void sum(float *input, float *sumptr, int inner_size, int outter_size) {
float32x4_t acc = vdupq_n_f32(0);
float sum_ = 0;
for (int i = 0; i < outter_size; ++i) {
float *input_outer_ptr = input + i * inner_size;
int nn = inner_size >> 2;
int left = inner_size - (nn << 2);
for (; nn > 0; nn--) {
float32x4_t vec_input = vld1q_f32(input_outer_ptr);
acc = vaddq_f32(acc, vec_input);
input_outer_ptr += 4;
}
float32x2_t vsum_ = vadd_f32(vget_high_f32(acc), vget_low_f32(acc));
sum_ = vget_lane_f32(vsum_, 0) + vget_lane_f32(vsum_, 1);
for (; left > 0; left--) {
sum_ += *input_outer_ptr;
input_outer_ptr++;
}
}
for (int j = 0; j < inner_size * outter_size; ++j) {
sumptr[j] = sum_;
}
#if defined(__ARM_NEON) || defined(__ARM_NEON__)
#ifndef __aarch64__
inline float32_t vmaxvq_f32(const float32x4_t &r) {
float32x2_t v = vmax_f32(vget_high_f32(r), vget_low_f32(r));
return vget_lane_f32(vpmax_f32(v, v), 0);
}
inline float32_t vaddvq_f32(const float32x4_t &r) {
float32x2_t v = vadd_f32(vget_high_f32(r), vget_low_f32(r));
return vget_lane_f32(vpadd_f32(v, v), 0);
}
#endif // __aarch64__
#endif // __ARM_NEON__
float find_max(const float *input, const int num_classes) {
int remain = num_classes;
float max = -std::numeric_limits<float>::max();
#if defined(__ARM_NEON) || defined(__ARM_NEON__)
int loop = num_classes >> 3;
remain = num_classes & 0x7;
float32x4_t __max = vdupq_n_f32(max);
for (int i = 0; i < loop; ++i, input += 8) {
float32x4_t x0 = vld1q_f32(input);
float32x4_t x1 = vld1q_f32(input + 4);
__max = vmaxq_f32(x0, __max);
__max = vmaxq_f32(x1, __max);
}
max = vmaxvq_f32(__max);
#endif
for (int i = 0; i < remain; ++i) {
max = std::max(max, input[i]);
}
return max;
}
void SoftmaxCacl(const Tensor *X, Tensor *Y) {
const float *input = X->data<float>();
const DDim &dDim = X->dims();
int axis_index = 1;
if (dDim.size() < 4) {
axis_index = 0;
}
DDim outer_ddim =
paddle_mobile::framework::slice_ddim(dDim, 0, axis_index + 1);
DDim inner_ddim =
paddle_mobile::framework::slice_ddim(dDim, axis_index + 1, dDim.size());
int out_size = paddle_mobile::framework::product(outer_ddim);
int inner_size = paddle_mobile::framework::product(inner_ddim);
auto *max_ptr = new float[inner_size * out_size];
// max
for (int j = 0; j < out_size; ++j) {
const float *input_outer_ptr = input + j * inner_size;
float *max_outer_ptr = max_ptr + j * inner_size;
float max_ = 0;
for (int i = 0; i < inner_size; ++i) {
const float *input_inner_ptr = input_outer_ptr + i;
max_ = std::max(max_, input_inner_ptr[0]);
}
for (int k = 0; k < inner_size; ++k) {
max_outer_ptr[k] = max_;
template <>
void SoftmaxFuntor<CPU, float>::operator()(const framework::Tensor *X,
framework::Tensor *Y) {
const framework::DDim &dims = X->dims();
int batch_size = dims[0];
int num_classes = dims[dims.size() - 1];
int channels = X->numel() / batch_size / num_classes;
const float *x = X->data<float>();
float *y = Y->mutable_data<float>();
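// Row-wise softmax over the last dimension:
//   softmax(x)_i = exp(x_i - max(x)) / sum_j exp(x_j - max(x))
// Subtracting the per-row max keeps expf() in a safe numeric range without
// changing the result.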
#pragma omp parallel for collapse(2)
for (int batch = 0; batch < X->dims()[0]; ++batch) {
for (int channel = 0; channel < channels; ++channel) {
size_t offset = (batch * channels + channel) * num_classes;
const float *input = x + offset;
float *output = y + offset;
// find max
float max = find_max(input, num_classes);
// exp(x - max)
int remain = num_classes;
#if defined(__ARM_NEON) || defined(__ARM_NEON__)
int loop = num_classes >> 3;
remain = num_classes & 0x7;
float32x4_t __max = vdupq_n_f32(max);
for (int i = 0; i < loop; ++i, input += 8, output += 8) {
float32x4_t x0 = vld1q_f32(input);
float32x4_t x1 = vld1q_f32(input + 4);
x0 = vsubq_f32(x0, __max);
x1 = vsubq_f32(x1, __max);
x0 = exp_ps(x0);
x1 = exp_ps(x1);
vst1q_f32(output, x0);
vst1q_f32(output + 4, x1);
}
}
// exp(value - max)
float *exp_sub_max = new float[inner_size * out_size];
float *exp_sub_max_ptr = &exp_sub_max[0];
for (int l = 0; l < out_size; ++l) {
const float *input_outer_ptr = input + l * inner_size;
float *max_outer_ptr = max_ptr + l * inner_size;
int nn = inner_size >> 2;
int left = inner_size - (nn << 2);
for (; nn > 0; nn--) {
float32x4_t vec_input = vld1q_f32(input_outer_ptr);
float32x4_t vec_max = vld1q_f32(max_outer_ptr);
float32x4_t vec_sub = vsubq_f32(vec_input, vec_max);
float32x4_t vec_exp = exp_ps(vec_sub);
vst1q_f32(exp_sub_max_ptr, vec_exp);
input_outer_ptr += 4;
max_outer_ptr += 4;
exp_sub_max_ptr += 4;
#endif // __ARM_NEON__
for (int i = 0; i < remain; ++i) {
output[i] = expf(input[i] - max);
}
for (; left > 0; left--) {
*exp_sub_max_ptr = expf(*input_outer_ptr - *max_outer_ptr);
input_outer_ptr++;
max_outer_ptr++;
exp_sub_max_ptr++;
// sum(exp(x - max))
float sum = 0.f;
output = y + offset;
#if defined(__ARM_NEON) || defined(__ARM_NEON__)
float32x4_t __sum = vdupq_n_f32(0.f);
for (int i = 0; i < loop; ++i, output += 8) {
float32x4_t x0 = vld1q_f32(output);
float32x4_t x1 = vld1q_f32(output + 4);
__sum = vaddq_f32(x0, __sum);
__sum = vaddq_f32(x1, __sum);
}
}
float *sumptr = new float[inner_size * out_size];
// sum exp
sum(exp_sub_max, sumptr, inner_size, out_size);
// div
auto *out_ptr = Y->mutable_data<float>();
for (int l = 0; l < out_size; ++l) {
const float *input_outer_ptr = exp_sub_max + l * inner_size;
float *output_outer_ptr = out_ptr + l * inner_size;
float *sum_outer_ptr = sumptr + l * inner_size;
int nn = inner_size >> 2;
int left = inner_size - (nn << 2);
for (; nn > 0; nn--) {
float32x4_t vec_input = vld1q_f32(input_outer_ptr);
float32x4_t vec_sum = vld1q_f32(sum_outer_ptr);
float32x4_t vec_div = div_ps(vec_input, vec_sum);
vst1q_f32(output_outer_ptr, vec_div);
input_outer_ptr += 4;
output_outer_ptr += 4;
sum_outer_ptr += 4;
sum += vaddvq_f32(__sum);
#endif // __ARM_NEON__
for (int i = 0; i < remain; ++i) {
sum += output[i];
}
for (; left > 0; left--) {
*output_outer_ptr = (*input_outer_ptr) / (*sum_outer_ptr);
input_outer_ptr++;
output_outer_ptr++;
sum_outer_ptr++;
// exp(x - max) / sum
float inv_sum = 1.f / sum;
output = y + offset;
#if defined(__ARM_NEON) || defined(__ARM_NEON__)
float32x4_t __inv_sum = vdupq_n_f32(inv_sum);
for (int i = 0; i < loop; ++i, output += 8) {
float32x4_t x0 = vld1q_f32(output);
float32x4_t x1 = vld1q_f32(output + 4);
x0 = vmulq_f32(x0, __inv_sum);
x1 = vmulq_f32(x1, __inv_sum);
vst1q_f32(output, x0);
vst1q_f32(output + 4, x1);
}
}
}
#else
#endif // ARM_NEON
public:
void operator()(const framework::Tensor *X, framework::Tensor *Y) {
const DDim dDim = X->dims();
int dim1 = dDim[dDim.size() - 1];
int dim0 = X->numel() / dim1 / dDim[0];
framework::DDim matrix_shape = {dim0, dim1};
for (int i = 0; i < dDim[0]; ++i) {
framework::Tensor sub_X = X->Slice(i, i + 1);
framework::Tensor sub_Y = Y->Slice(i, i + 1);
sub_X.Resize(matrix_shape);
sub_Y.Resize(matrix_shape);
for (int j = 0; j < dim0; j++) {
framework::Tensor sub_x = sub_X.Slice(j, j + 1);
framework::Tensor sub_y = sub_Y.Slice(j, j + 1);
#ifdef __ARM_NEON
SoftmaxCacl(&sub_x, &sub_y);
#endif
for (int i = 0; i < remain; ++i) {
output[i] *= inv_sum;
}
}
}
};
template class SoftmaxFuntor<CPU, float>;
}
} // namespace math
} // namespace operators
} // namespace paddle_mobile
#endif
#endif // SOFTMAX_OP
......@@ -13,17 +13,21 @@ See the License for the specific language governing permissions and
limitations under the License. */
#ifdef SOFTMAX_OP
#pragma once
#include "framework/tensor.h"
namespace paddle_mobile {
namespace operators {
namespace math {
template <typename DeviceType, typename T>
template <typename Device, typename T>
class SoftmaxFuntor {
public:
void operator()(const framework::Tensor *X, framework::Tensor *Y);
};
} // namespace math
} // namespace operators
} // namespace paddle_mobile
......
......@@ -327,8 +327,8 @@ void winograd_transform_input<8, 3>(const framework::Tensor &input,
int channel = input.dims()[1];
int height = input.dims()[2];
int width = input.dims()[3];
int h_tiles = (height + 3) / 6; // (height - 8 + 5 + 6) / 6
int w_tiles = (width + 3) / 6; // (width - 8 + 5 + 6) / 6
int h_tiles = (height + 3) / 6; // (height - 2 + 5) / 6
int w_tiles = (width + 3) / 6; // (width - 2 + 5) / 6
int tiles = (h_tiles * w_tiles + 7) / 8;
framework::DDim transformed_shape =
framework::make_ddim(std::vector<int>{tiles, 64, channel, 8});
......@@ -336,16 +336,10 @@ void winograd_transform_input<8, 3>(const framework::Tensor &input,
memset(outptr, 0, output->numel() * sizeof(float));
const float *inptr = input.data<float>();
int inter_h = (height - 2) / 6;
int inter_w = (width - 2) / 6;
int remain_h = height - (inter_h * 6);
int remain_w = width - (inter_w * 6);
height = h_tiles * 6 + 2;
width = w_tiles * 6 + 2;
framework::Tensor input_pad;
if (remain_h > 2 || remain_w > 2) {
inter_h += (remain_h > 2);
inter_w += (remain_w > 2);
height = (inter_h - 1) * 6 + 8;
width = (inter_w - 1) * 6 + 8;
if (height > input.dims()[2] || width > input.dims()[3]) {
framework::DDim input_shape =
framework::make_ddim(std::vector<int>{1, channel, height, width});
PadFunctor<CPU, float> pad;
......@@ -878,8 +872,8 @@ void winograd_transform_output<8, 3>(const framework::Tensor &input,
framework::Tensor *output) {
// weight shape is [out_channel/4, 64, in_channel, 4],
// input shape is [hw/8, 64, in_channel, 8]
int in_channel = input.dims()[2];
int tiles = input.dims()[0];
int in_channel = input.dims()[2];
int out_channel = weight.dims()[0];
// compute U*V first
......@@ -887,7 +881,6 @@ void winograd_transform_output<8, 3>(const framework::Tensor &input,
framework::DDim shape =
framework::make_ddim(std::vector<int>{out_channel, tiles, 64, 32});
float *uv_trans_ptr = uv_trans.mutable_data<float>(shape);
memset(uv_trans_ptr, 0, uv_trans.numel() * sizeof(float));
const float *input_ptr = input.data<float>();
const float *weight_ptr = weight.data<float>();
......@@ -910,7 +903,8 @@ void winograd_transform_output<8, 3>(const framework::Tensor &input,
"veor q14, q14, q14 \n"
"veor q15, q15, q15 \n"
"b store_res_%= \n"
"cmp %[inter_channel], #0 \n"
"ble loop_1c_%= \n"
// loop 2 channels
"loop_2c_%=: \n"
"vld1.32 {d0-d3}, [%[w_ptr]]! \n"
......@@ -936,13 +930,14 @@ void winograd_transform_output<8, 3>(const framework::Tensor &input,
"subs %[inter_channel], #1 \n"
"bne loop_2c_%= \n"
"mov pc, lr \n"
// loop 1 channel
"loop_c_%=: \n"
"loop_1c_%=: \n"
"cmp %[remain_channel], #0 \n"
"ble store_res_%= \n"
"vld1.32 {d0-d1}, [%[w_ptr]]! \n"
"vld1.32 {d4-d7}, [%[in_ptr]]! \n"
"vmla.f32 q8, q2, d0[0] \n"
"vmla.f32 q9, q3, d0[0] \n"
"vmla.f32 q10, q2, d0[1] \n"
......@@ -952,28 +947,16 @@ void winograd_transform_output<8, 3>(const framework::Tensor &input,
"vmla.f32 q14, q2, d1[1] \n"
"vmla.f32 q15, q3, d1[1] \n"
"subs %[remain_channel], #1 \n"
"bne loop_c_%= \n"
"mov pc, lr \n"
"store_res_%=: \n"
"cmp %[inter_channel], #0 \n"
"it gt \n"
"blgt loop_2c_%= \n"
"cmp %[remain_channel], #0 \n"
"it gt \n"
"blgt loop_c_%= \n"
"vst1.32 {d16-d19}, [%[uv_ptr]]! \n"
"vst1.32 {d20-d23}, [%[uv_ptr]]! \n"
"vst1.32 {d24-d27}, [%[uv_ptr]]! \n"
"vst1.32 {d28-d31}, [%[uv_ptr]]! \n"
: [w_ptr] "+r"(w_ptr), [in_ptr] "+r"(in_ptr), [uv_ptr] "+r"(uv_ptr),
[remain_channel] "+r"(remain_channel),
[inter_channel] "+r"(inter_channel)
:
: [remain_channel] "r"(remain_channel)
: "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7",
"q8", "q9", "q10", "q11", "q12", "q13", "q14", "q15", "pc", "lr");
"q8", "q9", "q10", "q11", "q12", "q13", "q14", "q15");
}
}
}
......@@ -1223,8 +1206,10 @@ void winograd_transform_output<8, 3>(const framework::Tensor &input,
"q8", "q9", "q10", "q11", "q12", "q13", "q14", "q15");
size_t offset = (oc * out_h + 6 * tile_h) * out_w + 6 * tile_w;
float *out_ptr = output_ptr + offset;
int remain_row = (tile_h < h_tiles - 1) ? 6 : remain_h;
int remain_col = (tile_w < w_tiles - 1) ? 6 : remain_w;
int remain_row = out_h - 6 * tile_h;
int remain_col = out_w - 6 * tile_w;
remain_row = (remain_row > 6) ? 6 : remain_row;
remain_col = (remain_col > 6) ? 6 : remain_col;
for (int i = 0; i < remain_row; ++i, out_ptr += out_w) {
memcpy(out_ptr, output_tmp + i * 6, remain_col * sizeof(float));
}
......
......@@ -12,14 +12,12 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
// Inspired by https://arxiv.org/abs/1509.09308 and referred from the nnpack
// and ncnn projects.
// We refer to https://github.com/andravin/wincnn to access the Winograd
// transform matrices.
#ifdef CONV_OP
#ifdef __aarch64__
#include "operators/math/pad.h"
#include "operators/math/winograd/winograd_transform.h"
namespace paddle_mobile {
......@@ -29,46 +27,382 @@ namespace math {
template <>
void winograd_transform_weight<8, 3>(const framework::Tensor &weight,
framework::Tensor *output) {
/*
* w0 = g0
* w1 = ((g0 + g2) + g1) * (-2.0 / 9)
* w2 = ((g0 + g2) - g1) * (-2.0 / 9)
* w3 = ((g0 + 4 * g2) + 2 * g1) * (1.0 / 90)
* w4 = ((g0 + 4 * g2) - 2 * g1) * (1.0 / 90)
* w5 = ((g2 + 4 * g0) + 2 * g1) * (1.0 / 180)
* w6 = ((g2 + 4 * g0) - 2 * g1) * (1.0 / 180)
* w7 = g2
*/
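// These formulas are the rows of the Winograd filter transform for
// F(6x6, 3x3), i.e. U = G * g * G^T with an 8x3 matrix G; below they are
// applied first to the kernel rows and then to the columns.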
// TODO(hjchen2)
PADDLE_MOBILE_THROW_EXCEPTION(
"Winograd for arm v8 has not been implemented.");
// weight shape is [out_channel, in_channel, kernel_h, kernel_w]
int out_channel = weight.dims()[0];
int in_channel = weight.dims()[1];
// reshape and alloc transformed weight
framework::DDim transformed_shape =
framework::make_ddim(std::vector<int>{out_channel, in_channel, 64});
float *outptr = output->mutable_data<float>(transformed_shape);
const float *inptr = weight.data<float>();
for (int oc = 0; oc < out_channel; ++oc) {
for (int ic = 0; ic < in_channel; ++ic) {
size_t offset = oc * in_channel + ic;
float *kout = outptr + offset * 64;
const float *k = inptr + offset * 9;
float gw[3][8];
for (int i = 0; i < 3; ++i, k += 3) {
float g0 = k[0];
float g1 = k[1];
float g2 = k[2];
float d0 = g0 + g2;
float d1 = g0 + 4 * g2;
float d2 = g2 + 4 * g0;
float d3 = 2 * g1;
gw[i][0] = g0;
gw[i][1] = -2.f / 9 * (d0 + g1); // -2.f/9 * (g0 + g1 + g2)
gw[i][2] = -2.f / 9 * (d0 - g1); // -2.f/9 * (g0 - g1 + g2)
gw[i][3] = 1.f / 90 * (d1 + d3); // 1.f/90 * (g0 + 2 * g1 + 4 * g2)
gw[i][4] = 1.f / 90 * (d1 - d3); // 1.f/90 * (g0 - 2 * g1 + 4 * g2)
gw[i][5] = 1.f / 180 * (d2 + d3); // 1.f/180 * (4 * g0 + 2 * g1 + g2)
gw[i][6] = 1.f / 180 * (d2 - d3); // 1.f/180 * (4 * g0 - 2 * g1 + g2)
gw[i][7] = g2;
}
for (int i = 0; i < 8; ++i, kout += 8) {
float g0 = gw[0][i];
float g1 = gw[1][i];
float g2 = gw[2][i];
float d0 = g0 + g2;
float d1 = g0 + 4 * g2;
float d2 = g2 + 4 * g0;
float d3 = 2 * g1;
kout[0] = g0;
kout[1] = -2.f / 9 * (d0 + g1); // -2.f/9 * (k0 + k1 + k2)
kout[2] = -2.f / 9 * (d0 - g1); // -2.f/9 * (k0 - k1 + k2)
kout[3] = 1.f / 90 * (d1 + d3); // 1.f/90 * (k0 + 2 * k1 + 4 * k2)
kout[4] = 1.f / 90 * (d1 - d3); // 1.f/90 * (k0 - 2 * k1 + 4 * k2)
kout[5] = 1.f / 180 * (d2 + d3);  // 1.f/180 * (4 * k0 + 2 * k1 + k2)
kout[6] = 1.f / 180 * (d2 - d3);  // 1.f/180 * (4 * k0 - 2 * k1 + k2)
kout[7] = g2;
}
}
}
}
template <>
void winograd_transform_input<8, 3>(const framework::Tensor &input,
framework::Tensor *output) {
/*
* x0 = (d0 - d6) + (d4 - d2) * 5.25
* x1 = (d2 + d6) - 4.25 * (d4 + d3) + (d1 + d5)
* x2 = (d2 + d6) - 4.25 * (d4 - d3) - (d1 + d5)
* x3 = (0.25 * d2 - 1.25 * d4 + d6) + (0.5 * d1 - 2.5 * d3 + 2 * d5)
* x4 = (0.25 * d2 - 1.25 * d4 + d6) - (0.5 * d1 - 2.5 * d3 + 2 * d5)
* x5 = (4 * d2 - 5 * d4 + d6) + (2 * d1 - 2.5 * d3 + 0.5 * d5)
* x6 = (4 * d2 - 5 * d4 + d6) - (2 * d1 - 2.5 * d3 + 0.5 * d5)
* x7 = (d7 - d1) + (d3 - d5) * 5.25
*/
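// These formulas are the rows of B^T for F(6x6, 3x3); the tile transform
// below computes B^T * d * B by applying them along the rows and then the
// columns of each 8x8 input tile.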
// TODO(hjchen2)
PADDLE_MOBILE_THROW_EXCEPTION(
"Winograd for arm v8 has not been implemented.");
// tile input to [c, roundup(h/6), roundup(w/6), 64] and do transformation
int channel = input.dims()[1];
int height = input.dims()[2];
int width = input.dims()[3];
int h_tiles = (height + 3) / 6; // (height + 5 - 2) / 6
int w_tiles = (width + 3) / 6; // (width + 5 - 2) / 6
framework::DDim transformed_shape =
framework::make_ddim(std::vector<int>{channel, h_tiles, w_tiles, 64});
float *outptr = output->mutable_data<float>(transformed_shape);
memset(outptr, 0, channel * h_tiles * w_tiles * 64 * sizeof(float));
const float *inptr = input.data<float>();
// pack input to tiles
for (int c = 0; c < channel; ++c) {
int inter_h = (height - 2) / 6;
int inter_w = (width - 2) / 6;
int remain_h = height - (inter_h * 6);
int remain_w = width - (inter_w * 6);
const float *in0 = inptr + c * height * width;
const float *in1 = in0 + width;
const float *in2 = in1 + width;
const float *in3 = in2 + width;
const float *in4 = in3 + width;
const float *in5 = in4 + width;
const float *in6 = in5 + width;
const float *in7 = in6 + width;
float *out = outptr + c * h_tiles * w_tiles * 64;
for (int h = 0; h < inter_h; ++h) {
for (int w = 0; w < inter_w; ++w) {
memcpy(out, in0, 8 * sizeof(float));
memcpy(out + 8, in1, 8 * sizeof(float));
memcpy(out + 16, in2, 8 * sizeof(float));
memcpy(out + 24, in3, 8 * sizeof(float));
memcpy(out + 32, in4, 8 * sizeof(float));
memcpy(out + 40, in5, 8 * sizeof(float));
memcpy(out + 48, in6, 8 * sizeof(float));
memcpy(out + 56, in7, 8 * sizeof(float));
in0 += 6;
in1 += 6;
in2 += 6;
in3 += 6;
in4 += 6;
in5 += 6;
in6 += 6;
in7 += 6;
out += 64;
}
// remain width
if (remain_w > 2) {
memcpy(out, in0, remain_w * sizeof(float));
memcpy(out + 8, in1, remain_w * sizeof(float));
memcpy(out + 16, in2, remain_w * sizeof(float));
memcpy(out + 24, in3, remain_w * sizeof(float));
memcpy(out + 32, in4, remain_w * sizeof(float));
memcpy(out + 40, in5, remain_w * sizeof(float));
memcpy(out + 48, in6, remain_w * sizeof(float));
memcpy(out + 56, in7, remain_w * sizeof(float));
out += 64;
}
in0 += 5 * width + remain_w;
in1 += 5 * width + remain_w;
in2 += 5 * width + remain_w;
in3 += 5 * width + remain_w;
in4 += 5 * width + remain_w;
in5 += 5 * width + remain_w;
in6 += 5 * width + remain_w;
in7 += 5 * width + remain_w;
}
// remain height
if (remain_h > 2) {
for (int w = 0; w < inter_w; ++w) {
for (int rh = 0; rh < remain_h; ++rh) {
memcpy(out + rh * 8, in0 + rh * width, 8 * sizeof(float));
}
out += 64;
in0 += 6;
}
// remain width
if (remain_w > 2) {
for (int rh = 0; rh < remain_h; ++rh) {
memcpy(out + rh * 8, in0 + rh * width, remain_w * sizeof(float));
}
}
}
}
// transform tiles, compute B_T * d(c, b) * B
for (int c = 0; c < channel; ++c) {
for (int tile = 0; tile < h_tiles * w_tiles; ++tile) {
float *out = outptr + (c * h_tiles * w_tiles + tile) * 64;
// compute B_T * d(c, b)
float bd[8][8];
for (int i = 0; i < 8; ++i) {
float d0 = out[8 * i + 0];
float d1 = out[8 * i + 1];
float d2 = out[8 * i + 2];
float d3 = out[8 * i + 3];
float d4 = out[8 * i + 4];
float d5 = out[8 * i + 5];
float d6 = out[8 * i + 6];
float d7 = out[8 * i + 7];
bd[i][0] = d0 - d6 + (d4 - d2) * 5.25;
float v1 = d2 - 4.25 * d4 + d6;
float v2 = d1 - 4.25 * d3 + d5;
// d1 + d2 - 4.25 * d3 - 4.25 * d4 + d5 + d6
bd[i][1] = v1 + v2;
// -d1 + d2 + 4.25 * d3 - 4.25 * d4 - d5 + d6
bd[i][2] = v1 - v2;
v1 = 0.25 * d2 - 1.25 * d4 + d6;
v2 = 0.5 * d1 - 2.5 * d3 + 2 * d5;
// 0.5 * d1 + 0.25 * d2 - 2.5 * d3 - 1.25 * d4 + 2 * d5 + d6
bd[i][3] = v1 + v2;
// -0.5 * d1 + 0.25 * d2 + 2.5 * d3 - 1.25 * d4 - 2 * d5 + d6
bd[i][4] = v1 - v2;
v1 = 4 * d2 - 5 * d4 + d6;
v2 = 2 * d1 - 2.5 * d3 + 0.5 * d5;
// 2 * d1 + 4 * d2 - 2.5 * d3 - 5 * d4 + 0.5 * d5 + d6
bd[i][5] = v1 + v2;
// -2 * d1 + 4 * d2 + 2.5 * d3 - 5 * d4 - 0.5 * d5 + d6
bd[i][6] = v1 - v2;
bd[i][7] = d7 - d1 + (d3 - d5) * 5.25;
}
// compute B_T * d(c, b) * B
for (int i = 0; i < 8; ++i, out += 8) {
float d0 = bd[0][i];
float d1 = bd[1][i];
float d2 = bd[2][i];
float d3 = bd[3][i];
float d4 = bd[4][i];
float d5 = bd[5][i];
float d6 = bd[6][i];
float d7 = bd[7][i];
out[0] = d0 - d6 + (d4 - d2) * 5.25;
float v1 = d2 - 4.25 * d4 + d6;
float v2 = d1 - 4.25 * d3 + d5;
// d1 + d2 - 4.25 * d3 - 4.25 * d4 + d5 + d6
out[1] = v1 + v2;
// -d1 + d2 + 4.25 * d3 - 4.25 * d4 - d5 + d6
out[2] = v1 - v2;
v1 = 0.25 * d2 - 1.25 * d4 + d6;
v2 = 0.5 * d1 - 2.5 * d3 + 2 * d5;
// 0.5 * d1 + 0.25 * d2 - 2.5 * d3 - 1.25 * d4 + 2 * d5 + d6
out[3] = v1 + v2;
// -0.5 * d1 + 0.25 * d2 + 2.5 * d3 - 1.25 * d4 - 2 * d5 + d6
out[4] = v1 - v2;
v1 = 4 * d2 - 5 * d4 + d6;
v2 = 2 * d1 - 2.5 * d3 + 0.5 * d5;
// 2 * d1 + 4 * d2 - 2.5 * d3 - 5 * d4 + 0.5 * d5 + d6
out[5] = v1 + v2;
// -2 * d1 + 4 * d2 + 2.5 * d3 - 5 * d4 - 0.5 * d5 + d6
out[6] = v1 - v2;
out[7] = d7 - d1 + (d3 - d5) * 5.25;
}
}
}
}
template <>
void winograd_transform_output<8, 3>(const framework::Tensor &input,
const framework::Tensor &weight,
framework::Tensor *output) {
// TODO(hjchen2)
PADDLE_MOBILE_THROW_EXCEPTION(
"Winograd for arm v8 has not been implemented.");
// input shape is [in_channel, h_tiles, w_tiles, 64]
// weight shape is [out_channel, in_channel, 64]
int in_channel = input.dims()[0];
int h_tiles = input.dims()[1];
int w_tiles = input.dims()[2];
int tiles = h_tiles * w_tiles;
int out_channel = weight.dims()[0];
// compute U*V first
framework::Tensor output_m;
framework::DDim shape =
framework::make_ddim(std::vector<int>{out_channel, tiles, 64});
float *output_m_ptr = output_m.mutable_data<float>(shape);
memset(output_m_ptr, 0, output_m.numel() * sizeof(float));
const float *input_ptr = input.data<float>();
const float *weight_ptr = weight.data<float>();
for (int i = 0; i < out_channel; ++i) {
for (int j = 0; j < tiles; ++j) {
const float *w_ptr = weight_ptr + i * in_channel * 64;
const float *in_ptr = input_ptr + j * 64;
float *m_ptr = output_m_ptr + (i * tiles + j) * 64;
for (int c = 0; c < in_channel; ++c) {
for (int k = 0; k < 64; ++k) {
m_ptr[k] += w_ptr[k] * in_ptr[k];
}
w_ptr += 64;
in_ptr += tiles * 64;
}
}
}
for (int oc = 0; oc < out_channel; ++oc) {
for (int tile = 0; tile < tiles; ++tile) {
float *m = output_m_ptr + (oc * tiles + tile) * 64;
// compute A_T * m
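// A^T is the 6x8 output transform of F(6x6, 3x3); applying it on both sides
// (A^T * m * A) maps each 8x8 accumulated tile back to a 6x6 output tile.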
float am[6][8];
for (int i = 0; i < 8; ++i) {
float d0 = m[i * 8 + 0];
float d1 = m[i * 8 + 1];
float d2 = m[i * 8 + 2];
float d3 = m[i * 8 + 3];
float d4 = m[i * 8 + 4];
float d5 = m[i * 8 + 5];
float d6 = m[i * 8 + 6];
float d7 = m[i * 8 + 7];
float v0 = d1 + d2;
float v1 = d1 - d2;
float v2 = d3 + d4;
float v3 = d3 - d4;
float v4 = d5 + d6;
float v5 = d5 - d6;
am[0][i] = d0 + v0 + v2 + 32 * v4;
am[1][i] = v1 + 2 * v3 + 16 * v5;
am[2][i] = v0 + 4 * v2 + 8 * v4;
am[3][i] = v1 + 8 * v3 + 4 * v5;
am[4][i] = v0 + 16 * v2 + 2 * v4;
am[5][i] = v1 + 32 * v3 + v5 + d7;
}
// compute A_T * m * A
for (int i = 0; i < 6; ++i, m += 8) {
float d0 = am[i][0];
float d1 = am[i][1];
float d2 = am[i][2];
float d3 = am[i][3];
float d4 = am[i][4];
float d5 = am[i][5];
float d6 = am[i][6];
float d7 = am[i][7];
float v0 = d1 + d2;
float v1 = d1 - d2;
float v2 = d3 + d4;
float v3 = d3 - d4;
float v4 = d5 + d6;
float v5 = d5 - d6;
m[0] = d0 + v0 + v2 + 32 * v4;
m[1] = v1 + 2 * v3 + 16 * v5;
m[2] = v0 + 4 * v2 + 8 * v4;
m[3] = v1 + 8 * v3 + 4 * v5;
m[4] = v0 + 16 * v2 + 2 * v4;
m[5] = v1 + 32 * v3 + v5 + d7;
}
}
}
int out_h = output->dims()[2];
int out_w = output->dims()[3];
float *output_ptr = output->mutable_data<float>();
// copy valid region to final output
for (int oc = 0; oc < out_channel; ++oc) {
int inter_h = out_h / 6;
int inter_w = out_w / 6;
int remain_h = out_h - inter_h * 6;
int remain_w = out_w - inter_w * 6;
float *out_ptr0 = output_ptr + oc * out_h * out_w;
float *out_ptr1 = out_ptr0 + out_w;
float *out_ptr2 = out_ptr1 + out_w;
float *out_ptr3 = out_ptr2 + out_w;
float *out_ptr4 = out_ptr3 + out_w;
float *out_ptr5 = out_ptr4 + out_w;
const float *m_ptr = output_m_ptr + oc * tiles * 64;
for (int tile_h = 0; tile_h < inter_h; ++tile_h) {
for (int tile_w = 0; tile_w < inter_w; ++tile_w) {
const float *m = m_ptr + (tile_h * w_tiles + tile_w) * 64;
memcpy(out_ptr0, m, 6 * sizeof(float));
memcpy(out_ptr1, m + 8, 6 * sizeof(float));
memcpy(out_ptr2, m + 16, 6 * sizeof(float));
memcpy(out_ptr3, m + 24, 6 * sizeof(float));
memcpy(out_ptr4, m + 32, 6 * sizeof(float));
memcpy(out_ptr5, m + 40, 6 * sizeof(float));
out_ptr0 += 6;
out_ptr1 += 6;
out_ptr2 += 6;
out_ptr3 += 6;
out_ptr4 += 6;
out_ptr5 += 6;
}
// remain w
if (remain_w > 0) {
const float *m = m_ptr + (tile_h * w_tiles + inter_w) * 64;
memcpy(out_ptr0, m, remain_w * sizeof(float));
memcpy(out_ptr1, m + 8, remain_w * sizeof(float));
memcpy(out_ptr2, m + 16, remain_w * sizeof(float));
memcpy(out_ptr3, m + 24, remain_w * sizeof(float));
memcpy(out_ptr4, m + 32, remain_w * sizeof(float));
memcpy(out_ptr5, m + 40, remain_w * sizeof(float));
out_ptr0 += remain_w;
out_ptr1 += remain_w;
out_ptr2 += remain_w;
out_ptr3 += remain_w;
out_ptr4 += remain_w;
out_ptr5 += remain_w;
}
out_ptr0 += 5 * out_w;
out_ptr1 += 5 * out_w;
out_ptr2 += 5 * out_w;
out_ptr3 += 5 * out_w;
out_ptr4 += 5 * out_w;
out_ptr5 += 5 * out_w;
}
// remain h
if (remain_h > 0) {
for (int tile_w = 0; tile_w < inter_w; ++tile_w) {
const float *m = m_ptr + (inter_h * w_tiles + tile_w) * 64;
for (int rh = 0; rh < remain_h; ++rh) {
memcpy(out_ptr0 + rh * out_w, m + rh * 8, 6 * sizeof(float));
}
out_ptr0 += 6;
}
if (remain_w > 0) {
const float *m = m_ptr + (inter_h * w_tiles + inter_w) * 64;
for (int rh = 0; rh < remain_h; ++rh) {
memcpy(out_ptr0 + rh * out_w, m + rh * 8, remain_w * sizeof(float));
}
}
}
}
}
} // namespace math
......
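Note on the output transform above: each 8x8 tile m is reduced to a 6x6 output block by applying the same 6x8 matrix A_T on both sides (first row-wise into am, then column-wise back into m). The rows of A_T can be read directly off the coefficients used for am[0]..am[5], and they match the commonly used Winograd F(6x6, 3x3) output transform:

\[
A^{T} =
\begin{bmatrix}
1 & 1 &  1 & 1  &  1  & 32 &  32 & 0 \\
0 & 1 & -1 & 2  & -2  & 16 & -16 & 0 \\
0 & 1 &  1 & 4  &  4  &  8 &   8 & 0 \\
0 & 1 & -1 & 8  & -8  &  4 &  -4 & 0 \\
0 & 1 &  1 & 16 &  16 &  2 &   2 & 0 \\
0 & 1 & -1 & 32 & -32 &  1 &  -1 & 1
\end{bmatrix}
\]

Because each tile contributes a 6x6 block, the copy-out loop that follows advances the output row pointers in steps of 6 and keeps only remain_h x remain_w values for the border tiles.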
......@@ -439,10 +439,11 @@ class ConvParam : public OpParam {
#endif
protected:
public:
RType *input_;
RType *output_;
RType *filter_;
RType *transformed_filter_;
vector<int> strides_;
vector<int> paddings_;
vector<int> dilations_;
......@@ -455,7 +456,7 @@ class ConvParam : public OpParam {
#ifdef PADDLE_MOBILE_FPGA
private:
public:
fpga::SplitConvArgs fpga_conv_args;
public:
......@@ -2515,6 +2516,52 @@ class ShapeParam : public OpParam {
};
#endif
#ifdef TOP_K_OP
template <typename Dtype>
class TopKParam : public OpParam {
typedef typename DtypeTensorTrait<Dtype>::gtype GType;
typedef typename DtypeTensorTrait<Dtype>::rtype RType;
public:
TopKParam(const VariableNameMap &inputs, const VariableNameMap &outputs,
const AttributeMap &attrs, const Scope &scope) {
input_ = OpParam::GetVarValue<GType>("X", inputs, scope);
output_ = OpParam::GetVarValue<GType>("Out", outputs, scope);
indices_ = OpParam::GetVarValue<GType>("Indices", outputs, scope);
k_ = OpParam::GetAttr<int>("k", attrs);
}
public:
RType *input_;
RType *output_;
RType *indices_;
int k_;
};
#endif // TOP_K_OP
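For reference, TopKParam describes an op that keeps, along the last dimension of X, the k largest values (Out) and their positions (Indices), which is also what the InferShape added further down implies. A minimal scalar sketch of that semantics, not the optimized kernel added by this PR, and assuming float values with int64 indices, could look like:

#include <algorithm>
#include <cstdint>
#include <numeric>
#include <vector>

// Naive per-row top-k: for each row of length num_cols, write the k largest
// values into out and their source positions into indices (k <= num_cols).
void NaiveTopK(const float *in, float *out, int64_t *indices,
               int num_rows, int num_cols, int k) {
  std::vector<int64_t> idx(num_cols);
  for (int r = 0; r < num_rows; ++r) {
    const float *row = in + r * num_cols;
    std::iota(idx.begin(), idx.end(), int64_t{0});
    // Order the index array so the k largest elements of the row come first.
    std::partial_sort(idx.begin(), idx.begin() + k, idx.end(),
                      [row](int64_t a, int64_t b) { return row[a] > row[b]; });
    for (int j = 0; j < k; ++j) {
      out[r * k + j] = row[idx[j]];
      indices[r * k + j] = idx[j];
    }
  }
}

std::partial_sort only orders the first k positions, which is all that is needed here and avoids sorting each full row.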
#ifdef CAST_OP
template <typename Dtype>
class CastParam : public OpParam {
typedef typename DtypeTensorTrait<Dtype>::gtype GType;
typedef typename DtypeTensorTrait<Dtype>::rtype RType;
public:
CastParam(const VariableNameMap &inputs, const VariableNameMap &outputs,
const AttributeMap &attrs, const Scope &scope) {
input_ = OpParam::GetVarValue<GType>("X", inputs, scope);
output_ = OpParam::GetVarValue<GType>("Out", outputs, scope);
input_type_ = OpParam::GetAttr<int>("in_dtype", attrs);
output_type_ = OpParam::GetAttr<int>("out_dtype", attrs);
}
public:
RType *input_;
RType *output_;
int input_type_;
int output_type_;
};
#endif // CAST_OP
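CastParam stores the two dtype attributes as raw ints, so the kernel (not shown in this hunk) has to dispatch on them to pick the element types. A simplified, purely illustrative dispatch is sketched below; the enum values and the RunCast helper are hypothetical and do not come from the framework's VarType definitions:

#include <cstdint>

// Illustrative dtype tags only; the real values come from the framework's
// VarType proto, not from this enum.
enum FakeDType : int { kFakeInt32 = 0, kFakeInt64 = 1, kFakeFloat32 = 2 };

template <typename InT, typename OutT>
void CastLoop(const InT *in, OutT *out, int64_t n) {
  for (int64_t i = 0; i < n; ++i) {
    out[i] = static_cast<OutT>(in[i]);
  }
}

// Sketch of dispatching on the two integer attributes carried by CastParam.
void RunCast(int in_dtype, int out_dtype, const void *in, void *out, int64_t n) {
  if (in_dtype == kFakeInt64 && out_dtype == kFakeFloat32) {
    CastLoop(static_cast<const int64_t *>(in), static_cast<float *>(out), n);
  } else if (in_dtype == kFakeFloat32 && out_dtype == kFakeInt64) {
    CastLoop(static_cast<const float *>(in), static_cast<int64_t *>(out), n);
  }
  // ... remaining (in, out) type combinations elided.
}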
#ifdef QUANT_OP
template <typename Dtype>
class QuantizeParam : public OpParam {
......@@ -2542,9 +2589,9 @@ class QuantizeParam : public OpParam {
public:
// op input
RType *input_;
GType *input_;
// op output
RType *output_;
GType *output_;
RType *online_scale_;
// quantize offline scale
RType *offline_scale_;
......@@ -2578,9 +2625,9 @@ class DequantizeParam : public OpParam {
public:
// op input
RType *input_;
GType *input_;
// op output
RType *output_;
GType *output_;
RType *activation_scale_;
float weight_scale_;
};
......
......@@ -36,4 +36,4 @@ namespace ops = paddle_mobile::operators;
REGISTER_OPERATOR_CPU(quantize, ops::QuantizeOp);
#endif
#endif
#endif // QUANT_OP
......@@ -43,4 +43,4 @@ class QuantizeOp : public framework::OperatorWithKernel<
} // namespace operators
} // namespace paddle_mobile
#endif
#endif // QUANT_OP
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef TOP_K_OP
#include "operators/top_k_op.h"
namespace paddle_mobile {
namespace operators {
template <typename DeviceType, typename T>
void TopKOp<DeviceType, T>::InferShape() const {
const int k = this->param_.k_;
auto dims = this->param_.input_->dims();
// should check k <= dims[-1] && k >= 1
dims[dims.size() - 1] = k;
this->param_.output_->Resize(dims);
this->param_.indices_->Resize(dims);
}
} // namespace operators
} // namespace paddle_mobile
namespace ops = paddle_mobile::operators;
#ifdef PADDLE_MOBILE_CPU
REGISTER_OPERATOR_CPU(top_k, ops::TopKOp);
#endif
#endif // TOP_K_OP
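Making the "should check" comment above concrete: a range guard in the spirit of the exception macro already used in this PR's Winograd code could be added before the Resize calls, for example (sketch only, using the identifiers visible in InferShape):

  const auto last_dim = dims[dims.size() - 1];
  if (k < 1 || k > last_dim) {
    // reject an out-of-range k instead of producing an ill-shaped output
    PADDLE_MOBILE_THROW_EXCEPTION("top_k expects 1 <= k <= last input dim");
  }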
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef TOP_K_OP
#pragma once
#include <string>
#include "framework/operator.h"
#include "operators/kernel/kernels.h"
#include "operators/op_param.h"
namespace paddle_mobile {
namespace operators {
template <typename DeviceType, typename T>
class TopKOp : public framework::OperatorWithKernel<
DeviceType, TopKParam<DeviceType>,
operators::TopKKernel<DeviceType, T>> {
public:
TopKOp(const std::string &type, const VariableNameMap &inputs,
const VariableNameMap &outputs, const framework::AttributeMap &attrs,
std::shared_ptr<framework::Scope> scope)
: framework::OperatorWithKernel<DeviceType, TopKParam<DeviceType>,
operators::TopKKernel<DeviceType, T>>(
type, inputs, outputs, attrs, scope) {}
// inference output shape
void InferShape() const override;
};
} // namespace operators
} // namespace paddle_mobile
#endif // TOP_K_OP
......@@ -261,20 +261,17 @@ if (NOT FOUND_MATCH)
ADD_EXECUTABLE(test-inference-api framework/test_inference_api.cpp)
target_link_libraries(test-inference-api paddle-mobile)
# gen test log
# gen test
ADD_EXECUTABLE(test-optimize framework/test_optimize.cpp)
target_link_libraries(test-optimize paddle-mobile)
#gen test
ADD_EXECUTABLE(test-pool-op operators/test_pool_op.cpp test_helper.h test_include.h executor_for_test.h)
target_link_libraries(test-pool-op paddle-mobile)
#gen test
ADD_EXECUTABLE(test-softmax operators/test_softmax_op.cpp test_helper.h test_include.h executor_for_test.h)
target_link_libraries(test-softmax paddle-mobile)
ADD_EXECUTABLE(test-softmax-op operators/test_softmax_op.cpp test_helper.h test_include.h executor_for_test.h)
target_link_libraries(test-softmax-op paddle-mobile)
# gen test
ADD_EXECUTABLE(test-gemm-accuracy common/test_gemm_accuracy.cpp)
......@@ -375,5 +372,8 @@ if (NOT FOUND_MATCH)
# gen test
ADD_EXECUTABLE(test-super net/test_super.cpp test_helper.h test_include.h)
target_link_libraries(test-super paddle-mobile)
#add_library(test-lib-size SHARED common/test_lib_size.h common/test_lib_size.cpp)
# gen test
ADD_EXECUTABLE(test-ocr net/test_ocr.cpp test_helper.h test_include.h)
target_link_libraries(test-ocr paddle-mobile)
endif ()
......@@ -73,14 +73,14 @@ int main() {
// float
// warm-up 10 times
for (int j = 0; j < 10; ++j) {
paddle_mobile::operators::math::matmul<float, float>(
paddle_mobile::operators::math::MatMul<float, float>(
aa, false, bb, false, static_cast<float>(1), &cc, static_cast<float>(0),
false, nullptr);
}
auto time_start0 = time();
for (int j = 0; j < 10; ++j) {
paddle_mobile::operators::math::matmul<float, float>(
paddle_mobile::operators::math::MatMul<float, float>(
aa, false, bb, false, static_cast<float>(1), &cc, static_cast<float>(0),
false, nullptr);
}
......@@ -91,14 +91,14 @@ int main() {
// int8_t without bias
// warm-up 10 times
for (int j = 0; j < 10; ++j) {
paddle_mobile::operators::math::matmul<int8_t, int32_t>(
paddle_mobile::operators::math::MatMul<int8_t, int32_t>(
aa_int8, false, bb_int8, false, static_cast<float>(1), &cc_int32,
static_cast<float>(0));
}
auto time_start1 = time();
for (int j = 0; j < 10; ++j) {
paddle_mobile::operators::math::matmul<int8_t, int32_t>(
paddle_mobile::operators::math::MatMul<int8_t, int32_t>(
aa_int8, false, bb_int8, false, static_cast<float>(1), &cc_int32,
static_cast<float>(0));
}
......@@ -109,13 +109,13 @@ int main() {
// int8_t with bias, column element wise add
// warm-up 10 times
for (int j = 0; j < 10; ++j) {
paddle_mobile::operators::math::matmul<int8_t, int32_t>(
paddle_mobile::operators::math::MatMul<int8_t, int32_t>(
aa_int8, false, bb_int8, false, static_cast<float>(0.618), &cc_int8,
static_cast<float>(0), false, bias_data_col, false);
}
auto time_start2 = time();
for (int j = 0; j < 10; ++j) {
paddle_mobile::operators::math::matmul<int8_t, int32_t>(
paddle_mobile::operators::math::MatMul<int8_t, int32_t>(
aa_int8, false, bb_int8, false, static_cast<float>(0.618), &cc_int8,
static_cast<float>(0), false, bias_data_col, false);
}
......@@ -126,13 +126,13 @@ int main() {
// int8_t with bias, row element wise add
// warm-up 10 times
for (int j = 0; j < 10; ++j) {
paddle_mobile::operators::math::matmul<int8_t, int32_t>(
paddle_mobile::operators::math::MatMul<int8_t, int32_t>(
aa_int8, false, bb_int8, false, static_cast<float>(0.618), &cc_int8,
static_cast<float>(0), false, bias_data_row, true);
}
auto time_start3 = time();
for (int j = 0; j < 10; ++j) {
paddle_mobile::operators::math::matmul<int8_t, int32_t>(
paddle_mobile::operators::math::MatMul<int8_t, int32_t>(
aa_int8, false, bb_int8, false, static_cast<float>(0.618), &cc_int8,
static_cast<float>(0), false, bias_data_row, true);
}
......@@ -143,13 +143,13 @@ int main() {
// int8_t with bias&relu
// warm-up 10 times
for (int j = 0; j < 10; ++j) {
paddle_mobile::operators::math::matmul<int8_t, int32_t>(
paddle_mobile::operators::math::MatMul<int8_t, int32_t>(
aa_int8, false, bb_int8, false, static_cast<float>(0.618), &cc_int8,
static_cast<float>(0), true, bias_data_col, false);
}
auto time_start4 = time();
for (int j = 0; j < 10; ++j) {
paddle_mobile::operators::math::matmul<int8_t, int32_t>(
paddle_mobile::operators::math::MatMul<int8_t, int32_t>(
aa_int8, false, bb_int8, false, static_cast<float>(0.618), &cc_int8,
static_cast<float>(0), true, bias_data_col, false);
}
......
......@@ -39,6 +39,7 @@ using paddle_mobile::framework::Tensor;
using paddle_mobile::framework::Variable;
using std::string;
using std::vector;
template <typename DeviceType, typename OpType>
class Executor4Test : public Executor<DeviceType> {
public:
......@@ -48,20 +49,19 @@ class Executor4Test : public Executor<DeviceType> {
this->use_optimize_ = use_optimize;
this->program_ = p;
if (this->use_optimize_) {
this->to_predict_program_ = this->program_.optimizeProgram;
this->program_desc_ = this->program_.optimizeProgram;
} else {
this->to_predict_program_ = this->program_.originProgram;
this->program_desc_ = this->program_.originProgram;
}
if (this->program_.originProgram == nullptr) {
LOG(paddle_mobile::LogLevel::kLOG_ERROR)
<< "to_predict_program_ == nullptr";
LOG(paddle_mobile::LogLevel::kLOG_ERROR) << "program_desc_ == nullptr";
}
const std::vector<std::shared_ptr<BlockDesc>> blocks =
this->to_predict_program_->Blocks();
for (std::shared_ptr<BlockDesc> block_desc : blocks) {
std::vector<std::shared_ptr<OpDesc>> ops = block_desc->Ops();
this->program_desc_->Blocks();
for (int block_id = 0; block_id < blocks.size(); ++block_id) {
std::vector<std::shared_ptr<OpDesc>> ops = blocks[block_id]->Ops();
for (int i = 0; i < ops.size(); ++i) {
auto op = ops[i];
if (op->Type() == op_type) {
......@@ -73,18 +73,16 @@ class Executor4Test : public Executor<DeviceType> {
paddle_mobile::framework::OpRegistry<DeviceType>::CreateOp(
op->Type(), op->GetInputs(), op->GetOutputs(),
op->GetAttrMap(), this->program_.scope);
this->ops_of_block_[*block_desc.get()].push_back(op_ptr);
this->ops_of_block_[block_id].push_back(op_ptr);
break;
}
}
}
this->InitMemory();
std::shared_ptr<paddle_mobile::framework::BlockDesc> to_predict_block =
this->to_predict_program_->Block(0);
auto &ops = this->ops_of_block_[*to_predict_block.get()];
for (const auto &op : ops) {
op->Init();
for (const auto &ops : this->ops_of_block_) {
for (const auto &op : ops) {
op->Init();
}
}
}
......@@ -117,12 +115,10 @@ class Executor4Test : public Executor<DeviceType> {
output_tensor_sptrs[i].reset(output_tensors[i]);
}
std::shared_ptr<paddle_mobile::framework::BlockDesc> to_predict_block =
this->to_predict_program_->Block(0);
for (int j = 0; j < this->ops_of_block_[*to_predict_block.get()].size();
++j) {
auto op = this->ops_of_block_[*to_predict_block.get()][j];
op->Run();
for (auto &ops : this->ops_of_block_) {
for (auto &op : ops) {
op->Run();
}
}
return output_tensor_sptrs;
......@@ -139,14 +135,11 @@ class Executor4Test : public Executor<DeviceType> {
auto *output_tensor = con_output->GetMutable<LoDTensor>();
output_tensor->mutable_data<float>(dDim);
std::shared_ptr<paddle_mobile::framework::BlockDesc> to_predict_block =
this->to_predict_program_->Block(0);
for (int j = 0; j < this->ops_of_block_[*to_predict_block.get()].size();
++j) {
auto op = this->ops_of_block_[*to_predict_block.get()][j];
op->Run();
for (auto &ops : this->ops_of_block_) {
for (auto &op : ops) {
op->Run();
}
}
return std::make_shared<paddle_mobile::framework::Tensor>(
paddle_mobile::framework::Tensor(*output_tensor));
}
......
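The harness now mirrors the executor refactor: operators are grouped per block and indexed by block id instead of being keyed by BlockDesc. The member declaration is not part of this hunk; a self-contained sketch of the assumed container shape and of the nested run loop used above (FakeOp stands in for the real operator type) is:

#include <memory>
#include <vector>

// Stand-in for the real operator type; only Run() matters for this sketch.
struct FakeOp {
  void Run() {}
};

// Assumed shape of ops_of_block_ after the refactor: one vector of ops per
// block, indexed by block id.
using OpsPerBlock = std::vector<std::vector<std::shared_ptr<FakeOp>>>;

void RunAllBlocks(OpsPerBlock *ops_of_block) {
  for (auto &ops : *ops_of_block) {  // blocks in program order
    for (auto &op : ops) {           // ops within one block
      op->Run();
    }
  }
}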
......@@ -52,15 +52,16 @@ int main(int argc, char* argv[]) {
SetupTensor<float>(&input, in_shape, 0.f, 255.f);
// warmup
for (int i = 0; i < 10; ++i) {
output = paddle_mobile.Predict(input);
paddle_mobile.Predict(input);
}
auto time3 = time();
for (int i = 0; i < 10; ++i) {
output = paddle_mobile.Predict(input);
paddle_mobile.Predict(input);
}
auto time4 = time();
std::cout << "predict cost :" << time_diff(time3, time4) / 10 << "ms\n";
std::ostringstream os("output tensor size: ");
output = paddle_mobile.Fetch();
os << output->numel() << "\n" << output->data<float>()[0];
for (int i = 1; i < output->numel(); ++i) {
os << ", " << output->data<float>()[i];
......
......@@ -36,11 +36,11 @@ int main() {
input_tensor.data<float>() + input_tensor.numel());
// warm up
for (int i = 0; i < 1; ++i) {
paddle_mobile.PredictLod(input_tensor);
paddle_mobile.Predict(input_tensor);
}
auto time3 = time();
for (int i = 0; i < 1; ++i) {
paddle_mobile.PredictLod(input_tensor);
paddle_mobile.Predict(input_tensor);
}
auto time4 = time();
std::cout << "predict cost :" << time_diff(time3, time4) << "ms"
......
......@@ -19,7 +19,7 @@ limitations under the License. */
int main(int argc, char* argv[]) {
if (argc < 2) {
std::cout << "Usage: ./test_benchmark feed_shape [thread_num] [use_fuse]\n"
<< "feed_shape: input tensor shape, such as 1,3,224,224.\n"
<< "feed_shape: input tensor shape, such as 3,224,224.\n"
<< "thread_num: optional int, threads count, default is 1.\n"
<< "use_fuse: optional bool, default is 0.\n";
return 1;
......@@ -41,18 +41,18 @@ int main(int argc, char* argv[]) {
#endif
paddle_mobile.SetThreadNum(thread_num);
auto time1 = time();
if (paddle_mobile.Load(g_googlenet, optimize)) {
std::vector<float> output;
if (paddle_mobile.Load(g_googlenet, optimize, false, 1, true)) {
auto time2 = paddle_mobile::time();
std::cout << "load cost :" << paddle_mobile::time_diff(time1, time2) << "ms"
<< std::endl;
std::vector<float> input;
std::vector<float> output;
std::vector<int64_t> dims{1, 3, 224, 224};
if (feed_shape) {
sscanf(feed_shape, "%d,%d,%d", &dims[1], &dims[2], &dims[3]);
}
std::cout << "feed shape: [" << dims[0] << ", " << dims[1] << ", "
<< dims[2] << ", " << dims[3] << "]\n";
<< dims[2] << ", " << dims[3] << "]" << std::endl;
GetInput<float>(g_test_image_1x3x224x224, &input, dims);
// warmup
for (int i = 0; i < 10; ++i) {
......
......@@ -48,8 +48,8 @@ int main() {
DLOG << "words lod 22: " << words.lod();
auto time3 = time();
for (int i = 0; i < 1; ++i) {
auto vec_result = paddle_mobile.PredictLod(words);
DLOG << *vec_result;
paddle_mobile.Predict(words);
DLOG << *paddle_mobile.Fetch();
}
auto time4 = time();
std::cout << "predict cost :" << time_diff(time3, time4) / 1 << "ms"
......@@ -84,8 +84,8 @@ int main() {
DLOG << "words lod 22: " << words.lod();
auto time3 = time();
for (int i = 0; i < 1; ++i) {
auto vec_result = paddle_mobile.PredictLod(words);
DLOG << *vec_result;
paddle_mobile.Predict(words);
DLOG << *paddle_mobile.Fetch();
}
auto time4 = time();
std::cout << "predict cost :" << time_diff(time3, time4) / 1 << "ms"
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <fstream>
#include <iostream>
#include "../test_helper.h"
#include "../test_include.h"
void load_images(const char *image_dir, const char *images_list,
std::vector<std::string> *image_names,
std::vector<std::pair<int, int>> *image_shapes) {
int height, width;
std::string filename;
std::ifstream if_list(images_list, std::ios::in);
// read (height, width, filename) triplets until the list is exhausted;
// looping on the stream state avoids duplicating the last entry at EOF
while (if_list >> height >> width >> filename) {
image_shapes->push_back(std::make_pair(height, width));
image_names->push_back(filename);
}
if_list.close();
}
int main(int argc, char **argv) {
if (argc < 4) {
std::cerr << "Usage: ./test_ocr model_dir image_dir images_list."
<< std::endl;
return 1;
}
char *model_dir = argv[1];
char *image_dir = argv[2];
char *images_list = argv[3];
paddle_mobile::PaddleMobile<paddle_mobile::CPU> paddle_mobile;
paddle_mobile.SetThreadNum(8);
auto isok = paddle_mobile.Load(std::string(model_dir) + "/model",
std::string(model_dir) + "/params", true,
false, 1, true);
DLOG << "pass init model";
std::vector<std::string> image_names;
std::vector<std::pair<int, int>> image_shapes;
load_images(image_dir, images_list, &image_names, &image_shapes);
DLOG << "pass load images";
for (int i = 0; i < image_names.size(); i++) {
std::string file_name = image_names[i];
std::vector<float> input_vec;
std::vector<int64_t> dims{1, 1, 48, 512};
dims[2] = image_shapes[i].first;
dims[3] = image_shapes[i].second;
// load input image
std::string img_path = std::string(image_dir) + "/" + file_name;
std::cerr << "img_path: " << img_path << std::endl;
std::cerr << "shape = [" << dims[0] << ", " << dims[1] << ", " << dims[2]
<< ", " << dims[3] << "]" << std::endl;
GetInput<float>(img_path, &input_vec, dims);
framework::Tensor input(input_vec, framework::make_ddim(dims));
// predict
paddle_mobile.Predict(input);
auto output_topk = paddle_mobile.Fetch("top_k_1.tmp_0");
auto output_indices = paddle_mobile.Fetch("cast_68.tmp_0");
// print result
std::cerr << file_name << std::endl;
std::cerr << output_topk->data<float>()[0];
for (int j = 1; j < output_topk->numel(); ++j) {
std::cerr << " " << output_topk->data<float>()[j];
}
std::cerr << std::endl;
std::cerr << output_indices->data<float>()[0];
for (int j = 1; j < output_indices->numel(); ++j) {
std::cerr << " " << output_indices->data<float>()[j];
}
std::cerr << std::endl;
}
return 0;
}
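test_ocr.cpp also shows the refactored predict API in one place: Predict() no longer returns the result, and outputs are pulled afterwards with Fetch() for the default output or Fetch(var_name) for any named variable (here the top_k and cast outputs). The sketch below is condensed from the calls exercised by the tests in this diff, reuses the test headers, and uses placeholder paths and input data; it is an illustration of the call sequence, not an additional test:

#include <string>
#include <vector>
#include "../test_helper.h"
#include "../test_include.h"

int main() {
  std::string model_dir = "ocr_model_dir";  // placeholder model directory
  std::vector<int64_t> dims{1, 1, 48, 512};
  std::vector<float> input_vec(1 * 1 * 48 * 512, 0.f);  // placeholder input

  paddle_mobile::PaddleMobile<paddle_mobile::CPU> paddle_mobile;
  paddle_mobile.SetThreadNum(4);
  if (paddle_mobile.Load(model_dir + "/model", model_dir + "/params", true,
                         false, 1, true)) {
    framework::Tensor input(input_vec, framework::make_ddim(dims));
    paddle_mobile.Predict(input);                      // run; nothing returned
    auto output = paddle_mobile.Fetch();               // default output
    auto topk = paddle_mobile.Fetch("top_k_1.tmp_0");  // any named variable
  }
  return 0;
}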
......@@ -12,29 +12,88 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <cmath>
#include <limits>
#include "../test_include.h"
#include "operators/softmax_op.h"
int main() {
paddle_mobile::framework::Loader<paddle_mobile::CPU> loader;
auto program = loader.Load(std::string(g_mobilenet));
if (program.originProgram == nullptr) {
DLOG << "program read file";
namespace paddle_mobile {
void Softmax(const framework::Tensor *X, framework::Tensor *Y) {
const framework::DDim &dims = X->dims();
int batch_size = dims[0];
int num_classes = dims[dims.size() - 1];
int channels = X->numel() / batch_size / num_classes;
const float *x = X->data<float>();
float *y = Y->mutable_data<float>();
for (int batch = 0; batch < batch_size; ++batch) {
for (int c = 0; c < channels; ++c) {
size_t offset = (batch * channels + c) * num_classes;
const float *input = x + offset;
float *output = y + offset;
float max = -std::numeric_limits<float>::max();
for (int j = 0; j < num_classes; ++j) {
max = (input[j] > max) ? input[j] : max;
}
float sum = 0.f;
for (int j = 0; j < num_classes; ++j) {
float tmp = std::expf(input[j] - max);
sum += tmp;
output[j] = tmp;
}
for (int j = 0; j < num_classes; ++j) {
output[j] /= sum;
}
}
}
Executor4Test<paddle_mobile::CPU,
paddle_mobile::operators::SoftmaxOp<paddle_mobile::CPU, float>>
executor(program, "softmax");
paddle_mobile::framework::Tensor input;
SetupTensor<float>(&input, {1, 1000}, static_cast<float>(0),
static_cast<float>(1));
auto out_ddim = paddle_mobile::framework::make_ddim({1, 1000});
auto output =
executor.Predict(input, "reshape_0.tmp_0", "softmax_0.tmp_0", out_ddim);
auto *output_ptr = output->data<float>();
for (int j = 0; j < output->numel(); ++j) {
DLOG << " value of output: " << output_ptr[j];
}
int TestSoftmaxOp(const std::vector<int> input_shape) {
framework::DDim dims = framework::make_ddim(input_shape);
VariableNameMap inputs;
VariableNameMap outputs;
auto scope = std::make_shared<framework::Scope>();
inputs["X"] = std::vector<std::string>({"input"});
outputs["Out"] = std::vector<std::string>({"output"});
auto input_var = scope.get()->Var("input");
auto input = input_var->template GetMutable<framework::LoDTensor>();
SetupTensor<float>(input, dims, -100.0, 100.0);
auto output_var = scope.get()->Var("output");
auto output = output_var->template Get<framework::LoDTensor>();
framework::AttributeMap attrs;
auto *op = new operators::SoftmaxOp<CPU, float>("softmax", inputs, outputs,
attrs, scope);
op->InferShape();
op->Init();
op->Run();
framework::Tensor output_cmp;
float *output_cmp_data = output_cmp.mutable_data<float>(output->dims());
Softmax(input, &output_cmp);
const float *output_data = output->data<float>();
for (int i = 0; i < output->numel(); ++i) {
float gap = output_data[i] - output_cmp_data[i];
if (std::abs(gap / (output_data[i] + 1e-5)) > 1e-3) {
LOG(kLOG_INFO) << "output_data[" << i << "] = " << output_data[i]
<< ", output_cmp_data[" << i
<< "] = " << output_cmp_data[i];
delete op;
exit(1);
}
}
delete op;
return 0;
}
} // namespace paddle_mobile
int main(int argc, char *argv[]) {
TestSoftmaxOp({128, 1000});
TestSoftmaxOp({128, 10, 1000});
return 0;
}
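The reference Softmax used for the comparison above subtracts the per-row maximum before exponentiating; with inputs drawn from [-100, 100] this shift is what keeps the exponentials from overflowing, and it does not change the result:

\[
\mathrm{softmax}(x)_j = \frac{e^{\,x_j - m}}{\sum_{k=1}^{C} e^{\,x_k - m}}, \qquad m = \max_{k} x_k .
\]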
......@@ -247,6 +247,8 @@ if(NOT FOUND_MATCH)
set(SHAPE_OP ON)
set(ELEMENTWISEMUL_OP ON)
set(SUM_OP ON)
set(TOP_K_OP ON)
set(CAST_OP ON)
set(QUANT_OP ON)
set(DEQUANT_OP ON)
set(FUSION_DEQUANT_BN_OP ON)
......@@ -449,7 +451,12 @@ endif()
if (SUM_OP)
add_definitions(-DSUM_OP)
endif()
if (TOP_K_OP)
add_definitions(-DTOP_K_OP)
endif()
if (CAST_OP)
add_definitions(-DCAST_OP)
endif()
if (QUANT_OP)
add_definitions(-DQUANT_OP)
endif()
......