Commit 91f9e364 authored by dolphin8

Merge branch 'opencl' of https://github.com/PaddlePaddle/paddle-mobile into opencl

......@@ -7,7 +7,7 @@ option(DEBUGING "enable debug mode" ON)
option(USE_EXCEPTION "use std exception" OFF)
option(LOG_PROFILE "log profile" OFF)
# select the platform to build
option(CPU "armv7 with neon" ON)
option(CPU "armv7 with neon" OFF)
option(GPU_MALI "mali gpu" OFF)
option(GPU_CL "opencl gpu" ON)
option(FPGA "fpga" OFF)
......
......@@ -46,7 +46,8 @@ struct PaddleMobileException : public std::exception {
std::string detail(buffer); \
throw paddle_mobile::PaddleMobileException("Custom Exception", buffer, \
__FILE__, __LINE__); \
}
} \
exit(0);
#define PADDLE_MOBILE_ENFORCE(stat, ...) \
{ \
......
......@@ -39,7 +39,13 @@ struct PrecisionTrait<Precision::FP16> {
};
//! device type
enum DeviceTypeEnum { kINVALID = -1, kCPU = 0, kFPGA = 1, kGPU_MALI = 2, kGPU_CL = 3};
enum DeviceTypeEnum {
kINVALID = -1,
kCPU = 0,
kFPGA = 1,
kGPU_MALI = 2,
kGPU_CL = 3
};
template <DeviceTypeEnum T>
struct DeviceType {};
......@@ -49,7 +55,6 @@ typedef DeviceType<kFPGA> FPGA;
typedef DeviceType<kGPU_MALI> GPU_MALI;
typedef DeviceType<kGPU_CL> GPU_CL;
//! data type
enum DataType {
PM_INVALID = -1,
......
......@@ -12,9 +12,9 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "framework/cl/cl_engine.h"
#include "CL/cl.h"
#include "framework/cl/cl_tool.h"
#include "framework/cl/cl_engine.h"
#include <cstdlib>
#include <cstring>
......@@ -28,11 +28,11 @@ bool CLEngine::Init() {
SetClDeviceId();
initialized_ = true;
// setClContext();
// setClCommandQueue();
// std::string filename = "./HelloWorld_Kernel.cl";
// loadKernelFromFile(filename.c_str());
// buildProgram();
// setClContext();
// setClCommandQueue();
// std::string filename = "./HelloWorld_Kernel.cl";
// loadKernelFromFile(filename.c_str());
// buildProgram();
}
CLEngine *CLEngine::Instance() {
......@@ -74,26 +74,26 @@ bool CLEngine::SetClDeviceId() {
return false;
}
//std::unique_ptr<_cl_kernel, clKernel_deleter> CLEngine::GSetKernel(
// std::unique_ptr<_cl_kernel, clKernel_deleter> CLEngine::GSetKernel(
// const std::string &kernel_name) {
// std::unique_ptr<_cl_kernel, clKernel_deleter> kernel(
// clCreateKernel(program_.get(), kernel_name.c_str(), NULL));
// return std::move(kernel);
//}
//
//bool CLEngine::SetClCommandQueue() {
// bool CLEngine::SetClCommandQueue() {
// cl_int status;
// command_queue_.reset(
// clCreateCommandQueue(context_.get(), devices_[0], 0, &status));
// return true;
//}
//bool CLEngine::SetClContext() {
// bool CLEngine::SetClContext() {
// context_.reset(clCreateContext(NULL, 1, devices_, NULL, NULL, NULL));
// return true;
//}
//bool CLEngine::LoadKernelFromFile(const char *kernel_file) {
// bool CLEngine::LoadKernelFromFile(const char *kernel_file) {
// size_t size;
// char *str;
// std::fstream f(kernel_file, (std::fstream::in | std::fstream::binary));
......@@ -118,10 +118,10 @@ bool CLEngine::SetClDeviceId() {
// const char *source = str;
// size_t sourceSize[] = {strlen(source)};
// program_.reset(
// clCreateProgramWithSource(context_.get(), 1, &source, sourceSize, NULL));
// clCreateProgramWithSource(context_.get(), 1, &source, sourceSize,
// NULL));
// return true;
//}
} // namespace framework
} // namespace paddle_mobile
......@@ -17,9 +17,10 @@ limitations under the License. */
#include <memory>
#include <string>
#include "CL/cl.h"
#include "common/enforce.h"
#include "framework/cl/cl_deleter.h"
#include "CL/cl.h"
#include "framework/cl/cl_tool.h"
namespace paddle_mobile {
namespace framework {
......@@ -36,16 +37,18 @@ class CLEngine {
return std::move(context_ptr);
}
std::unique_ptr<_cl_command_queue, CLCommQueueDeleter> CreateClCommandQueue() {
std::unique_ptr<_cl_command_queue, CLCommQueueDeleter>
CreateClCommandQueue() {
cl_int status;
cl_command_queue queue = clCreateCommandQueue(context_.get(), devices_[0], 0, &status);
std::unique_ptr<_cl_command_queue, CLCommQueueDeleter> command_queue_ptr(queue);
cl_command_queue queue =
clCreateCommandQueue(context_.get(), devices_[0], 0, &status);
std::unique_ptr<_cl_command_queue, CLCommQueueDeleter> command_queue_ptr(
queue);
return std::move(command_queue_ptr);
}
std::unique_ptr<_cl_program, CLProgramDeleter> CreateProgramWith(cl_context context, std::string file_name) {
std::unique_ptr<_cl_program, CLProgramDeleter> CreateProgramWith(
cl_context context, std::string file_name) {
FILE *file = fopen(file_name.c_str(), "rb");
PADDLE_MOBILE_ENFORCE(file != nullptr, "can't open file: %s ",
file_name.c_str());
......@@ -62,7 +65,8 @@ class CLEngine {
const char *source = data;
size_t sourceSize[] = {strlen(source)};
cl_program p = clCreateProgramWithSource(context, 1, &source, sourceSize, NULL);
cl_program p =
clCreateProgramWithSource(context, 1, &source, sourceSize, NULL);
std::unique_ptr<_cl_program, CLProgramDeleter> program_ptr(p);
return std::move(program_ptr);
}
......@@ -81,7 +85,6 @@ class CLEngine {
bool SetClDeviceId();
bool initialized_;
cl_platform_id platform_;
......@@ -94,14 +97,13 @@ class CLEngine {
std::unique_ptr<_cl_program, CLProgramDeleter> program_;
// bool SetClContext();
// bool SetClCommandQueue();
// bool SetClContext();
// bool LoadKernelFromFile(const char *kernel_file);
// bool SetClCommandQueue();
// bool BuildProgram();
// bool LoadKernelFromFile(const char *kernel_file);
// bool BuildProgram();
};
} // namespace framework
......
This diff has been collapsed.
......@@ -18,4 +18,4 @@ limitations under the License. */
typedef uint16_t half_t;
half_t float2half(float f);
float half2float(half_t h);
\ No newline at end of file
float half2float(half_t h);
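
A small hedged round-trip example for the fp16 helpers declared above; it assumes the declarations are visible as shown in this header, and the sample value is chosen to be exactly representable in half precision:

```cpp
// Round-trip sketch for the fp16 helpers declared in cl_half.h above.
// Assumes the declarations are visible as shown; 0.15625f is exactly
// representable in IEEE half precision, so the round trip is lossless.
#include <cstdio>
#include "framework/cl/cl_half.h"

int main() {
  float value = 0.15625f;
  half_t h = float2half(value);   // 16-bit storage type (uint16_t)
  float back = half2float(h);     // convert back to float
  std::printf("%f -> 0x%04x -> %f\n", value, static_cast<unsigned>(h), back);
  return 0;
}
```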
......@@ -14,11 +14,13 @@ limitations under the License. */
#pragma once
#include <vector>
#include <string>
#include <type_traits>
#include <vector>
#include "framework/cl/cl_scope.h"
#include "framework/cl/cl_deleter.h"
#include "framework/cl/cl_image.h"
#include "framework/cl/cl_scope.h"
namespace paddle_mobile {
namespace framework {
......@@ -27,24 +29,38 @@ class CLHelper {
public:
CLHelper() = default;
CLHelper(CLScope *scope): scope_(scope) {
}
explicit CLHelper(CLScope *scope) : scope_(scope) {}
void AddKernel(const std::string &kernel_name, const std::string &file_name) {
auto kernel = scope_->GetKernel(kernel_name, file_name);
kernels.emplace_back(std::move(kernel));
}
cl_kernel KernelAt(const int index) {
return kernels[index].get();
}
cl_kernel KernelAt(const int index) { return kernels[index].get(); }
cl_command_queue CLCommandQueue() {
return scope_->CommandQueue();
}
cl_command_queue CLCommandQueue() { return scope_->CommandQueue(); }
cl_context CLContext() { return scope_->Context(); }
std::vector<size_t> DefaultWorkSize(const CLImage &image) {
// n c h w
auto image_dim = image.dims();
if (image_dim.size() == 4) {
auto n = image_dim[0];
auto h = image_dim[2];
auto w = image_dim[3];
auto image_width = image.ImageWidth();
auto work_size_0 = image_width / w;
auto work_size_1 = w;
auto work_size_2 = n * h;
cl_context CLContext() {
return scope_->Context();
return {work_size_0, work_size_1, work_size_2};
}
PADDLE_MOBILE_THROW_EXCEPTION("not support this dim, need imp");
}
private:
......@@ -52,5 +68,5 @@ class CLHelper {
std::vector<std::unique_ptr<_cl_kernel, CLKernelDeleter>> kernels;
};
}
}
} // namespace framework
} // namespace paddle_mobile
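
The DefaultWorkSize added above turns a 4-D image into a 3-D NDRange: dimension 0 is the number of channel blocks (ImageWidth / W), dimension 1 is W, and dimension 2 is N * H. A standalone sketch of that arithmetic, with dimensions that are assumptions chosen only for illustration:

```cpp
// Standalone sketch of the DefaultWorkSize arithmetic above, assuming the
// CLImage width is W * ((C + 3) / 4) texels as in cl_image.h. The NCHW dims
// below are hypothetical values chosen only for illustration.
#include <cstddef>
#include <cstdio>

int main() {
  size_t N = 1, C = 8, H = 16, W = 32;
  size_t image_width = W * ((C + 3) / 4);  // 32 * 2 = 64 texels per row
  size_t work_size_0 = image_width / W;    // 2  -> number of channel blocks
  size_t work_size_1 = W;                  // 32 -> one work-item per column
  size_t work_size_2 = N * H;              // 16 -> one work-item per row
  std::printf("global work size = {%zu, %zu, %zu}\n", work_size_0, work_size_1,
              work_size_2);
  return 0;
}
```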
......@@ -14,10 +14,12 @@ limitations under the License. */
#pragma once
#include <vector>
#include "CL/cl.h"
#include "framework/cl/cl_half.h"
#include "framework/ddim.h"
#include "framework/tensor.h"
#include "CL/cl.h"
#include "cl_half.h"
namespace paddle_mobile {
namespace framework {
......@@ -27,18 +29,44 @@ class CLImage {
CLImage() = default;
void Init(cl_context context, float *tensorInput, DDim ddim) {
cl_image_format cf = {
.image_channel_order = CL_RGBA,
.image_channel_data_type = CL_HALF_FLOAT
};
tensor_dims_ = ddim;
cl_image_format cf = {.image_channel_order = CL_RGBA,
.image_channel_data_type = CL_HALF_FLOAT};
// NCHW -> [W * (C+3)/4, H * N]
size_t N = tensorDims_[0];
size_t C = tensorDims_[1];
size_t H = tensorDims_[2];
size_t W = tensorDims_[3];
DLOG << tensor_dims_;
size_t N, C, H, W;
if (tensor_dims_.size() == 4) {
N = tensor_dims_[0];
if (N < 0) {
N = 1;
}
C = tensor_dims_[1];
H = tensor_dims_[2];
W = tensor_dims_[3];
width_of_one_block_ = W;
height_of_one_block_ = H;
} else if (tensor_dims_.size() == 1) {
N = 1;
C = tensor_dims_[0];
H = 1;
W = 1;
width_of_one_block_ = W;
height_of_one_block_ = H;
}
DLOG << "-------InitMemory-------";
size_t width = W * ((C + 3) / 4);
size_t height = H * N;
image_width_ = width;
image_height_ = height;
std::unique_ptr<half_t[]> imageData{};
int count = 0;
if (tensorInput != nullptr) {
imageData.reset(new half_t[width * height * 4]);
float *p = tensorInput;
......@@ -47,11 +75,19 @@ class CLImage {
for (int c = 0; c < C; c++) {
size_t i1 = i0;
for (int h = 0; h < H; h++) {
size_t i2 = i1 << 2 + c % 4;
size_t i2 = (i1 << 2) + c % 4;
for (int w = 0; w < W; w++) {
if (i2 >= width * height * 4) {
printf("%d > %d ----> %d, %d, %d, %d --- %d, %d, %d\n", i2,
width * height * 4, n, c, h, w, i0, i1, i2);
}
assert(i2 < width * height * 4);
imageData[i2] = float2half(*p);
i2 += 4;
p++;
// count++;
// DLOG<<count;
}
i1 += width;
}
......@@ -59,57 +95,74 @@ class CLImage {
i0 += width * H;
}
}
DLOG << "-------InitMemory-------";
cl_int err;
cl_image_ = clCreateImage2D(
context, // cl_context context
CL_MEM_READ_WRITE | CL_MEM_COPY_HOST_PTR, // cl_mem_flags flags
&cf, // const cl_image_format *image_format
width, // size_t image_width
height, // size_t image_height
0, // size_t image_row_pitch
reinterpret_cast<void*>(imageData.get()), // void *host_ptr
&err // cl_int *errcode_ret
);
context, // cl_context context
CL_MEM_READ_WRITE | CL_MEM_COPY_HOST_PTR, // cl_mem_flags flags
&cf, // const cl_image_format *image_format
width, // size_t image_width
height, // size_t image_height
0, // size_t image_row_pitch
reinterpret_cast<void *>(imageData.get()), // void *host_ptr
&err);
if (err != CL_SUCCESS) {
// TODO: error handling
// TODO(HaiPeng): error handling
PADDLE_MOBILE_THROW_EXCEPTION(" create image 2d error ");
}
}
void Init(cl_context context, DDim ddim) {
Init(context, nullptr, ddim);
initialized_ = true;
}
void Init(cl_context context, DDim ddim) { Init(context, nullptr, ddim); }
inline CLImage &Resize(const DDim &dims) {
tensorDims_ = dims;
tensor_dims_ = dims;
return *this;
}
const DDim &dims() const {
return tensorDims_;
}
const DDim &dims() const { return tensor_dims_; }
std::vector<size_t> DefaultWorkSize() {
return {};
}
cl_mem GetCLImage() const { return cl_image_; }
cl_mem GetCLImage() {
return cl_image_;
template <typename T>
T *data() const {
return reinterpret_cast<T *>(tensor_input_);
}
inline int64_t numel() const { return product(tensor_dims_); }
inline size_t ImageWidth() const { return image_width_; }
inline size_t ImageHeight() const { return image_height_; }
inline size_t CBlock() const { return c_block_; }
inline size_t WidthOfOneBlock() const { return width_of_one_block_; }
inline size_t HeightOfOneBlock() const { return height_of_one_block_; }
private:
bool initialized_ = false;
cl_mem cl_image_;
DDim tensorDims_;
size_t image_width_;
size_t width_of_one_block_;
size_t height_of_one_block_;
size_t image_height_;
size_t c_block_;
DDim tensor_dims_;
float *tensor_input_;
cl_context context_;
};
//void TensorToCLImage(Tensor *tensor, CLImage *image) {
// void TensorToCLImage(Tensor *tensor, CLImage *image) {
//
//}
//
//void CLImageToTensor(CLImage *image, Tensor *tensor) {
// void CLImageToTensor(CLImage *image, Tensor *tensor) {
//
//}
}
}
\ No newline at end of file
} // namespace framework
} // namespace paddle_mobile
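
CLImage::Init above packs an NCHW tensor into one CL_RGBA / CL_HALF_FLOAT 2-D image of size [W * (C+3)/4, H * N], with four consecutive channels sharing one texel. Below is a hedged host-side reading of that mapping; it computes destination indices directly rather than mirroring the incremental i0/i1/i2 bookkeeping, and it uses float in place of half_t purely for illustration:

```cpp
// Direct-index sketch of the "NCHW -> [W * (C+3)/4, H * N]" packing described
// by the comment in CLImage::Init: four consecutive channels share one RGBA
// texel. This is a reading of the intended layout for illustration only.
#include <cstddef>
#include <vector>

std::vector<float> PackNCHW(const float *src, size_t N, size_t C, size_t H,
                            size_t W) {
  size_t width = W * ((C + 3) / 4);  // image width in texels
  size_t height = H * N;             // image height in texels
  std::vector<float> image(width * height * 4, 0.f);
  for (size_t n = 0; n < N; ++n)
    for (size_t c = 0; c < C; ++c)
      for (size_t h = 0; h < H; ++h)
        for (size_t w = 0; w < W; ++w) {
          size_t col = (c / 4) * W + w;  // texel column: channel block, then w
          size_t row = n * H + h;        // texel row: images stacked over N
          size_t idx = (row * width + col) * 4 + c % 4;  // RGBA component
          image[idx] = *src++;
        }
  return image;
}
```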
......@@ -18,10 +18,10 @@ limitations under the License. */
#include <string>
#include <unordered_map>
#include "framework/cl/cl_tool.h"
#include "framework/cl/cl_engine.h"
#include "framework/cl/cl_deleter.h"
#include "CL/cl.h"
#include "framework/cl/cl_deleter.h"
#include "framework/cl/cl_engine.h"
#include "framework/cl/cl_tool.h"
namespace paddle_mobile {
namespace framework {
......@@ -35,19 +35,17 @@ class CLScope {
command_queue_ = engin->CreateClCommandQueue();
}
cl_command_queue CommandQueue() {
return command_queue_.get();
}
cl_command_queue CommandQueue() { return command_queue_.get(); }
std::unique_ptr<_cl_kernel, CLKernelDeleter> GetKernel(const std::string &kernel_name, const std::string &file_name) {
std::unique_ptr<_cl_kernel, CLKernelDeleter> GetKernel(
const std::string &kernel_name, const std::string &file_name) {
auto program = Program(file_name);
std::unique_ptr<_cl_kernel, CLKernelDeleter> kernel(clCreateKernel(program, kernel_name.c_str(), NULL));
std::unique_ptr<_cl_kernel, CLKernelDeleter> kernel(
clCreateKernel(program, kernel_name.c_str(), NULL));
return std::move(kernel);
}
cl_context Context() {
return context_.get();
}
cl_context Context() { return context_.get(); }
cl_program Program(const std::string &file_name) {
auto it = programs_.find(file_name);
......@@ -55,20 +53,23 @@ class CLScope {
return it->second.get();
}
auto program = CLEngine::Instance()->CreateProgramWith(context_.get(), file_name);
auto program =
CLEngine::Instance()->CreateProgramWith(context_.get(), file_name);
programs_[file_name] = std::move(program);
status_ = clBuildProgram(program.get(), 0, 0, 0, 0, 0);
status_ = clBuildProgram(program.get(), 0, 0, 0, 0, 0);
CL_CHECK_ERRORS(status_);
return program.get();
}
private:
cl_int status_;
cl_int status_;
std::unique_ptr<_cl_context, CLContextDeleter> context_;
std::unique_ptr<_cl_command_queue, CLCommQueueDeleter> command_queue_;
std::unordered_map<std::string, std::unique_ptr<_cl_program, CLProgramDeleter>> programs_;
std::unordered_map<std::string,
std::unique_ptr<_cl_program, CLProgramDeleter>>
programs_;
};
}
}
} // namespace framework
} // namespace paddle_mobile
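
CLScope::Program above caches one cl_program per .cl file and GetKernel creates kernels from it. The following sketch shows the same look-up-or-build-then-create-kernel flow against the plain OpenCL C API; CreateProgramFromFile is a hypothetical stand-in for CLEngine::CreateProgramWith and is not part of this class:

```cpp
// Minimal sketch of the look-up-or-build flow behind CLScope::GetKernel.
// CreateProgramFromFile is an assumed helper standing in for
// CLEngine::CreateProgramWith; error handling style is illustrative only.
#include <string>
#include <unordered_map>
#include "CL/cl.h"

extern cl_program CreateProgramFromFile(cl_context ctx,
                                        const std::string &file_name);

cl_kernel GetKernelSketch(
    std::unordered_map<std::string, cl_program> *programs, cl_context ctx,
    const std::string &kernel_name, const std::string &file_name) {
  cl_program program;
  auto it = programs->find(file_name);
  if (it != programs->end()) {
    program = it->second;                            // cached, already built
  } else {
    program = CreateProgramFromFile(ctx, file_name); // compile the .cl source
    cl_int build_status =
        clBuildProgram(program, 0, nullptr, nullptr, nullptr, nullptr);
    if (build_status != CL_SUCCESS) return nullptr;
    (*programs)[file_name] = program;                // keep the handle alive
  }
  cl_int err;
  return clCreateKernel(program, kernel_name.c_str(), &err);
}
```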
......@@ -18,17 +18,17 @@ limitations under the License. */
#include <string>
#include <vector>
#include "framework/tensor_base.h"
#include "framework/cl/cl_engine.h"
#include "framework/cl/cl_deleter.h"
#include "CL/cl.h"
#include "framework/cl/cl_deleter.h"
#include "framework/cl/cl_engine.h"
#include "framework/tensor_base.h"
namespace paddle_mobile {
namespace framework {
class CLTensor : TensorBase {
public:
CLTensor(cl_context context) : context_(context) {}
explicit CLTensor(cl_context context) : context_(context) {}
/*! Resize the dimensions of the memory block. */
inline CLTensor &Resize(const DDim &dims) {
......@@ -84,7 +84,6 @@ class CLTensor : TensorBase {
}
private:
cl_context context_;
/*
......@@ -99,18 +98,15 @@ class CLTensor : TensorBase {
virtual void set_type(std::type_index type) = 0;
* */
struct PlaceholderImpl : public Placeholder {
PlaceholderImpl(size_t size, void *input, std::type_index type, cl_context context)
: ptr_(clCreateBuffer(context,
CL_MEM_READ_ONLY | CL_MEM_COPY_HOST_PTR, size,
reinterpret_cast<void *>(input), NULL)),
PlaceholderImpl(size_t size, void *input, std::type_index type,
cl_context context)
: ptr_(clCreateBuffer(context, CL_MEM_READ_ONLY | CL_MEM_COPY_HOST_PTR,
size, reinterpret_cast<void *>(input), NULL)),
size_(size),
type_(type) {
}
type_(type) {}
PlaceholderImpl(size_t size, std::type_index type, cl_context context)
: ptr_(clCreateBuffer(context,
CL_MEM_READ_WRITE, size, NULL, NULL)),
: ptr_(clCreateBuffer(context, CL_MEM_READ_WRITE, size, NULL, NULL)),
size_(size),
type_(type) {}
......@@ -128,9 +124,7 @@ class CLTensor : TensorBase {
/* the current type of memory */
std::type_index type_;
};
};
} // namespace framework
......
......@@ -12,13 +12,15 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "cl_tool.h"
#include "framework/cl/cl_tool.h"
namespace paddle_mobile {
namespace framework {
const char *opencl_error_to_str(cl_int error) {
#define CASE_CL_CONSTANT(NAME) case NAME: return #NAME;
#define CASE_CL_CONSTANT(NAME) \
case NAME: \
return #NAME;
// Suppose that no combinations are possible.
switch (error) {
CASE_CL_CONSTANT(CL_SUCCESS)
......@@ -78,5 +80,5 @@ const char *opencl_error_to_str(cl_int error) {
#undef CASE_CL_CONSTANT
}
}
}
} // namespace framework
} // namespace paddle_mobile
......@@ -19,16 +19,15 @@ limitations under the License. */
namespace paddle_mobile {
namespace framework {
const char* opencl_error_to_str (cl_int error);
#define CL_CHECK_ERRORS(ERR) \
if(ERR != CL_SUCCESS) \
{ \
printf( \
"OpenCL error with code %s happened in file %s at line %d. Exiting.\n", \
opencl_error_to_str(ERR), __FILE__, __LINE__ \
); \
}
}
}
const char* opencl_error_to_str(cl_int error);
#define CL_CHECK_ERRORS(ERR) \
if (ERR != CL_SUCCESS) { \
printf( \
"OpenCL error with code %s happened in file %s at line %d. " \
"Exiting.\n", \
opencl_error_to_str(ERR), __FILE__, __LINE__); \
}
} // namespace framework
} // namespace paddle_mobile
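
Because CL_CHECK_ERRORS above expands to a bare if statement, it is meant to be used as a standalone statement after each OpenCL call. A minimal usage sketch, assuming it is invoked inside the paddle_mobile::framework namespace where opencl_error_to_str is visible; the buffer size is an arbitrary illustrative value:

```cpp
// Usage sketch for CL_CHECK_ERRORS, placed inside the framework namespace so
// the unqualified opencl_error_to_str in the macro resolves. The buffer size
// is an arbitrary illustrative value, not taken from the library.
#include <cstdio>
#include "CL/cl.h"
#include "framework/cl/cl_tool.h"

namespace paddle_mobile {
namespace framework {

cl_mem CreateScratchBuffer(cl_context context) {
  cl_int status = CL_SUCCESS;
  cl_mem buf = clCreateBuffer(context, CL_MEM_READ_WRITE, 1024 * sizeof(float),
                              nullptr, &status);
  CL_CHECK_ERRORS(status);  // logs the symbolic error name, file and line
  return buf;
}

}  // namespace framework
}  // namespace paddle_mobile
```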
......@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "executor.h"
#include "framework/executor.h"
#include <operators/math/gemm.h>
#include <algorithm>
#include <vector>
......@@ -265,7 +265,7 @@ void Executor<Dtype, P>::InitCombineMemory() {
char *origin_data;
if (program_.combined_params_buf && program_.combined_params_len) {
LOG(kLOG_INFO) << "use outter memory";
origin_data = (char *)program_.combined_params_buf;
origin_data = reinterpret_cast<char *>(program_.combined_params_buf);
} else {
LOG(kLOG_INFO) << " begin init combine memory";
origin_data = Get_binary_data(program_.para_path);
......@@ -666,12 +666,12 @@ void Executor<Dtype, P>::InjectVariable(const framework::Tensor &t,
g_feed_value->GetMutable<framework::LoDTensor>();
feed_tensor->Resize(t.dims());
feed_tensor->ShareDataWith(t);
};
}
template <typename Dtype, Precision P>
void Executor<Dtype, P>::FeedData(const framework::Tensor &t) {
InjectVariable(t, "feed");
};
}
template <typename Dtype, Precision P>
std::shared_ptr<framework::Tensor> Executor<Dtype, P>::FetchResult(int id) {
......@@ -687,14 +687,14 @@ std::shared_ptr<framework::Tensor> Executor<Dtype, P>::FetchResult(int id) {
auto *output_tensor = framework::GetVarValue<framework::LoDTensor>(
out_keys[0], output_map, *(program_.scope));
return std::make_shared<framework::Tensor>(framework::Tensor(*output_tensor));
};
}
template <typename Dtype, Precision P>
void Executor<Dtype, P>::Predict_From_To(int start, int end) {
std::shared_ptr<framework::BlockDesc> to_predict_block =
to_predict_program_->Block(0);
auto &ops = ops_of_block_[*to_predict_block.get()];
end = end < 0 ? (int)ops.size() : end;
end = end < 0 ? static_cast<int>(ops.size()) : end;
PADDLE_MOBILE_ENFORCE(start >= 0 && start < end && end <= ops.size(),
"start or end parameter is wrong");
......@@ -715,17 +715,17 @@ void Executor<Dtype, P>::Predict_From_To(int start, int end) {
profile[i].runEnd = (uint64_t)ts.tv_sec * 1e9 + ts.tv_nsec;
#endif
}
};
}
template <typename Dtype, Precision P>
void Executor<Dtype, P>::Predict_From(int start) {
Predict_From_To(start);
};
}
template <typename Dtype, Precision P>
void Executor<Dtype, P>::Predict_To(int end) {
Predict_From_To(0, end);
};
}
#endif
#ifdef PADDLE_MOBILE_FPGA
......@@ -738,12 +738,12 @@ void Executor<Dtype, P>::InjectVariable(const framework::Tensor &t,
g_feed_value->GetMutable<framework::LoDTensor>();
feed_tensor->Resize(t.dims());
feed_tensor->ShareDataWith(t);
};
}
template <typename Dtype, Precision P>
void Executor<Dtype, P>::FeedData(const framework::Tensor &t) {
InjectVariable(t, "feed");
};
}
template <typename Dtype, Precision P>
std::shared_ptr<framework::Tensor> Executor<Dtype, P>::FetchResult(int id) {
......@@ -759,14 +759,14 @@ std::shared_ptr<framework::Tensor> Executor<Dtype, P>::FetchResult(int id) {
auto *output_tensor = framework::GetVarValue<framework::LoDTensor>(
out_keys[0], output_map, *(program_.scope));
return std::make_shared<framework::Tensor>(framework::Tensor(*output_tensor));
};
}
template <typename Dtype, Precision P>
void Executor<Dtype, P>::Predict_From_To(int start, int end) {
std::shared_ptr<framework::BlockDesc> to_predict_block =
to_predict_program_->Block(0);
auto &ops = ops_of_block_[*to_predict_block.get()];
end = end < 0 ? (int)ops.size() : end;
end = end < 0 ? static_cast<int>(ops.size()) : end;
PADDLE_MOBILE_ENFORCE(start >= 0 && start < end && end <= ops.size(),
"start or end parameter is wrong");
......@@ -787,20 +787,120 @@ void Executor<Dtype, P>::Predict_From_To(int start, int end) {
profile[i].runEnd = (uint64_t)ts.tv_sec * 1e9 + ts.tv_nsec;
#endif
}
};
}
template <typename Dtype, Precision P>
void Executor<Dtype, P>::Predict_From(int start) {
Predict_From_To(start);
};
}
template <typename Dtype, Precision P>
void Executor<Dtype, P>::Predict_To(int end) {
Predict_From_To(0, end);
};
}
#endif
#ifdef PADDLE_MOBILE_CL
template <>
void Executor<GPU_CL, Precision::FP32>::LoadMemory(
const framework::VarDesc var_desc, float *tensorInput, char **data) {
// 1. version
uint32_t version = *reinterpret_cast<uint32_t *>(*data);
(*data) += sizeof(uint32_t);
// 2 Lod information
uint64_t *lod_level_ptr = new uint64_t();
memcpy(lod_level_ptr, (*data), sizeof(uint64_t));
uint64_t lod_level = *lod_level_ptr;
delete lod_level_ptr;
(*data) += sizeof(uint64_t);
for (uint64_t i = 0; i < lod_level; ++i) {
uint64_t size = *reinterpret_cast<uint64_t *>(*data);
(*data) += sizeof(uint64_t);
std::vector<size_t> tmp(size / sizeof(size_t));
for (int k = 0; k < tmp.size(); ++k) {
tmp[k] = *reinterpret_cast<size_t *>(*data);
(*data) += sizeof(size_t);
}
}
// 3. tensor version
uint32_t tensor_version = *reinterpret_cast<uint32_t *>(*data);
(*data) += sizeof(uint32_t);
// 4. tensor desc
int32_t size = *reinterpret_cast<int32_t *>(*data);
(*data) += sizeof(int32_t);
std::unique_ptr<char[]> buf(new char[size]);
for (int m = 0; m < size; ++m) {
buf.get()[m] = (*data)[m];
}
(*data) += (sizeof(char) * size);
const framework::TensorDesc &desc = var_desc.Tensor_desc();
int memory_size = 1;
for (auto l : desc.Dims()) {
memory_size *= l;
}
void *memory = nullptr;
// int type_size = 0;
// switch (desc.DataType()) {
// case framework::VARTYPE_TYPE_FP16:
// type_size = 2;
// break;
// case framework::VARTYPE_TYPE_FP32:
// type_size = 4;
// memory = tensor->mutable_data<float>();
// break;
// case framework::VARTYPE_TYPE_FP64:
// type_size = 8;
// break;
// case framework::VARTYPE_TYPE_INT32:
// memory = tensor->mutable_data<int32_t>();
// type_size = 4;
// break;
// case framework::VARTYPE_TYPE_INT64:
// type_size = 8;
// break;
// case framework::VARTYPE_TYPE_BOOL:
// type_size = 1;
// break;
// default:
// break;
// }
int type_size = 4;
memory = tensorInput;
if (program_.quantification) {
float min_value;
float max_value;
memcpy(&min_value, *data, sizeof(float));
memcpy(&max_value, *data + sizeof(float), sizeof(float));
*data += 2 * sizeof(float);
const float factor = (max_value - min_value) / 255.0;
uint8_t *uint8_data = reinterpret_cast<uint8_t *>(*data);
for (int k = 0; k < memory_size; ++k) {
static_cast<float *>(memory)[k] = uint8_data[k] * factor + min_value;
}
*data += (memory_size * sizeof(uint8_t));
} else {
for (int n = 0; n < memory_size; n++) {
float value;
memcpy(&value, *data + n * type_size, type_size);
if (value < 1e-30 && value > -1e-30) {
static_cast<float *>(memory)[n] = 0.0;
} else {
static_cast<float *>(memory)[n] = value;
}
}
(*data) += (sizeof(char) * memory_size * type_size);
}
}
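
For quantified models, the specialization above reconstructs float weights from uint8 data using a per-tensor range: value = q * (max - min) / 255 + min, so q = 0 maps to min and q = 255 maps to max. A standalone sketch of just that dequantization step, assuming the blob layout read above (two floats followed by memory_size bytes):

```cpp
// Standalone sketch of the uint8 -> float dequantization performed above:
// the quantized blob is preceded by the float min/max of the original tensor.
#include <cstdint>
#include <cstring>
#include <vector>

std::vector<float> Dequantize(const char *data, int memory_size) {
  float min_value, max_value;
  std::memcpy(&min_value, data, sizeof(float));
  std::memcpy(&max_value, data + sizeof(float), sizeof(float));
  data += 2 * sizeof(float);
  const float factor = (max_value - min_value) / 255.0f;
  const uint8_t *q = reinterpret_cast<const uint8_t *>(data);
  std::vector<float> out(memory_size);
  for (int k = 0; k < memory_size; ++k) {
    out[k] = q[k] * factor + min_value;  // q = 0 -> min, q = 255 -> max
  }
  return out;
}
```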
template <>
void Executor<GPU_CL, Precision::FP32>::InitMemory() {
......@@ -812,27 +912,37 @@ void Executor<GPU_CL, Precision::FP32>::InitMemory() {
if (var_desc->Name() == "feed" || var_desc->Name() == "fetch") {
continue;
}
char *origin_data =
Get_binary_data(program_.model_path + "/" + var_desc->Name());
char *data = origin_data;
cl_context context = program_.scope->GetCLScpoe()->Context();
float *tensorInput = (float *)origin_data;
const framework::TensorDesc &desc = var_desc->Tensor_desc();
framework::DDim ddim = cl_image->dims();
int numel = 1;
for (auto l : desc.Dims()) {
numel *= l;
}
DLOG << var_desc->Name();
float *tensorInput = static_cast<float *>(
paddle_mobile::memory::Alloc(sizeof(float) * numel));
LoadMemory(*var_desc, tensorInput, &data);
framework::DDim ddim = framework::make_ddim(desc.Dims());
cl_image->Init(context, tensorInput, ddim);
delete origin_data;
}else{
auto cl_image = var->template GetMutable<framework::CLImage>();
cl_context context = program_.scope->GetCLScpoe()->Context();
const framework::TensorDesc &desc = var_desc->Tensor_desc();
framework::DDim ddim = cl_image->dims();
delete origin_data;
paddle_mobile::memory::Free(tensorInput);
} else {
if (var_desc->Type() == framework::VARTYPE_TYPE_LOD_TENSOR) {
auto cl_image = var->template GetMutable<framework::CLImage>();
cl_context context = program_.scope->GetCLScpoe()->Context();
cl_image->Init(context, ddim);
const framework::TensorDesc &desc = var_desc->Tensor_desc();
framework::DDim ddim = framework::make_ddim(desc.Dims());
DLOG << var_desc->Name();
cl_image->Init(context, ddim);
}
}
}
}
......@@ -843,13 +953,13 @@ void Executor<GPU_CL, Precision::FP32>::InitCombineMemory() {
char *origin_data;
if (program_.combined_params_buf && program_.combined_params_len) {
LOG(kLOG_INFO) << "use outter memory";
origin_data = (char *)program_.combined_params_buf;
origin_data = reinterpret_cast<char *>(program_.combined_params_buf);
} else {
LOG(kLOG_INFO) << " begin init combine memory";
origin_data = Get_binary_data(program_.para_path);
}
PADDLE_MOBILE_ENFORCE(origin_data != nullptr, "origin_data==nullptr!!!");
float *data = (float *)origin_data;
float *data = reinterpret_cast<float *>(origin_data);
for (const auto &block : to_predict_program_->Blocks()) {
for (const auto &var_desc : block->Vars()) {
......@@ -863,21 +973,23 @@ void Executor<GPU_CL, Precision::FP32>::InitCombineMemory() {
cl_context context = program_.scope->GetCLScpoe()->Context();
const framework::TensorDesc &desc = var_desc->Tensor_desc();
framework::DDim ddim = cl_image->dims();
framework::DDim ddim = framework::make_ddim(desc.Dims());
int numel = 1;
for (int i = 0; i < ddim.size(); i++) {
numel = numel * ddim[i];
}
float *tensorInput = data;
float *tensorInput = static_cast<float *>(
paddle_mobile::memory::Alloc(sizeof(float) * numel));
LoadMemory(*var_desc, tensorInput, &origin_data);
cl_image->Init(context, tensorInput, ddim);
data += numel;
}else{
paddle_mobile::memory::Free(tensorInput);
} else {
auto cl_image = var->template GetMutable<framework::CLImage>();
cl_context context = program_.scope->GetCLScpoe()->Context();
const framework::TensorDesc &desc = var_desc->Tensor_desc();
framework::DDim ddim = cl_image->dims();
framework::DDim ddim = framework::make_ddim(desc.Dims());
cl_image->Init(context, ddim);
}
......
......@@ -35,7 +35,7 @@ using std::string;
namespace paddle_mobile {
namespace framework {
template<typename Dtype = CPU, Precision P = Precision::FP32>
template <typename Dtype = CPU, Precision P = Precision::FP32>
class Executor {
public:
typedef typename PrecisionTrait<P>::ptype Ptype;
......@@ -56,7 +56,7 @@ class Executor {
* @b to predict
* */
std::shared_ptr<framework::LoDTensor> PredictLod(
const framework::LoDTensor &t);
const framework::LoDTensor &t);
/*
* @b to predict with vector and dim
......@@ -73,6 +73,8 @@ class Executor {
void LoadMemory(const framework::VarDesc var_desc,
framework::LoDTensor *tensor, char **data);
void LoadMemory(const framework::VarDesc var_desc, float *tensorInput,
char **data);
void InitCombineMemory();
......@@ -84,8 +86,8 @@ class Executor {
int block_id);
std::map<framework::BlockDesc,
std::vector<std::shared_ptr<framework::OperatorBase<Dtype>>>>
ops_of_block_;
std::vector<std::shared_ptr<framework::OperatorBase<Dtype>>>>
ops_of_block_;
bool use_optimize_ = false;
bool loddable_ = false;
#ifdef PADDLE_EXECUTOR_MULTITHREAD
......@@ -105,15 +107,15 @@ class Executor {
#ifdef PADDLE_MOBILE_FPGA
public:
void InjectVariable(const framework::Tensor &t, string var_name);
void FeedData(const framework::Tensor &t);
std::shared_ptr<framework::Tensor> FetchResult(int id = -1);
void Predict_From_To(int start = 0, int end = -1);
void Predict_From(int start);
void Predict_To(int end);
public:
void InjectVariable(const framework::Tensor &t, string var_name);
void FeedData(const framework::Tensor &t);
std::shared_ptr<framework::Tensor> FetchResult(int id = -1);
void Predict_From_To(int start = 0, int end = -1);
void Predict_From(int start);
void Predict_To(int end);
#endif
};
}
} // namespace framework
} // namespace paddle_mobile
......@@ -12,10 +12,13 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "loader.h"
#include "framework/loader.h"
#include "framework/lod_tensor.h"
#include "framework/program/program-optimize/program_optimize.h"
#ifdef PADDLE_MOBILE_CL
#include "framework/cl/cl_image.h"
#endif
namespace paddle_mobile {
namespace framework {
......@@ -26,9 +29,10 @@ namespace framework {
* @param originProgramDesc
* @param scope
*/
void InitMemoryFromProgram(
std::shared_ptr<ProgramDesc> &originProgramDesc,
std::shared_ptr<Scope> &scope) {
template <typename Dtype, Precision P>
void Loader<Dtype, P>::InitMemoryFromProgram(
const std::shared_ptr<ProgramDesc> &originProgramDesc,
const std::shared_ptr<Scope> &scope) {
for (const auto &block : originProgramDesc.get()->Blocks()) {
for (const auto &var_desc : block->Vars()) {
auto var = scope.get()->Var(var_desc->Name());
......@@ -51,6 +55,35 @@ void InitMemoryFromProgram(
}
}
#ifdef PADDLE_MOBILE_CL
template <>
void Loader<GPU_CL, Precision::FP32>::InitMemoryFromProgram(
const std::shared_ptr<ProgramDesc> &originProgramDesc,
const std::shared_ptr<Scope> &scope) {
for (const auto &block : originProgramDesc.get()->Blocks()) {
for (const auto &var_desc : block->Vars()) {
auto var = scope.get()->Var(var_desc->Name());
if (var_desc->Type() == VARTYPE_TYPE_LOD_TENSOR) {
if (var_desc->Persistable()) {
auto dim = var_desc->Tensor_desc().Dims();
// auto tensor = var->GetMutable<LoDTensor>();
auto cl_image = var->GetMutable<framework::CLImage>();
cl_image->Resize(make_ddim(dim));
} else {
auto dim = var_desc->Tensor_desc().Dims();
PADDLE_MOBILE_ENFORCE(dim.size() > 0, "dim size is 0");
dim[0] = 1;
auto cl_image = var->GetMutable<framework::CLImage>();
cl_image->Resize(make_ddim(dim));
}
} else {
// TODO(codeWorm): some.
}
}
}
}
#endif
/**
* fusion and print someinfos
* @tparam Dtype
......@@ -60,17 +93,17 @@ void InitMemoryFromProgram(
* @param program
* @param originProgramDesc
*/
template<typename Dtype, Precision P>
template <typename Dtype, Precision P>
void FusionAndPrintInfos(
bool &optimize, bool &can_add_split, Program<Dtype, P> &program,
const std::shared_ptr<ProgramDesc> &originProgramDesc) {
bool optimize, bool can_add_split, Program<Dtype, P> *program,
const std::shared_ptr<ProgramDesc> &originProgramDesc) {
if (optimize) {
ProgramOptimize program_optimize;
program.optimizeProgram =
program_optimize.FusionOptimize(originProgramDesc, can_add_split);
program->optimizeProgram =
program_optimize.FusionOptimize(originProgramDesc, can_add_split);
}
if (optimize) {
program.optimizeProgram->Description("optimize: ");
program->optimizeProgram->Description("optimize: ");
} else {
originProgramDesc->Description("program: ");
}
......@@ -98,20 +131,22 @@ static size_t ReadBuffer(const char *file_name, uint8_t **out) {
return cur_len;
}
template<typename Dtype, Precision P>
const Program<Dtype, P> Loader<Dtype, P>::Load(
const std::string &dirname, bool optimize, bool quantification,
bool can_add_split) {
template <typename Dtype, Precision P>
const Program<Dtype, P> Loader<Dtype, P>::Load(const std::string &dirname,
bool optimize,
bool quantification,
bool can_add_split) {
auto program = this->LoadProgram(dirname + "/__model__", optimize,
quantification, can_add_split);
program.model_path = dirname;
return program;
}
template<typename Dtype, Precision P>
const Program<Dtype, P> Loader<Dtype, P>::Load(
const std::string &model_path, const std::string &para_path, bool optimize,
bool quantification) {
template <typename Dtype, Precision P>
const Program<Dtype, P> Loader<Dtype, P>::Load(const std::string &model_path,
const std::string &para_path,
bool optimize,
bool quantification) {
auto program = this->LoadProgram(model_path, optimize, quantification);
program.para_path = para_path;
......@@ -120,10 +155,10 @@ const Program<Dtype, P> Loader<Dtype, P>::Load(
return program;
}
template<typename Dtype, Precision P>
template <typename Dtype, Precision P>
const Program<Dtype, P> Loader<Dtype, P>::LoadProgram(
const std::string &model_path, bool optimize, bool quantification,
bool can_add_split) {
const std::string &model_path, bool optimize, bool quantification,
bool can_add_split) {
std::string model_filename = model_path;
PaddleMobile__Framework__Proto__ProgramDesc *c_program;
uint8_t *buf = NULL;
......@@ -132,7 +167,7 @@ const Program<Dtype, P> Loader<Dtype, P>::LoadProgram(
PADDLE_MOBILE_ENFORCE(buf != NULL, "read from __model__ is null");
c_program = paddle_mobile__framework__proto__program_desc__unpack(
NULL, read_size, buf);
NULL, read_size, buf);
//
PADDLE_MOBILE_ENFORCE(c_program != NULL, "program is null");
//
......@@ -151,23 +186,23 @@ const Program<Dtype, P> Loader<Dtype, P>::LoadProgram(
// use originProgramDesc and scope to init tensors
InitMemoryFromProgram(originProgramDesc, scope);
// perform fusion and print infos
FusionAndPrintInfos(optimize, can_add_split, program, originProgramDesc);
FusionAndPrintInfos(optimize, can_add_split, &program, originProgramDesc);
paddle_mobile__framework__proto__program_desc__free_unpacked(c_program, NULL);
return program;
}
template<typename Dtype, Precision P>
template <typename Dtype, Precision P>
const Program<Dtype, P> Loader<Dtype, P>::LoadCombinedMemory(
size_t read_size, const uint8_t *buf, size_t combined_params_len,
const uint8_t *combined_params_buf, bool optimize, bool quantification) {
size_t read_size, const uint8_t *buf, size_t combined_params_len,
uint8_t *combined_params_buf, bool optimize, bool quantification) {
bool can_add_split = false;
PaddleMobile__Framework__Proto__ProgramDesc *c_program;
PADDLE_MOBILE_ENFORCE(buf != nullptr, "read from __model__ is null");
c_program = paddle_mobile__framework__proto__program_desc__unpack(
nullptr, read_size, buf);
nullptr, read_size, buf);
//
PADDLE_MOBILE_ENFORCE(c_program != nullptr, "program is null");
//
......@@ -186,23 +221,19 @@ const Program<Dtype, P> Loader<Dtype, P>::LoadCombinedMemory(
auto scope = std::make_shared<Scope>();
program.scope = scope;
InitMemoryFromProgram(originProgramDesc, scope);
FusionAndPrintInfos(optimize, can_add_split, program, originProgramDesc);
FusionAndPrintInfos(optimize, can_add_split, &program, originProgramDesc);
paddle_mobile__framework__proto__program_desc__free_unpacked(c_program,
nullptr);
return program;
}
template
class Loader<CPU, Precision::FP32>;
template class Loader<CPU, Precision::FP32>;
template
class Loader<FPGA, Precision::FP32>;
template class Loader<FPGA, Precision::FP32>;
template
class Loader<GPU_MALI, Precision::FP32>;
template class Loader<GPU_MALI, Precision::FP32>;
template
class Loader<GPU_CL, Precision::FP32>;
template class Loader<GPU_CL, Precision::FP32>;
}
} // namespace framework
} // namespace paddle_mobile
......@@ -20,7 +20,7 @@ limitations under the License. */
#include "framework/program/program.h"
namespace paddle_mobile {
namespace framework{
namespace framework {
template <typename Dtype = CPU, Precision P = Precision::FP32>
class Loader {
......@@ -30,30 +30,36 @@ class Loader {
* @b Load a fluid model stored as separate files
* */
const Program<Dtype, P> Load(const std::string &dirname,
bool optimize = false,
bool quantification = false,
bool can_add_split = false);
bool optimize = false,
bool quantification = false,
bool can_add_split = false);
/*
* @b load combine format fluid mode
* @b Load a model stored in the combined format
* */
const Program<Dtype, P> Load(const std::string &model_path,
const std::string &para_path,
bool optimize = false,
bool quantification = false);
const std::string &para_path,
bool optimize = false,
bool quantification = false);
const Program<Dtype, P> LoadCombinedMemory(
size_t model_len, const uint8_t *model_buf, size_t combined_params_len,
const uint8_t *combined_params_buf, bool optimize = false,
bool quantification = false);
const Program<Dtype, P> LoadCombinedMemory(size_t model_len,
const uint8_t *model_buf,
size_t combined_params_len,
uint8_t *combined_params_buf,
bool optimize = false,
bool quantification = false);
private:
const Program<Dtype, P> LoadProgram(const std::string &model_path,
bool optimize = false,
bool quantification = false,
bool can_add_split = false);
bool optimize = false,
bool quantification = false,
bool can_add_split = false);
void InitMemoryFromProgram(
const std::shared_ptr<ProgramDesc> &originProgramDesc,
const std::shared_ptr<Scope> &scope);
};
}
} // namespace framework
} // namespace paddle_mobile
......@@ -14,8 +14,8 @@ limitations under the License. */
#pragma once
#include <string>
#include <memory>
#include <string>
#include <tuple>
#include "common/log.h"
......@@ -92,7 +92,6 @@ class OpRegistry {
const std::string& type, const VariableNameMap& inputs,
const VariableNameMap& outputs, const AttributeMap attrs,
std::shared_ptr<paddle_mobile::framework::Scope> scope) {
auto& info = OpInfoMap<Dtype>::Instance()->Get(type);
auto op = info.Creator()(type, inputs, outputs, attrs, scope);
return std::shared_ptr<OperatorBase<Dtype>>(op);
......
......@@ -16,6 +16,7 @@ limitations under the License. */
#include <map>
#include <string>
#include <utility>
#include <vector>
#include "common/enforce.h"
......@@ -32,8 +33,8 @@ limitations under the License. */
#include "framework/tensor.h"
#include "framework/variable.h"
#ifdef PADDLE_MOBILE_CL
#include "framework/cl/cl_scope.h"
#include "framework/cl/cl_helper.h"
#include "framework/cl/cl_scope.h"
#endif
namespace paddle_mobile {
namespace framework {
......@@ -131,7 +132,6 @@ class OperatorWithKernel : public OperatorBase<Dtype> {
// DLOG << i.second;
// }
PADDLE_MOBILE_ENFORCE(kernel_.Init(&param_), " %s kernel init failed",
this->type_.c_str());
}
......@@ -147,7 +147,6 @@ class OperatorWithKernel : public OperatorBase<Dtype> {
template <typename Dtype, typename P>
class OpKernelBase {
public:
OpKernelBase() = default;
#ifdef PADDLE_MOBILE_CL
......@@ -156,11 +155,11 @@ class OpKernelBase {
}
#endif
/*
* @b Every kernel must implement the Compute method.
* @p para This parameter is a struct holding the arguments the kernel needs at run time;
* all such structs live in: paddle-mobile/src/operators/op_param.h
* */
/*
* @b Every kernel must implement the Compute method.
* @p para This parameter is a struct holding the arguments the kernel needs at run time;
* all such structs live in: paddle-mobile/src/operators/op_param.h
* */
#ifdef PADDLE_MOBILE_MALI_GPU
OpKernelBase() { acl_op_ = nullptr; }
void *GetAclOp() const { return acl_op_; }
......@@ -181,8 +180,6 @@ class OpKernelBase {
#ifdef PADDLE_MOBILE_MALI_GPU
void *acl_op_;
#endif
};
#define DEFINE_OP_CONSTRUCTOR(cls, parent_cls) \
......
......@@ -32,7 +32,7 @@ class Program {
bool combined = false;
bool quantification = false;
size_t combined_params_len;
const uint8_t *combined_params_buf;
uint8_t *combined_params_buf;
private:
};
......
......@@ -15,13 +15,14 @@ limitations under the License. */
#pragma once
#include <list>
#include <string>
#include <unordered_map>
#include <vector>
#ifdef PADDLE_MOBILE_CL
#ifdef PADDLE_MOBILE_CL
#include "framework/cl/cl_scope.h"
#endif
#include <unordered_map>
#include "variable.h"
#include "framework/variable.h"
namespace paddle_mobile {
namespace framework {
......@@ -42,7 +43,6 @@ class Scope {
#ifdef PADDLE_MOBILE_CL
delete cl_scope_;
#endif
}
Scope &NewScope() const;
......@@ -83,9 +83,7 @@ class Scope {
Variable *FindVarLocally(const std::string &name) const;
#ifdef PADDLE_MOBILE_CL
CLScope *GetCLScpoe() {
return cl_scope_;
}
CLScope *GetCLScpoe() { return cl_scope_; }
#endif
private:
......@@ -99,7 +97,6 @@ class Scope {
#ifdef PADDLE_MOBILE_CL
CLScope *cl_scope_ = new CLScope();
#endif
};
} // namespace framework
} // namespace paddle_mobile
......@@ -68,9 +68,10 @@ bool PaddleMobile<Dtype, P>::Load(const std::string &model_path,
}
template <typename Dtype, Precision P>
bool PaddleMobile<Dtype, P>::LoadCombinedMemory(
size_t model_len, const uint8_t *model_buf, size_t combined_params_len,
const uint8_t *combined_params_buf) {
bool PaddleMobile<Dtype, P>::LoadCombinedMemory(size_t model_len,
const uint8_t *model_buf,
size_t combined_params_len,
uint8_t *combined_params_buf) {
int batch_size = 1;
bool optimise = true;
bool quantification = false;
......
......@@ -22,9 +22,9 @@ limitations under the License. */
#endif // _OPENMP
#include "common/types.h"
#include "framework/tensor.h"
#include "framework/executor.h"
#include "framework/loader.h"
#include "framework/tensor.h"
namespace paddle_mobile {
......@@ -83,7 +83,7 @@ class PaddleMobile {
*/
bool LoadCombinedMemory(size_t model_len, const uint8_t *model_buf,
size_t combined_params_len,
const uint8_t *combined_params_buf);
uint8_t *combined_params_buf);
void Clear();
......@@ -94,6 +94,7 @@ class PaddleMobile {
std::shared_ptr<framework::Executor<Dtype, P>> executor_;
#ifdef PADDLE_MOBILE_FPGA
public:
void InjectVariable(const framework::Tensor &t, string var_name);
void FeedData(const framework::Tensor &t);
......
......@@ -40,4 +40,8 @@ REGISTER_OPERATOR_MALI_GPU(batch_norm, ops::BatchNormOp);
#ifdef PADDLE_MOBILE_FPGA
#endif
#ifdef PADDLE_MOBILE_CL
REGISTER_OPERATOR_CL(batch_norm, ops::BatchNormOp);
#endif
#endif
......@@ -54,5 +54,8 @@ USE_OP_MALI_GPU(batch_norm);
#endif
#ifdef PADDLE_MOBILE_FPGA
#endif
#ifdef PADDLE_MOBILE_CL
USE_OP_CL(batch_norm);
#endif
#endif
......@@ -43,13 +43,14 @@ class FeedOp : public framework::OperatorBase<DeviceType> {
#ifdef PADDLE_MOBILE_FPGA
void Init() {
void Init() {
Tensor *output = param_.Out();
fpga::format_fp16_ofm(output);
}
void RunImpl() const {
auto input = (Tensor *)const_cast<LoDTensor *>(param_.InputX());
auto input =
reinterpret_cast<Tensor *>(const_cast<LoDTensor *>(param_.InputX()));
auto input_ptr = input->data<float>();
fpga::format_image(input);
Tensor *output = param_.Out();
......@@ -61,7 +62,7 @@ class FeedOp : public framework::OperatorBase<DeviceType> {
args.output_data_type = fpga::DATA_TYPE_FP16;
args.input_layout_type = fpga::LAYOUT_CHW;
args.output_layout_type = fpga::LAYOUT_HWC;
args.image.address = (void *)input_ptr;
args.image.address = reinterpret_cast<void *>(input_ptr);
args.image.channels = (uint32_t)input->dims()[1];
args.image.height = (uint32_t)input->dims()[2];
args.image.width = (uint32_t)input->dims()[3];
......@@ -74,13 +75,10 @@ class FeedOp : public framework::OperatorBase<DeviceType> {
#else
#ifdef PADDLE_MOBILE_CL
void Init() {}
void RunImpl() {
}
void Init() {}
void RunImpl() {}
#else
void Init() {}
void Init() {}
void RunImpl() {
param_.Out()->ShareDataWith(*param_.InputX());
param_.Out()->set_lod(param_.InputX()->lod());
......
......@@ -43,3 +43,6 @@ REGISTER_OPERATOR_MALI_GPU(fetch, ops::FetchOp);
#ifdef PADDLE_MOBILE_FPGA
REGISTER_OPERATOR_FPGA(fetch, ops::FetchOp);
#endif
#ifdef PADDLE_MOBILE_CL
REGISTER_OPERATOR_CL(fetch, ops::FetchOp);
#endif
......@@ -54,3 +54,6 @@ USE_OP_MALI_GPU(fetch);
#ifdef PADDLE_MOBILE_FPGA
USE_OP_FPGA(fetch);
#endif
#ifdef PADDLE_MOBILE_CL
USE_OP_CL(fetch);
#endif
......@@ -20,8 +20,8 @@ limitations under the License. */
#include <vector>
#include "framework/operator.h"
#include "framework/program/program-optimize/fusion_op_register.h"
#include "op_param.h"
#include "operators/kernel/conv_add_bn_relu_kernel.h"
#include "operators/op_param.h"
namespace paddle_mobile {
namespace operators {
......@@ -103,7 +103,7 @@ static framework::FusionOpRegistrar fusion_conv_add_bn_relu_registrar(
#ifdef PADDLE_MOBILE_CL
#ifndef FUSION_CONV_ADD_BN_RELU_REGISTER
static framework::FusionOpRegistrar fusion_conv_add_bn_relu_registrar(
static framework::FusionOpRegistrar fusion_conv_add_bn_relu_registrar(
new FusionConvAddBNReluMatcher());
#define FUSION_CONV_ADD_BN_RELU_REGISTER
#endif
......
......@@ -26,8 +26,7 @@ bool BatchNormKernel<CPU, float>::Init(BatchNormParam<CPU> *param) {
}
template <>
void BatchNormKernel<CPU, float>::Compute(
const BatchNormParam<CPU> &param) {
void BatchNormKernel<CPU, float>::Compute(const BatchNormParam<CPU> &param) {
BatchnormCompute<float>(param);
}
......
......@@ -26,8 +26,7 @@ bool BoxCoderKernel<CPU, float>::Init(BoxCoderParam<CPU> *param) {
}
template <>
void BoxCoderKernel<CPU, float>::Compute(
const BoxCoderParam<CPU> &param) {
void BoxCoderKernel<CPU, float>::Compute(const BoxCoderParam<CPU> &param) {
BoxCoderCompute<float>(param);
}
......
......@@ -25,8 +25,7 @@ bool ConvAddKernel<CPU, float>::Init(FusionConvAddParam<CPU> *param) {
}
template <>
void ConvAddKernel<CPU, float>::Compute(
const FusionConvAddParam<CPU> &param) {
void ConvAddKernel<CPU, float>::Compute(const FusionConvAddParam<CPU> &param) {
ConvAddCompute<float>(param);
}
......
......@@ -26,8 +26,7 @@ bool DepthwiseConvKernel<CPU, float>::Init(ConvParam<CPU> *param) {
}
template <>
void DepthwiseConvKernel<CPU, float>::Compute(
const ConvParam<CPU> &param) {
void DepthwiseConvKernel<CPU, float>::Compute(const ConvParam<CPU> &param) {
DepthwiseConvCompute<float>(param);
}
......
......@@ -26,8 +26,7 @@ bool FusionFcKernel<CPU, float>::Init(FusionFcParam<CPU> *param) {
}
template <>
void FusionFcKernel<CPU, float>::Compute(
const FusionFcParam<CPU> &param) {
void FusionFcKernel<CPU, float>::Compute(const FusionFcParam<CPU> &param) {
FusionFcCompute<float>(param);
param.Out()->set_lod(param.InputX()->lod());
}
......
......@@ -26,8 +26,7 @@ bool PriorBoxKernel<CPU, float>::Init(PriorBoxParam<CPU> *param) {
}
template <>
void PriorBoxKernel<CPU, float>::Compute(
const PriorBoxParam<CPU> &param) {
void PriorBoxKernel<CPU, float>::Compute(const PriorBoxParam<CPU> &param) {
PriorBoxCompute<float>(param);
}
......
......@@ -25,8 +25,7 @@ bool TransposeKernel<CPU, float>::Init(TransposeParam<CPU> *param) {
}
template <>
void TransposeKernel<CPU, float>::Compute(
const TransposeParam<CPU> &param) {
void TransposeKernel<CPU, float>::Compute(const TransposeParam<CPU> &param) {
TransposeCompute<float>(param);
}
......
......@@ -33,4 +33,3 @@ inline hafl4 activation(half4 in
}
*/
......@@ -12,9 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
/*
conv
conv_bn
......@@ -30,7 +27,6 @@ conv_add_bn_relu
#include "common.h"
__kernel void conv_1x1(__private const int global_size_dim0,
__private const int global_size_dim1,
__private const int global_size_dim2,
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
__kernel void elementwise_add(__global float* in, __global float* out) {
int num = get_global_id(0);
out[num] = in[num] * 0.1 + 102;
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef FUSION_CONVADDBN_OP
#include "operators/kernel/conv_add_bn_kernel.h"
namespace paddle_mobile {
namespace operators {
template <>
bool ConvAddBNReluKernel<GPU_CL, float>::Init(
FusionConvAddBNReluParam<GPU_CL> *param) {
return true;
}
template <>
void ConvAddBNReluKernel<GPU_CL, float>::Compute(
const FusionConvAddBNReluParam<GPU_CL> &param) {
}
template class ConvAddBNReluKernel<GPU_CL, float>;
} // namespace operators
} // namespace paddle_mobile
#endif
......@@ -15,20 +15,122 @@ limitations under the License. */
#ifdef FUSION_CONVADDBNRELU_OP
#include "operators/kernel/conv_add_bn_relu_kernel.h"
#include "framework/cl/cl_image.h"
namespace paddle_mobile {
namespace operators {
template <>
bool ConvAddBNReluKernel<GPU_CL, float>::Init(
FusionConvAddBNReluParam<GPU_CL> *param) {
FusionConvAddBNReluParam<GPU_CL> *param) {
// const CL *mean = param->InputMean();
const framework::CLImage *mean = param->InputMean();
const framework::CLImage *variance = param->InputVariance();
const framework::CLImage *scale = param->InputScale();
const framework::CLImage *bias = param->InputBias();
const float epsilon = param->Epsilon();
auto mean_ptr = mean->data<float>();
auto variance_ptr = variance->data<float>();
auto scale_ptr = scale->data<float>();
auto bias_ptr = bias->data<float>();
const int C = mean->numel();
float inv_std_ptr[C];
for (int i = 0; i < C; i++) {
inv_std_ptr[i] =
1 / static_cast<float>(pow((variance_ptr[i] + epsilon), 0.5));
}
float *new_scale_ptr = new float[C];
float *new_bias_ptr = new float[C];
for (int i = 0; i < C; i++) {
new_scale_ptr[i] = inv_std_ptr[i] * scale_ptr[i];
new_bias_ptr[i] = bias_ptr[i] - mean_ptr[i] * inv_std_ptr[i] * scale_ptr[i];
}
framework::CLImage *new_scale = new framework::CLImage();
new_scale->Init(this->cl_helper_.CLContext(), new_scale_ptr,
variance->dims());
framework::CLImage *new_bias = new framework::CLImage();
new_bias->Init(this->cl_helper_.CLContext(), new_bias_ptr, variance->dims());
param->SetNewScale(new_scale);
param->SetNewBias(new_bias);
delete[](new_scale_ptr);
delete[](new_bias_ptr);
PADDLE_MOBILE_ENFORCE(
param->Filter()->dims()[2] == param->Filter()->dims()[3] &&
param->Paddings()[0] == param->Paddings()[1],
"need equal");
int offset = static_cast<int>(param->Filter()->dims()[2]) / 2 -
static_cast<int>(param->Paddings()[1]);
param->SetOffset(offset);
if (param->Filter()->WidthOfOneBlock() == 1 &&
param->Filter()->HeightOfOneBlock() == 1) {
this->cl_helper_.AddKernel("conv_1x1", "conv_add_bn_relu_kernel.cl");
} else if (param->Filter()->dims()[1] == 1) {
this->cl_helper_.AddKernel("depth_conv_3x3", "conv_add_bn_relu_kernel.cl");
} else if (param->Filter()->WidthOfOneBlock() == 3 &&
param->Filter()->HeightOfOneBlock() == 3) {
this->cl_helper_.AddKernel("conv_3x3", "conv_add_bn_relu_kernel.cl");
} else {
PADDLE_MOBILE_THROW_EXCEPTION(" not support ");
}
return true;
}
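
Init above folds the batch-norm statistics into per-channel scale and bias terms, so the fused kernel only applies one multiply-add after the convolution: new_scale = scale / sqrt(variance + epsilon) and new_bias = bias - mean * new_scale. A standalone sketch of that fold; the vector types and names are assumptions for illustration:

```cpp
// Standalone sketch of the batch-norm fold computed in Init above:
//   bn(x) = scale * (x - mean) / sqrt(var + eps) + bias
//         = new_scale * x + new_bias
#include <cmath>
#include <vector>

void FoldBatchNorm(const std::vector<float> &mean,
                   const std::vector<float> &variance,
                   const std::vector<float> &scale,
                   const std::vector<float> &bias, float epsilon,
                   std::vector<float> *new_scale,
                   std::vector<float> *new_bias) {
  const size_t C = mean.size();
  new_scale->resize(C);
  new_bias->resize(C);
  for (size_t i = 0; i < C; ++i) {
    float inv_std = 1.0f / std::sqrt(variance[i] + epsilon);
    (*new_scale)[i] = inv_std * scale[i];
    (*new_bias)[i] = bias[i] - mean[i] * inv_std * scale[i];
  }
}
```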
template <>
void ConvAddBNReluKernel<GPU_CL, float>::Compute(
const FusionConvAddBNReluParam<GPU_CL> &param) {
const FusionConvAddBNReluParam<GPU_CL> &param) {
auto kernel = this->cl_helper_.KernelAt(0);
auto default_work_size = this->cl_helper_.DefaultWorkSize(*param.Output());
int c_block = default_work_size[0];
int w = default_work_size[1];
int nh = default_work_size[2];
auto input = param.Input()->GetCLImage();
auto filter = param.Filter()->GetCLImage();
auto biase = param.Bias()->GetCLImage();
auto new_scale = param.NewScale()->GetCLImage();
auto new_bias = param.NewBias()->GetCLImage();
auto output = param.Output();
int stride = param.Strides()[0];
int offset = param.Offset();
int input_c = param.Input()->CBlock();
int input_width = param.Input()->WidthOfOneBlock();
int input_height = param.Input()->HeightOfOneBlock();
clSetKernelArg(kernel, 0, sizeof(int), &c_block);
clSetKernelArg(kernel, 1, sizeof(int), &w);
clSetKernelArg(kernel, 2, sizeof(int), &nh);
clSetKernelArg(kernel, 3, sizeof(cl_mem), &input);
clSetKernelArg(kernel, 4, sizeof(cl_mem), &filter);
clSetKernelArg(kernel, 5, sizeof(cl_mem), &biase);
clSetKernelArg(kernel, 6, sizeof(cl_mem), &new_scale);
clSetKernelArg(kernel, 7, sizeof(cl_mem), &new_bias);
clSetKernelArg(kernel, 8, sizeof(cl_mem), &output);
clSetKernelArg(kernel, 9, sizeof(int), &stride);
clSetKernelArg(kernel, 10, sizeof(int), &offset);
clSetKernelArg(kernel, 11, sizeof(int), &input_c);
clSetKernelArg(kernel, 12, sizeof(int), &input_width);
clSetKernelArg(kernel, 13, sizeof(int), &input_height);
clEnqueueNDRangeKernel(this->cl_helper_.CLCommandQueue(), kernel, 3, NULL,
default_work_size.data(), NULL, 0, NULL, NULL);
}
template class ConvAddBNReluKernel<GPU_CL, float>;
} // namespace operators
......
......@@ -21,12 +21,62 @@ namespace operators {
template <>
bool ConvAddKernel<GPU_CL, float>::Init(FusionConvAddParam<GPU_CL> *param) {
PADDLE_MOBILE_ENFORCE(
param->Filter()->dims()[2] == param->Filter()->dims()[3] &&
param->Paddings()[0] == param->Paddings()[1],
"need equal");
int offset = static_cast<int>(param->Filter()->dims()[2]) / 2 -
static_cast<int>(param->Paddings()[1]);
param->SetOffset(offset);
if (param->Filter()->WidthOfOneBlock() == 1 &&
param->Filter()->HeightOfOneBlock() == 1) {
this->cl_helper_.AddKernel("conv_1x1", "conv_add_bn_relu_kernel.cl");
} else if (param->Filter()->dims()[1] == 1) {
this->cl_helper_.AddKernel("depth_conv_3x3", "conv_add_bn_relu_kernel.cl");
} else if (param->Filter()->WidthOfOneBlock() == 3 &&
param->Filter()->HeightOfOneBlock() == 3) {
this->cl_helper_.AddKernel("conv_3x3", "conv_add_bn_relu_kernel.cl");
} else {
PADDLE_MOBILE_THROW_EXCEPTION(" not support ");
}
return true;
}
template <>
void ConvAddKernel<GPU_CL, float>::Compute(
const FusionConvAddParam<GPU_CL> &param) {
const FusionConvAddParam<GPU_CL> &param) {
auto kernel = this->cl_helper_.KernelAt(0);
auto default_work_size = this->cl_helper_.DefaultWorkSize(*param.Output());
int c_block = default_work_size[0];
int w = default_work_size[1];
int nh = default_work_size[2];
auto input = param.Input()->GetCLImage();
auto filter = param.Filter()->GetCLImage();
auto biase = param.Bias()->GetCLImage();
auto output = param.Output();
int stride = param.Strides()[0];
int offset = param.Offset();
int input_c = param.Input()->CBlock();
int input_width = param.Input()->WidthOfOneBlock();
int input_height = param.Input()->HeightOfOneBlock();
clSetKernelArg(kernel, 0, sizeof(int), &c_block);
clSetKernelArg(kernel, 1, sizeof(int), &w);
clSetKernelArg(kernel, 2, sizeof(int), &nh);
clSetKernelArg(kernel, 3, sizeof(cl_mem), &input);
clSetKernelArg(kernel, 4, sizeof(cl_mem), &filter);
clSetKernelArg(kernel, 5, sizeof(cl_mem), &biase);
clSetKernelArg(kernel, 8, sizeof(cl_mem), &output);
clSetKernelArg(kernel, 9, sizeof(int), &stride);
clSetKernelArg(kernel, 10, sizeof(int), &offset);
clSetKernelArg(kernel, 11, sizeof(int), &input_c);
clSetKernelArg(kernel, 12, sizeof(int), &input_width);
clSetKernelArg(kernel, 13, sizeof(int), &input_height);
clEnqueueNDRangeKernel(this->cl_helper_.CLCommandQueue(), kernel, 3, NULL,
default_work_size.data(), NULL, 0, NULL, NULL);
}
template class ConvAddKernel<GPU_CL, float>;
......
......@@ -15,22 +15,72 @@ limitations under the License. */
#ifdef CONV_OP
#include "operators/kernel/conv_kernel.h"
#include "operators/kernel/central-arm-func/conv_arm_func.h"
namespace paddle_mobile {
namespace operators {
template <>
bool ConvKernel<GPU_CL, float>::Init(ConvParam<GPU_CL> *param) {
this->cl_helper_.AddKernel("conv_3x3", "conv_kernel.cl");
PADDLE_MOBILE_ENFORCE(
param->Filter()->dims()[2] == param->Filter()->dims()[3] &&
param->Paddings()[0] == param->Paddings()[1],
"need equal");
int offset = static_cast<int>(param->Filter()->dims()[2]) / 2 -
static_cast<int>(param->Paddings()[1]);
param->SetOffset(offset);
if (param->Filter()->WidthOfOneBlock() == 1 &&
param->Filter()->HeightOfOneBlock() == 1) {
this->cl_helper_.AddKernel("conv_1x1", "conv_add_bn_relu_kernel.cl");
} else if (param->Filter()->dims()[1] == 1) {
this->cl_helper_.AddKernel("depth_conv_3x3", "conv_add_bn_relu_kernel.cl");
} else if (param->Filter()->WidthOfOneBlock() == 3 &&
param->Filter()->HeightOfOneBlock() == 3) {
this->cl_helper_.AddKernel("conv_3x3", "conv_add_bn_relu_kernel.cl");
} else {
PADDLE_MOBILE_THROW_EXCEPTION(" not support ");
}
return true;
}
template <>
void ConvKernel<GPU_CL, float>::Compute(const ConvParam<GPU_CL> &param) {
auto kernel = this->cl_helper_.KernelAt(0);
size_t global_work_size[3] = {1, 2, 3};
clEnqueueNDRangeKernel(this->cl_helper_.CLCommandQueue(), kernel, 3, NULL, global_work_size, NULL, 0, NULL, NULL);
auto default_work_size = this->cl_helper_.DefaultWorkSize(*param.Output());
int c_block = default_work_size[0];
int w = default_work_size[1];
int nh = default_work_size[2];
auto input = param.Input()->GetCLImage();
auto filter = param.Filter()->GetCLImage();
auto output = param.Output();
int stride = param.Strides()[0];
int offset = param.Offset();
int input_c = param.Input()->CBlock();
int dilation = param.Dilations()[0];
int input_width = param.Input()->WidthOfOneBlock();
int input_height = param.Input()->HeightOfOneBlock();
clSetKernelArg(kernel, 0, sizeof(int), &c_block);
clSetKernelArg(kernel, 1, sizeof(int), &w);
clSetKernelArg(kernel, 2, sizeof(int), &nh);
clSetKernelArg(kernel, 3, sizeof(cl_mem), &input);
clSetKernelArg(kernel, 4, sizeof(cl_mem), &filter);
clSetKernelArg(kernel, 5, sizeof(cl_mem), &output);
clSetKernelArg(kernel, 6, sizeof(int), &stride);
clSetKernelArg(kernel, 7, sizeof(int), &offset);
clSetKernelArg(kernel, 8, sizeof(int), &input_c);
clSetKernelArg(kernel, 9, sizeof(int), &dilation);
clSetKernelArg(kernel, 10, sizeof(int), &input_width);
clSetKernelArg(kernel, 11, sizeof(int), &input_height);
clEnqueueNDRangeKernel(this->cl_helper_.CLCommandQueue(), kernel, 3, NULL,
default_work_size.data(), NULL, 0, NULL, NULL);
// auto kernel = this->cl_helper_.KernelAt(0);
// size_t global_work_size[3] = {1, 2, 3};
// clEnqueueNDRangeKernel(this->cl_helper_.CLCommandQueue(), kernel, 3, NULL,
// global_work_size, NULL, 0, NULL, NULL);
}
template class ConvKernel<GPU_CL, float>;
......
......@@ -17,22 +17,23 @@ limitations under the License. */
#include "operators/kernel/elementwise_add_kernel.h"
namespace paddle_mobile {
namespace operators {
template <>
bool ElementwiseAddKernel<GPU_CL, float>::Init(
    ElementwiseAddParam<GPU_CL> *param) {
  //  this->cl_helper_.AddKernel("elementwise_add",
  //  "elementwise_add_kernel.cl");
  return true;
}
template <>
void ElementwiseAddKernel<GPU_CL, float>::Compute(
    const ElementwiseAddParam<GPU_CL> &param) {}
template class ElementwiseAddKernel<GPU_CL, float>;
}  // namespace operators
} // namespace paddle_mobile
#endif
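The GPU_CL elementwise-add kernel is left as a stub here: the AddKernel call is commented out and Compute has an empty body. For reference, a hedged sketch of what the Compute body could look like once the kernel is enabled; the argument order of "elementwise_add" and the InputX/InputY/Out accessors are assumptions, not taken from this patch:

// Sketch only: assumes Init() restores
//   this->cl_helper_.AddKernel("elementwise_add", "elementwise_add_kernel.cl");
template <>
void ElementwiseAddKernel<GPU_CL, float>::Compute(
    const ElementwiseAddParam<GPU_CL> &param) {
  auto kernel = this->cl_helper_.KernelAt(0);
  auto default_work_size = this->cl_helper_.DefaultWorkSize(*param.Out());
  auto input_x = param.InputX()->GetCLImage();  // accessor names assumed
  auto input_y = param.InputY()->GetCLImage();  // accessor names assumed
  auto output = param.Out()->GetCLImage();
  clSetKernelArg(kernel, 0, sizeof(cl_mem), &input_x);
  clSetKernelArg(kernel, 1, sizeof(cl_mem), &input_y);
  clSetKernelArg(kernel, 2, sizeof(cl_mem), &output);
  clEnqueueNDRangeKernel(this->cl_helper_.CLCommandQueue(), kernel, 3, NULL,
                         default_work_size.data(), NULL, 0, NULL, NULL);
}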
......@@ -12,7 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "operators/kernel/relu_kernel.h"
namespace paddle_mobile {
......@@ -30,4 +29,3 @@ template class ReluKernel<GPU_CL, float>;
} // namespace operators
} // namespace paddle_mobile
......@@ -12,7 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "operators/kernel/reshape_kernel.h"
namespace paddle_mobile {
......@@ -30,4 +29,3 @@ template class ReshapeKernel<GPU_CL, float>;
} // namespace operators
} // namespace paddle_mobile
......@@ -12,8 +12,9 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef SOFTMAX_OP
#include "operators/kernel/pool_kernel.h"
#include "operators/kernel/softmax_kernel.h"
namespace paddle_mobile {
namespace operators {
......@@ -30,4 +31,4 @@ template class SoftmaxKernel<GPU_CL, float>;
} // namespace operators
} // namespace paddle_mobile
#endif
......@@ -67,8 +67,7 @@ bool ConvBNKernel<FPGA, float>::Init(FusionConvBNParam<FPGA> *param) {
}
template <>
void ConvBNKernel<FPGA, float>::Compute(
const FusionConvBNParam<FPGA> &param) {
void ConvBNKernel<FPGA, float>::Compute(const FusionConvBNParam<FPGA> &param) {
fpga::ComputeFpgaConv(param.FpgaArgs());
}
......
......@@ -26,8 +26,7 @@ bool DropoutKernel<FPGA, float>::Init(DropoutParam<FPGA> *param) {
}
template <>
void DropoutKernel<FPGA, float>::Compute(
const DropoutParam<FPGA> &param) {}
void DropoutKernel<FPGA, float>::Compute(const DropoutParam<FPGA> &param) {}
} // namespace operators
} // namespace paddle_mobile
......
......@@ -60,8 +60,7 @@ bool FusionFcKernel<FPGA, float>::Init(FusionFcParam<FPGA> *param) {
}
template <>
void FusionFcKernel<FPGA, float>::Compute(
const FusionFcParam<FPGA> &param) {
void FusionFcKernel<FPGA, float>::Compute(const FusionFcParam<FPGA> &param) {
fpga::ComputeFpgaConv(param.FpgaArgs());
}
} // namespace operators
......
......@@ -47,8 +47,7 @@ bool SoftmaxKernel<FPGA, float>::Init(SoftmaxParam<FPGA> *param) {
}
template <>
void SoftmaxKernel<FPGA, float>::Compute(
const SoftmaxParam<FPGA> &param) {
void SoftmaxKernel<FPGA, float>::Compute(const SoftmaxParam<FPGA> &param) {
Tensor *in_x = param.FloatInput();
Tensor *out = param.Out();
......
......@@ -211,8 +211,7 @@ bool ConvKernel<GPU_MALI, float>::Init(ConvParam<GPU_MALI>* param) {
}
template <>
void ConvKernel<GPU_MALI, float>::Compute(
const ConvParam<GPU_MALI>& param) {
void ConvKernel<GPU_MALI, float>::Compute(const ConvParam<GPU_MALI>& param) {
std::cout << "init acl" << std::endl;
AclConvOp<GPU_MALI, float>* acl_op =
reinterpret_cast<AclConvOp<GPU_MALI, float>*>(this->GetAclOp());
......
......@@ -127,8 +127,7 @@ bool LrnKernel<GPU_MALI, float>::Init(LrnParam<GPU_MALI>* param) {
}
template <>
void LrnKernel<GPU_MALI, float>::Compute(
const LrnParam<GPU_MALI>& param) {
void LrnKernel<GPU_MALI, float>::Compute(const LrnParam<GPU_MALI>& param) {
std::cout << "init acl" << std::endl;
AclLrnOp<GPU_MALI, float>* acl_op =
reinterpret_cast<AclLrnOp<GPU_MALI, float>*>(this->GetAclOp());
......
......@@ -27,8 +27,7 @@ bool MulKernel<GPU_MALI, float>::Init(MulParam<GPU_MALI> *param) {
}
template <>
void MulKernel<GPU_MALI, float>::Compute(
const MulParam<GPU_MALI> &param) {
void MulKernel<GPU_MALI, float>::Compute(const MulParam<GPU_MALI> &param) {
const Tensor *input_x = param.InputX();
const Tensor *input_y = param.InputY();
Tensor *out = param.Out();
......
......@@ -195,8 +195,7 @@ bool PoolKernel<GPU_MALI, float>::Init(PoolParam<GPU_MALI>* param) {
}
template <>
void PoolKernel<GPU_MALI, float>::Compute(
const PoolParam<GPU_MALI>& param) {
void PoolKernel<GPU_MALI, float>::Compute(const PoolParam<GPU_MALI>& param) {
std::cout << "init acl" << std::endl;
AclPoolOp<GPU_MALI, float>* acl_op =
reinterpret_cast<AclPoolOp<GPU_MALI, float>*>(this->GetAclOp());
......
......@@ -115,8 +115,7 @@ bool ReluKernel<GPU_MALI, float>::Init(ReluParam<GPU_MALI>* param) {
}
template <>
void ReluKernel<GPU_MALI, float>::Compute(
const ReluParam<GPU_MALI>& param) {
void ReluKernel<GPU_MALI, float>::Compute(const ReluParam<GPU_MALI>& param) {
std::cout << "init acl" << std::endl;
AclReluOp<GPU_MALI, float>* acl_op =
reinterpret_cast<AclReluOp<GPU_MALI, float>*>(this->GetAclOp());
......
......@@ -389,6 +389,13 @@ class ConvParam : public OpParam {
const int &Groups() const { return groups; }
#ifdef PADDLE_MOBILE_CL
int Offset() const { return offset_; }
void SetOffset(int in_offset) { offset_ = in_offset; }
#endif
private:
RType *input_;
RType *output_;
......@@ -397,6 +404,10 @@ class ConvParam : public OpParam {
vector<int> paddings_;
vector<int> dilations_;
int groups;
#ifdef PADDLE_MOBILE_CL
int offset_;
#endif
};
template <typename Dtype>
Print &operator<<(Print &printer, const ConvParam<Dtype> &conv_param);
......@@ -1520,6 +1531,7 @@ class FusionConvAddBNReluParam : public ConvParam<Dtype> {
bool is_test_;
RType *new_bias_;
RType *new_scale_;
#ifdef PADDLE_MOBILE_FPGA
private:
......
......@@ -68,5 +68,8 @@ REGISTER_OPERATOR_MALI_GPU(pool2d, ops::PoolOp);
#ifdef PADDLE_MOBILE_FPGA
REGISTER_OPERATOR_FPGA(pool2d, ops::PoolOp);
#endif
#ifdef PADDLE_MOBILE_CL
REGISTER_OPERATOR_CL(pool2d, ops::PoolOp);
#endif
#endif
......@@ -54,5 +54,8 @@ USE_OP_MALI_GPU(pool2d);
#ifdef PADDLE_MOBILE_FPGA
USE_OP_FPGA(pool2d);
#endif
#ifdef PADDLE_MOBILE_CL
USE_OP_CL(pool2d);
#endif
#endif
......@@ -41,5 +41,8 @@ REGISTER_OPERATOR_MALI_GPU(relu, ops::ReluOp);
#endif
#ifdef PADDLE_MOBILE_FPGA
#endif
#ifdef PADDLE_MOBILE_CL
REGISTER_OPERATOR_CL(relu, ops::ReluOp);
#endif
#endif
......@@ -57,5 +57,8 @@ USE_OP_MALI_GPU(relu);
#endif
#ifdef PADDLE_MOBILE_FPGA
#endif
#ifdef PADDLE_MOBILE_CL
USE_OP_CL(relu);
#endif
#endif
......@@ -40,5 +40,8 @@ REGISTER_OPERATOR_MALI_GPU(reshape, ops::ReshapeOp);
#endif
#ifdef PADDLE_MOBILE_FPGA
#endif
#ifdef PADDLE_MOBILE_CL
REGISTER_OPERATOR_CL(reshape, ops::ReshapeOp);
#endif
#endif
......@@ -56,5 +56,8 @@ USE_OP_MALI_GPU(reshape);
#endif
#ifdef PADDLE_MOBILE_FPGA
#endif
#ifdef PADDLE_MOBILE_CL
USE_OP_CL(reshape);
#endif
#endif
......@@ -36,5 +36,8 @@ REGISTER_OPERATOR_MALI_GPU(softmax, ops::SoftmaxOp);
#ifdef PADDLE_MOBILE_FPGA
REGISTER_OPERATOR_FPGA(softmax, ops::SoftmaxOp);
#endif
#ifdef PADDLE_MOBILE_CL
REGISTER_OPERATOR_CL(softmax, ops::SoftmaxOp);
#endif
#endif
......@@ -52,5 +52,8 @@ USE_OP_MALI_GPU(softmax);
#ifdef PADDLE_MOBILE_FPGA
USE_OP_FPGA(softmax);
#endif
#ifdef PADDLE_MOBILE_CL
USE_OP_CL(softmax);
#endif
#endif
......@@ -83,175 +83,175 @@ elseif("genet" IN_LIST NET)
target_link_libraries(test-genet paddle-mobile)
else ()
# gen test
ADD_EXECUTABLE(test-resnet net/test_resnet.cpp test_helper.h test_include.h executor_for_test.h)
target_link_libraries(test-resnet paddle-mobile)
# gen test
ADD_EXECUTABLE(test-squeezenet net/test_squeezenet.cpp test_helper.h test_include.h executor_for_test.h)
target_link_libraries(test-squeezenet paddle-mobile)
# gen test
ADD_EXECUTABLE(test-yolo net/test_yolo.cpp test_helper.h test_include.h executor_for_test.h)
target_link_libraries(test-yolo paddle-mobile)
# gen test
ADD_EXECUTABLE(test-googlenet net/test_googlenet.cpp test_helper.h test_include.h executor_for_test.h)
target_link_libraries(test-googlenet paddle-mobile)
# gen test
ADD_EXECUTABLE(test-conv-op operators/test_cov_op.cpp test_helper.h test_include.h executor_for_test.h)
target_link_libraries(test-conv-op paddle-mobile)
# gen test
ADD_EXECUTABLE(test-mul-op operators/test_mul_op.cpp test_helper.h test_include.h)
target_link_libraries(test-mul-op paddle-mobile)
# gen test
ADD_EXECUTABLE(test-elementwiseadd-op operators/test_elementwise_add_op.cpp test_helper.h test_include.h)
target_link_libraries(test-elementwiseadd-op paddle-mobile)
# gen test
ADD_EXECUTABLE(test-concat-op operators/test_concat_op.cpp test_helper.h test_include.h)
target_link_libraries(test-concat-op paddle-mobile)
# gen test
ADD_EXECUTABLE(test-lrn-op operators/test_lrn_op.cpp test_helper.h test_include.h)
target_link_libraries(test-lrn-op paddle-mobile)
# gen test
ADD_EXECUTABLE(test-batchnorm-op operators/test_batchnorm_op.cpp test_helper.h test_include.h)
target_link_libraries(test-batchnorm-op paddle-mobile)
# gen test
ADD_EXECUTABLE(test-priorbox-op operators/test_prior_box_op.cpp test_helper.h test_include.h)
target_link_libraries(test-priorbox-op paddle-mobile)
# gen test
ADD_EXECUTABLE(test-boxcoder-op operators/test_box_coder_op.cpp test_helper.h test_include.h)
target_link_libraries(test-boxcoder-op paddle-mobile)
# gen test
ADD_EXECUTABLE(test-transpose-op operators/test_transpose_op.cpp test_helper.h test_include.h)
target_link_libraries(test-transpose-op paddle-mobile)
# gen test
ADD_EXECUTABLE(test-multiclassnms-op operators/test_multiclass_nms_op.cpp test_helper.h test_include.h)
target_link_libraries(test-multiclassnms-op paddle-mobile)
# gen test
ADD_EXECUTABLE(test-reshape-op operators/test_reshape_op.cpp test_helper.h test_include.h)
target_link_libraries(test-reshape-op paddle-mobile)
# gen test
ADD_EXECUTABLE(test-relu-op operators/test_relu_op.cpp test_helper.h test_include.h)
target_link_libraries(test-relu-op paddle-mobile)
# gen test
ADD_EXECUTABLE(test-fc-op operators/test_fusion_fc_op.cpp test_helper.h test_include.h)
target_link_libraries(test-fc-op paddle-mobile)
# gen test log
ADD_EXECUTABLE(test-log common/test_log.cpp)
target_link_libraries(test-log paddle-mobile)
# gen test log
ADD_EXECUTABLE(test-load framework/test_load.cpp)
target_link_libraries(test-load paddle-mobile)
# gen test log
ADD_EXECUTABLE(test-loadmemory framework/test_load_memory.cpp)
target_link_libraries(test-loadmemory paddle-mobile)
ADD_EXECUTABLE(test-inference-api framework/test_inference_api.cpp)
target_link_libraries(test-inference-api paddle-mobile)
# gen test log
# gen test
ADD_EXECUTABLE(test-optimize framework/test_optimize.cpp)
target_link_libraries(test-optimize paddle-mobile)
#gen test
ADD_EXECUTABLE(test-pool operators/test_pool_op.cpp test_helper.h test_include.h executor_for_test.h)
target_link_libraries(test-pool paddle-mobile)
#gen test
ADD_EXECUTABLE(test-softmax operators/test_softmax_op.cpp test_helper.h test_include.h executor_for_test.h)
target_link_libraries(test-softmax paddle-mobile)
# gen test
ADD_EXECUTABLE(test-gemm-accuracy common/test_gemm_accuracy.cpp)
target_link_libraries(test-gemm-accuracy paddle-mobile)
# gen test
ADD_EXECUTABLE(test-gemm-perf common/test_gemm_perf.cpp)
target_link_libraries(test-gemm-perf paddle-mobile)
# gen test
ADD_EXECUTABLE(test-enforce common/test_enforce.cpp)
target_link_libraries(test-enforce paddle-mobile)
# gen test - test if openmp works
ADD_EXECUTABLE(test-openmp common/test_openmp.cpp test_helper.h test_include.h executor_for_test.h)
target_link_libraries(test-openmp paddle-mobile)
# gen test
ADD_EXECUTABLE(test-mobilenetssd net/test_mobilenet+ssd.cpp test_helper.h test_include.h executor_for_test.h)
target_link_libraries(test-mobilenetssd paddle-mobile)
# gen test
ADD_EXECUTABLE(test-mobilenet-combine net/test_mobilenet_combine.cpp test_helper.h test_include.h executor_for_test.h)
target_link_libraries(test-mobilenet-combine paddle-mobile)
# gen test
ADD_EXECUTABLE(test-genet net/test_genet_combine.cpp test_helper.h test_include.h executor_for_test.h)
target_link_libraries(test-genet paddle-mobile)
# gen test
ADD_EXECUTABLE(test-sigmoid operators/test_sigmoid_op.cpp test_include.h)
target_link_libraries(test-sigmoid paddle-mobile)
# gen test
ADD_EXECUTABLE(test-depthwise-conv-op operators/test_depthwise_conv_op.cpp test_helper.h test_include.h executor_for_test.h)
target_link_libraries(test-depthwise-conv-op paddle-mobile)
# gen test
ADD_EXECUTABLE(test-mobilenet net/test_mobilenet.cpp test_helper.h test_include.h executor_for_test.h)
target_link_libraries(test-mobilenet paddle-mobile)
# gen test
ADD_EXECUTABLE(test-conv-add-relu-op operators/test_conv_add_relu_op.cpp test_helper.h test_include.h executor_for_test.h)
target_link_libraries(test-conv-add-relu-op paddle-mobile)
# gen test
ADD_EXECUTABLE(test-conv-add-bn-relu-op operators/test_fusion_conv_add_bn_relu_op.cpp test_helper.h test_include.h executor_for_test.h)
target_link_libraries(test-conv-add-bn-relu-op paddle-mobile)
# gen test
ADD_EXECUTABLE(test-nlp net/test_nlp.cpp test_helper.h test_include.h executor_for_test.h)
target_link_libraries(test-nlp paddle-mobile)
# gen test
ADD_EXECUTABLE(test-gru-op operators/test_gru_op.cpp test_helper.h test_include.h)
target_link_libraries(test-gru-op paddle-mobile)
# gen test
ADD_EXECUTABLE(test-inceptionv4 net/test_inceptionv4.cpp test_helper.h test_include.h executor_for_test.h)
target_link_libraries(test-inceptionv4 paddle-mobile)
# gen test
ADD_EXECUTABLE(test-alexnet net/test_alexnet.cpp test_helper.h test_include.h executor_for_test.h)
target_link_libraries(test-alexnet paddle-mobile)
ADD_EXECUTABLE(test-googlenetv1 net/test_googlenetv1_combine.cpp test_helper.h test_include.h)
target_link_libraries(test-googlenetv1 paddle-mobile)
# gen test
ADD_EXECUTABLE(test-fssd net/test_mobilenet_025_fssd.cpp test_helper.h test_include.h)
target_link_libraries(test-fssd paddle-mobile)
# # gen test
# ADD_EXECUTABLE(test-resnet net/test_resnet.cpp test_helper.h test_include.h executor_for_test.h)
# target_link_libraries(test-resnet paddle-mobile)
#
# # gen test
# ADD_EXECUTABLE(test-squeezenet net/test_squeezenet.cpp test_helper.h test_include.h executor_for_test.h)
# target_link_libraries(test-squeezenet paddle-mobile)
#
# # gen test
# ADD_EXECUTABLE(test-yolo net/test_yolo.cpp test_helper.h test_include.h executor_for_test.h)
# target_link_libraries(test-yolo paddle-mobile)
#
# # gen test
# ADD_EXECUTABLE(test-googlenet net/test_googlenet.cpp test_helper.h test_include.h executor_for_test.h)
# target_link_libraries(test-googlenet paddle-mobile)
#
# # gen test
# ADD_EXECUTABLE(test-conv-op operators/test_cov_op.cpp test_helper.h test_include.h executor_for_test.h)
# target_link_libraries(test-conv-op paddle-mobile)
#
# # gen test
# ADD_EXECUTABLE(test-mul-op operators/test_mul_op.cpp test_helper.h test_include.h)
# target_link_libraries(test-mul-op paddle-mobile)
#
# # gen test
# ADD_EXECUTABLE(test-elementwiseadd-op operators/test_elementwise_add_op.cpp test_helper.h test_include.h)
# target_link_libraries(test-elementwiseadd-op paddle-mobile)
#
# # gen test
# ADD_EXECUTABLE(test-concat-op operators/test_concat_op.cpp test_helper.h test_include.h)
# target_link_libraries(test-concat-op paddle-mobile)
#
# # gen test
# ADD_EXECUTABLE(test-lrn-op operators/test_lrn_op.cpp test_helper.h test_include.h)
# target_link_libraries(test-lrn-op paddle-mobile)
#
# # gen test
# ADD_EXECUTABLE(test-batchnorm-op operators/test_batchnorm_op.cpp test_helper.h test_include.h)
# target_link_libraries(test-batchnorm-op paddle-mobile)
#
# # gen test
# ADD_EXECUTABLE(test-priorbox-op operators/test_prior_box_op.cpp test_helper.h test_include.h)
# target_link_libraries(test-priorbox-op paddle-mobile)
#
# # gen test
# ADD_EXECUTABLE(test-boxcoder-op operators/test_box_coder_op.cpp test_helper.h test_include.h)
# target_link_libraries(test-boxcoder-op paddle-mobile)
#
# # gen test
# ADD_EXECUTABLE(test-transpose-op operators/test_transpose_op.cpp test_helper.h test_include.h)
# target_link_libraries(test-transpose-op paddle-mobile)
#
# # gen test
# ADD_EXECUTABLE(test-multiclassnms-op operators/test_multiclass_nms_op.cpp test_helper.h test_include.h)
# target_link_libraries(test-multiclassnms-op paddle-mobile)
#
# # gen test
# ADD_EXECUTABLE(test-reshape-op operators/test_reshape_op.cpp test_helper.h test_include.h)
# target_link_libraries(test-reshape-op paddle-mobile)
#
# # gen test
# ADD_EXECUTABLE(test-relu-op operators/test_relu_op.cpp test_helper.h test_include.h)
# target_link_libraries(test-relu-op paddle-mobile)
#
# # gen test
# ADD_EXECUTABLE(test-fc-op operators/test_fusion_fc_op.cpp test_helper.h test_include.h)
# target_link_libraries(test-fc-op paddle-mobile)
#
# # gen test log
# ADD_EXECUTABLE(test-log common/test_log.cpp)
# target_link_libraries(test-log paddle-mobile)
#
# # gen test log
# ADD_EXECUTABLE(test-load framework/test_load.cpp)
# target_link_libraries(test-load paddle-mobile)
#
# # gen test log
# ADD_EXECUTABLE(test-loadmemory framework/test_load_memory.cpp)
# target_link_libraries(test-loadmemory paddle-mobile)
#
# ADD_EXECUTABLE(test-inference-api framework/test_inference_api.cpp)
# target_link_libraries(test-inference-api paddle-mobile)
#
#
# # gen test log
# # gen test
# ADD_EXECUTABLE(test-optimize framework/test_optimize.cpp)
# target_link_libraries(test-optimize paddle-mobile)
#
#
# #gen test
# ADD_EXECUTABLE(test-pool operators/test_pool_op.cpp test_helper.h test_include.h executor_for_test.h)
# target_link_libraries(test-pool paddle-mobile)
#
# #gen test
# ADD_EXECUTABLE(test-softmax operators/test_softmax_op.cpp test_helper.h test_include.h executor_for_test.h)
# target_link_libraries(test-softmax paddle-mobile)
#
# # gen test
# ADD_EXECUTABLE(test-gemm-accuracy common/test_gemm_accuracy.cpp)
# target_link_libraries(test-gemm-accuracy paddle-mobile)
#
# # gen test
# ADD_EXECUTABLE(test-gemm-perf common/test_gemm_perf.cpp)
# target_link_libraries(test-gemm-perf paddle-mobile)
#
# # gen test
# ADD_EXECUTABLE(test-enforce common/test_enforce.cpp)
# target_link_libraries(test-enforce paddle-mobile)
#
# # gen test - test if openmp works
# ADD_EXECUTABLE(test-openmp common/test_openmp.cpp test_helper.h test_include.h executor_for_test.h)
# target_link_libraries(test-openmp paddle-mobile)
#
# # gen test
# ADD_EXECUTABLE(test-mobilenetssd net/test_mobilenet+ssd.cpp test_helper.h test_include.h executor_for_test.h)
# target_link_libraries(test-mobilenetssd paddle-mobile)
#
# # gen test
# ADD_EXECUTABLE(test-mobilenet-combine net/test_mobilenet_combine.cpp test_helper.h test_include.h executor_for_test.h)
# target_link_libraries(test-mobilenet-combine paddle-mobile)
#
# # gen test
# ADD_EXECUTABLE(test-genet net/test_genet_combine.cpp test_helper.h test_include.h executor_for_test.h)
# target_link_libraries(test-genet paddle-mobile)
#
# # gen test
# ADD_EXECUTABLE(test-sigmoid operators/test_sigmoid_op.cpp test_include.h)
# target_link_libraries(test-sigmoid paddle-mobile)
#
# # gen test
# ADD_EXECUTABLE(test-depthwise-conv-op operators/test_depthwise_conv_op.cpp test_helper.h test_include.h executor_for_test.h)
# target_link_libraries(test-depthwise-conv-op paddle-mobile)
#
# # gen test
# ADD_EXECUTABLE(test-mobilenet net/test_mobilenet.cpp test_helper.h test_include.h executor_for_test.h)
# target_link_libraries(test-mobilenet paddle-mobile)
#
# # gen test
# ADD_EXECUTABLE(test-conv-add-relu-op operators/test_conv_add_relu_op.cpp test_helper.h test_include.h executor_for_test.h)
# target_link_libraries(test-conv-add-relu-op paddle-mobile)
#
# # gen test
# ADD_EXECUTABLE(test-conv-add-bn-relu-op operators/test_fusion_conv_add_bn_relu_op.cpp test_helper.h test_include.h executor_for_test.h)
# target_link_libraries(test-conv-add-bn-relu-op paddle-mobile)
#
# # gen test
# ADD_EXECUTABLE(test-nlp net/test_nlp.cpp test_helper.h test_include.h executor_for_test.h)
# target_link_libraries(test-nlp paddle-mobile)
#
# # gen test
# ADD_EXECUTABLE(test-gru-op operators/test_gru_op.cpp test_helper.h test_include.h)
# target_link_libraries(test-gru-op paddle-mobile)
#
# # gen test
#
# ADD_EXECUTABLE(test-inceptionv4 net/test_inceptionv4.cpp test_helper.h test_include.h executor_for_test.h)
# target_link_libraries(test-inceptionv4 paddle-mobile)
#
# # gen test
# ADD_EXECUTABLE(test-alexnet net/test_alexnet.cpp test_helper.h test_include.h executor_for_test.h)
# target_link_libraries(test-alexnet paddle-mobile)
#
# ADD_EXECUTABLE(test-googlenetv1 net/test_googlenetv1_combine.cpp test_helper.h test_include.h)
# target_link_libraries(test-googlenetv1 paddle-mobile)
#
# # gen test
# ADD_EXECUTABLE(test-fssd net/test_mobilenet_025_fssd.cpp test_helper.h test_include.h)
# target_link_libraries(test-fssd paddle-mobile)
# gen test
ADD_EXECUTABLE(test-mobilenetgpu net/test_mobilenet_GPU.cpp test_helper.h test_include.h)
......
......@@ -18,8 +18,8 @@ limitations under the License. */
#include <vector>
#include "common/log.h"
#include "framework/op_registry.h"
#include "framework/executor.h"
#include "framework/op_registry.h"
#include "operators/conv_op.h"
#include "operators/elementwise_add_op.h"
#include "operators/pool_op.h"
......@@ -29,9 +29,9 @@ limitations under the License. */
#include "operators/softmax_op.h"
#include "operators/transpose_op.h"
using paddle_mobile::framework::BlockDesc;
using paddle_mobile::framework::DDim;
using paddle_mobile::framework::Executor;
using paddle_mobile::framework::LoDTensor;
using paddle_mobile::framework::OpDesc;
using paddle_mobile::framework::Program;
......
......@@ -13,9 +13,9 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "../test_helper.h"
#include "framework/loader.h"
#include "framework/program/program-optimize/node.h"
#include "framework/program/program-optimize/program_optimize.h"
#include "framework/loader.h"
int main() {
paddle_mobile::framework::Loader<paddle_mobile::CPU> loader;
......
......@@ -17,43 +17,43 @@ limitations under the License. */
#include "../test_include.h"
int main() {
  paddle_mobile::PaddleMobile<paddle_mobile::GPU_CL> paddle_mobile;
  //  paddle_mobile.SetThreadNum(4);
  auto time1 = time();
  //  auto isok = paddle_mobile.Load(std::string(g_mobilenet_detect) + "/model",
  //  std::string(g_mobilenet_detect) + "/params", true);
  auto isok = paddle_mobile.Load(g_mobilenet, false);
  if (isok) {
    auto time2 = time();
    std::cout << "load cost :" << time_diff(time1, time2) << "ms" << std::endl;
    std::vector<float> input;
    std::vector<int64_t> dims{1, 3, 224, 224};
    GetInput<float>(g_test_image_1x3x224x224_banana, &input, dims);
    auto vec_result = paddle_mobile.Predict(input, dims);
    std::vector<float>::iterator biggest =
        std::max_element(std::begin(vec_result), std::end(vec_result));
    std::cout << " Max element is " << *biggest << " at position "
              << std::distance(std::begin(vec_result), biggest) << std::endl;
    // warm up: run the prediction ten times first
    for (int i = 0; i < 10; ++i) {
      auto vec_result = paddle_mobile.Predict(input, dims);
    }
    auto time3 = time();
    for (int i = 0; i < 10; ++i) {
      auto vec_result = paddle_mobile.Predict(input, dims);
    }
    DLOG << vec_result;
    auto time4 = time();
    std::cout << "predict cost :" << time_diff(time3, time4) / 10 << "ms"
              << std::endl;
  }
  std::cout << "If the result is NaN, check whether "
               "test/images/g_test_image_1x3x224x224_banana exists."
            << std::endl;
  return 0;
}
......@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "../../src/operators/kernel/sigmoid_kernel.h"
#include "../../src/operators/kernel/central-arm-func/sigmoid_arm_func.h"
#include "../../src/operators/kernel/sigmoid_kernel.h"
#include "../test_helper.h"
#include "framework/executor.h"
......