未验证 提交 5f227934 编写于 作者: J Jiaying Zhao 提交者: GitHub

add pre and post process in feed and fetch kernel test=develop (#2157)

上级 69da22ec
......@@ -87,6 +87,11 @@ enum PMStatus {
PMException = 0x09 /*!< throw exception. */
};
enum PrePostType {
NONE_PRE_POST = 0,
UINT8_255 = 1,
};
enum RoundType {
ROUND_NEAREST_AWAY_ZERO = 0,
ROUND_NEAREST_TOWARDS_ZERO = 1,
......@@ -143,6 +148,7 @@ struct PaddleMobileConfigInternal {
MemoryOptimizationLevel memory_optimization_level =
MemoryOptimizationWithoutFeeds;
std::string model_obfuscate_key = "";
PrePostType pre_post_type = NONE_PRE_POST;
};
enum ARMArch {
......
......@@ -112,6 +112,9 @@ Executor<Device, T>::Executor(const Program<Device> &program,
profile[op_index].runBegin = (uint64_t)ts.tv_sec * 1e9 + ts.tv_nsec;
#endif
DLOG << "Initialize op[" << count++ << "]: " << op_handler->Type();
if (op_handler->Type() == "feed" || op_handler->Type() == "fetch") {
op_handler->setPrePostType(config_.pre_post_type);
}
op_handler->Init();
#ifdef PADDLE_MOBILE_PROFILE
clock_gettime(CLOCK_MONOTONIC, &ts);
......
......@@ -16,6 +16,7 @@ limitations under the License. */
#include <functional>
#include <map>
#include <memory>
#include <string>
#include <utility>
#include <vector>
......@@ -73,6 +74,7 @@ class OperatorBase {
const VariableNameMap &Outputs() const { return outputs_; }
const std::string &Type() const { return type_; }
const AttributeMap &Attrs() const { return attrs_; }
void setPrePostType(int prePostType) { pre_post_type_ = prePostType; }
void ClearVariables(const std::vector<std::string> &var_names) const {
if (this->scope_) {
......@@ -89,6 +91,7 @@ class OperatorBase {
VariableNameMap inputs_;
VariableNameMap outputs_;
AttributeMap attrs_;
int pre_post_type_ = 0;
private:
void CheckAllInputOutputSet() const;
......@@ -111,6 +114,9 @@ class OperatorWithKernel : public OperatorBase<Dtype> {
virtual void InferShape() const = 0;
void Init() {
if (this->pre_post_type_ != NONE_PRE_POST) {
kernel_.setPrePostType(this->pre_post_type_);
}
PADDLE_MOBILE_ENFORCE(kernel_.Init(&param_), " %s kernel init failed",
this->type_.c_str());
}
......@@ -134,11 +140,13 @@ class OpKernelBase {
virtual void Compute(const P &para) = 0;
virtual bool Init(P *para) { return true; }
virtual ~OpKernelBase() = default;
virtual void setPrePostType(int prePostType) { pre_post_type_ = prePostType; }
protected:
#ifdef PADDLE_MOBILE_CL
CLHelper cl_helper_;
#endif
int pre_post_type_ = 0;
private:
};
......
......@@ -14,6 +14,7 @@ limitations under the License. */
#pragma once
#include <memory>
#include "common/enforce.h"
#include "common/type_define.h"
#include "common/types.h"
......@@ -55,8 +56,8 @@ struct SizeOfTypeFunctor<HEAD, TAIL...> {
};
static inline size_t SizeOfType(const kTypeId_t type) {
SizeOfTypeFunctor<int8_t, int, half, float, double, int16_t, int64_t, bool,
size_t>
SizeOfTypeFunctor<int8_t, uint8_t, int, half, float, double, int16_t, int64_t,
bool, size_t>
functor;
size_t size = functor(type);
......
......@@ -38,6 +38,9 @@ template <typename Device, typename T>
bool PaddleMobilePredictor<Device, T>::Init(const PaddleMobileConfig &config) {
PaddleMobileConfigInternal configInternal;
configInternal.load_when_predict = config.load_when_predict;
if (config.pre_post_type == PaddleMobileConfig::UINT8_255) {
configInternal.pre_post_type = PrePostType::UINT8_255;
}
paddle_mobile_.reset(new PaddleMobile<Device, T>(configInternal));
#ifdef PADDLE_MOBILE_CL
paddle_mobile_->SetCLPath(config.cl_path);
......@@ -86,26 +89,33 @@ bool PaddleMobilePredictor<Device, T>::Run(
// use tensor
framework::DDim ddim = framework::make_ddim(dims);
framework::Tensor input_tensor;
framework::LoDTensor input_lod_tensor;
paddle_mobile::framework::LoD lod{{}};
for (int i = 0; i < input.lod.size(); ++i) {
lod[0].push_back(input.lod[i]);
}
input_lod_tensor.set_lod(lod);
int input_length = framework::product(ddim);
if (input.lod.size() > 0) {
framework::LoDTensor input_lod_tensor;
paddle_mobile::framework::LoD lod{{}};
for (int i = 0; i < input.lod.size(); ++i) {
lod[0].push_back(input.lod[i]);
}
input_lod_tensor.set_lod(lod);
input_lod_tensor.Resize(ddim);
memcpy(input_lod_tensor.mutable_data<T>(),
static_cast<T *>(input.data.data()), input_length * sizeof(T));
if (input.dtype == UINT8) {
memcpy(input_lod_tensor.mutable_data<uint8_t>(),
static_cast<uint8_t *>(input.data.data()),
input_length * sizeof(uint8_t));
} else {
memcpy(input_lod_tensor.mutable_data<T>(),
static_cast<T *>(input.data.data()), input_length * sizeof(T));
}
paddle_mobile_->Predict(input_lod_tensor);
} else {
input_tensor.Resize(ddim);
memcpy(input_tensor.mutable_data<T>(), static_cast<T *>(input.data.data()),
input_length * sizeof(T));
paddle_mobile_->Predict(input_tensor);
if (input.dtype == UINT8) {
framework::Tensor input_tensor(static_cast<uint8_t *>(input.data.data()),
ddim);
paddle_mobile_->Predict(input_tensor);
} else {
framework::Tensor input_tensor(static_cast<T *>(input.data.data()), ddim);
paddle_mobile_->Predict(input_tensor);
}
}
auto output_tensor = paddle_mobile_->Fetch();
......@@ -124,12 +134,21 @@ bool PaddleMobilePredictor<Device, T>::Run(
output.shape.push_back(static_cast<int>(d));
}
if (output.data.length() < output_length * sizeof(T)) {
output.data.Resize(output_length * sizeof(T));
}
if (output.dtype == UINT8) {
if (output.data.length() < output_length * sizeof(uint8_t)) {
output.data.Resize(output_length * sizeof(uint8_t));
}
memcpy(output.data.data(), output_tensor->template data<T>(),
output_length * sizeof(T));
memcpy(output.data.data(), output_tensor->template data<uint8_t>(),
output_length * sizeof(uint8_t));
} else {
if (output.data.length() < output_length * sizeof(T)) {
output.data.Resize(output_length * sizeof(T));
}
memcpy(output.data.data(), output_tensor->template data<T>(),
output_length * sizeof(T));
}
return true;
}
......
......@@ -48,6 +48,7 @@ enum PaddleDType {
FLOAT16,
INT64,
INT8,
UINT8,
};
enum LayoutType {
......@@ -206,9 +207,11 @@ struct PaddleModelMemoryPack {
struct PaddleMobileConfig : public PaddlePredictor::Config {
enum Precision { FP32 = 0 };
enum Device { kCPU = 0, kFPGA = 1, kGPU_MALI = 2, kGPU_CL = 3 };
enum PrePostType { NONE_PRE_POST = 0, UINT8_255 = 1 };
enum Precision precision;
enum Device device;
enum PrePostType pre_post_type;
int batch_size = 1;
bool optimize = true;
......
......@@ -60,3 +60,51 @@ __kernel void feed(__global float *in,
write_imageh(output_image, output_pos, output);
}
__kernel void feed_with_pre(__global uchar *in,
__write_only image2d_t output_image,
__private const int out_H,
__private const int out_W,
__private const int out_C,
__private const int Stride0,
__private const int Stride1,
__private const int Stride2){
const int out_c = get_global_id(0);
const int out_w = get_global_id(1);
const int out_nh = get_global_id(2);
const int out_n = out_nh/out_H;
const int out_h = out_nh%out_H;
const int in_n = out_n;
const int in_c0 = out_c * 4 + 0;
const int in_c1 = out_c * 4 + 1;
const int in_c2 = out_c * 4 + 2;
const int in_c3 = out_c * 4 + 3;
const int in_h = out_h;
const int in_w = out_w;
int input_pos0 = in_n * Stride2 + in_c0 * Stride1 + in_h * Stride0 + in_w;
int input_pos1 = in_n * Stride2 + in_c1 * Stride1 + in_h * Stride0 + in_w;
int input_pos2 = in_n * Stride2 + in_c2 * Stride1 + in_h * Stride0 + in_w;
int input_pos3 = in_n * Stride2 + in_c3 * Stride1 + in_h * Stride0 + in_w;
int2 output_pos;
output_pos.x = out_c * out_W + out_w;
output_pos.y = out_nh;
half4 output = (half4)0.0f;
output.x = convert_half(in[input_pos0]) / 255;
if(out_C - 4 * out_c>=2){
output.y = convert_half(in[input_pos1]) / 255;
}
if(out_C - 4 * out_c>=3){
output.z = convert_half(in[input_pos2]) / 255;
}
if(out_C - 4 * out_c>=4){
output.w = convert_half(in[input_pos3]) / 255;
}
write_imageh(output_image, output_pos, output);
}
......@@ -67,3 +67,38 @@ __kernel void fetch_2d(__private const int in_height,
out[index + 2] = convert_float(in.z);
out[index + 3] = convert_float(in.w);
}
__kernel void fetch_with_post(__private const int in_height,
__private const int in_width,
__read_only image2d_t input,
__global uchar* out,
__private const int size_ch,
__private const int size_block,
__private const int size_batch,
__private const int C) {
const int in_c = get_global_id(0);
const int in_w = get_global_id(1);
const int in_nh = get_global_id(2);
const int in_n = in_nh / in_height;
const int in_h = in_nh % in_height;
const sampler_t sampler =
CLK_NORMALIZED_COORDS_TRUE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;
const int pos_x = mad24(in_c, in_width, in_w);
half4 in = read_imageh(input, sampler, (int2)(pos_x, in_nh));
const int index = in_n * size_batch + in_c * size_block + in_h * in_width + in_w;
out[index] = convert_uchar_sat(in.x * 255);
if(C - 4 * in_c>=2){
out[index + size_ch] = convert_uchar_sat(in.y * 255);
}
if(C - 4 * in_c>=3){
out[index + size_ch * 2] = convert_uchar_sat(in.z * 255);
}
if(C - 4 * in_c>=4){
out[index + size_ch * 3] = convert_uchar_sat(in.w * 255);
}
}
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
__kernel void pre(__global const uchar *input,
__global float *output){
int index = get_global_id(0);
output[index] = convert_float(input[index]) / 255;
}
......@@ -21,7 +21,11 @@ namespace operators {
template <>
bool FeedKernel<GPU_CL, float>::Init(FeedParam<GPU_CL> *param) {
DLOG << "Init feed";
this->cl_helper_.AddKernel("feed", "feed_kernel.cl");
if (this->pre_post_type_ == UINT8_255) {
this->cl_helper_.AddKernel("feed_with_pre", "feed_kernel.cl");
} else {
this->cl_helper_.AddKernel("feed", "feed_kernel.cl");
}
return true;
}
......@@ -34,7 +38,7 @@ void FeedKernel<GPU_CL, float>::Compute(const FeedParam<GPU_CL> &param) {
auto output = param.Out();
const Tensor *input = &param.InputX()->at(col);
// DLOG << *input;
const float *input_data = input->data<float>();
int numel = input->numel();
cl_mem output_image = output->GetCLImage();
const int out_C = output->dims()[1];
......@@ -46,7 +50,14 @@ void FeedKernel<GPU_CL, float>::Compute(const FeedParam<GPU_CL> &param) {
framework::CLTensor input_cl_tensor(this->cl_helper_.CLContext(),
this->cl_helper_.CLCommandQueue());
input_cl_tensor.Resize(input->dims());
cl_mem inputBuffer = input_cl_tensor.mutable_with_data<float>(input_data);
cl_mem inputBuffer;
if (this->pre_post_type_ == UINT8_255) {
inputBuffer =
input_cl_tensor.mutable_with_data<uint8_t>(input->data<uint8_t>());
} else {
inputBuffer =
input_cl_tensor.mutable_with_data<float>(input->data<float>());
}
status = clSetKernelArg(kernel, 0, sizeof(cl_mem), &inputBuffer);
CL_CHECK_ERRORS(status);
......
......@@ -20,7 +20,11 @@ namespace operators {
template <>
bool FetchKernel<GPU_CL, float>::Init(FetchParam<GPU_CL> *param) {
this->cl_helper_.AddKernel("fetch", "fetch_kernel.cl");
if (this->pre_post_type_ == UINT8_255) {
this->cl_helper_.AddKernel("fetch_with_post", "fetch_kernel.cl");
} else {
this->cl_helper_.AddKernel("fetch", "fetch_kernel.cl");
}
return true;
}
......@@ -33,7 +37,6 @@ void FetchKernel<GPU_CL, float>::Compute(const FetchParam<GPU_CL> &param) {
auto input = param.InputX()->GetCLImage();
auto *out = &param.Out()->at(col);
out->Resize(param.InputX()->dims());
out->mutable_data<float>();
DLOG << "fetch kernel out dims = " << out->dims();
DLOG << "fetch kernel out memory size = " << out->memory_size();
......@@ -57,7 +60,14 @@ void FetchKernel<GPU_CL, float>::Compute(const FetchParam<GPU_CL> &param) {
framework::CLTensor out_cl_tensor(this->cl_helper_.CLContext(),
this->cl_helper_.CLCommandQueue());
out_cl_tensor.Resize(out->dims());
cl_mem outBuffer = out_cl_tensor.mutable_data<float>();
cl_mem outBuffer;
if (this->pre_post_type_ == UINT8_255) {
out->mutable_data<uint8_t>();
outBuffer = out_cl_tensor.mutable_data<uint8_t>();
} else {
out->mutable_data<float>();
outBuffer = out_cl_tensor.mutable_data<float>();
}
cl_int status;
status = clSetKernelArg(kernel, 0, sizeof(int), &in_height);
......@@ -91,8 +101,13 @@ void FetchKernel<GPU_CL, float>::Compute(const FetchParam<GPU_CL> &param) {
DLOG << "fetch kernel out_cl_tensor dims = " << out_cl_tensor.dims();
DLOG << "fetch kernel out_cl_tensor memery size = "
<< out_cl_tensor.memory_size();
memcpy(out->data<float>(), out_cl_tensor.Data<float>(),
sizeof(float) * out->numel());
if (this->pre_post_type_ == UINT8_255) {
memcpy(out->data<uint8_t>(), out_cl_tensor.Data<uint8_t>(),
sizeof(uint8_t) * out->numel());
} else {
memcpy(out->data<float>(), out_cl_tensor.Data<float>(),
sizeof(float) * out->numel());
}
}
template class FetchKernel<GPU_CL, float>;
......
......@@ -539,4 +539,7 @@ else()
# gen test
ADD_EXECUTABLE(test-net net/test_net.cpp test_helper.h test_include.h executor_for_test.h)
target_link_libraries(test-net paddle-mobile)
ADD_EXECUTABLE(test-inference-pre-post net/test_inference_pre_post.cpp)
target_link_libraries(test-inference-pre-post paddle-mobile)
endif()
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <iostream>
#include "../test_helper.h"
#include "io/paddle_inference_api.h"
using namespace paddle_mobile; // NOLINT
PaddleMobileConfig GetConfig() {
PaddleMobileConfig config;
config.precision = PaddleMobileConfig::FP32;
config.device = PaddleMobileConfig::kGPU_CL;
config.pre_post_type = PaddleMobileConfig::UINT8_255;
config.prog_file = "../models/superv2/model";
config.param_file = "../models/superv2/params";
config.lod_mode = false;
config.load_when_predict = true;
config.cl_path = "/data/local/tmp/bin";
return config;
}
int main() {
PaddleMobileConfig config = GetConfig();
auto predictor =
CreatePaddlePredictor<PaddleMobileConfig,
PaddleEngineKind::kPaddleMobile>(config);
int input_length = 1 * 1 * 300 * 300;
int output_length = input_length;
uint8_t data_ui[300 * 300];
for (int i = 0; i < input_length; ++i) {
data_ui[i] = i % 256;
}
PaddleTensor input;
input.shape = std::vector<int>({1, 1, 300, 300});
input.data = PaddleBuf(data_ui, sizeof(data_ui));
input.dtype = PaddleDType::UINT8;
input.layout = LayoutType::LAYOUT_CHW;
std::vector<PaddleTensor> inputs(1, input);
PaddleTensor output;
output.shape = std::vector<int>({});
output.data = PaddleBuf();
output.dtype = PaddleDType::UINT8;
output.layout = LayoutType::LAYOUT_CHW;
std::vector<PaddleTensor> outputs(1, output);
std::cout << " print input : " << std::endl;
int stride = input_length / 20;
stride = stride > 0 ? stride : 1;
for (size_t j = 0; j < input_length; j += stride) {
std::cout << (unsigned)data_ui[j] << " ";
}
std::cout << std::endl;
predictor->Run(inputs, &outputs);
std::cout << " print output : " << std::endl;
uint8_t *data_o = static_cast<uint8_t *>(outputs[0].data.data());
int numel = outputs[0].data.length() / sizeof(uint8_t);
stride = numel / 20;
stride = stride > 0 ? stride : 1;
for (size_t j = 0; j < numel; j += stride) {
std::cout << (unsigned)data_o[j] << " ";
}
std::cout << std::endl;
return 0;
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册