提交 b1e877e9 编写于 作者: D dolphin8

merge

......@@ -16,7 +16,6 @@ file(GLOB_RECURSE PADDLE_MOBILE_CC src/*.cc src/*.cpp src/*.c src/*.mm)
file(GLOB_RECURSE PADDLE_MOBILE_H src/*.h)
include_directories(src/)
if(IS_IOS)
set(CMAKE_CXX_FLAGS "-mfpu=neon -marm -fobjc-abi-version=2 -fobjc-arc -std=gnu++11 -stdlib=libc++ -O3 -s -isysroot ${CMAKE_OSX_SYSROOT} ${CMAKE_CXX_FLAGS}")
else()
......@@ -145,16 +144,16 @@ endif()
if (ANDROID_NDK_TOOLCHAIN_INCLUDED)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -llog")
else()
list(REMOVE_ITEM PADDLE_MOBILE_H ${CMAKE_CURRENT_SOURCE_DIR}/src/jni/paddle_mobile_jni.h)
list(REMOVE_ITEM PADDLE_MOBILE_CC ${CMAKE_CURRENT_SOURCE_DIR}/src/jni/paddle_mobile_jni.cpp)
list(REMOVE_ITEM PADDLE_MOBILE_H ${CMAKE_CURRENT_SOURCE_DIR}/src/io/jni/paddle_mobile_jni.h)
list(REMOVE_ITEM PADDLE_MOBILE_CC ${CMAKE_CURRENT_SOURCE_DIR}/src/io/jni/paddle_mobile_jni.cpp)
list(REMOVE_ITEM PADDLE_MOBILE_H ${CMAKE_CURRENT_SOURCE_DIR}/src/operators/math/math_func_neon.h)
endif ()
if (IS_IOS)
else()
list(REMOVE_ITEM PADDLE_MOBILE_H ${CMAKE_CURRENT_SOURCE_DIR}/src/ios_io/PaddleMobileCPU.h)
list(REMOVE_ITEM PADDLE_MOBILE_CC ${CMAKE_CURRENT_SOURCE_DIR}/src/ios_io/PaddleMobileCPU.mm)
list(REMOVE_ITEM PADDLE_MOBILE_H ${CMAKE_CURRENT_SOURCE_DIR}/src/ios_io/op_symbols.h)
list(REMOVE_ITEM PADDLE_MOBILE_H ${CMAKE_CURRENT_SOURCE_DIR}/src/io/ios_io/PaddleMobileCPU.h)
list(REMOVE_ITEM PADDLE_MOBILE_CC ${CMAKE_CURRENT_SOURCE_DIR}/src/io/ios_io/PaddleMobileCPU.mm)
list(REMOVE_ITEM PADDLE_MOBILE_H ${CMAKE_CURRENT_SOURCE_DIR}/src/io/ios_io/op_symbols.h)
endif ()
set(CMAKE_VERBOSE_MAKEFILE ON)
......
......@@ -16,6 +16,9 @@ limitations under the License. */
#include "framework/cl/cl_half.h"
namespace paddle_mobile {
namespace framework {
static const uint32_t mantissatable[2048] = {
0x00000000, 0x33800000, 0x34000000, 0x34400000, 0x34800000, 0x34a00000,
0x34c00000, 0x34e00000, 0x35000000, 0x35100000, 0x35200000, 0x35300000,
......@@ -510,3 +513,6 @@ void HalfArray2FloatArray(half_t *h_array, float *f_array, int count) {
f_array[i] = Half2Float(h_array[i]);
}
}
} // namespace framework
} // namespace paddle_mobile
......@@ -15,6 +15,9 @@ limitations under the License. */
#pragma once
#include <cstdint>
namespace paddle_mobile {
namespace framework {
typedef uint16_t half_t;
half_t Float2Half(float f);
......@@ -24,3 +27,6 @@ float Half2Float(half_t h);
void FloatArray2HalfArray(float *f_array, half_t *h_array, int count);
void HalfArray2FloatArray(half_t *h_array, float *f_array, int count);
} // namespace framework
} // namespace paddle_mobile
......@@ -64,6 +64,16 @@ class CLHelper {
auto work_size_2 = n * h;
return {work_size_0, work_size_1, work_size_2};
} else if (image_dim.size() == 2) {
auto image_width = image.ImageWidth();
auto work_size_0 = image_width / image_dim[1];
auto work_size_1 = image_dim[1];
auto work_size_2 = image_dim[0];
return {work_size_0, work_size_1, work_size_2};
}
PADDLE_MOBILE_THROW_EXCEPTION("not support this dim, need imp");
......
......@@ -12,7 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "cl_image.h"
#include "framework/cl/cl_image.h"
namespace paddle_mobile {
namespace framework {
void CLImageToTensor(CLImage *cl_image, Tensor *tensor,
......@@ -37,7 +38,7 @@ void CLImageToTensor(CLImage *cl_image, Tensor *tensor,
size_t width = W * ((C + 3) / 4);
size_t height = H * N;
float *p = tensor->data<float>();
float *p = tensor->mutable_data<float>();
half imageData[width * height * 4];
cl_int err;
cl_mem image = cl_image->GetCLImage();
......@@ -63,7 +64,7 @@ void CLImageToTensor(CLImage *cl_image, Tensor *tensor,
}
if (err != CL_SUCCESS) {
// TODO: error handling
CL_CHECK_ERRORS(err);
}
}
void TensorToCLImage(const Tensor *tensor, CLImage *cl_image,
......@@ -97,7 +98,7 @@ void TensorToCLImage(const Tensor *tensor, CLImage *cl_image,
err = clEnqueueReadImage(commandQueue, image, CL_TRUE, origin, region, 0, 0,
imageData, 0, NULL, NULL);
if (err != CL_SUCCESS) {
// TODO: error handling
CL_CHECK_ERRORS(err);
}
size_t i0 = 0;
for (int n = 0; n < N; n++) {
......@@ -116,5 +117,64 @@ void TensorToCLImage(const Tensor *tensor, CLImage *cl_image,
i0 += width * H;
}
}
#ifdef PADDLE_MOBILE_DEBUG
/*
 * Debug printer for a CLImage.
 *
 * Prints the tensor dims, reads the whole image back from the device with a
 * blocking clEnqueueReadImage on the image's own command queue, converts the
 * half values to float in N, C, H, W order and prints about 20 evenly spaced
 * samples.
 *
 * Only rank-4 (N, C, H, W) and rank-1 (C) dims are decoded; for any other rank
 * just the dims are printed.
 * */
Print &operator<<(Print &printer, const CLImage &cl_image) {
  printer << " dims: " << cl_image.dims() << "\n";
  int stride = cl_image.numel() / 20;
  stride = stride > 0 ? stride : 1;

  DDim ddim = cl_image.dims();
  size_t N, C, H, W;
  if (ddim.size() == 4) {
    // Test the sign before the value is converted to size_t; once stored in an
    // unsigned variable a negative batch dim can no longer be detected.
    N = ddim[0] < 0 ? 1 : ddim[0];
    C = ddim[1];
    H = ddim[2];
    W = ddim[3];
  } else if (ddim.size() == 1) {
    N = 1;
    C = ddim[0];
    H = 1;
    W = 1;
  } else {
    // No decoding rule for this rank; N/C/H/W would be read uninitialised.
    return printer;
  }

  // Image layout: NCHW -> [W * (C + 3) / 4, H * N], four channels per pixel.
  const size_t width = W * ((C + 3) / 4);
  const size_t height = H * N;

  // Heap buffers: the image size is data dependent and may not fit the stack.
  float *data = new float[cl_image.numel()];
  half *imageData = new half[width * height * 4];

  cl_mem image = cl_image.GetCLImage();
  size_t origin[3] = {0, 0, 0};
  size_t region[3] = {width, height, 1};
  cl_int err = clEnqueueReadImage(cl_image.CommandQueue(), image, CL_TRUE,
                                  origin, region, 0, 0, imageData, 0, NULL,
                                  NULL);
  // Report a failed read before the buffer contents are interpreted.
  CL_CHECK_ERRORS(err);

  float *p = data;
  size_t i0 = 0;
  for (size_t n = 0; n < N; n++) {
    for (size_t c = 0; c < C; c++) {
      size_t i1 = i0;
      for (size_t h = 0; h < H; h++) {
        // Four halves per pixel; c % 4 selects the channel inside the pixel.
        size_t i2 = (i1 << 2) + c % 4;
        for (size_t w = 0; w < W; w++) {
          *p = Half2Float(imageData[i2]);
          i2 += 4;
          p++;
        }
        i1 += width;
      }
    }
    i0 += width * H;
  }

  for (int i = 0; i < cl_image.numel(); i += stride) {
    printer << data[i] << " ";
  }

  delete[] imageData;
  delete[] data;
  return printer;
}
#endif
} // namespace framework
} // namespace paddle_mobile
......@@ -46,27 +46,28 @@ class CLImage {
/*
* need call SetTensorData first
* */
void InitCLImage(cl_context context) {
void InitCLImage(cl_context context, cl_command_queue command_queue) {
if (tensor_data_ == nullptr) {
PADDLE_MOBILE_THROW_EXCEPTION(" need call SetTensorData first");
}
if (tensor_dims_.size() <= 2) {
InitCLImage2C(context, tensor_data_, tensor_dims_);
InitCLImage2C(context, command_queue, tensor_data_, tensor_dims_);
} else {
InitCLImage(context, tensor_data_, tensor_dims_);
InitCLImage(context, command_queue, tensor_data_, tensor_dims_);
}
delete[](tensor_data_);
tensor_data_ = nullptr;
initialized_ = true;
}
void InitEmptyImage(cl_context context, const DDim &dim) {
void InitEmptyImage(cl_context context, cl_command_queue command_queue,
const DDim &dim) {
if (tensor_data_ != nullptr) {
PADDLE_MOBILE_THROW_EXCEPTION(
" empty image tensor data shouldn't have value");
}
DLOG << " init empty image ";
InitCLImage(context, nullptr, dim);
InitCLImage(context, command_queue, nullptr, dim);
initialized_ = true;
}
......@@ -93,6 +94,8 @@ class CLImage {
* */
inline size_t HeightOfOneBlock() const { return height_of_one_block_; }
/*
* command queue stored by InitCLImage / InitCLImage2C when the image was
* initialised
* */
inline cl_command_queue CommandQueue() const { return command_queue_; }
/*
* resize original tensor dim
* */
......@@ -122,7 +125,9 @@ class CLImage {
const DDim &dims() const { return tensor_dims_; }
private:
void InitCLImage2C(cl_context context, float *tensor_data, const DDim &dim) {
void InitCLImage2C(cl_context context, cl_command_queue command_queue,
float *tensor_data, const DDim &dim) {
command_queue_ = command_queue;
assert(dim.size() <= 2);
int tdim[2] = {1, 1};
if (dim.size() == 1) {
......@@ -138,7 +143,8 @@ class CLImage {
imageData.reset(new half_t[width * height * 4]);
for (int h = 0; h < tdim[0]; h++) {
for (int w = 0; w < tdim[1]; w++) {
imageData[(h * width + w / 4) * 4 + (w % 4)] = Float2Half(tensor_data[h * tdim[1] + w]);
imageData[(h * width + w / 4) * 4 + (w % 4)] =
Float2Half(tensor_data[h * tdim[1] + w]);
}
}
}
......@@ -149,35 +155,36 @@ class CLImage {
cl_image_format cf = {.image_channel_order = CL_RGBA,
.image_channel_data_type = CL_HALF_FLOAT};
cl_image_desc cid = {
.image_type = CL_MEM_OBJECT_IMAGE2D,
.image_width = width,
.image_height = height,
.image_depth = 1,
.image_array_size = 1,
.image_row_pitch = 0,
.image_slice_pitch = 0,
.num_mip_levels = 0,
.num_samples = 0,
// .buffer = nullptr
.image_type = CL_MEM_OBJECT_IMAGE2D,
.image_width = width,
.image_height = height,
.image_depth = 1,
.image_array_size = 1,
.image_row_pitch = 0,
.image_slice_pitch = 0,
.num_mip_levels = 0,
.num_samples = 0,
// .buffer = nullptr
};
cid.buffer = nullptr;
cl_int err;
cl_image_ = clCreateImage(
context, CL_MEM_READ_WRITE | (data ? CL_MEM_COPY_HOST_PTR : 0),
&cf, // const cl_image_format *image_format
&cid, // const cl_image_desc *image_desc
data, // void *host_ptr
&err
);
context, CL_MEM_READ_WRITE | (data ? CL_MEM_COPY_HOST_PTR : 0),
&cf, // const cl_image_format *image_format
&cid, // const cl_image_desc *image_desc
data, // void *host_ptr
&err);
if (err != CL_SUCCESS) {
CL_CHECK_ERRORS(err);
PADDLE_MOBILE_THROW_EXCEPTION(" create image 2d error ");
}
}
void InitCLImage(cl_context context, float *tensor_data, const DDim &dim) {
void InitCLImage(cl_context context, cl_command_queue command_queue,
float *tensor_data, const DDim &dim) {
DLOG << " tensor dim: " << dim;
// NCHW -> [W * (C+3)/4, H * N]
tensor_dims_ = dim;
command_queue_ = command_queue;
if (tensor_data) {
tensor_data_ = tensor_data;
}
......@@ -203,6 +210,7 @@ class CLImage {
image_width_ = width;
image_height_ = height;
image_dims_ = make_ddim({image_width_, image_height_});
c_block_ = W / width;
std::unique_ptr<half_t[]> imageData{};
int count = 0;
......@@ -241,6 +249,7 @@ class CLImage {
DDim image_dims_;
float *tensor_data_;
cl_context context_;
cl_command_queue command_queue_;
};
void TensorToCLImage(Tensor *tensor, CLImage *image,
......
......@@ -28,7 +28,19 @@ namespace framework {
class CLTensor : TensorBase {
public:
explicit CLTensor(cl_context context) : context_(context) {}
CLTensor(cl_context context, cl_command_queue command_queue)
: context_(context), command_queue_(command_queue) {}
CLTensor() = default;
/*
* Sets the OpenCL context and command queue used by this tensor.
* Must be called when the tensor was created with the default constructor,
* which does not set either of them.
* */
void SetContextAndCommandQueue(cl_context context,
cl_command_queue command_queue) {
context_ = context;
command_queue_ = command_queue;
}
/*! Resize the dimensions of the memory block. */
inline CLTensor &Resize(const DDim &dims) {
......@@ -39,7 +51,8 @@ class CLTensor : TensorBase {
template <typename T>
inline T mutable_with_data(void *data) {
int64_t size = numel() * sizeof(float);
holder_.reset(new PlaceholderImpl(size, data, typeid(T), context_));
holder_.reset(
new PlaceholderImpl(size, data, typeid(T), context_, command_queue_));
return reinterpret_cast<T>(
reinterpret_cast<void *>(reinterpret_cast<uintptr_t>(holder_->ptr())));
}
......@@ -51,7 +64,7 @@ class CLTensor : TensorBase {
PADDLE_MOBILE_ENFORCE(numel() >= 0, "the Tensor's numel must >=0.")
int64_t size = numel() * SizeOfType(type);
if (holder_ == nullptr || holder_->size() < size + offset_) {
holder_.reset(new PlaceholderImpl(size, type, context_));
holder_.reset(new PlaceholderImpl(size, type, context_, command_queue_));
offset_ = 0;
}
return reinterpret_cast<void *>(
......@@ -85,6 +98,7 @@ class CLTensor : TensorBase {
private:
cl_context context_;
cl_command_queue command_queue_;
/*
* virtual ~Placeholder() = default;
......@@ -99,20 +113,31 @@ class CLTensor : TensorBase {
* */
struct PlaceholderImpl : public Placeholder {
PlaceholderImpl(size_t size, void *input, std::type_index type,
cl_context context)
cl_context context, cl_command_queue command_queue)
: ptr_(clCreateBuffer(context, CL_MEM_READ_ONLY | CL_MEM_COPY_HOST_PTR,
size, reinterpret_cast<void *>(input), NULL)),
size_(size),
type_(type) {}
type_(type),
command_queue_(command_queue) {}
PlaceholderImpl(size_t size, std::type_index type, cl_context context)
PlaceholderImpl(size_t size, std::type_index type, cl_context context,
cl_command_queue command_queue)
: ptr_(clCreateBuffer(context, CL_MEM_READ_WRITE, size, NULL, NULL)),
size_(size),
type_(type) {}
type_(type),
command_queue_(command_queue) {}
virtual size_t size() const { return size_; }
virtual void *ptr() const { return static_cast<void *>(ptr_.get()); }
/*
* Reads size_ bytes of the device buffer into a newly allocated host buffer
* (blocking read on command_queue_) and returns that host buffer.
*
* NOTE(review): the returned buffer is held only in the local host_ptr and is
* never stored in host_ptr_, so nothing in this method frees it later; it
* looks like every call leaks size_ bytes unless the caller deletes it --
* confirm.
* NOTE(review): the status returned by clEnqueueReadBuffer is ignored.
* */
virtual void *ptr() const {
// host_ptr_ is not assigned anywhere in this method; it is deleted through a
// void* here.
if (host_ptr_) {
delete (host_ptr_);
}
char *host_ptr = new char[size_];
// CL_TRUE makes the read blocking: host_ptr is filled when the call returns.
clEnqueueReadBuffer(command_queue_, ptr_.get(), CL_TRUE, 0, size_,
host_ptr, 0, NULL, NULL);
return static_cast<void *>(host_ptr);
}
virtual std::type_index type() const { return type_; }
......@@ -124,6 +149,17 @@ class CLTensor : TensorBase {
/* the current type of memory */
std::type_index type_;
cl_command_queue command_queue_;
/*
 * Releases the host-side buffer, if one was ever stored in host_ptr_.
 * */
~PlaceholderImpl() {
  // Deleting through a void* is undefined behaviour, so cast back to the
  // element type first. char[] matches the only host allocation visible in
  // ptr(). delete[] on nullptr is a no-op, so no explicit check is needed.
  delete[] static_cast<char *>(host_ptr_);
}

private:
/*
 * Host-side buffer owned by this placeholder. Defaults to nullptr because
 * neither constructor initialises it and both ptr() and the destructor test
 * it before deleting.
 * */
void *host_ptr_ = nullptr;
};
};
......
......@@ -87,7 +87,7 @@ Executor<Dtype, P>::Executor(const framework::Program<Dtype> p, int batch_size,
for (int i = 0; i < blocks.size(); ++i) {
std::shared_ptr<framework::BlockDesc> block_desc = blocks[i];
std::vector<std::shared_ptr<framework::OpDesc>> ops = block_desc->Ops();
for (int j = 0; j < debug_to; ++j) {
for (int j = 0; j < ops.size(); ++j) {
std::shared_ptr<framework::OpDesc> op = ops[j];
DLOG << "create op: " << j << " " << op->Type();
auto op_base = framework::OpRegistry<Dtype>::CreateOp(
......@@ -416,7 +416,7 @@ std::shared_ptr<framework::Tensor> Executor<Dtype, P>::Predict(
}
}
#else
for (int i = 0; i < debug_to; i++) {
for (int i = 0; i < ops.size(); i++) {
#ifdef PADDLE_MOBILE_PROFILE
struct timespec ts;
clock_gettime(CLOCK_MONOTONIC, &ts);
......@@ -953,12 +953,14 @@ void Executor<GPU_CL, Precision::FP32>::InitMemory() {
if (var_desc->Type() == framework::VARTYPE_TYPE_LOD_TENSOR) {
auto cl_image = var->template GetMutable<framework::CLImage>();
cl_context context = program_.scope->GetCLScpoe()->Context();
cl_command_queue command_queue =
program_.scope->GetCLScpoe()->CommandQueue();
const framework::TensorDesc &desc = var_desc->Tensor_desc();
// framework::DDim ddim = framework::make_ddim(desc.Dims());
framework::DDim ddim = cl_image->dims();
DLOG << var_desc->Name();
cl_image->InitEmptyImage(context, ddim);
cl_image->InitEmptyImage(context, command_queue, ddim);
}
}
}
......@@ -1010,11 +1012,12 @@ void Executor<GPU_CL, Precision::FP32>::InitCombineMemory() {
} else {
auto cl_image = var->template GetMutable<framework::CLImage>();
cl_context context = program_.scope->GetCLScpoe()->Context();
cl_command_queue command_queue =
program_.scope->GetCLScpoe()->CommandQueue();
const framework::TensorDesc &desc = var_desc->Tensor_desc();
framework::DDim ddim = cl_image->dims();
// framework::DDim ddim = framework::make_ddim(desc.Dims());
cl_image->InitEmptyImage(context, ddim);
cl_image->InitEmptyImage(context, command_queue, ddim);
}
}
}
......
......@@ -57,10 +57,9 @@ void OperatorBase<Dtype>::CheckAllInputOutputSet() const {}
template <typename Dtype>
void OperatorBase<Dtype>::Run() {
DLOG << " begin run " << type_;
DLOG << " ----- Begin run impl --- " << type_ << " ----- ";
RunImpl();
DLOG << " end run " << type_;
return;
DLOG << " ----- End run impl --- " << type_ << " ----- ";
#ifdef PADDLE_MOBILE_DEBUG
DLOG << "-------------" << type_ << "----------------------------";
vector<string> input_keys = GetInputKeys();
......@@ -75,16 +74,8 @@ void OperatorBase<Dtype>::Run() {
if (tensor) DLOG << type_ << " input- " << key << "=" << *tensor;
} else {
CLImage *cl_image = vari->template GetMutable<framework::CLImage>();
// cl_command_queue commandQueue =
// scope_->GetCLScpoe()->CommandQueue(); Tensor
// *tmp ;
// CLImageToTensor(cl_image,tmp,commandQueue);
// tmp->Resize(cl_image->dims());
const float *input = cl_image->data<float>();
if (cl_image) {
DLOG << type_ << " input- " << key << "=" << cl_image->dims();
// if(input)
// DLOG<<type_<<" input- "<<key<<"="<<*input;
DLOG << type_ << " input- " << key << "=" << *cl_image;
}
}
......@@ -108,15 +99,8 @@ void OperatorBase<Dtype>::Run() {
}
} else {
CLImage *cl_image = vari->template GetMutable<framework::CLImage>();
// cl_command_queue commandQueue =
// scope_->GetCLScpoe()->CommandQueue(); Tensor *tmp ;
// CLImageToTensor(cl_image,tmp,commandQueue);
// tmp->Resize(cl_image->dims());
if (cl_image) {
const float *output = cl_image->data<float>();
DLOG << type_ << " output- " << key << "=" << cl_image->dims();
// if(output)
// DLOG<<type_<<" output- "<<key<<"="<<*output;
DLOG << type_ << " output- " << key << "=" << *cl_image;
}
}
......
......@@ -49,11 +49,13 @@ bool BatchNormKernel<GPU_CL, float>::Init(BatchNormParam<GPU_CL> *param) {
framework::CLImage *new_scale = new framework::CLImage();
new_scale->SetTensorData(new_scale_ptr, variance->dims());
new_scale->InitCLImage(this->cl_helper_.CLContext());
new_scale->InitCLImage(this->cl_helper_.CLContext(),
this->cl_helper_.CLCommandQueue());
framework::CLImage *new_bias = new framework::CLImage();
new_bias->SetTensorData(new_bias_ptr, variance->dims());
new_bias->InitCLImage(this->cl_helper_.CLContext());
new_bias->InitCLImage(this->cl_helper_.CLContext(),
this->cl_helper_.CLCommandQueue());
param->SetNewScale(new_scale);
param->SetNewBias(new_bias);
......
......@@ -12,11 +12,9 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once;
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
inline hafl4 activation(half4 in
inline half4 activation(half4 in
#ifdef PRELU
,
half4 prelu_alpha
......@@ -28,7 +26,7 @@ inline hafl4 activation(half4 in
#endif
#ifdef RELU
fmax(in, 0.0);
output = fmax(in, (half4)(0.0f));
#endif
return output;
}
......@@ -12,10 +12,328 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#define BIASE
#define BATCH_NORM
#define RELU
#include "conv_kernel.inc.cl"
#undef
#undef
#undef
#include "cl_kernel/cl_common.h"
/*
 * 3x3 convolution with dilation, optionally fused with bias (BIASE),
 * batch norm (BATCH_NORM) and activation (RELU).
 *
 * Work item (out_c, out_w, out_nh) produces one half4 output pixel, i.e. the
 * four output channels of channel block out_c at column out_w, row out_nh.
 *
 * global_size_dim1 is the width of one output block and is used to place the
 * result at x = out_c * global_size_dim1 + out_w.
 * input_c is the number of input channel blocks (four channels each).
 * input_width / input_height are the size of one input block.
 * global_size_dim0, global_size_dim2, output_width and output_height are not
 * read by this kernel.
 *
 * NOTE(review): out_nh is used directly as the row inside one block, so rows
 * at or beyond input_height are treated as padding; confirm how batch > 1 is
 * expected to be handled.
 */
__kernel void conv_3x3(__private const int global_size_dim0,
                       __private const int global_size_dim1,
                       __private const int global_size_dim2,
                       __read_only image2d_t input_image,
                       __read_only image2d_t filter,
#ifdef BIASE
                       __read_only image2d_t bias,
#endif
#ifdef BATCH_NORM
                       __read_only image2d_t new_scale,
                       __read_only image2d_t new_biase,
#endif
                       __write_only image2d_t output_image,
                       __private const int stride,
                       __private const int offset,
                       __private const int input_c,
                       __private const int dilation,
                       __private const int input_width,  /* of one block */
                       __private const int input_height, /* of one block */
                       __private const int output_width,
                       __private const int output_height) {
  const int out_c = get_global_id(0);
  const int out_w = get_global_id(1);
  const int out_nh = get_global_id(2);

  // Integer pixel coordinates are only defined for a sampler with
  // un-normalised coordinates.
  const sampler_t sampler = CLK_NORMALIZED_COORDS_FALSE |
                            CLK_ADDRESS_CLAMP |
                            CLK_FILTER_NEAREST;

  int2 in_pos_in_one_block;
  in_pos_in_one_block.x = out_w * stride + offset;
  in_pos_in_one_block.y = out_nh * stride + offset;

#ifdef BIASE
  half4 output = read_imageh(bias, sampler, (int2)(out_c, 0));
#else
  half4 output = 0.0f;
#endif

  for (int i = 0; i < input_c; ++i) {
    // Top-left independent position of the centre tap inside input block i.
    const int2 pos_in = (int2)(i * input_width + in_pos_in_one_block.x,
                               in_pos_in_one_block.y);

    // Taps are visited row by row: j % 3 is the column, j / 3 the row.
    for (int j = 0; j < 9; ++j) {
      const int dx = (j % 3 - 1) * dilation;
      const int dy = (j / 3 - 1) * dilation;

      // Padding is decided on the position inside ONE block; the absolute
      // image x (pos_in.x) also contains i * input_width and must not be used.
      const int x_in_block = in_pos_in_one_block.x + dx;
      const int y_in_block = in_pos_in_one_block.y + dy;
      const int outside = x_in_block < 0 || y_in_block < 0 ||
                          x_in_block >= input_width ||
                          y_in_block >= input_height;

      // Vector select() picks the second operand where the most significant
      // bit of the mask component is set, hence the shift of 0/1 into bit 15.
      const half4 tap = select(
          read_imageh(input_image, sampler,
                      (int2)(pos_in.x + dx, pos_in.y + dy)),
          (half4)(0.0f),
          (ushort4)(outside << 15));

      // Filter image layout is [W * (C + 3) / 4, H * N]: output channel
      // (out_c * 4 + k) owns the three rows starting at (out_c * 4 + k) * 3.
      int2 filter_pos;
      filter_pos.x = i * 3 + j % 3;

      filter_pos.y = out_c * 4 * 3 + 0 * 3 + j / 3;
      half4 weight_x = read_imageh(filter, sampler, filter_pos);
      output.x += dot(tap, weight_x);

      filter_pos.y = out_c * 4 * 3 + 1 * 3 + j / 3;
      half4 weight_y = read_imageh(filter, sampler, filter_pos);
      output.y += dot(tap, weight_y);

      filter_pos.y = out_c * 4 * 3 + 2 * 3 + j / 3;
      half4 weight_z = read_imageh(filter, sampler, filter_pos);
      output.z += dot(tap, weight_z);

      filter_pos.y = out_c * 4 * 3 + 3 * 3 + j / 3;
      half4 weight_w = read_imageh(filter, sampler, filter_pos);
      output.w += dot(tap, weight_w);
    }
  }

#ifdef BATCH_NORM
  output = output * read_imageh(new_scale, sampler, (int2)(out_c, 0)) +
           read_imageh(new_biase, sampler, (int2)(out_c, 0));
#endif

#ifdef RELU
  output = activation(output);
#endif

  write_imageh(output_image, (int2)(out_c * global_size_dim1 + out_w, out_nh),
               output);
}
/*
 * Depth-wise 3x3 convolution, optionally fused with bias (BIASE), batch norm
 * (BATCH_NORM) and activation (RELU).
 *
 * Work item (out_c, out_w, out_nh) produces one half4 output pixel. out_nh is
 * split into a batch index and a row inside that batch using output_height.
 * Each of the four channels of a pixel is multiplied only with the matching
 * channel of the filter pixel (no reduction across channels).
 *
 * global_size_dim1 is the width of one output block and is used to place the
 * result at x = out_c * global_size_dim1 + out_w.
 * input_width / input_height are the size of one input block.
 * The neighbours are always taken at distance 1; dilation, input_c,
 * global_size_dim0, global_size_dim2 and output_width are not read.
 */
__kernel void depth_conv_3x3(__private const int global_size_dim0,
                             __private const int global_size_dim1,
                             __private const int global_size_dim2,
                             __read_only image2d_t input,
                             __read_only image2d_t filter,
#ifdef BIASE
                             __read_only image2d_t bias,
#endif
#ifdef BATCH_NORM
                             __read_only image2d_t new_scale,
                             __read_only image2d_t new_biase,
#endif
                             __write_only image2d_t output_image,
                             __private const int stride,
                             __private const int offset,
                             __private const int input_c,
                             __private const int dilation,
                             __private const int input_width,  /* of one block */
                             __private const int input_height, /* of one block */
                             __private const int output_width,
                             __private const int output_height) {
  const int out_c = get_global_id(0);
  const int out_w = get_global_id(1);
  const int out_nh = get_global_id(2);

  // Integer pixel coordinates are only defined for a sampler with
  // un-normalised coordinates.
  const sampler_t sampler = CLK_NORMALIZED_COORDS_FALSE |
                            CLK_ADDRESS_CLAMP |
                            CLK_FILTER_NEAREST;

  const int batch_index = out_nh / output_height;
  const int out_nh_in_one_batch = out_nh % output_height;

  const int2 in_pos_in_one_block =
      (int2)(out_w, out_nh_in_one_batch) * (int2)(stride, stride) +
      (int2)(offset, offset);

#ifdef BIASE
  half4 output = read_imageh(bias, sampler, (int2)(out_c, 0));
#else
  half4 output = 0.0f;
#endif

  // Origin of the input block this work item reads from.
  const int2 pos_in_input_block =
      (int2)(out_c * input_width, batch_index * input_height);
  const int weight_x_to = out_c * 3;

  // Taps are visited row by row: j % 3 is the column, j / 3 the row.
  for (int j = 0; j < 9; ++j) {
    const int x_in_block = in_pos_in_one_block.x + (j % 3 - 1);
    const int y_in_block = in_pos_in_one_block.y + (j / 3 - 1);
    const int outside = x_in_block < 0 || y_in_block < 0 ||
                        x_in_block >= input_width ||
                        y_in_block >= input_height;

    // Vector select() picks the second operand where the most significant bit
    // of the mask component is set, hence the shift of 0/1 into bit 15.
    const half4 tap = select(
        read_imageh(input, sampler,
                    (int2)(pos_in_input_block.x + x_in_block,
                           pos_in_input_block.y + y_in_block)),
        (half4)(0.0f),
        (ushort4)(outside << 15));

    const half4 weight =
        read_imageh(filter, sampler, (int2)(weight_x_to + j % 3, j / 3));

    // Component-wise multiply-accumulate: channel k only sees weight k.
    output.x += tap.x * weight.x;
    output.y += tap.y * weight.y;
    output.z += tap.z * weight.z;
    output.w += tap.w * weight.w;
  }

#ifdef BATCH_NORM
  output = output * read_imageh(new_scale, sampler, (int2)(out_c, 0)) +
           read_imageh(new_biase, sampler, (int2)(out_c, 0));
#endif

#ifdef RELU
  output = activation(output);
#endif

  int2 output_pos = (int2)(out_c * global_size_dim1 + out_w, out_nh);
  write_imageh(output_image, output_pos, output);
}
/*
 * 1x1 convolution, optionally fused with bias (BIASE), batch norm
 * (BATCH_NORM) and activation (RELU).
 *
 * Work item (out_c, out_w, out_nh) produces one half4 output pixel, i.e. the
 * four output channels of channel block out_c at column out_w, row out_nh.
 * Filter row out_c * 4 + k holds the weights of output channel k of the block;
 * filter column i holds the four input channels of input block i.
 *
 * global_size_dim1 is the width of one output block and is used to place the
 * result at x = out_c * global_size_dim1 + out_w.
 * input_c is the number of input channel blocks (four channels each).
 * input_width / input_height are the size of one input block.
 * dilation, global_size_dim0, global_size_dim2, output_width and
 * output_height are not read by this kernel.
 *
 * NOTE(review): out_nh is used directly as the row inside one block, so rows
 * at or beyond input_height contribute nothing; confirm how batch > 1 is
 * expected to be handled.
 */
__kernel void conv_1x1(__private const int global_size_dim0,
                       __private const int global_size_dim1,
                       __private const int global_size_dim2,
                       __read_only image2d_t input_image,
                       __read_only image2d_t filter,
#ifdef BIASE
                       __read_only image2d_t bias,
#endif
#ifdef BATCH_NORM
                       __read_only image2d_t new_scale,
                       __read_only image2d_t new_biase,
#endif
                       __write_only image2d_t output_image,
                       __private const int stride,
                       __private const int offset,
                       __private const int input_c,
                       __private const int dilation,
                       __private const int input_width,  /* of one block */
                       __private const int input_height, /* of one block */
                       __private const int output_width,
                       __private const int output_height) {
  const int out_c = get_global_id(0);
  const int out_w = get_global_id(1);
  const int out_nh = get_global_id(2);

  // Integer pixel coordinates are only defined for a sampler with
  // un-normalised coordinates.
  const sampler_t sampler = CLK_NORMALIZED_COORDS_FALSE |
                            CLK_ADDRESS_CLAMP |
                            CLK_FILTER_NEAREST;

  const int2 in_pos_in_one_block =
      (int2)(out_w, out_nh) * (int2)(stride, stride) + (int2)(offset, offset);

#ifdef BIASE
  half4 output = read_imageh(bias, sampler, (int2)(out_c, 0));
#else
  half4 output = 0.0f;
#endif

  // The bounds test is on the position inside ONE block. Testing the absolute
  // image x (which includes i * input_width) would reject every input block
  // after the first. The test does not depend on i, so it is done once.
  const int inside = in_pos_in_one_block.x >= 0 &&
                     in_pos_in_one_block.y >= 0 &&
                     in_pos_in_one_block.x < input_width &&
                     in_pos_in_one_block.y < input_height;

  if (inside) {
    for (int i = 0; i < input_c; ++i) {
      const int2 pos_in = (int2)(i * input_width + in_pos_in_one_block.x,
                                 in_pos_in_one_block.y);
      const half4 in_value = read_imageh(input_image, sampler, pos_in);

      half4 weight_x = read_imageh(filter, sampler, (int2)(i, out_c * 4 + 0));
      output.x += dot(in_value, weight_x);

      half4 weight_y = read_imageh(filter, sampler, (int2)(i, out_c * 4 + 1));
      output.y += dot(in_value, weight_y);

      half4 weight_z = read_imageh(filter, sampler, (int2)(i, out_c * 4 + 2));
      output.z += dot(in_value, weight_z);

      half4 weight_w = read_imageh(filter, sampler, (int2)(i, out_c * 4 + 3));
      output.w += dot(in_value, weight_w);
    }
  }

#ifdef BATCH_NORM
  output = output * read_imageh(new_scale, sampler, (int2)(out_c, 0)) +
           read_imageh(new_biase, sampler, (int2)(out_c, 0));
#endif

#ifdef RELU
  output = activation(output);
#endif

  int2 output_pos = (int2)(out_c * global_size_dim1 + out_w, out_nh);
  write_imageh(output_image, output_pos, output);
}
......@@ -12,6 +12,324 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
#define BIASE
#include "conv_kernel.inc.cl"
#undef
/*
 * 3x3 convolution on half-precision image data.
 *
 * Each work-item (out_c, out_w, out_nh) accumulates one half4 of output: the
 * four components are four output channels, each using its own filter rows.
 * The result is written to output_image at (out_c * global_size_dim1 + out_w, out_nh).
 *
 * stride / offset   map the output position to the centre tap inside one input block
 * input_c           number of input channel blocks iterated over
 * dilation          distance between neighbouring taps
 * input_width/height size of one channel block; taps outside it are treated as zero
 *
 * Optional stages are selected at compile time: BIASE (initial value read from
 * `bias`), BATCH_NORM (scale + shift), RELU (activation()).
 *
 * NOTE(review): integer-coordinate read_imageh is specified for samplers created
 * with CLK_NORMALIZED_COORDS_FALSE; the sampler flags are kept as in the original
 * here - confirm against the OpenCL specification and target drivers.
 */
__kernel void conv_3x3(__private const int global_size_dim0,
                       __private const int global_size_dim1,
                       __private const int global_size_dim2,
                       __read_only image2d_t input_image,
                       __read_only image2d_t filter,
#ifdef BIASE
                       __read_only image2d_t bias,
#endif
#ifdef BATCH_NORM
                       __read_only image2d_t new_scale,
                       __read_only image2d_t new_biase,
#endif
                       __write_only image2d_t output_image,
                       __private const int stride,
                       __private const int offset,
                       __private const int input_c,
                       __private const int dilation,
                       __private const int input_width,/* of one block */
                       __private const int input_height,/* of one block */
                       __private const int output_width,
                       __private const int output_height) {
  const int out_c = get_global_id(0);
  const int out_w = get_global_id(1);
  const int out_nh = get_global_id(2);

  /* Centre tap position inside a single input channel block. */
  const int2 in_pos_in_one_block =
      (int2)(out_w * stride + offset, out_nh * stride + offset);

  const sampler_t sampler = CLK_NORMALIZED_COORDS_TRUE |
                            CLK_ADDRESS_CLAMP |
                            CLK_FILTER_NEAREST;

#ifdef BIASE
  half4 output = read_imageh(bias, sampler, (int2)(out_c, 0));
#else
  half4 output = 0.0f;
#endif

  for (int i = 0; i < input_c; ++i) {
    /* Channel block i starts at x = i * input_width in the input image. */
    const int2 pos_in = (int2)(i * input_width + in_pos_in_one_block.x,
                               in_pos_in_one_block.y);

    for (int j = 0; j < 9; ++j) {
      /* Tap j sits at column (j % 3 - 1), row (j / 3 - 1) of the 3x3 window. */
      const int dx = (j % 3 - 1) * dilation;
      const int dy = (j / 3 - 1) * dilation;
      const int tap_x = in_pos_in_one_block.x + dx;
      const int tap_y = in_pos_in_one_block.y + dy;

      /* The bounds test uses the in-block position (not pos_in, which carries
       * the channel-block offset) so padding is decided per block. */
      const int outside = tap_x < 0 || tap_y < 0 ||
                          tap_x >= input_width || tap_y >= input_height;

      /* Vector select() picks its second operand only where the mask lane has
       * its most significant bit set, hence the shift into bit 15. */
      const half4 tap = select(
          read_imageh(input_image, sampler, (int2)(pos_in.x + dx, pos_in.y + dy)),
          (half4)(0.0f),
          (ushort4)(outside << 15));

      /* Filter rows: three kernel rows per output channel, four output
       * channels per out_c - i.e. row = (out_c * 4 + k) * 3 + j / 3, the 3x3
       * analogue of conv_1x1's (out_c * 4 + k) rows.
       * NOTE(review): confirm against the host-side filter image packing. */
      int2 weight_pos;
      weight_pos.x = i * 3 + j % 3;

      weight_pos.y = (out_c * 4 + 0) * 3 + j / 3;
      output.x += dot(tap, read_imageh(filter, sampler, weight_pos));

      weight_pos.y = (out_c * 4 + 1) * 3 + j / 3;
      output.y += dot(tap, read_imageh(filter, sampler, weight_pos));

      weight_pos.y = (out_c * 4 + 2) * 3 + j / 3;
      output.z += dot(tap, read_imageh(filter, sampler, weight_pos));

      weight_pos.y = (out_c * 4 + 3) * 3 + j / 3;
      output.w += dot(tap, read_imageh(filter, sampler, weight_pos));
    }
  }

#ifdef BATCH_NORM
  output = output * read_imageh(new_scale, sampler, (int2)(out_c, 0)) + read_imageh(new_biase, sampler, (int2)(out_c, 0));
#endif
#ifdef RELU
  output = activation(output);
#endif

  write_imageh(output_image, (int2)(out_c * global_size_dim1 + out_w, out_nh), output);
}
/*
 * Depthwise 3x3 convolution on half-precision image data.
 *
 * Each work-item (out_c, out_w, out_nh) reads the 3x3 neighbourhood of one
 * channel block and multiplies it component-wise with the matching filter
 * texels, so every half4 lane is convolved independently. out_nh is split into
 * a batch index and a row within the batch using output_height.
 *
 * Taps whose in-block position lies outside [0, input_width) x [0, input_height)
 * contribute zero. Optional stages are selected at compile time: BIASE,
 * BATCH_NORM and RELU.
 *
 * NOTE(review): `dilation` and `input_c` are accepted but not used here; the
 * taps are always at +/-1 - confirm callers only use dilation == 1.
 * NOTE(review): integer-coordinate read_imageh is specified for samplers created
 * with CLK_NORMALIZED_COORDS_FALSE; the sampler flags are kept as in the original.
 */
__kernel void depth_conv_3x3(__private const int global_size_dim0,
                             __private const int global_size_dim1,
                             __private const int global_size_dim2,
                             __read_only image2d_t input,
                             __read_only image2d_t filter,
#ifdef BIASE
                             __read_only image2d_t bias,
#endif
#ifdef BATCH_NORM
                             __read_only image2d_t new_scale,
                             __read_only image2d_t new_biase,
#endif
                             __write_only image2d_t output_image,
                             __private const int stride,
                             __private const int offset,
                             __private const int input_c,
                             __private const int dilation,
                             __private const int input_width,/* of one block */
                             __private const int input_height, /* of one block */
                             __private const int output_width,
                             __private const int output_height) {
  const int out_c = get_global_id(0);
  const int out_w = get_global_id(1);
  const int out_nh = get_global_id(2);

  const sampler_t sampler = CLK_NORMALIZED_COORDS_TRUE |
                            CLK_ADDRESS_CLAMP |
                            CLK_FILTER_NEAREST;

  /* Rows of all batches are stacked vertically in the image. */
  const int batch_index = out_nh / output_height;
  const int out_nh_in_one_batch = out_nh % output_height;

  /* Centre tap position inside one input block. */
  const int2 in_pos_in_one_block =
      (int2)(out_w, out_nh_in_one_batch) * (int2)(stride, stride) +
      (int2)(offset, offset);

#ifdef BIASE
  half4 output = read_imageh(bias, sampler, (int2)(out_c, 0));
#else
  half4 output = 0.0f;
#endif

  /* Origin of this work-item's channel block / batch inside the input image. */
  const int2 pos_in_input_block =
      (int2)(out_c * input_width, batch_index * input_height);
  /* First filter column belonging to this channel block. */
  const int weight_x_to = out_c * 3;

  for (int j = 0; j < 9; ++j) {
    /* Tap j sits at column (j % 3 - 1), row (j / 3 - 1) of the 3x3 window. */
    const int tap_x = in_pos_in_one_block.x + (j % 3 - 1);
    const int tap_y = in_pos_in_one_block.y + (j / 3 - 1);

    const int outside = tap_x < 0 || tap_y < 0 ||
                        tap_x >= input_width || tap_y >= input_height;

    /* Vector select() picks its second operand only where the mask lane has
     * its most significant bit set, hence the shift into bit 15. */
    const half4 tap = select(
        read_imageh(input, sampler,
                    (int2)(pos_in_input_block.x + tap_x, pos_in_input_block.y + tap_y)),
        (half4)(0.0f),
        (ushort4)(outside << 15));

    const half4 weight =
        read_imageh(filter, sampler, (int2)(weight_x_to + j % 3, j / 3));

    /* Component-wise multiply-accumulate: each lane is its own channel. */
    output += tap * weight;
  }

#ifdef BATCH_NORM
  output = output * read_imageh(new_scale, sampler, (int2)(out_c, 0)) + read_imageh(new_biase, sampler, (int2)(out_c, 0));
#endif
#ifdef RELU
  output = activation(output);
#endif

  int2 output_pos = (int2)(out_c * global_size_dim1 + out_w, out_nh);
  write_imageh(output_image, output_pos, output);
}
/*
 * 1x1 (pointwise) convolution on half-precision image data.
 *
 * Each work-item (out_c, out_w, out_nh) accumulates one half4 of output. For
 * every input channel block i it reads one input texel and four filter texels
 * at rows out_c * 4 + {0,1,2,3}; each dot product feeds one output component.
 * The result is written at (out_c * global_size_dim1 + out_w, out_nh).
 *
 * If the sampled position falls outside one input block, no input is
 * accumulated and the output is just the bias (or zero) plus the optional
 * BATCH_NORM / RELU stages.
 *
 * NOTE(review): integer-coordinate read_imageh is specified for samplers created
 * with CLK_NORMALIZED_COORDS_FALSE; the sampler flags are kept as in the original.
 */
__kernel void conv_1x1(__private const int global_size_dim0,
                       __private const int global_size_dim1,
                       __private const int global_size_dim2,
                       __read_only image2d_t input_image,
                       __read_only image2d_t filter,
#ifdef BIASE
                       __read_only image2d_t bias,
#endif
#ifdef BATCH_NORM
                       __read_only image2d_t new_scale,
                       __read_only image2d_t new_biase,
#endif
                       __write_only image2d_t output_image,
                       __private const int stride,
                       __private const int offset,
                       __private const int input_c,
                       __private const int dilation,
                       __private const int input_width,/* of one block */
                       __private const int input_height,/* of one block */
                       __private const int output_width,
                       __private const int output_height) {
  const int out_c = get_global_id(0);
  const int out_w = get_global_id(1);
  const int out_nh = get_global_id(2);

  const sampler_t sampler = CLK_NORMALIZED_COORDS_TRUE |
                            CLK_ADDRESS_CLAMP |
                            CLK_FILTER_NEAREST;

  /* Sample position inside a single input channel block. */
  const int2 in_pos_in_one_block =
      (int2)(out_w, out_nh) * (int2)(stride, stride) + (int2)(offset, offset);

#ifdef BIASE
  half4 output = read_imageh(bias, sampler, (int2)(out_c, 0));
#else
  half4 output = 0.0f;
#endif

  /* The bounds test must use the in-block position: the image x coordinate
   * (i * input_width + x) is >= input_width for every block i >= 1, so testing
   * it against input_width would drop all but the first channel block. The test
   * does not depend on i, so it is evaluated once. */
  const int inside = in_pos_in_one_block.x >= 0 && in_pos_in_one_block.y >= 0 &&
                     in_pos_in_one_block.x < input_width &&
                     in_pos_in_one_block.y < input_height;

  if (inside) {
    for (int i = 0; i < input_c; ++i) {
      /* Channel block i starts at x = i * input_width in the input image. */
      const int2 pos_in = (int2)(i * input_width + in_pos_in_one_block.x,
                                 in_pos_in_one_block.y);
      const half4 in_texel = read_imageh(input_image, sampler, pos_in);

      output.x += dot(in_texel, read_imageh(filter, sampler, (int2)(i, out_c * 4 + 0)));
      output.y += dot(in_texel, read_imageh(filter, sampler, (int2)(i, out_c * 4 + 1)));
      output.z += dot(in_texel, read_imageh(filter, sampler, (int2)(i, out_c * 4 + 2)));
      output.w += dot(in_texel, read_imageh(filter, sampler, (int2)(i, out_c * 4 + 3)));
    }
  }

#ifdef BATCH_NORM
  output = output * read_imageh(new_scale, sampler, (int2)(out_c, 0)) + read_imageh(new_biase, sampler, (int2)(out_c, 0));
#endif
#ifdef RELU
  output = activation(output);
#endif

  int2 output_pos = (int2)(out_c * global_size_dim1 + out_w, out_nh);
  write_imageh(output_image, output_pos, output);
}
......@@ -19,6 +19,7 @@ __kernel void conv_3x3(__private const int global_size_dim0,
__private const int global_size_dim2,
__read_only image2d_t input_image,
__read_only image2d_t filter,
#ifdef BIASE
__read_only image2d_t bias,
#endif
......
......@@ -29,8 +29,10 @@ bool ConvAddBNReluKernel<GPU_CL, float>::Init(
param->Paddings()[0] == param->Paddings()[1],
"need equal");
param->Filter()->InitCLImage(cl_helper_.CLContext());
param->Bias()->InitCLImage(cl_helper_.CLContext());
param->Filter()->InitCLImage(cl_helper_.CLContext(),
cl_helper_.CLCommandQueue());
param->Bias()->InitCLImage(cl_helper_.CLContext(),
cl_helper_.CLCommandQueue());
// const CL *mean = param->InputMean();
const framework::CLImage *mean = param->InputMean();
......@@ -38,6 +40,11 @@ bool ConvAddBNReluKernel<GPU_CL, float>::Init(
const framework::CLImage *scale = param->InputScale();
const framework::CLImage *bias = param->InputBias();
const float epsilon = param->Epsilon();
//
// DLOG << " climage mean: " << *mean;
// DLOG << " climage variance: " << *variance;
// DLOG << " climage scale: " << *scale;
// DLOG << " climage bias: " << *bias;
auto mean_ptr = mean->data<float>();
auto variance_ptr = variance->data<float>();
......@@ -62,12 +69,22 @@ bool ConvAddBNReluKernel<GPU_CL, float>::Init(
framework::CLImage *new_scale = new framework::CLImage();
new_scale->SetTensorData(new_scale_ptr, variance->dims());
new_scale->InitCLImage(this->cl_helper_.CLContext());
new_scale->InitCLImage(this->cl_helper_.CLContext(),
cl_helper_.CLCommandQueue());
DLOG << " climage - y bias: " << *(param->Bias());
DLOG << " climage - new scale: " << *new_scale;
framework::CLImage *new_bias = new framework::CLImage();
new_bias->SetTensorData(new_bias_ptr, variance->dims());
new_bias->InitCLImage(this->cl_helper_.CLContext());
new_bias->InitCLImage(this->cl_helper_.CLContext(),
cl_helper_.CLCommandQueue());
DLOG << " climage - new bias: " << *new_bias;
DLOG << " climage - filter: " << *(param->Filter());
param->SetNewScale(new_scale);
param->SetNewBias(new_bias);
......@@ -113,7 +130,7 @@ void ConvAddBNReluKernel<GPU_CL, float>::Compute(
auto biase = param.Bias()->GetCLImage();
auto new_scale = param.NewScale()->GetCLImage();
auto new_bias = param.NewBias()->GetCLImage();
auto output = param.Output();
auto output = param.Output()->GetCLImage();
int stride = param.Strides()[0];
int offset = param.Offset();
int input_c = param.Input()->CBlock();
......@@ -126,23 +143,54 @@ void ConvAddBNReluKernel<GPU_CL, float>::Compute(
cl_int status;
status = clSetKernelArg(kernel, 0, sizeof(int), &c_block);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 1, sizeof(int), &w);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 2, sizeof(int), &nh);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 3, sizeof(cl_mem), &input);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 4, sizeof(cl_mem), &filter);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 5, sizeof(cl_mem), &biase);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 6, sizeof(cl_mem), &new_scale);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 7, sizeof(cl_mem), &new_bias);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 8, sizeof(cl_mem), &output);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 9, sizeof(int), &stride);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 10, sizeof(int), &offset);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 11, sizeof(int), &input_c);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 12, sizeof(int), &dilation);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 13, sizeof(int), &input_width);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 14, sizeof(int), &input_height);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 15, sizeof(int), &output_width);
status = clSetKernelArg(kernel, 16, sizeof(int), &output_height);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 16, sizeof(int), &output_height);
CL_CHECK_ERRORS(status);
status =
......
......@@ -25,8 +25,10 @@ bool ConvAddKernel<GPU_CL, float>::Init(FusionConvAddParam<GPU_CL> *param) {
param->Filter()->dims()[2] == param->Filter()->dims()[3] &&
param->Paddings()[0] == param->Paddings()[1],
"need equal");
param->Filter()->InitCLImage(cl_helper_.CLContext());
param->Bias()->InitCLImage(cl_helper_.CLContext());
param->Filter()->InitCLImage(cl_helper_.CLContext(),
this->cl_helper_.CLCommandQueue());
param->Bias()->InitCLImage(cl_helper_.CLContext(),
this->cl_helper_.CLCommandQueue());
int offset = static_cast<int>(param->Filter()->dims()[2]) / 2 -
static_cast<int>(param->Paddings()[1]);
......@@ -71,27 +73,53 @@ void ConvAddKernel<GPU_CL, float>::Compute(
cl_int status;
status = clSetKernelArg(kernel, 0, sizeof(int), &c_block);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 1, sizeof(int), &w);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 2, sizeof(int), &nh);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 3, sizeof(cl_mem), &input);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 4, sizeof(cl_mem), &filter);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 5, sizeof(cl_mem), &biase);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 6, sizeof(cl_mem), &output);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 7, sizeof(int), &stride);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 8, sizeof(int), &offset);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 9, sizeof(int), &input_c);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 10, sizeof(int), &dilation);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 11, sizeof(int), &input_width);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 12, sizeof(int), &input_height);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 13, sizeof(int), &output_width);
status = clSetKernelArg(kernel, 14, sizeof(int), &output_height);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 14, sizeof(int), &output_height);
CL_CHECK_ERRORS(status);
status =
clEnqueueNDRangeKernel(this->cl_helper_.CLCommandQueue(), kernel, 3, NULL,
default_work_size.data(), NULL, 0, NULL, NULL);
CL_CHECK_ERRORS(status);
}
......
......@@ -26,7 +26,8 @@ bool ConvKernel<GPU_CL, float>::Init(ConvParam<GPU_CL> *param) {
param->Paddings()[0] == param->Paddings()[1],
"need equal");
param->Filter()->InitCLImage(cl_helper_.CLContext());
param->Filter()->InitCLImage(cl_helper_.CLContext(),
this->cl_helper_.CLCommandQueue());
int offset = static_cast<int>(param->Filter()->dims()[2]) / 2 -
static_cast<int>(param->Paddings()[1]);
......@@ -95,6 +96,17 @@ void ConvKernel<GPU_CL, float>::Compute(const ConvParam<GPU_CL> &param) {
cl_int status;
DLOG << " begin set kernel arg ";
DLOG << " c block " << c_block;
DLOG << " w " << w;
DLOG << " nh " << nh;
DLOG << " stride " << stride;
DLOG << " offset " << offset;
DLOG << " input_c " << input_c;
DLOG << " dilation " << dilation;
DLOG << " input width " << input_width;
DLOG << " input height " << input_height;
DLOG << " output width " << output_width;
DLOG << " output height " << output_height;
status = clSetKernelArg(kernel, 0, sizeof(int), &c_block);
CL_CHECK_ERRORS(status);
......
......@@ -27,7 +27,8 @@ bool DepthwiseConvKernel<GPU_CL, float>::Init(ConvParam<GPU_CL> *param) {
param->Filter()->dims()[2] == param->Filter()->dims()[3] &&
param->Paddings()[0] == param->Paddings()[1],
"need equal");
param->Filter()->InitCLImage(cl_helper_.CLContext());
param->Filter()->InitCLImage(cl_helper_.CLContext(),
this->cl_helper_.CLCommandQueue());
int offset = static_cast<int>(param->Filter()->dims()[2]) / 2 -
static_cast<int>(param->Paddings()[1]);
param->SetOffset(offset);
......
......@@ -22,16 +22,16 @@ namespace operators {
template <>
bool ElementwiseAddKernel<GPU_CL, float>::Init(
ElementwiseAddParam<GPU_CL> *param) {
CLImage *bias = (CLImage*)param->InputY();
bias->InitCLImage(cl_helper_.CLContext());
if(bias->dims().size()==4){
this->cl_helper_.AddKernel("elementwise_add", "elementwise_add_kernel.cl");
}else if(param->InputY()->dims().size()==1){
DLOG<<"-----init add-----";
this->cl_helper_.AddKernel("channel_add", "channel_add_kernel.cl");
}else{
DLOG << "error:bias dims is error";
}
CLImage *bias = (CLImage *)param->InputY();
bias->InitCLImage(cl_helper_.CLContext(), this->cl_helper_.CLCommandQueue());
if (bias->dims().size() == 4) {
this->cl_helper_.AddKernel("elementwise_add", "elementwise_add_kernel.cl");
} else if (param->InputY()->dims().size() == 1) {
DLOG << "-----init add-----";
this->cl_helper_.AddKernel("channel_add", "channel_add_kernel.cl");
} else {
DLOG << "error:bias dims is error";
}
return true;
}
......@@ -44,7 +44,7 @@ void ElementwiseAddKernel<GPU_CL, float>::Compute(
auto output = param.Out();
cl_int status;
auto kernel = this->cl_helper_.KernelAt(0);
if(bias->dims().size()==4){
if (bias->dims().size() == 4) {
cl_mem input_image = input->GetCLImage();
cl_mem bias_image = bias->GetCLImage();
cl_mem output_image = output->GetCLImage();
......@@ -57,14 +57,15 @@ void ElementwiseAddKernel<GPU_CL, float>::Compute(
int width = input->ImageWidth();
int height = input->ImageHeight();
size_t global_work_size[2] = {width, height};
status = clEnqueueNDRangeKernel(this->cl_helper_.CLCommandQueue(), kernel, 2,
NULL, global_work_size, NULL, 0, NULL, NULL);
status =
clEnqueueNDRangeKernel(this->cl_helper_.CLCommandQueue(), kernel, 2,
NULL, global_work_size, NULL, 0, NULL, NULL);
CL_CHECK_ERRORS(status);
}else if(bias->dims().size()==1){
} else if (bias->dims().size() == 1) {
cl_mem input_image = input->GetCLImage();
cl_mem bias_image = bias->GetCLImage();
cl_mem output_image = output->GetCLImage();
int tensor_w = input->dims()[4];
int tensor_w = input->dims()[3];
status = clSetKernelArg(kernel, 0, sizeof(cl_mem), (void *)&input_image);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 1, sizeof(cl_mem), (void *)&bias_image);
......@@ -76,13 +77,13 @@ void ElementwiseAddKernel<GPU_CL, float>::Compute(
int width = input->ImageWidth();
int height = input->ImageHeight();
size_t global_work_size[2] = {width, height};
status = clEnqueueNDRangeKernel(this->cl_helper_.CLCommandQueue(), kernel, 2,
NULL, global_work_size, NULL, 0, NULL, NULL);
status =
clEnqueueNDRangeKernel(this->cl_helper_.CLCommandQueue(), kernel, 2,
NULL, global_work_size, NULL, 0, NULL, NULL);
CL_CHECK_ERRORS(status);
}else{
} else {
DLOG << "error:bias dims is error";
}
}
template class ElementwiseAddKernel<GPU_CL, float>;
......
......@@ -30,12 +30,14 @@ void FeedKernel<GPU_CL, float>::Compute(const FeedParam<GPU_CL> &param) {
cl_int status;
auto output = param.Out();
const Tensor *input = param.InputX();
DLOG << *input;
const float *input_data = input->data<float>();
int numel = input->numel();
cl_mem cl_image = output->GetCLImage();
int height = output->dims()[2];
int width = output->dims()[3];
CLTensor input_cl_tensor(this->cl_helper_.CLContext());
CLTensor input_cl_tensor(this->cl_helper_.CLContext(),
this->cl_helper_.CLCommandQueue());
input_cl_tensor.Resize(input->dims());
cl_mem inputBuffer =
input_cl_tensor.mutable_with_data<cl_mem>((void *)input_data);
......@@ -53,14 +55,6 @@ void FeedKernel<GPU_CL, float>::Compute(const FeedParam<GPU_CL> &param) {
status = clEnqueueNDRangeKernel(this->cl_helper_.CLCommandQueue(), kernel, 2,
NULL, global_work_size, NULL, 0, NULL, NULL);
CL_CHECK_ERRORS(status);
int len = 4 * 224 * 224;
half *out = new half[len];
cl_command_queue commandQueue = this->cl_helper_.CLCommandQueue();
size_t origin[3] = {0, 0, 0};
size_t region[3] = {height, width, 1};
clEnqueueReadImage(commandQueue, cl_image, CL_TRUE, origin, region, 0, 0, out,
0, NULL, NULL);
}
template class FeedKernel<GPU_CL, float>;
......
......@@ -19,44 +19,45 @@ namespace operators {
template <>
bool FetchKernel<GPU_CL, float>::Init(FetchParam<GPU_CL> *param) {
this->cl_helper_.AddKernel("fetch", "fetch_kernel.cl");
// this->cl_helper_.AddKernel("fetch", "fetch_kernel.cl");
return true;
}
template <>
void FetchKernel<GPU_CL, float>::Compute(const FetchParam<GPU_CL> &param) {
auto kernel = this->cl_helper_.KernelAt(0);
auto default_work_size = this->cl_helper_.DefaultWorkSize(*param.InputX());
auto input = param.InputX()->GetCLImage();
auto *out = param.Out();
const auto &dims = param.InputX()->dims();
const int N = dims[0];
const int C = dims[1];
const int in_height = dims[2];
const int in_width = dims[3];
int size_ch = in_height * in_width;
int size_block = size_ch * 4;
int size_batch = size_ch * C;
// need create outputBuffer
cl_image_format imageFormat;
imageFormat.image_channel_order = CL_RGBA;
imageFormat.image_channel_data_type = CL_FLOAT;
cl_mem outputBuffer;
clSetKernelArg(kernel, 0, sizeof(int), &in_height);
clSetKernelArg(kernel, 1, sizeof(int), &in_width);
clSetKernelArg(kernel, 2, sizeof(int), &size_ch);
clSetKernelArg(kernel, 3, sizeof(int), &size_block);
clSetKernelArg(kernel, 4, sizeof(int), &size_batch);
clSetKernelArg(kernel, 5, sizeof(cl_mem), &input);
clSetKernelArg(kernel, 6, sizeof(cl_mem), &outputBuffer);
clEnqueueNDRangeKernel(this->cl_helper_.CLCommandQueue(), kernel, 3, NULL,
default_work_size.data(), NULL, 0, NULL, NULL);
// auto kernel = this->cl_helper_.KernelAt(0);
// auto default_work_size =
// this->cl_helper_.DefaultWorkSize(*param.InputX());
//
// auto input = param.InputX()->GetCLImage();
// auto *out = param.Out();
//
// const auto &dims = param.InputX()->dims();
// const int N = dims[0];
// const int C = dims[1];
// const int in_height = dims[2];
// const int in_width = dims[3];
//
// int size_ch = in_height * in_width;
// int size_block = size_ch * 4;
// int size_batch = size_ch * C;
//
// // need create outputBuffer
// cl_image_format imageFormat;
// imageFormat.image_channel_order = CL_RGBA;
// imageFormat.image_channel_data_type = CL_FLOAT;
// cl_mem outputBuffer;
//
// clSetKernelArg(kernel, 0, sizeof(int), &in_height);
// clSetKernelArg(kernel, 1, sizeof(int), &in_width);
// clSetKernelArg(kernel, 2, sizeof(int), &size_ch);
// clSetKernelArg(kernel, 3, sizeof(int), &size_block);
// clSetKernelArg(kernel, 4, sizeof(int), &size_batch);
// clSetKernelArg(kernel, 5, sizeof(cl_mem), &input);
// clSetKernelArg(kernel, 6, sizeof(cl_mem), &outputBuffer);
//
// clEnqueueNDRangeKernel(this->cl_helper_.CLCommandQueue(), kernel, 3, NULL,
// default_work_size.data(), NULL, 0, NULL, NULL);
}
template class FetchKernel<GPU_CL, float>;
......
......@@ -37,19 +37,19 @@ void ReshapeKernel<GPU_CL, float>::Compute(const ReshapeParam<GPU_CL> &param) {
int dims[4] = {1, 1, 1, 1};
int odims[4] = {1, 1, 1, 1};
for (int i = 0; i < inputDim.size(); i++) {
dims[4-inputDim.size()+i] = inputDim[i];
dims[4 - inputDim.size() + i] = inputDim[i];
}
for (int i = 0; i < outputDim.size(); i++) {
odims[4-outputDim.size()+i] = outputDim[i];
odims[4 - outputDim.size() + i] = outputDim[i];
}
clSetKernelArg(kernel, 2, sizeof(int), dims);
clSetKernelArg(kernel, 3, sizeof(int), dims + 1);
clSetKernelArg(kernel, 4, sizeof(int), dims + 2);
clSetKernelArg(kernel, 5, sizeof(int), dims + 3);
clSetKernelArg(kernel, 6, sizeof(int), odims);
clSetKernelArg(kernel, 7, sizeof(int), odims + 1);
clSetKernelArg(kernel, 8, sizeof(int), odims + 2);
clSetKernelArg(kernel, 9, sizeof(int), odims + 3);
// Kernel args 2-5 carry the four (left-padded) input dims, args 6-9 the four
// output dims. Each argument must point at its own element: args 8 and 9 take
// odims[2] and odims[3], not odims[1] again.
clSetKernelArg(kernel, 2, sizeof(cl_int), &dims[0]);
clSetKernelArg(kernel, 3, sizeof(cl_int), &dims[1]);
clSetKernelArg(kernel, 4, sizeof(cl_int), &dims[2]);
clSetKernelArg(kernel, 5, sizeof(cl_int), &dims[3]);
clSetKernelArg(kernel, 6, sizeof(cl_int), &odims[0]);
clSetKernelArg(kernel, 7, sizeof(cl_int), &odims[1]);
clSetKernelArg(kernel, 8, sizeof(cl_int), &odims[2]);
clSetKernelArg(kernel, 9, sizeof(cl_int), &odims[3]);
const size_t work_size[2] = {output->ImageWidth(), output->ImageHeight()};
clEnqueueNDRangeKernel(this->cl_helper_.CLCommandQueue(), kernel, 2, NULL,
......
......@@ -36,11 +36,14 @@ void SoftmaxKernel<GPU_CL, float>::Compute(const SoftmaxParam<GPU_CL> &param) {
clSetKernelArg(kernel, 0, sizeof(cl_mem), &inputImage);
clSetKernelArg(kernel, 1, sizeof(cl_mem), &outputImage);
const auto &inputDim = input->dims();
int dims[4] = {inputDim[0], inputDim[1], inputDim[2], inputDim[3]};
clSetKernelArg(kernel, 2, sizeof(int), dims);
clSetKernelArg(kernel, 3, sizeof(int), dims + 1);
clSetKernelArg(kernel, 4, sizeof(int), dims + 2);
clSetKernelArg(kernel, 5, sizeof(int), dims + 3);
int dims[4] = {1, 1, 1, 1};
for (int i = 0; i < inputDim.size(); i++) {
dims[4 - inputDim.size() + i] = inputDim[i];
}
clSetKernelArg(kernel, 2, sizeof(int), &dims);
clSetKernelArg(kernel, 3, sizeof(int), &dims[1]);
clSetKernelArg(kernel, 4, sizeof(int), &dims[2]);
clSetKernelArg(kernel, 5, sizeof(int), &dims[3]);
clEnqueueNDRangeKernel(this->cl_helper_.CLCommandQueue(), kernel, 3, NULL,
default_work_size.data(), NULL, 0, NULL, NULL);
......
......@@ -23,7 +23,7 @@ int main() {
// auto isok = paddle_mobile.Load(std::string(g_mobilenet_detect) + "/model",
// std::string(g_mobilenet_detect) + "/params", true);
auto isok = paddle_mobile.Load(g_mobilenet, false);
auto isok = paddle_mobile.Load(g_mobilenet, true);
if (isok) {
auto time2 = paddle_mobile::time();
std::cout << "load cost :" << paddle_mobile::time_diff(time1, time1) << "ms"
......
......@@ -17,7 +17,7 @@ shift
perl -i -pe 's|^\s+#pragma\s+omp|// <TRICKY-CLANG-FORMAT-PRAGMA-FIX> #pragma omp|' "$@"
(
# remove clang format ios_io folder
flist=$(echo "$@" | perl -pe 's|src/ios_io/[^ ]*||')
flist=$(echo "$@" | perl -pe 's|src/io/ios_io/[^ ]*||')
clang-format -i $flist
)
perl -i -pe 's|// <TRICKY-CLANG-FORMAT-PRAGMA-FIX> ||' "$@"
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册