Unverified commit 4d6fdc95, authored by Ray Liu, committed by GitHub

Merge pull request #1052 from codeWorm2015/opencl

add default cl work size func, format files
......@@ -46,7 +46,8 @@ struct PaddleMobileException : public std::exception {
std::string detail(buffer); \
throw paddle_mobile::PaddleMobileException("Custom Exception", buffer, \
__FILE__, __LINE__); \
}
} \
exit(0);
#define PADDLE_MOBILE_ENFORCE(stat, ...) \
{ \
......
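
For orientation, a self-contained sketch of the enforce-macro pattern the hunk above edits: format a message, capture __FILE__/__LINE__, and throw. The names SKETCH_ENFORCE and SimpleException are ours, not the library's, and the conditions are illustrative.

#include <cstdio>
#include <stdexcept>
#include <string>

struct SimpleException : std::runtime_error {
  SimpleException(const std::string &msg, const char *file, int line)
      : std::runtime_error(msg + " (" + file + ":" + std::to_string(line) +
                           ")") {}
};

#define SKETCH_ENFORCE(stat, ...)                        \
  do {                                                   \
    if (!(stat)) {                                       \
      char buffer[256];                                  \
      std::snprintf(buffer, sizeof(buffer), __VA_ARGS__); \
      throw SimpleException(buffer, __FILE__, __LINE__); \
    }                                                    \
  } while (0)

int main() {
  try {
    SKETCH_ENFORCE(2 + 2 == 4, "arithmetic holds");   // passes silently
    SKETCH_ENFORCE(1 > 2, "expected %d > %d", 1, 2);  // throws
  } catch (const SimpleException &e) {
    std::printf("caught: %s\n", e.what());
  }
  return 0;
}
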
......@@ -39,7 +39,13 @@ struct PrecisionTrait<Precision::FP16> {
};
//! device type
enum DeviceTypeEnum { kINVALID = -1, kCPU = 0, kFPGA = 1, kGPU_MALI = 2, kGPU_CL = 3};
enum DeviceTypeEnum {
kINVALID = -1,
kCPU = 0,
kFPGA = 1,
kGPU_MALI = 2,
kGPU_CL = 3
};
template <DeviceTypeEnum T>
struct DeviceType {};
......@@ -49,7 +55,6 @@ typedef DeviceType<kFPGA> FPGA;
typedef DeviceType<kGPU_MALI> GPU_MALI;
typedef DeviceType<kGPU_CL> GPU_CL;
//! data type
enum DataType {
PM_INVALID = -1,
......
......@@ -12,9 +12,9 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "framework/cl/cl_engine.h"
#include "CL/cl.h"
#include "framework/cl/cl_tool.h"
#include "framework/cl/cl_engine.h"
#include <cstdlib>
#include <cstring>
......@@ -28,11 +28,11 @@ bool CLEngine::Init() {
SetClDeviceId();
initialized_ = true;
// setClContext();
// setClCommandQueue();
// std::string filename = "./HelloWorld_Kernel.cl";
// loadKernelFromFile(filename.c_str());
// buildProgram();
// setClContext();
// setClCommandQueue();
// std::string filename = "./HelloWorld_Kernel.cl";
// loadKernelFromFile(filename.c_str());
// buildProgram();
}
CLEngine *CLEngine::Instance() {
......@@ -74,26 +74,26 @@ bool CLEngine::SetClDeviceId() {
return false;
}
//std::unique_ptr<_cl_kernel, clKernel_deleter> CLEngine::GSetKernel(
// std::unique_ptr<_cl_kernel, clKernel_deleter> CLEngine::GSetKernel(
// const std::string &kernel_name) {
// std::unique_ptr<_cl_kernel, clKernel_deleter> kernel(
// clCreateKernel(program_.get(), kernel_name.c_str(), NULL));
// return std::move(kernel);
//}
//
//bool CLEngine::SetClCommandQueue() {
// bool CLEngine::SetClCommandQueue() {
// cl_int status;
// command_queue_.reset(
// clCreateCommandQueue(context_.get(), devices_[0], 0, &status));
// return true;
//}
//bool CLEngine::SetClContext() {
// bool CLEngine::SetClContext() {
// context_.reset(clCreateContext(NULL, 1, devices_, NULL, NULL, NULL));
// return true;
//}
//bool CLEngine::LoadKernelFromFile(const char *kernel_file) {
// bool CLEngine::LoadKernelFromFile(const char *kernel_file) {
// size_t size;
// char *str;
// std::fstream f(kernel_file, (std::fstream::in | std::fstream::binary));
......@@ -118,10 +118,10 @@ bool CLEngine::SetClDeviceId() {
// const char *source = str;
// size_t sourceSize[] = {strlen(source)};
// program_.reset(
// clCreateProgramWithSource(context_.get(), 1, &source, sourceSize, NULL));
// clCreateProgramWithSource(context_.get(), 1, &source, sourceSize,
// NULL));
// return true;
//}
} // namespace framework
} // namespace paddle_mobile
......@@ -17,9 +17,9 @@ limitations under the License. */
#include <memory>
#include <string>
#include "CL/cl.h"
#include "common/enforce.h"
#include "framework/cl/cl_deleter.h"
#include "CL/cl.h"
namespace paddle_mobile {
namespace framework {
......@@ -36,16 +36,18 @@ class CLEngine {
return std::move(context_ptr);
}
std::unique_ptr<_cl_command_queue, CLCommQueueDeleter> CreateClCommandQueue() {
std::unique_ptr<_cl_command_queue, CLCommQueueDeleter>
CreateClCommandQueue() {
cl_int status;
cl_command_queue queue = clCreateCommandQueue(context_.get(), devices_[0], 0, &status);
std::unique_ptr<_cl_command_queue, CLCommQueueDeleter> command_queue_ptr(queue);
cl_command_queue queue =
clCreateCommandQueue(context_.get(), devices_[0], 0, &status);
std::unique_ptr<_cl_command_queue, CLCommQueueDeleter> command_queue_ptr(
queue);
return std::move(command_queue_ptr);
}
std::unique_ptr<_cl_program, CLProgramDeleter> CreateProgramWith(cl_context context, std::string file_name) {
std::unique_ptr<_cl_program, CLProgramDeleter> CreateProgramWith(
cl_context context, std::string file_name) {
FILE *file = fopen(file_name.c_str(), "rb");
PADDLE_MOBILE_ENFORCE(file != nullptr, "can't open file: %s ",
file_name.c_str());
......@@ -62,7 +64,8 @@ class CLEngine {
const char *source = data;
size_t sourceSize[] = {strlen(source)};
cl_program p = clCreateProgramWithSource(context, 1, &source, sourceSize, NULL);
cl_program p =
clCreateProgramWithSource(context, 1, &source, sourceSize, NULL);
std::unique_ptr<_cl_program, CLProgramDeleter> program_ptr(p);
return std::move(program_ptr);
}
......@@ -81,7 +84,6 @@ class CLEngine {
bool SetClDeviceId();
bool initialized_;
cl_platform_id platform_;
......@@ -94,14 +96,13 @@ class CLEngine {
std::unique_ptr<_cl_program, CLProgramDeleter> program_;
// bool SetClContext();
// bool SetClCommandQueue();
// bool SetClContext();
// bool LoadKernelFromFile(const char *kernel_file);
// bool SetClCommandQueue();
// bool BuildProgram();
// bool LoadKernelFromFile(const char *kernel_file);
// bool BuildProgram();
};
} // namespace framework
......
This diff is collapsed.
......@@ -18,4 +18,4 @@ limitations under the License. */
typedef uint16_t half_t;
half_t float2half(float f);
float half2float(half_t h);
\ No newline at end of file
float half2float(half_t h);
......@@ -14,11 +14,13 @@ limitations under the License. */
#pragma once
#include <vector>
#include <string>
#include <type_traits>
#include <vector>
#include "framework/cl/cl_scope.h"
#include "framework/cl/cl_deleter.h"
#include "framework/cl/cl_image.h"
#include "framework/cl/cl_scope.h"
namespace paddle_mobile {
namespace framework {
......@@ -27,24 +29,38 @@ class CLHelper {
public:
CLHelper() = default;
CLHelper(CLScope *scope): scope_(scope) {
}
explicit CLHelper(CLScope *scope) : scope_(scope) {}
void AddKernel(const std::string &kernel_name, const std::string &file_name) {
auto kernel = scope_->GetKernel(kernel_name, file_name);
kernels.emplace_back(std::move(kernel));
}
cl_kernel KernelAt(const int index) {
return kernels[index].get();
}
cl_kernel KernelAt(const int index) { return kernels[index].get(); }
cl_command_queue CLCommandQueue() {
return scope_->CommandQueue();
}
cl_command_queue CLCommandQueue() { return scope_->CommandQueue(); }
cl_context CLContext() { return scope_->Context(); }
std::vector<size_t> DefaultWorkSize(const CLImage &image) {
// n c h w
auto image_dim = image.dims();
if (image_dim.size() == 4) {
auto n = image_dim[0];
auto h = image_dim[2];
auto w = image_dim[3];
auto image_width = image.ImageWidth();
auto work_size_0 = image_width / w;
auto work_size_1 = w;
auto work_size_2 = n * h;
return {work_size_0, work_size_1, work_size_2};
}
PADDLE_MOBILE_THROW_EXCEPTION("not support this dim, need imp");
}
private:
......@@ -52,5 +68,5 @@ class CLHelper {
std::vector<std::unique_ptr<_cl_kernel, CLKernelDeleter>> kernels;
};
}
}
} // namespace framework
} // namespace paddle_mobile
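
The DefaultWorkSize added above derives a 3D NDRange from an image-backed NCHW tensor. Below is a standalone sketch of the same arithmetic, assuming the RGBA packing from cl_image.h (4 channels per texel); the example dims are ours.

#include <cstddef>
#include <iostream>
#include <vector>

std::vector<std::size_t> DefaultWorkSizeSketch(std::size_t n, std::size_t c,
                                               std::size_t h, std::size_t w) {
  // Image width is w * ceil(c / 4), so image_width / w recovers the
  // number of 4-channel blocks; the NDRange is {c_blocks, w, n * h}.
  std::size_t image_width = w * ((c + 3) / 4);
  return {image_width / w, w, n * h};
}

int main() {
  // A 1x32x56x56 feature map maps to the work size {8, 56, 56}.
  for (std::size_t s : DefaultWorkSizeSketch(1, 32, 56, 56)) {
    std::cout << s << " ";
  }
  std::cout << std::endl;
  return 0;
}
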
......@@ -14,10 +14,12 @@ limitations under the License. */
#pragma once
#include <vector>
#include "CL/cl.h"
#include "framework/cl/cl_half.h"
#include "framework/ddim.h"
#include "framework/tensor.h"
#include "CL/cl.h"
#include "cl_half.h"
namespace paddle_mobile {
namespace framework {
......@@ -27,113 +29,127 @@ class CLImage {
CLImage() = default;
void Init(cl_context context, float *tensorInput, DDim ddim) {
tensorDims_ = ddim;
cl_image_format cf = {
.image_channel_order = CL_RGBA,
.image_channel_data_type = CL_HALF_FLOAT
};
tensor_dims_ = ddim;
cl_image_format cf = {.image_channel_order = CL_RGBA,
.image_channel_data_type = CL_HALF_FLOAT};
// NCHW -> [W * (C+3)/4, H * N]
DLOG<<tensorDims_;
size_t N,C,H,W;
if(tensorDims_.size()==4){
N = tensorDims_[0];
if(N<0){
N = 1;
}
C = tensorDims_[1];
H = tensorDims_[2];
W = tensorDims_[3];
}else if(tensorDims_.size()==1){
N = 1;
C = tensorDims_[0];
H = 1;
W = 1;
}
DLOG << tensor_dims_;
size_t N, C, H, W;
if (tensor_dims_.size() == 4) {
N = tensor_dims_[0];
if (N < 0) {
N = 1;
}
C = tensor_dims_[1];
H = tensor_dims_[2];
W = tensor_dims_[3];
} else if (tensor_dims_.size() == 1) {
N = 1;
C = tensor_dims_[0];
H = 1;
W = 1;
}
DLOG<<"-------InitMemory-------";
DLOG << "-------InitMemory-------";
size_t width = W * ((C + 3) / 4);
size_t height = H * N;
std::unique_ptr<half_t[]> imageData{};
int count = 0;
int count = 0;
if (tensorInput != nullptr) {
imageData.reset(new half_t[width * height * 4]);
float *p = tensorInput;
size_t i0 = 0;
for (int n = 0; n < N; n++) {
for (int c = 0; c < C; c++) {
size_t i1 = i0;
for (int h = 0; h < H; h++) {
size_t i2 = (i1<<2) + c % 4;
for (int w = 0; w < W; w++) {
if (i2 >= width * height * 4) {
printf("%d > %d ----> %d, %d, %d, %d --- %d, %d, %d\n", i2, width*height*4, n, c, h, w, i0, i1, i2);
}
assert(i2 < width * height * 4);
imageData[i2] = float2half(*p);
i2 += 4;
p++;
// count++;
// DLOG<<count;
}
i1 += width;
}
}
i0 += width * H;
}
float *p = tensorInput;
size_t i0 = 0;
for (int n = 0; n < N; n++) {
for (int c = 0; c < C; c++) {
size_t i1 = i0;
for (int h = 0; h < H; h++) {
size_t i2 = (i1 << 2) + c % 4;
for (int w = 0; w < W; w++) {
if (i2 >= width * height * 4) {
printf("%d > %d ----> %d, %d, %d, %d --- %d, %d, %d\n", i2,
width * height * 4, n, c, h, w, i0, i1, i2);
}
assert(i2 < width * height * 4);
imageData[i2] = float2half(*p);
i2 += 4;
p++;
// count++;
// DLOG<<count;
}
i1 += width;
}
}
i0 += width * H;
}
}
DLOG<<"-------InitMemory-------";
DLOG << "-------InitMemory-------";
cl_int err;
cl_image_ = clCreateImage2D(
context, // cl_context context
CL_MEM_READ_WRITE | CL_MEM_COPY_HOST_PTR, // cl_mem_flags flags
&cf, // const cl_image_format *image_format
width, // size_t image_width
height, // size_t image_height
0, // size_t image_row_pitch
reinterpret_cast<void*>(imageData.get()), // void *host_ptr
&err // cl_int *errcode_ret
);
context, // cl_context context
CL_MEM_READ_WRITE | CL_MEM_COPY_HOST_PTR, // cl_mem_flags flags
&cf, // const cl_image_format *image_format
width, // size_t image_width
height, // size_t image_height
0, // size_t image_row_pitch
reinterpret_cast<void *>(imageData.get()), // void *host_ptr
&err);
if (err != CL_SUCCESS) {
// TODO: error handling
// TODO(HaiPeng): error handling
}
}
void Init(cl_context context, DDim ddim) {
Init(context, nullptr, ddim);
}
void Init(cl_context context, DDim ddim) { Init(context, nullptr, ddim); }
inline CLImage &Resize(const DDim &dims) {
tensorDims_ = dims;
tensor_dims_ = dims;
return *this;
}
const DDim &dims() const {
return tensorDims_;
}
const DDim &dims() const { return tensor_dims_; }
std::vector<size_t> DefaultWorkSize() {
return {};
}
std::vector<size_t> DefaultWorkSize() { return {}; }
cl_mem GetCLImage() {
return cl_image_;
}
cl_mem GetCLImage() const { return cl_image_; }
template <typename T>
T *data() const {
return reinterpret_cast<T *>(tensor_input_);
}
inline int64_t numel() const { return product(tensor_dims_); }
int ImageWidth() const { return image_width_; }
int ImageHeight() const { return image_height_; }
int CBlock() const { return c_block_; }
int WidthOfOneBlock() const { return width_of_one_block_; }
int HeightOfOneBlock() const { return height_of_one_block_; }
private:
bool initialized_ = false;
cl_mem cl_image_;
DDim tensorDims_;
int image_width_;
int width_of_one_block_;
int height_of_one_block_;
int image_height_;
int c_block_;
DDim tensor_dims_;
float *tensor_input_;
cl_context context_;
};
//void TensorToCLImage(Tensor *tensor, CLImage *image) {
// void TensorToCLImage(Tensor *tensor, CLImage *image) {
//
//}
//
//void CLImageToTensor(CLImage *image, Tensor *tensor) {
// void CLImageToTensor(CLImage *image, Tensor *tensor) {
//
//}
}
}
\ No newline at end of file
} // namespace framework
} // namespace paddle_mobile
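
To make the NCHW -> [W * (C+3)/4, H * N] layout used by Init above concrete, here is a standalone sketch of the index math with plain floats instead of half (function and variable names are ours): texel x = (c/4)*W + w, texel y = n*H + h, and c % 4 picks the RGBA lane.

#include <cstddef>
#include <vector>

std::vector<float> PackNCHWToImage(const std::vector<float> &src,
                                   std::size_t N, std::size_t C, std::size_t H,
                                   std::size_t W) {
  std::size_t width = W * ((C + 3) / 4);  // 4 channels share one texel
  std::size_t height = H * N;
  std::vector<float> image(width * height * 4, 0.f);
  for (std::size_t n = 0; n < N; ++n)
    for (std::size_t c = 0; c < C; ++c)
      for (std::size_t h = 0; h < H; ++h)
        for (std::size_t w = 0; w < W; ++w) {
          std::size_t x = (c / 4) * W + w;  // channel block picks a column band
          std::size_t y = n * H + h;
          image[(y * width + x) * 4 + c % 4] =
              src[((n * C + c) * H + h) * W + w];
        }
  return image;
}

int main() {
  std::vector<float> t(1 * 6 * 2 * 3);
  for (std::size_t i = 0; i < t.size(); ++i) t[i] = static_cast<float>(i);
  auto img = PackNCHWToImage(t, 1, 6, 2, 3);  // width = 3 * 2, height = 2
  return img.size() == 6 * 2 * 4 ? 0 : 1;
}
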
......@@ -18,10 +18,10 @@ limitations under the License. */
#include <string>
#include <unordered_map>
#include "framework/cl/cl_tool.h"
#include "framework/cl/cl_engine.h"
#include "framework/cl/cl_deleter.h"
#include "CL/cl.h"
#include "framework/cl/cl_deleter.h"
#include "framework/cl/cl_engine.h"
#include "framework/cl/cl_tool.h"
namespace paddle_mobile {
namespace framework {
......@@ -35,19 +35,17 @@ class CLScope {
command_queue_ = engin->CreateClCommandQueue();
}
cl_command_queue CommandQueue() {
return command_queue_.get();
}
cl_command_queue CommandQueue() { return command_queue_.get(); }
std::unique_ptr<_cl_kernel, CLKernelDeleter> GetKernel(const std::string &kernel_name, const std::string &file_name) {
std::unique_ptr<_cl_kernel, CLKernelDeleter> GetKernel(
const std::string &kernel_name, const std::string &file_name) {
auto program = Program(file_name);
std::unique_ptr<_cl_kernel, CLKernelDeleter> kernel(clCreateKernel(program, kernel_name.c_str(), NULL));
std::unique_ptr<_cl_kernel, CLKernelDeleter> kernel(
clCreateKernel(program, kernel_name.c_str(), NULL));
return std::move(kernel);
}
cl_context Context() {
return context_.get();
}
cl_context Context() { return context_.get(); }
cl_program Program(const std::string &file_name) {
auto it = programs_.find(file_name);
......@@ -55,20 +53,23 @@ class CLScope {
return it->second.get();
}
auto program = CLEngine::Instance()->CreateProgramWith(context_.get(), file_name);
auto program =
CLEngine::Instance()->CreateProgramWith(context_.get(), file_name);
programs_[file_name] = std::move(program);
status_ = clBuildProgram(program.get(), 0, 0, 0, 0, 0);
status_ = clBuildProgram(program.get(), 0, 0, 0, 0, 0);
CL_CHECK_ERRORS(status_);
return program.get();
}
private:
cl_int status_;
cl_int status_;
std::unique_ptr<_cl_context, CLContextDeleter> context_;
std::unique_ptr<_cl_command_queue, CLCommQueueDeleter> command_queue_;
std::unordered_map<std::string, std::unique_ptr<_cl_program, CLProgramDeleter>> programs_;
std::unordered_map<std::string,
std::unique_ptr<_cl_program, CLProgramDeleter>>
programs_;
};
}
}
} // namespace framework
} // namespace paddle_mobile
......@@ -18,17 +18,17 @@ limitations under the License. */
#include <string>
#include <vector>
#include "framework/tensor_base.h"
#include "framework/cl/cl_engine.h"
#include "framework/cl/cl_deleter.h"
#include "CL/cl.h"
#include "framework/cl/cl_deleter.h"
#include "framework/cl/cl_engine.h"
#include "framework/tensor_base.h"
namespace paddle_mobile {
namespace framework {
class CLTensor : TensorBase {
public:
CLTensor(cl_context context) : context_(context) {}
explicit CLTensor(cl_context context) : context_(context) {}
/*! Resize the dimensions of the memory block. */
inline CLTensor &Resize(const DDim &dims) {
......@@ -84,7 +84,6 @@ class CLTensor : TensorBase {
}
private:
cl_context context_;
/*
......@@ -99,18 +98,15 @@ class CLTensor : TensorBase {
virtual void set_type(std::type_index type) = 0;
* */
struct PlaceholderImpl : public Placeholder {
PlaceholderImpl(size_t size, void *input, std::type_index type, cl_context context)
: ptr_(clCreateBuffer(context,
CL_MEM_READ_ONLY | CL_MEM_COPY_HOST_PTR, size,
reinterpret_cast<void *>(input), NULL)),
PlaceholderImpl(size_t size, void *input, std::type_index type,
cl_context context)
: ptr_(clCreateBuffer(context, CL_MEM_READ_ONLY | CL_MEM_COPY_HOST_PTR,
size, reinterpret_cast<void *>(input), NULL)),
size_(size),
type_(type) {
}
type_(type) {}
PlaceholderImpl(size_t size, std::type_index type, cl_context context)
: ptr_(clCreateBuffer(context,
CL_MEM_READ_WRITE, size, NULL, NULL)),
: ptr_(clCreateBuffer(context, CL_MEM_READ_WRITE, size, NULL, NULL)),
size_(size),
type_(type) {}
......@@ -128,9 +124,7 @@ class CLTensor : TensorBase {
/* the current type of memory */
std::type_index type_;
};
};
} // namespace framework
......
......@@ -12,13 +12,15 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "cl_tool.h"
#include "framework/cl/cl_tool.h"
namespace paddle_mobile {
namespace framework {
const char *opencl_error_to_str(cl_int error) {
#define CASE_CL_CONSTANT(NAME) case NAME: return #NAME;
#define CASE_CL_CONSTANT(NAME) \
case NAME: \
return #NAME;
// Suppose that no combinations are possible.
switch (error) {
CASE_CL_CONSTANT(CL_SUCCESS)
......@@ -78,5 +80,5 @@ const char *opencl_error_to_str(cl_int error) {
#undef CASE_CL_CONSTANT
}
}
}
} // namespace framework
} // namespace paddle_mobile
......@@ -19,16 +19,15 @@ limitations under the License. */
namespace paddle_mobile {
namespace framework {
const char* opencl_error_to_str (cl_int error);
#define CL_CHECK_ERRORS(ERR) \
if(ERR != CL_SUCCESS) \
{ \
printf( \
"OpenCL error with code %s happened in file %s at line %d. Exiting.\n", \
opencl_error_to_str(ERR), __FILE__, __LINE__ \
); \
}
}
}
const char* opencl_error_to_str(cl_int error);
#define CL_CHECK_ERRORS(ERR) \
if (ERR != CL_SUCCESS) { \
printf( \
"OpenCL error with code %s happened in file %s at line %d. " \
"Exiting.\n", \
opencl_error_to_str(ERR), __FILE__, __LINE__); \
}
} // namespace framework
} // namespace paddle_mobile
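
A runnable mimic of the CL_CHECK_ERRORS pattern reformatted above; the error-name lookup is stubbed so this compiles without an OpenCL SDK, and the failing status code is invented for the demo.

#include <cstdio>

typedef int sketch_cl_int;  // stand-in for cl_int
#define SKETCH_CL_SUCCESS 0

static const char *error_to_str_stub(sketch_cl_int e) {
  return e == -5 ? "CL_OUT_OF_RESOURCES" : "CL_UNKNOWN_ERROR";
}

#define SKETCH_CL_CHECK(ERR)                                     \
  if ((ERR) != SKETCH_CL_SUCCESS) {                              \
    std::printf("OpenCL error %s at %s:%d. Exiting.\n",          \
                error_to_str_stub(ERR), __FILE__, __LINE__);     \
  }

int main() {
  sketch_cl_int status = -5;  // pretend a clCreateCommandQueue call failed
  SKETCH_CL_CHECK(status);
  return 0;
}
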
......@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "executor.h"
#include "framework/executor.h"
#include <operators/math/gemm.h>
#include <algorithm>
#include <vector>
......@@ -265,7 +265,7 @@ void Executor<Dtype, P>::InitCombineMemory() {
char *origin_data;
if (program_.combined_params_buf && program_.combined_params_len) {
LOG(kLOG_INFO) << "use outter memory";
origin_data = (char *)program_.combined_params_buf;
origin_data = reinterpret_cast<char *>(program_.combined_params_buf);
} else {
LOG(kLOG_INFO) << " begin init combine memory";
origin_data = Get_binary_data(program_.para_path);
......@@ -666,12 +666,12 @@ void Executor<Dtype, P>::InjectVariable(const framework::Tensor &t,
g_feed_value->GetMutable<framework::LoDTensor>();
feed_tensor->Resize(t.dims());
feed_tensor->ShareDataWith(t);
};
}
template <typename Dtype, Precision P>
void Executor<Dtype, P>::FeedData(const framework::Tensor &t) {
InjectVariable(t, "feed");
};
}
template <typename Dtype, Precision P>
std::shared_ptr<framework::Tensor> Executor<Dtype, P>::FetchResult(int id) {
......@@ -687,14 +687,14 @@ std::shared_ptr<framework::Tensor> Executor<Dtype, P>::FetchResult(int id) {
auto *output_tensor = framework::GetVarValue<framework::LoDTensor>(
out_keys[0], output_map, *(program_.scope));
return std::make_shared<framework::Tensor>(framework::Tensor(*output_tensor));
};
}
template <typename Dtype, Precision P>
void Executor<Dtype, P>::Predict_From_To(int start, int end) {
std::shared_ptr<framework::BlockDesc> to_predict_block =
to_predict_program_->Block(0);
auto &ops = ops_of_block_[*to_predict_block.get()];
end = end < 0 ? (int)ops.size() : end;
end = end < 0 ? static_cast<int>(ops.size()) : end;
PADDLE_MOBILE_ENFORCE(start >= 0 && start < end && end <= ops.size(),
"start or end parameter is wrong");
......@@ -715,17 +715,17 @@ void Executor<Dtype, P>::Predict_From_To(int start, int end) {
profile[i].runEnd = (uint64_t)ts.tv_sec * 1e9 + ts.tv_nsec;
#endif
}
};
}
template <typename Dtype, Precision P>
void Executor<Dtype, P>::Predict_From(int start) {
Predict_From_To(start);
};
}
template <typename Dtype, Precision P>
void Executor<Dtype, P>::Predict_To(int end) {
Predict_From_To(0, end);
};
}
#endif
#ifdef PADDLE_MOBILE_FPGA
......@@ -738,12 +738,12 @@ void Executor<Dtype, P>::InjectVariable(const framework::Tensor &t,
g_feed_value->GetMutable<framework::LoDTensor>();
feed_tensor->Resize(t.dims());
feed_tensor->ShareDataWith(t);
};
}
template <typename Dtype, Precision P>
void Executor<Dtype, P>::FeedData(const framework::Tensor &t) {
InjectVariable(t, "feed");
};
}
template <typename Dtype, Precision P>
std::shared_ptr<framework::Tensor> Executor<Dtype, P>::FetchResult(int id) {
......@@ -759,14 +759,14 @@ std::shared_ptr<framework::Tensor> Executor<Dtype, P>::FetchResult(int id) {
auto *output_tensor = framework::GetVarValue<framework::LoDTensor>(
out_keys[0], output_map, *(program_.scope));
return std::make_shared<framework::Tensor>(framework::Tensor(*output_tensor));
};
}
template <typename Dtype, Precision P>
void Executor<Dtype, P>::Predict_From_To(int start, int end) {
std::shared_ptr<framework::BlockDesc> to_predict_block =
to_predict_program_->Block(0);
auto &ops = ops_of_block_[*to_predict_block.get()];
end = end < 0 ? (int)ops.size() : end;
end = end < 0 ? static_cast<int>(ops.size()) : end;
PADDLE_MOBILE_ENFORCE(start >= 0 && start < end && end <= ops.size(),
"start or end parameter is wrong");
......@@ -787,120 +787,120 @@ void Executor<Dtype, P>::Predict_From_To(int start, int end) {
profile[i].runEnd = (uint64_t)ts.tv_sec * 1e9 + ts.tv_nsec;
#endif
}
};
}
template <typename Dtype, Precision P>
void Executor<Dtype, P>::Predict_From(int start) {
Predict_From_To(start);
};
}
template <typename Dtype, Precision P>
void Executor<Dtype, P>::Predict_To(int end) {
Predict_From_To(0, end);
};
}
#endif
#ifdef PADDLE_MOBILE_CL
template <>
void Executor<GPU_CL, Precision::FP32>::LoadMemory(const framework::VarDesc var_desc,
float *tensorInput, char **data) {
// 1. version
uint32_t version = *reinterpret_cast<uint32_t *>(*data);
(*data) += sizeof(uint32_t);
// 2 Lod information
uint64_t *lod_level_ptr = new uint64_t();
memcpy(lod_level_ptr, (*data), sizeof(uint64_t));
uint64_t lod_level = *lod_level_ptr;
delete lod_level_ptr;
(*data) += sizeof(uint64_t);
for (uint64_t i = 0; i < lod_level; ++i) {
uint64_t size = *reinterpret_cast<uint64_t *>(*data);
(*data) += sizeof(uint64_t);
std::vector<size_t> tmp(size / sizeof(size_t));
for (int k = 0; k < tmp.size(); ++k) {
tmp[k] = *reinterpret_cast<size_t *>(*data);
(*data) += sizeof(size_t);
}
}
// 3. tensor version
uint32_t tensor_version = *reinterpret_cast<uint32_t *>(*data);
(*data) += sizeof(uint32_t);
// 4. tensor desc
int32_t size = *reinterpret_cast<int32_t *>(*data);
(*data) += sizeof(int32_t);
std::unique_ptr<char[]> buf(new char[size]);
for (int m = 0; m < size; ++m) {
buf.get()[m] = (*data)[m];
}
(*data) += (sizeof(char) * size);
const framework::TensorDesc &desc = var_desc.Tensor_desc();
int memory_size = 1;
for (auto l : desc.Dims()) {
memory_size *= l;
}
void *memory = nullptr;
// int type_size = 0;
// switch (desc.DataType()) {
// case framework::VARTYPE_TYPE_FP16:
// type_size = 2;
// break;
// case framework::VARTYPE_TYPE_FP32:
// type_size = 4;
// memory = tensor->mutable_data<float>();
// break;
// case framework::VARTYPE_TYPE_FP64:
// type_size = 8;
// break;
// case framework::VARTYPE_TYPE_INT32:
// memory = tensor->mutable_data<int32_t>();
// type_size = 4;
// break;
// case framework::VARTYPE_TYPE_INT64:
// type_size = 8;
// break;
// case framework::VARTYPE_TYPE_BOOL:
// type_size = 1;
// break;
// default:
// break;
// }
int type_size = 4;
memory = tensorInput;
if (program_.quantification) {
float min_value;
float max_value;
memcpy(&min_value, *data, sizeof(float));
memcpy(&max_value, *data + sizeof(float), sizeof(float));
*data += 2 * sizeof(float);
const float factor = (max_value - min_value) / 255.0;
uint8_t *uint8_data = reinterpret_cast<uint8_t *>(*data);
for (int k = 0; k < memory_size; ++k) {
static_cast<float *>(memory)[k] = uint8_data[k] * factor + min_value;
}
*data += (memory_size * sizeof(uint8_t));
} else {
for (int n = 0; n < memory_size; n++) {
float value;
memcpy(&value, *data + n * type_size, type_size);
if (value < 1e-30 && value > -1e-30) {
static_cast<float *>(memory)[n] = 0.0;
} else {
static_cast<float *>(memory)[n] = value;
}
}
(*data) += (sizeof(char) * memory_size * type_size);
}
}
void Executor<GPU_CL, Precision::FP32>::LoadMemory(
const framework::VarDesc var_desc, float *tensorInput, char **data) {
// 1. version
uint32_t version = *reinterpret_cast<uint32_t *>(*data);
(*data) += sizeof(uint32_t);
// 2 Lod information
uint64_t *lod_level_ptr = new uint64_t();
memcpy(lod_level_ptr, (*data), sizeof(uint64_t));
uint64_t lod_level = *lod_level_ptr;
delete lod_level_ptr;
(*data) += sizeof(uint64_t);
for (uint64_t i = 0; i < lod_level; ++i) {
uint64_t size = *reinterpret_cast<uint64_t *>(*data);
(*data) += sizeof(uint64_t);
std::vector<size_t> tmp(size / sizeof(size_t));
for (int k = 0; k < tmp.size(); ++k) {
tmp[k] = *reinterpret_cast<size_t *>(*data);
(*data) += sizeof(size_t);
}
}
// 3. tensor version
uint32_t tensor_version = *reinterpret_cast<uint32_t *>(*data);
(*data) += sizeof(uint32_t);
// 4. tensor desc
int32_t size = *reinterpret_cast<int32_t *>(*data);
(*data) += sizeof(int32_t);
std::unique_ptr<char[]> buf(new char[size]);
for (int m = 0; m < size; ++m) {
buf.get()[m] = (*data)[m];
}
(*data) += (sizeof(char) * size);
const framework::TensorDesc &desc = var_desc.Tensor_desc();
int memory_size = 1;
for (auto l : desc.Dims()) {
memory_size *= l;
}
void *memory = nullptr;
// int type_size = 0;
// switch (desc.DataType()) {
// case framework::VARTYPE_TYPE_FP16:
// type_size = 2;
// break;
// case framework::VARTYPE_TYPE_FP32:
// type_size = 4;
// memory = tensor->mutable_data<float>();
// break;
// case framework::VARTYPE_TYPE_FP64:
// type_size = 8;
// break;
// case framework::VARTYPE_TYPE_INT32:
// memory = tensor->mutable_data<int32_t>();
// type_size = 4;
// break;
// case framework::VARTYPE_TYPE_INT64:
// type_size = 8;
// break;
// case framework::VARTYPE_TYPE_BOOL:
// type_size = 1;
// break;
// default:
// break;
// }
int type_size = 4;
memory = tensorInput;
if (program_.quantification) {
float min_value;
float max_value;
memcpy(&min_value, *data, sizeof(float));
memcpy(&max_value, *data + sizeof(float), sizeof(float));
*data += 2 * sizeof(float);
const float factor = (max_value - min_value) / 255.0;
uint8_t *uint8_data = reinterpret_cast<uint8_t *>(*data);
for (int k = 0; k < memory_size; ++k) {
static_cast<float *>(memory)[k] = uint8_data[k] * factor + min_value;
}
*data += (memory_size * sizeof(uint8_t));
} else {
for (int n = 0; n < memory_size; n++) {
float value;
memcpy(&value, *data + n * type_size, type_size);
if (value < 1e-30 && value > -1e-30) {
static_cast<float *>(memory)[n] = 0.0;
} else {
static_cast<float *>(memory)[n] = value;
}
}
(*data) += (sizeof(char) * memory_size * type_size);
}
}
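
The quantification branch in LoadMemory above expands uint8 weights through a per-tensor linear map: factor = (max - min) / 255 and value = q * factor + min. A minimal standalone sketch of that dequantization (names ours):

#include <cstddef>
#include <cstdint>
#include <vector>

std::vector<float> Dequantize(const std::vector<uint8_t> &q, float min_value,
                              float max_value) {
  // q = 0 maps to min_value, q = 255 maps to max_value, linear in between.
  const float factor = (max_value - min_value) / 255.0f;
  std::vector<float> out(q.size());
  for (std::size_t k = 0; k < q.size(); ++k) {
    out[k] = q[k] * factor + min_value;
  }
  return out;
}

int main() {
  auto w = Dequantize({0, 128, 255}, -1.f, 1.f);  // -> {-1, ~0.004, 1}
  return w.size() == 3 ? 0 : 1;
}
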
template <>
void Executor<GPU_CL, Precision::FP32>::InitMemory() {
......@@ -914,37 +914,35 @@ void Executor<GPU_CL, Precision::FP32>::InitMemory() {
}
char *origin_data =
Get_binary_data(program_.model_path + "/" + var_desc->Name());
char *data = origin_data;
char *data = origin_data;
cl_context context = program_.scope->GetCLScpoe()->Context();
const framework::TensorDesc &desc = var_desc->Tensor_desc();
int numel = 1;
for (auto l : desc.Dims()) {
numel *= l;
}
DLOG<<var_desc->Name();
const framework::TensorDesc &desc = var_desc->Tensor_desc();
int numel = 1;
for (auto l : desc.Dims()) {
numel *= l;
}
DLOG << var_desc->Name();
float *tensorInput = static_cast<float *>(
paddle_mobile::memory::Alloc(sizeof(float) * numel));
LoadMemory(*var_desc,tensorInput,&data);
float *tensorInput = static_cast<float *>(
paddle_mobile::memory::Alloc(sizeof(float) * numel));
LoadMemory(*var_desc, tensorInput, &data);
framework::DDim ddim = framework::make_ddim(desc.Dims());
cl_image->Init(context, tensorInput, ddim);
delete origin_data;
paddle_mobile::memory::Free(tensorInput);
}else{
if (var_desc->Type() == framework::VARTYPE_TYPE_LOD_TENSOR) {
auto cl_image = var->template GetMutable<framework::CLImage>();
cl_context context = program_.scope->GetCLScpoe()->Context();
const framework::TensorDesc &desc = var_desc->Tensor_desc();
framework::DDim ddim = framework::make_ddim(desc.Dims());
DLOG<<var_desc->Name();
cl_image->Init(context, ddim);
delete origin_data;
paddle_mobile::memory::Free(tensorInput);
} else {
if (var_desc->Type() == framework::VARTYPE_TYPE_LOD_TENSOR) {
auto cl_image = var->template GetMutable<framework::CLImage>();
cl_context context = program_.scope->GetCLScpoe()->Context();
const framework::TensorDesc &desc = var_desc->Tensor_desc();
framework::DDim ddim = framework::make_ddim(desc.Dims());
DLOG << var_desc->Name();
cl_image->Init(context, ddim);
}
}
}
}
......@@ -955,13 +953,13 @@ void Executor<GPU_CL, Precision::FP32>::InitCombineMemory() {
char *origin_data;
if (program_.combined_params_buf && program_.combined_params_len) {
LOG(kLOG_INFO) << "use outter memory";
origin_data = (char *)program_.combined_params_buf;
origin_data = reinterpret_cast<char *>(program_.combined_params_buf);
} else {
LOG(kLOG_INFO) << " begin init combine memory";
origin_data = Get_binary_data(program_.para_path);
}
PADDLE_MOBILE_ENFORCE(origin_data != nullptr, "origin_data==nullptr!!!");
float *data = (float *)origin_data;
float *data = reinterpret_cast<float *>(origin_data);
for (const auto &block : to_predict_program_->Blocks()) {
for (const auto &var_desc : block->Vars()) {
......@@ -981,12 +979,12 @@ void Executor<GPU_CL, Precision::FP32>::InitCombineMemory() {
for (int i = 0; i < ddim.size(); i++) {
numel = numel * ddim[i];
}
float *tensorInput = static_cast<float *>(
paddle_mobile::memory::Alloc(sizeof(float) * numel));
LoadMemory(*var_desc,tensorInput,&origin_data);
float *tensorInput = static_cast<float *>(
paddle_mobile::memory::Alloc(sizeof(float) * numel));
LoadMemory(*var_desc, tensorInput, &origin_data);
cl_image->Init(context, tensorInput, ddim);
paddle_mobile::memory::Free(tensorInput);
}else{
paddle_mobile::memory::Free(tensorInput);
} else {
auto cl_image = var->template GetMutable<framework::CLImage>();
cl_context context = program_.scope->GetCLScpoe()->Context();
......
......@@ -35,7 +35,7 @@ using std::string;
namespace paddle_mobile {
namespace framework {
template<typename Dtype = CPU, Precision P = Precision::FP32>
template <typename Dtype = CPU, Precision P = Precision::FP32>
class Executor {
public:
typedef typename PrecisionTrait<P>::ptype Ptype;
......@@ -56,7 +56,7 @@ class Executor {
* @b to predict
* */
std::shared_ptr<framework::LoDTensor> PredictLod(
const framework::LoDTensor &t);
const framework::LoDTensor &t);
/*
* @b to predict with vector and dim
......@@ -73,8 +73,8 @@ class Executor {
void LoadMemory(const framework::VarDesc var_desc,
framework::LoDTensor *tensor, char **data);
void LoadMemory(const framework::VarDesc var_desc,
float * tensorInput, char **data);
void LoadMemory(const framework::VarDesc var_desc, float *tensorInput,
char **data);
void InitCombineMemory();
......@@ -86,8 +86,8 @@ class Executor {
int block_id);
std::map<framework::BlockDesc,
std::vector<std::shared_ptr<framework::OperatorBase<Dtype>>>>
ops_of_block_;
std::vector<std::shared_ptr<framework::OperatorBase<Dtype>>>>
ops_of_block_;
bool use_optimize_ = false;
bool loddable_ = false;
#ifdef PADDLE_EXECUTOR_MULTITHREAD
......@@ -107,15 +107,15 @@ class Executor {
#ifdef PADDLE_MOBILE_FPGA
public:
void InjectVariable(const framework::Tensor &t, string var_name);
void FeedData(const framework::Tensor &t);
std::shared_ptr<framework::Tensor> FetchResult(int id = -1);
void Predict_From_To(int start = 0, int end = -1);
void Predict_From(int start);
void Predict_To(int end);
public:
void InjectVariable(const framework::Tensor &t, string var_name);
void FeedData(const framework::Tensor &t);
std::shared_ptr<framework::Tensor> FetchResult(int id = -1);
void Predict_From_To(int start = 0, int end = -1);
void Predict_From(int start);
void Predict_To(int end);
#endif
};
}
} // namespace framework
} // namespace paddle_mobile
......@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "loader.h"
#include "framework/loader.h"
#include "framework/lod_tensor.h"
#include "framework/program/program-optimize/program_optimize.h"
......@@ -29,10 +29,10 @@ namespace framework {
* @param originProgramDesc
* @param scope
*/
template<typename Dtype, Precision P>
template <typename Dtype, Precision P>
void Loader<Dtype, P>::InitMemoryFromProgram(
std::shared_ptr<ProgramDesc> &originProgramDesc,
std::shared_ptr<Scope> &scope) {
const std::shared_ptr<ProgramDesc> &originProgramDesc,
const std::shared_ptr<Scope> &scope) {
for (const auto &block : originProgramDesc.get()->Blocks()) {
for (const auto &var_desc : block->Vars()) {
auto var = scope.get()->Var(var_desc->Name());
......@@ -56,32 +56,32 @@ void Loader<Dtype, P>::InitMemoryFromProgram(
}
#ifdef PADDLE_MOBILE_CL
template<>
void Loader<GPU_CL, Precision::FP32>::InitMemoryFromProgram(
std::shared_ptr<ProgramDesc> &originProgramDesc,
std::shared_ptr<Scope> &scope) {
for (const auto &block : originProgramDesc.get()->Blocks()) {
for (const auto &var_desc : block->Vars()) {
auto var = scope.get()->Var(var_desc->Name());
if (var_desc->Type() == VARTYPE_TYPE_LOD_TENSOR) {
if (var_desc->Persistable()) {
auto dim = var_desc->Tensor_desc().Dims();
// auto tensor = var->GetMutable<LoDTensor>();
auto cl_image = var->GetMutable<framework::CLImage>();
cl_image->Resize(make_ddim(dim));
} else {
auto dim = var_desc->Tensor_desc().Dims();
PADDLE_MOBILE_ENFORCE(dim.size() > 0, "dim size is 0");
dim[0] = 1;
auto cl_image = var->GetMutable<framework::CLImage>();
cl_image->Resize(make_ddim(dim));
}
} else {
// TODO(codeWorm): some.
}
}
}
template <>
void Loader<GPU_CL, Precision::FP32>::InitMemoryFromProgram(
const std::shared_ptr<ProgramDesc> &originProgramDesc,
const std::shared_ptr<Scope> &scope) {
for (const auto &block : originProgramDesc.get()->Blocks()) {
for (const auto &var_desc : block->Vars()) {
auto var = scope.get()->Var(var_desc->Name());
if (var_desc->Type() == VARTYPE_TYPE_LOD_TENSOR) {
if (var_desc->Persistable()) {
auto dim = var_desc->Tensor_desc().Dims();
// auto tensor = var->GetMutable<LoDTensor>();
auto cl_image = var->GetMutable<framework::CLImage>();
cl_image->Resize(make_ddim(dim));
} else {
auto dim = var_desc->Tensor_desc().Dims();
PADDLE_MOBILE_ENFORCE(dim.size() > 0, "dim size is 0");
dim[0] = 1;
auto cl_image = var->GetMutable<framework::CLImage>();
cl_image->Resize(make_ddim(dim));
}
} else {
// TODO(codeWorm): some.
}
}
}
}
#endif
/**
......@@ -93,14 +93,14 @@ void Loader<Dtype, P>::InitMemoryFromProgram(
* @param program
* @param originProgramDesc
*/
template<typename Dtype, Precision P>
template <typename Dtype, Precision P>
void FusionAndPrintInfos(
bool &optimize, bool &can_add_split, Program<Dtype, P> &program,
const std::shared_ptr<ProgramDesc> &originProgramDesc) {
bool optimize, bool can_add_split, const Program<Dtype, P> &program,
const std::shared_ptr<ProgramDesc> &originProgramDesc) {
if (optimize) {
ProgramOptimize program_optimize;
program.optimizeProgram =
program_optimize.FusionOptimize(originProgramDesc, can_add_split);
program_optimize.FusionOptimize(originProgramDesc, can_add_split);
}
if (optimize) {
program.optimizeProgram->Description("optimize: ");
......@@ -131,20 +131,22 @@ static size_t ReadBuffer(const char *file_name, uint8_t **out) {
return cur_len;
}
template<typename Dtype, Precision P>
const Program<Dtype, P> Loader<Dtype, P>::Load(
const std::string &dirname, bool optimize, bool quantification,
bool can_add_split) {
template <typename Dtype, Precision P>
const Program<Dtype, P> Loader<Dtype, P>::Load(const std::string &dirname,
bool optimize,
bool quantification,
bool can_add_split) {
auto program = this->LoadProgram(dirname + "/__model__", optimize,
quantification, can_add_split);
program.model_path = dirname;
return program;
}
template<typename Dtype, Precision P>
const Program<Dtype, P> Loader<Dtype, P>::Load(
const std::string &model_path, const std::string &para_path, bool optimize,
bool quantification) {
template <typename Dtype, Precision P>
const Program<Dtype, P> Loader<Dtype, P>::Load(const std::string &model_path,
const std::string &para_path,
bool optimize,
bool quantification) {
auto program = this->LoadProgram(model_path, optimize, quantification);
program.para_path = para_path;
......@@ -153,10 +155,10 @@ const Program<Dtype, P> Loader<Dtype, P>::Load(
return program;
}
template<typename Dtype, Precision P>
template <typename Dtype, Precision P>
const Program<Dtype, P> Loader<Dtype, P>::LoadProgram(
const std::string &model_path, bool optimize, bool quantification,
bool can_add_split) {
const std::string &model_path, bool optimize, bool quantification,
bool can_add_split) {
std::string model_filename = model_path;
PaddleMobile__Framework__Proto__ProgramDesc *c_program;
uint8_t *buf = NULL;
......@@ -165,7 +167,7 @@ const Program<Dtype, P> Loader<Dtype, P>::LoadProgram(
PADDLE_MOBILE_ENFORCE(buf != NULL, "read from __model__ is null");
c_program = paddle_mobile__framework__proto__program_desc__unpack(
NULL, read_size, buf);
NULL, read_size, buf);
//
PADDLE_MOBILE_ENFORCE(c_program != NULL, "program is null");
//
......@@ -190,17 +192,17 @@ const Program<Dtype, P> Loader<Dtype, P>::LoadProgram(
return program;
}
template<typename Dtype, Precision P>
template <typename Dtype, Precision P>
const Program<Dtype, P> Loader<Dtype, P>::LoadCombinedMemory(
size_t read_size, const uint8_t *buf, size_t combined_params_len,
const uint8_t *combined_params_buf, bool optimize, bool quantification) {
size_t read_size, const uint8_t *buf, size_t combined_params_len,
const uint8_t *combined_params_buf, bool optimize, bool quantification) {
bool can_add_split = false;
PaddleMobile__Framework__Proto__ProgramDesc *c_program;
PADDLE_MOBILE_ENFORCE(buf != nullptr, "read from __model__ is null");
c_program = paddle_mobile__framework__proto__program_desc__unpack(
nullptr, read_size, buf);
nullptr, read_size, buf);
//
PADDLE_MOBILE_ENFORCE(c_program != nullptr, "program is null");
//
......@@ -225,17 +227,13 @@ const Program<Dtype, P> Loader<Dtype, P>::LoadCombinedMemory(
return program;
}
template
class Loader<CPU, Precision::FP32>;
template class Loader<CPU, Precision::FP32>;
template
class Loader<FPGA, Precision::FP32>;
template class Loader<FPGA, Precision::FP32>;
template
class Loader<GPU_MALI, Precision::FP32>;
template class Loader<GPU_MALI, Precision::FP32>;
template
class Loader<GPU_CL, Precision::FP32>;
template class Loader<GPU_CL, Precision::FP32>;
}
} // namespace framework
} // namespace paddle_mobile
......@@ -20,7 +20,7 @@ limitations under the License. */
#include "framework/program/program.h"
namespace paddle_mobile {
namespace framework{
namespace framework {
template <typename Dtype = CPU, Precision P = Precision::FP32>
class Loader {
......@@ -30,33 +30,36 @@ class Loader {
* @b load a fluid model stored as separate model/param files
* */
const Program<Dtype, P> Load(const std::string &dirname,
bool optimize = false,
bool quantification = false,
bool can_add_split = false);
bool optimize = false,
bool quantification = false,
bool can_add_split = false);
/*
* @b load the combined-format fluid model
* */
const Program<Dtype, P> Load(const std::string &model_path,
const std::string &para_path,
bool optimize = false,
bool quantification = false);
const std::string &para_path,
bool optimize = false,
bool quantification = false);
const Program<Dtype, P> LoadCombinedMemory(
size_t model_len, const uint8_t *model_buf, size_t combined_params_len,
const uint8_t *combined_params_buf, bool optimize = false,
bool quantification = false);
const Program<Dtype, P> LoadCombinedMemory(size_t model_len,
const uint8_t *model_buf,
size_t combined_params_len,
const uint8_t *combined_params_buf,
bool optimize = false,
bool quantification = false);
private:
const Program<Dtype, P> LoadProgram(const std::string &model_path,
bool optimize = false,
bool quantification = false,
bool can_add_split = false);
bool optimize = false,
bool quantification = false,
bool can_add_split = false);
void InitMemoryFromProgram(std::shared_ptr<ProgramDesc> &originProgramDesc,
std::shared_ptr<Scope> &scope);
void InitMemoryFromProgram(
const std::shared_ptr<ProgramDesc> &originProgramDesc,
const std::shared_ptr<Scope> &scope);
};
}
} // namespace framework
} // namespace paddle_mobile
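
A hedged usage sketch of the Loader interface declared above, mirroring how the tests later in this diff call it; the model directory is a placeholder and error handling is omitted.

#include "framework/loader.h"

int main() {
  paddle_mobile::framework::Loader<paddle_mobile::CPU> loader;
  // Load the separate-files fluid format with fusion optimization on
  // and quantization off ("./models/mobilenet" is illustrative).
  auto program = loader.Load("./models/mobilenet", /*optimize=*/true,
                             /*quantification=*/false);
  return program.model_path.empty() ? 1 : 0;
}
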
......@@ -14,8 +14,8 @@ limitations under the License. */
#pragma once
#include <string>
#include <memory>
#include <string>
#include <tuple>
#include "common/log.h"
......@@ -92,7 +92,6 @@ class OpRegistry {
const std::string& type, const VariableNameMap& inputs,
const VariableNameMap& outputs, const AttributeMap attrs,
std::shared_ptr<paddle_mobile::framework::Scope> scope) {
auto& info = OpInfoMap<Dtype>::Instance()->Get(type);
auto op = info.Creator()(type, inputs, outputs, attrs, scope);
return std::shared_ptr<OperatorBase<Dtype>>(op);
......
......@@ -16,6 +16,7 @@ limitations under the License. */
#include <map>
#include <string>
#include <utility>
#include <vector>
#include "common/enforce.h"
......@@ -32,8 +33,8 @@ limitations under the License. */
#include "framework/tensor.h"
#include "framework/variable.h"
#ifdef PADDLE_MOBILE_CL
#include "framework/cl/cl_scope.h"
#include "framework/cl/cl_helper.h"
#include "framework/cl/cl_scope.h"
#endif
namespace paddle_mobile {
namespace framework {
......@@ -131,7 +132,6 @@ class OperatorWithKernel : public OperatorBase<Dtype> {
// DLOG << i.second;
// }
PADDLE_MOBILE_ENFORCE(kernel_.Init(&param_), " %s kernel init failed",
this->type_.c_str());
}
......@@ -147,7 +147,6 @@ class OperatorWithKernel : public OperatorBase<Dtype> {
template <typename Dtype, typename P>
class OpKernelBase {
public:
OpKernelBase() = default;
#ifdef PADDLE_MOBILE_CL
......@@ -156,11 +155,11 @@ class OpKernelBase {
}
#endif
/*
* @b every kernel must implement the Compute method
* @p para this parameter is a struct bundling everything the kernel
* needs at compute time; all such structs live in:
* paddle-mobile/src/operators/op_param.h
* */
#ifdef PADDLE_MOBILE_MALI_GPU
OpKernelBase() { acl_op_ = nullptr; }
void *GetAclOp() const { return acl_op_; }
......@@ -181,8 +180,6 @@ class OpKernelBase {
#ifdef PADDLE_MOBILE_MALI_GPU
void *acl_op_;
#endif
};
#define DEFINE_OP_CONSTRUCTOR(cls, parent_cls) \
......
......@@ -15,13 +15,14 @@ limitations under the License. */
#pragma once
#include <list>
#include <string>
#include <unordered_map>
#include <vector>
#ifdef PADDLE_MOBILE_CL
#ifdef PADDLE_MOBILE_CL
#include "framework/cl/cl_scope.h"
#endif
#include <unordered_map>
#include "variable.h"
#include "framework/variable.h"
namespace paddle_mobile {
namespace framework {
......@@ -42,7 +43,6 @@ class Scope {
#ifdef PADDLE_MOBILE_CL
delete cl_scope_;
#endif
}
Scope &NewScope() const;
......@@ -83,9 +83,7 @@ class Scope {
Variable *FindVarLocally(const std::string &name) const;
#ifdef PADDLE_MOBILE_CL
CLScope *GetCLScpoe() {
return cl_scope_;
}
CLScope *GetCLScpoe() { return cl_scope_; }
#endif
private:
......@@ -99,7 +97,6 @@ class Scope {
#ifdef PADDLE_MOBILE_CL
CLScope *cl_scope_ = new CLScope();
#endif
};
} // namespace framework
} // namespace paddle_mobile
......@@ -22,9 +22,9 @@ limitations under the License. */
#endif // _OPENMP
#include "common/types.h"
#include "framework/tensor.h"
#include "framework/executor.h"
#include "framework/loader.h"
#include "framework/tensor.h"
namespace paddle_mobile {
......@@ -94,6 +94,7 @@ class PaddleMobile {
std::shared_ptr<framework::Executor<Dtype, P>> executor_;
#ifdef PADDLE_MOBILE_FPGA
public:
void InjectVariable(const framework::Tensor &t, string var_name);
void FeedData(const framework::Tensor &t);
......
......@@ -43,13 +43,14 @@ class FeedOp : public framework::OperatorBase<DeviceType> {
#ifdef PADDLE_MOBILE_FPGA
void Init() {
void Init() {
Tensor *output = param_.Out();
fpga::format_fp16_ofm(output);
}
void RunImpl() const {
auto input = (Tensor *)const_cast<LoDTensor *>(param_.InputX());
auto input =
reinterpret_cast<Tensor *>(const_cast<LoDTensor *>(param_.InputX()));
auto input_ptr = input->data<float>();
fpga::format_image(input);
Tensor *output = param_.Out();
......@@ -61,7 +62,7 @@ class FeedOp : public framework::OperatorBase<DeviceType> {
args.output_data_type = fpga::DATA_TYPE_FP16;
args.input_layout_type = fpga::LAYOUT_CHW;
args.output_layout_type = fpga::LAYOUT_HWC;
args.image.address = (void *)input_ptr;
args.image.address = reinterpret_cast<void *>(input_ptr);
args.image.channels = (uint32_t)input->dims()[1];
args.image.height = (uint32_t)input->dims()[2];
args.image.width = (uint32_t)input->dims()[3];
......@@ -74,12 +75,10 @@ class FeedOp : public framework::OperatorBase<DeviceType> {
#else
#ifdef PADDLE_MOBILE_CL
void Init() {}
void RunImpl() {
}
void Init() {}
void RunImpl() {}
#else
void Init() {}
void Init() {}
void RunImpl() {
param_.Out()->ShareDataWith(*param_.InputX());
param_.Out()->set_lod(param_.InputX()->lod());
......
......@@ -20,8 +20,8 @@ limitations under the License. */
#include <vector>
#include "framework/operator.h"
#include "framework/program/program-optimize/fusion_op_register.h"
#include "op_param.h"
#include "operators/kernel/conv_add_bn_relu_kernel.h"
#include "operators/op_param.h"
namespace paddle_mobile {
namespace operators {
......@@ -103,7 +103,7 @@ static framework::FusionOpRegistrar fusion_conv_add_bn_relu_registrar(
#ifdef PADDLE_MOBILE_CL
#ifndef FUSION_CONV_ADD_BN_RELU_REGISTER
static framework::FusionOpRegistrar fusion_conv_add_bn_relu_registrar(
static framework::FusionOpRegistrar fusion_conv_add_bn_relu_registrar(
new FusionConvAddBNReluMatcher());
#define FUSION_CONV_ADD_BN_RELU_REGISTER
#endif
......
......@@ -26,8 +26,7 @@ bool BatchNormKernel<CPU, float>::Init(BatchNormParam<CPU> *param) {
}
template <>
void BatchNormKernel<CPU, float>::Compute(
const BatchNormParam<CPU> &param) {
void BatchNormKernel<CPU, float>::Compute(const BatchNormParam<CPU> &param) {
BatchnormCompute<float>(param);
}
......
......@@ -26,8 +26,7 @@ bool BoxCoderKernel<CPU, float>::Init(BoxCoderParam<CPU> *param) {
}
template <>
void BoxCoderKernel<CPU, float>::Compute(
const BoxCoderParam<CPU> &param) {
void BoxCoderKernel<CPU, float>::Compute(const BoxCoderParam<CPU> &param) {
BoxCoderCompute<float>(param);
}
......
......@@ -25,8 +25,7 @@ bool ConvAddKernel<CPU, float>::Init(FusionConvAddParam<CPU> *param) {
}
template <>
void ConvAddKernel<CPU, float>::Compute(
const FusionConvAddParam<CPU> &param) {
void ConvAddKernel<CPU, float>::Compute(const FusionConvAddParam<CPU> &param) {
ConvAddCompute<float>(param);
}
......
......@@ -26,8 +26,7 @@ bool DepthwiseConvKernel<CPU, float>::Init(ConvParam<CPU> *param) {
}
template <>
void DepthwiseConvKernel<CPU, float>::Compute(
const ConvParam<CPU> &param) {
void DepthwiseConvKernel<CPU, float>::Compute(const ConvParam<CPU> &param) {
DepthwiseConvCompute<float>(param);
}
......
......@@ -26,8 +26,7 @@ bool FusionFcKernel<CPU, float>::Init(FusionFcParam<CPU> *param) {
}
template <>
void FusionFcKernel<CPU, float>::Compute(
const FusionFcParam<CPU> &param) {
void FusionFcKernel<CPU, float>::Compute(const FusionFcParam<CPU> &param) {
FusionFcCompute<float>(param);
param.Out()->set_lod(param.InputX()->lod());
}
......
......@@ -26,8 +26,7 @@ bool PriorBoxKernel<CPU, float>::Init(PriorBoxParam<CPU> *param) {
}
template <>
void PriorBoxKernel<CPU, float>::Compute(
const PriorBoxParam<CPU> &param) {
void PriorBoxKernel<CPU, float>::Compute(const PriorBoxParam<CPU> &param) {
PriorBoxCompute<float>(param);
}
......
......@@ -25,8 +25,7 @@ bool TransposeKernel<CPU, float>::Init(TransposeParam<CPU> *param) {
}
template <>
void TransposeKernel<CPU, float>::Compute(
const TransposeParam<CPU> &param) {
void TransposeKernel<CPU, float>::Compute(const TransposeParam<CPU> &param) {
TransposeCompute<float>(param);
}
......
......@@ -33,4 +33,3 @@ inline hafl4 activation(half4 in
}
*/
......@@ -12,9 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
/*
conv
conv_bn
......@@ -30,7 +27,6 @@ conv_add_bn_relu
#include "common.h"
__kernel void conv_1x1(__private const int global_size_dim0,
__private const int global_size_dim1,
__private const int global_size_dim2,
......
......@@ -21,14 +21,14 @@ namespace operators {
template <>
bool ConvAddBNReluKernel<GPU_CL, float>::Init(
FusionConvAddBNReluParam<GPU_CL> *param) {
FusionConvAddBNReluParam<GPU_CL> *param) {
return true;
}
template <>
void ConvAddBNReluKernel<GPU_CL, float>::Compute(
const FusionConvAddBNReluParam<GPU_CL> &param) {
}
const FusionConvAddBNReluParam<GPU_CL> &param) {}
template class ConvAddBNReluKernel<GPU_CL, float>;
} // namespace operators
......
......@@ -15,20 +15,103 @@ limitations under the License. */
#ifdef FUSION_CONVADDBNRELU_OP
#include "operators/kernel/conv_add_bn_relu_kernel.h"
#include "framework/cl/cl_image.h"
namespace paddle_mobile {
namespace operators {
template <>
bool ConvAddBNReluKernel<GPU_CL, float>::Init(
FusionConvAddBNReluParam<GPU_CL> *param) {
FusionConvAddBNReluParam<GPU_CL> *param) {
// const CL *mean = param->InputMean();
const framework::CLImage *mean = param->InputMean();
const framework::CLImage *variance = param->InputVariance();
const framework::CLImage *scale = param->InputScale();
const framework::CLImage *bias = param->InputBias();
const float epsilon = param->Epsilon();
auto mean_ptr = mean->data<float>();
auto variance_ptr = variance->data<float>();
auto scale_ptr = scale->data<float>();
auto bias_ptr = bias->data<float>();
const int C = mean->numel();
float inv_std_ptr[C];
for (int i = 0; i < C; i++) {
inv_std_ptr[i] =
1 / static_cast<float>(pow((variance_ptr[i] + epsilon), 0.5));
}
float *new_scale_ptr = new float[C];
float *new_bias_ptr = new float[C];
for (int i = 0; i < C; i++) {
new_scale_ptr[i] = inv_std_ptr[i] * scale_ptr[i];
new_bias_ptr[i] = bias_ptr[i] - mean_ptr[i] * inv_std_ptr[i] * scale_ptr[i];
}
delete[](new_scale_ptr);
delete[](new_bias_ptr);
framework::CLImage *new_scale = new framework::CLImage();
framework::CLImage *new_bias = new framework::CLImage();
param->SetNewScale(new_scale);
param->SetNewBias(new_bias);
PADDLE_MOBILE_ENFORCE(
param->Filter()->dims()[2] == param->Filter()->dims()[3] &&
param->Paddings()[0] == param->Paddings()[1],
"need equal");
param->SetOffset(param->Filter()->dims()[2] / 2 -
static_cast<int>(param->Paddings()[1]));
this->cl_helper_.AddKernel("conv_3x3", "conv_add_bn_relu_kernel.cl");
return true;
}
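
In other words, the Init above folds batch norm into a single per-channel affine map that the fused conv_3x3 kernel can apply after convolution. With \gamma = scale, \beta = bias, \mu = mean, \sigma^2 = variance, and \epsilon the stabilizer, the fold computed by the loops over C is:

\text{new\_scale}_c = \frac{\gamma_c}{\sqrt{\sigma_c^2 + \epsilon}}, \qquad
\text{new\_bias}_c = \beta_c - \mu_c \, \frac{\gamma_c}{\sqrt{\sigma_c^2 + \epsilon}}, \qquad
y_c = \text{new\_scale}_c \, x_c + \text{new\_bias}_c
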
template <>
void ConvAddBNReluKernel<GPU_CL, float>::Compute(
const FusionConvAddBNReluParam<GPU_CL> &param) {
const FusionConvAddBNReluParam<GPU_CL> &param) {
auto kernel = this->cl_helper_.KernelAt(0);
auto default_work_size = this->cl_helper_.DefaultWorkSize(*param.Output());
int c_block = default_work_size[0];
int w = default_work_size[1];
int nh = default_work_size[2];
auto input = param.Input()->GetCLImage();
auto filter = param.Filter()->GetCLImage();
auto biase = param.Bias()->GetCLImage();
auto new_scale = param.NewScale()->GetCLImage();
auto new_bias = param.NewBias()->GetCLImage();
auto output = param.Output();
int stride = param.Strides()[0];
int offset = param.Offset();
int input_c = param.Input()->CBlock();
int input_width = param.Input()->WidthOfOneBlock();
int input_height = param.Input()->HeightOfOneBlock();
clSetKernelArg(kernel, 0, sizeof(int), &c_block);
clSetKernelArg(kernel, 1, sizeof(int), &w);
clSetKernelArg(kernel, 2, sizeof(int), &nh);
clSetKernelArg(kernel, 3, sizeof(cl_mem), &input);
clSetKernelArg(kernel, 4, sizeof(cl_mem), &filter);
clSetKernelArg(kernel, 5, sizeof(cl_mem), &biase);
clSetKernelArg(kernel, 6, sizeof(cl_mem), &new_scale);
clSetKernelArg(kernel, 7, sizeof(cl_mem), &new_bias);
clSetKernelArg(kernel, 8, sizeof(cl_mem), &output);
clSetKernelArg(kernel, 9, sizeof(int), &stride);
clSetKernelArg(kernel, 10, sizeof(int), &offset);
clSetKernelArg(kernel, 11, sizeof(int), &input_c);
clSetKernelArg(kernel, 12, sizeof(int), &input_width);
clSetKernelArg(kernel, 13, sizeof(int), &input_height);
clEnqueueNDRangeKernel(this->cl_helper_.CLCommandQueue(), kernel, 3, NULL,
default_work_size.data(), NULL, 0, NULL, NULL);
}
template class ConvAddBNReluKernel<GPU_CL, float>;
} // namespace operators
......
......@@ -26,8 +26,7 @@ bool ConvAddKernel<GPU_CL, float>::Init(FusionConvAddParam<GPU_CL> *param) {
template <>
void ConvAddKernel<GPU_CL, float>::Compute(
const FusionConvAddParam<GPU_CL> &param) {
}
const FusionConvAddParam<GPU_CL> &param) {}
template class ConvAddKernel<GPU_CL, float>;
......
......@@ -21,15 +21,16 @@ namespace operators {
template <>
bool ConvKernel<GPU_CL, float>::Init(ConvParam<GPU_CL> *param) {
// this->cl_helper_.AddKernel("conv_3x3", "conv_kernel.cl");
// this->cl_helper_.AddKernel("conv_3x3", "conv_kernel.cl");
return true;
}
template <>
void ConvKernel<GPU_CL, float>::Compute(const ConvParam<GPU_CL> &param) {
// auto kernel = this->cl_helper_.KernelAt(0);
// size_t global_work_size[3] = {1, 2, 3};
// clEnqueueNDRangeKernel(this->cl_helper_.CLCommandQueue(), kernel, 3, NULL, global_work_size, NULL, 0, NULL, NULL);
// auto kernel = this->cl_helper_.KernelAt(0);
// size_t global_work_size[3] = {1, 2, 3};
// clEnqueueNDRangeKernel(this->cl_helper_.CLCommandQueue(), kernel, 3, NULL,
// global_work_size, NULL, 0, NULL, NULL);
}
template class ConvKernel<GPU_CL, float>;
......
......@@ -17,22 +17,23 @@ limitations under the License. */
#include "operators/kernel/elementwise_add_kernel.h"
namespace paddle_mobile {
namespace operators {
namespace operators {
template <>
bool ElementwiseAddKernel<GPU_CL, float>::Init(ElementwiseAddParam<GPU_CL> *param) {
// this->cl_helper_.AddKernel("elementwise_add", "elementwise_add_kernel.cl");
return true;
}
template <>
bool ElementwiseAddKernel<GPU_CL, float>::Init(
ElementwiseAddParam<GPU_CL> *param) {
// this->cl_helper_.AddKernel("elementwise_add",
// "elementwise_add_kernel.cl");
return true;
}
template <>
void ElementwiseAddKernel<GPU_CL, float>::Compute(const ElementwiseAddParam<GPU_CL> &param) {
template <>
void ElementwiseAddKernel<GPU_CL, float>::Compute(
const ElementwiseAddParam<GPU_CL> &param) {}
}
template class ElementwiseAddKernel<GPU_CL, float>;
template class ElementwiseAddKernel<GPU_CL, float>;
} // namespace operators
} // namespace operators
} // namespace paddle_mobile
#endif
......@@ -12,7 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "operators/kernel/relu_kernel.h"
namespace paddle_mobile {
......@@ -30,4 +29,3 @@ template class ReluKernel<GPU_CL, float>;
} // namespace operators
} // namespace paddle_mobile
......@@ -12,7 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "operators/kernel/reshape_kernel.h"
namespace paddle_mobile {
......@@ -30,4 +29,3 @@ template class ReshapeKernel<GPU_CL, float>;
} // namespace operators
} // namespace paddle_mobile
......@@ -17,20 +17,18 @@ limitations under the License. */
#include "operators/kernel/softmax_kernel.h"
namespace paddle_mobile {
namespace operators {
namespace operators {
template <>
bool SoftmaxKernel<GPU_CL, float>::Init(SoftmaxParam<GPU_CL> *param) {
return true;
}
template <>
bool SoftmaxKernel<GPU_CL, float>::Init(SoftmaxParam<GPU_CL> *param) {
return true;
}
template <>
void SoftmaxKernel<GPU_CL, float>::Compute(const SoftmaxParam<GPU_CL> &param) {}
template <>
void SoftmaxKernel<GPU_CL, float>::Compute(const SoftmaxParam<GPU_CL> &param) {}
template class SoftmaxKernel<GPU_CL, float>;
template class SoftmaxKernel<GPU_CL, float>;
} // namespace operators
} // namespace operators
} // namespace paddle_mobile
#endif
......@@ -67,8 +67,7 @@ bool ConvBNKernel<FPGA, float>::Init(FusionConvBNParam<FPGA> *param) {
}
template <>
void ConvBNKernel<FPGA, float>::Compute(
const FusionConvBNParam<FPGA> &param) {
void ConvBNKernel<FPGA, float>::Compute(const FusionConvBNParam<FPGA> &param) {
fpga::ComputeFpgaConv(param.FpgaArgs());
}
......
......@@ -26,8 +26,7 @@ bool DropoutKernel<FPGA, float>::Init(DropoutParam<FPGA> *param) {
}
template <>
void DropoutKernel<FPGA, float>::Compute(
const DropoutParam<FPGA> &param) {}
void DropoutKernel<FPGA, float>::Compute(const DropoutParam<FPGA> &param) {}
} // namespace operators
} // namespace paddle_mobile
......
......@@ -60,8 +60,7 @@ bool FusionFcKernel<FPGA, float>::Init(FusionFcParam<FPGA> *param) {
}
template <>
void FusionFcKernel<FPGA, float>::Compute(
const FusionFcParam<FPGA> &param) {
void FusionFcKernel<FPGA, float>::Compute(const FusionFcParam<FPGA> &param) {
fpga::ComputeFpgaConv(param.FpgaArgs());
}
} // namespace operators
......
......@@ -47,8 +47,7 @@ bool SoftmaxKernel<FPGA, float>::Init(SoftmaxParam<FPGA> *param) {
}
template <>
void SoftmaxKernel<FPGA, float>::Compute(
const SoftmaxParam<FPGA> &param) {
void SoftmaxKernel<FPGA, float>::Compute(const SoftmaxParam<FPGA> &param) {
Tensor *in_x = param.FloatInput();
Tensor *out = param.Out();
......
......@@ -211,8 +211,7 @@ bool ConvKernel<GPU_MALI, float>::Init(ConvParam<GPU_MALI>* param) {
}
template <>
void ConvKernel<GPU_MALI, float>::Compute(
const ConvParam<GPU_MALI>& param) {
void ConvKernel<GPU_MALI, float>::Compute(const ConvParam<GPU_MALI>& param) {
std::cout << "init acl" << std::endl;
AclConvOp<GPU_MALI, float>* acl_op =
reinterpret_cast<AclConvOp<GPU_MALI, float>*>(this->GetAclOp());
......
......@@ -127,8 +127,7 @@ bool LrnKernel<GPU_MALI, float>::Init(LrnParam<GPU_MALI>* param) {
}
template <>
void LrnKernel<GPU_MALI, float>::Compute(
const LrnParam<GPU_MALI>& param) {
void LrnKernel<GPU_MALI, float>::Compute(const LrnParam<GPU_MALI>& param) {
std::cout << "init acl" << std::endl;
AclLrnOp<GPU_MALI, float>* acl_op =
reinterpret_cast<AclLrnOp<GPU_MALI, float>*>(this->GetAclOp());
......
......@@ -27,8 +27,7 @@ bool MulKernel<GPU_MALI, float>::Init(MulParam<GPU_MALI> *param) {
}
template <>
void MulKernel<GPU_MALI, float>::Compute(
const MulParam<GPU_MALI> &param) {
void MulKernel<GPU_MALI, float>::Compute(const MulParam<GPU_MALI> &param) {
const Tensor *input_x = param.InputX();
const Tensor *input_y = param.InputY();
Tensor *out = param.Out();
......
......@@ -195,8 +195,7 @@ bool PoolKernel<GPU_MALI, float>::Init(PoolParam<GPU_MALI>* param) {
}
template <>
void PoolKernel<GPU_MALI, float>::Compute(
const PoolParam<GPU_MALI>& param) {
void PoolKernel<GPU_MALI, float>::Compute(const PoolParam<GPU_MALI>& param) {
std::cout << "init acl" << std::endl;
AclPoolOp<GPU_MALI, float>* acl_op =
reinterpret_cast<AclPoolOp<GPU_MALI, float>*>(this->GetAclOp());
......
......@@ -115,8 +115,7 @@ bool ReluKernel<GPU_MALI, float>::Init(ReluParam<GPU_MALI>* param) {
}
template <>
void ReluKernel<GPU_MALI, float>::Compute(
const ReluParam<GPU_MALI>& param) {
void ReluKernel<GPU_MALI, float>::Compute(const ReluParam<GPU_MALI>& param) {
std::cout << "init acl" << std::endl;
AclReluOp<GPU_MALI, float>* acl_op =
reinterpret_cast<AclReluOp<GPU_MALI, float>*>(this->GetAclOp());
......
......@@ -389,6 +389,13 @@ class ConvParam : public OpParam {
const int &Groups() const { return groups; }
#ifdef PADDLE_MOBILE_CL
int Offset() const { return offset_; }
void SetOffset(int in_offset) { offset_ = in_offset; }
#endif
private:
RType *input_;
RType *output_;
......@@ -397,6 +404,10 @@ class ConvParam : public OpParam {
vector<int> paddings_;
vector<int> dilations_;
int groups;
#ifdef PADDLE_MOBILE_CL
int offset_;
#endif
};
template <typename Dtype>
Print &operator<<(Print &printer, const ConvParam<Dtype> &conv_param);
......@@ -1520,6 +1531,7 @@ class FusionConvAddBNReluParam : public ConvParam<Dtype> {
bool is_test_;
RType *new_bias_;
RType *new_scale_;
#ifdef PADDLE_MOBILE_FPGA
private:
......
......@@ -18,8 +18,8 @@ limitations under the License. */
#include <vector>
#include "common/log.h"
#include "framework/op_registry.h"
#include "framework/executor.h"
#include "framework/op_registry.h"
#include "operators/conv_op.h"
#include "operators/elementwise_add_op.h"
#include "operators/pool_op.h"
......@@ -29,9 +29,9 @@ limitations under the License. */
#include "operators/softmax_op.h"
#include "operators/transpose_op.h"
using paddle_mobile::framework::Executor;
using paddle_mobile::framework::BlockDesc;
using paddle_mobile::framework::DDim;
using paddle_mobile::framework::Executor;
using paddle_mobile::framework::LoDTensor;
using paddle_mobile::framework::OpDesc;
using paddle_mobile::framework::Program;
......
......@@ -13,9 +13,9 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "../test_helper.h"
#include "framework/loader.h"
#include "framework/program/program-optimize/node.h"
#include "framework/program/program-optimize/program_optimize.h"
#include "framework/loader.h"
int main() {
paddle_mobile::framework::Loader<paddle_mobile::CPU> loader;
......
......@@ -17,43 +17,43 @@ limitations under the License. */
#include "../test_include.h"
int main() {
paddle_mobile::PaddleMobile<paddle_mobile::GPU_CL> paddle_mobile;
// paddle_mobile.SetThreadNum(4);
auto time1 = time();
// auto isok = paddle_mobile.Load(std::string(g_mobilenet_detect) + "/model",
// std::string(g_mobilenet_detect) + "/params", true);
auto isok = paddle_mobile.Load(g_mobilenet, false);
if (isok) {
auto time2 = time();
std::cout << "load cost :" << time_diff(time1, time2) << "ms" << std::endl;
std::vector<float> input;
std::vector<int64_t> dims{1, 3, 224, 224};
GetInput<float>(g_test_image_1x3x224x224_banana, &input, dims);
auto vec_result = paddle_mobile.Predict(input, dims);
std::vector<float>::iterator biggest =
std::max_element(std::begin(vec_result), std::end(vec_result));
std::cout << " Max element is " << *biggest << " at position "
<< std::distance(std::begin(vec_result), biggest) << std::endl;
// warm up 10 times
for (int i = 0; i < 10; ++i) {
auto vec_result = paddle_mobile.Predict(input, dims);
}
auto time3 = time();
for (int i = 0; i < 10; ++i) {
auto vec_result = paddle_mobile.Predict(input, dims);
}
DLOG << vec_result;
auto time4 = time();
std::cout << "predict cost :" << time_diff(time3, time4) / 10 << "ms"
<< std::endl;
paddle_mobile::PaddleMobile<paddle_mobile::GPU_CL> paddle_mobile;
// paddle_mobile.SetThreadNum(4);
auto time1 = time();
// auto isok = paddle_mobile.Load(std::string(g_mobilenet_detect) + "/model",
// std::string(g_mobilenet_detect) + "/params", true);
auto isok = paddle_mobile.Load(g_mobilenet, false);
if (isok) {
auto time2 = time();
std::cout << "load cost :" << time_diff(time1, time2) << "ms" << std::endl;
std::vector<float> input;
std::vector<int64_t> dims{1, 3, 224, 224};
GetInput<float>(g_test_image_1x3x224x224_banana, &input, dims);
auto vec_result = paddle_mobile.Predict(input, dims);
std::vector<float>::iterator biggest =
std::max_element(std::begin(vec_result), std::end(vec_result));
std::cout << " Max element is " << *biggest << " at position "
<< std::distance(std::begin(vec_result), biggest) << std::endl;
// warm up 10 times
for (int i = 0; i < 10; ++i) {
auto vec_result = paddle_mobile.Predict(input, dims);
}
auto time3 = time();
for (int i = 0; i < 10; ++i) {
auto vec_result = paddle_mobile.Predict(input, dims);
}
DLOG << vec_result;
auto time4 = time();
std::cout << "predict cost :" << time_diff(time3, time4) / 10 << "ms"
<< std::endl;
}
std::cout << "If the result is NaN, check whether "
"test/images/g_test_image_1x3x224x224_banana exists."
<< std::endl;
return 0;
}
......@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "../../src/operators/kernel/sigmoid_kernel.h"
#include "../../src/operators/kernel/central-arm-func/sigmoid_arm_func.h"
#include "../../src/operators/kernel/sigmoid_kernel.h"
#include "../test_helper.h"
#include "framework/executor.h"
......