未验证 提交 526c446d 编写于 作者: Y Yanzhan Yang 提交者: GitHub

fix nearest_interp, concat by channel, leaky_relu. (#1775)

上级 a63d9e9d
......@@ -32,10 +32,6 @@ void ConcatOp<Dtype, T>::InferShape() const {
inputs_dims.push_back(inputs[i]->dims());
}
auto axis = static_cast<size_t>(this->param_.Axis()) -
(this->param_.original_output_dims_size_ -
this->param_.Out()->dims().size());
if (n == 1) {
DLOG << "Warning: concat op have only one input, "
"may waste memory";
......@@ -43,6 +39,8 @@ void ConcatOp<Dtype, T>::InferShape() const {
/// add all dim[axis] and check other dims if equal.
auto out_dims = inputs_dims[0];
auto axis = static_cast<size_t>(this->param_.Axis()) -
(this->param_.original_output_dims_size_ - out_dims.size());
int in_zero_dims_size = out_dims.size();
for (size_t i = 1; i < n; i++) {
for (size_t j = 0; j < in_zero_dims_size; j++) {
......
......@@ -22,20 +22,61 @@ __kernel void concatByCWith2Inputs(__read_only image2d_t input_image_0,
__write_only image2d_t output_image,
__private const int out_C,
__private const int out_W) {
// const int in_c = get_global_id(0);
// const int in_w = get_global_id(1);
// const int in_nh = get_global_id(2);
//
// int2 input_pos ;
// input_pos.x = in_c * out_W + in_w;
// input_pos.y = in_nh;
// const sampler_t sampler = CLK_NORMALIZED_COORDS_TRUE |
// CLK_ADDRESS_CLAMP |
// CLK_FILTER_NEAREST;
// half4 input;
// input = read_imageh(input_image, sampler,input_pos);
//
// write_imageh(output_image, input_pos, input);
const int out_c = get_global_id(0);
const int out_w = get_global_id(1);
const int out_nh = get_global_id(2);
const sampler_t sampler = CLK_NORMALIZED_COORDS_TRUE |
CLK_ADDRESS_CLAMP |
CLK_FILTER_NEAREST;
int2 output_pos;
output_pos.x = out_c * out_W + out_w;
output_pos.y = out_nh;
half4 output_data;
for (int i = 0; i < 4; i++) {
int c = out_c * 4 + i;
if (c >= out_C) {
break;
}
int c_in;
half4 input_data;
if (c < C_0) {
c_in = c;
int2 input_pos;
input_pos.x = (c_in / 4) * out_W + out_w;
input_pos.y = out_nh;
input_data = read_imageh(input_image_0, sampler, input_pos);
} else {
c_in = c - C_0;
int2 input_pos;
input_pos.x = (c_in / 4) * out_W + out_w;
input_pos.y = out_nh;
input_data = read_imageh(input_image_1, sampler, input_pos);
}
int value_offset = c_in % 4;
float value;
if (value_offset == 0) {
value = input_data.x;
} else if (value_offset == 1) {
value = input_data.y;
} else if (value_offset == 2) {
value = input_data.z;
} else if (value_offset == 3) {
value = input_data.w;
}
if (i == 0) {
output_data.x = value;
} else if (i == 1) {
output_data.y = value;
} else if (i == 2) {
output_data.z = value;
} else if (i == 3) {
output_data.w = value;
}
}
write_imageh(output_image, output_pos, output_data);
}
__kernel void concatByCWith3Inputs(__read_only image2d_t input_image_0,
......@@ -47,20 +88,67 @@ __kernel void concatByCWith3Inputs(__read_only image2d_t input_image_0,
__write_only image2d_t output_image,
__private const int out_C,
__private const int out_W) {
// const int in_c = get_global_id(0);
// const int in_w = get_global_id(1);
// const int in_nh = get_global_id(2);
//
// int2 input_pos ;
// input_pos.x = in_c * out_W + in_w;
// input_pos.y = in_nh;
// const sampler_t sampler = CLK_NORMALIZED_COORDS_TRUE |
// CLK_ADDRESS_CLAMP |
// CLK_FILTER_NEAREST;
// half4 input;
// input = read_imageh(input_image, sampler,input_pos);
//
// write_imageh(output_image, input_pos, input);
const int out_c = get_global_id(0);
const int out_w = get_global_id(1);
const int out_nh = get_global_id(2);
const sampler_t sampler = CLK_NORMALIZED_COORDS_TRUE |
CLK_ADDRESS_CLAMP |
CLK_FILTER_NEAREST;
int2 output_pos;
output_pos.x = out_c * out_W + out_w;
output_pos.y = out_nh;
half4 output_data;
for (int i = 0; i < 4; i++) {
int c = out_c * 4 + i;
if (c >= out_C) {
break;
}
int c_in;
half4 input_data;
if (c < C_0) {
c_in = c;
int2 input_pos;
input_pos.x = (c_in / 4) * out_W + out_w;
input_pos.y = out_nh;
input_data = read_imageh(input_image_0, sampler, input_pos);
} else if (c < C_0 + C_1) {
c_in = c - C_0;
int2 input_pos;
input_pos.x = (c_in / 4) * out_W + out_w;
input_pos.y = out_nh;
input_data = read_imageh(input_image_1, sampler, input_pos);
} else {
c_in = c - C_0 - C_1;
int2 input_pos;
input_pos.x = (c_in / 4) * out_W + out_w;
input_pos.y = out_nh;
input_data = read_imageh(input_image_2, sampler, input_pos);
}
int value_offset = c_in % 4;
float value;
if (value_offset == 0) {
value = input_data.x;
} else if (value_offset == 1) {
value = input_data.y;
} else if (value_offset == 2) {
value = input_data.z;
} else if (value_offset == 3) {
value = input_data.w;
}
if (i == 0) {
output_data.x = value;
} else if (i == 1) {
output_data.y = value;
} else if (i == 2) {
output_data.z = value;
} else if (i == 3) {
output_data.w = value;
}
}
write_imageh(output_image, output_pos, output_data);
}
__kernel void concatByH(__read_only image2d_t input_image,
......
......@@ -30,9 +30,9 @@ __kernel void leakyrelu(__read_only image2d_t input,
half4 output_data;
output_data.x = max((float)(in.x), (float)(alpha * (in.x)));
output_data.y = max((float)(in.x), (float)(alpha * (in.y)));
output_data.z = max((float)(in.x), (float)(alpha * (in.z)));
output_data.w = max((float)(in.x), (float)(alpha * (in.w)));
output_data.y = max((float)(in.y), (float)(alpha * (in.y)));
output_data.z = max((float)(in.z), (float)(alpha * (in.z)));
output_data.w = max((float)(in.w), (float)(alpha * (in.w)));
write_imageh(output, (int2)(input_pos.x, input_pos.y), output_data);
}
......@@ -15,19 +15,23 @@ limitations under the License. */
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
// Nearest-neighbour resize kernel: each work-item copies one half4 texel from
// `input` to `output`.
//
// Work-item ids: id(0) is used as `c`, id(1) as `w`, id(2) as `nh`. Both
// images are addressed as x = c * <width> + w and y = (batch * <height> + row),
// as the index arithmetic below shows.
//
// scale_h / scale_w : output coordinates are divided by these factors to find
//                     the source row / column.
// in_dims_h / out_dims_h, in_dims_w / out_dims_w : per-image height and width
//                     used to pack / unpack the x and y coordinates.
//
// NOTE(review): this listing appears to be a diff with its +/- markers
// stripped. The `dims_w` parameter line (which already closes the parameter
// list) and the first `output_pos.x` assignment look like the pre-change
// versions of the lines that immediately follow them - confirm against the
// real file before compiling.
__kernel void nearest_interp(__read_only image2d_t input, __write_only image2d_t output,
__private const float scale_h, __private const float scale_w,
__private const int dims_w){
__private const int in_dims_h, __private const int out_dims_h,
__private const int in_dims_w, __private const int out_dims_w) {
const int c = get_global_id(0);
const int w = get_global_id(1);
const int nh = get_global_id(2);
int2 output_pos;
// NOTE(review): presumably the superseded form; the next line overwrites it.
output_pos.x = c * dims_w + w;
output_pos.x = c * out_dims_w + w;
output_pos.y = nh;
// Split the fused id into an image index and a row within that image.
int out_n = nh / out_dims_h;
int out_h = nh % out_dims_h;
int2 input_pos;
// `w / scale_w` and `out_h / scale_h` are int / float divisions; the float
// result is converted back to int when stored into the int2 component.
input_pos.x = c * in_dims_w + w / scale_w;
input_pos.y = out_n * in_dims_h + out_h / scale_h;
// NOTE(review): the sampler requests normalized coordinates, yet read_imageh
// below is called with integer (int2) coordinates - verify against the OpenCL C
// specification whether CLK_NORMALIZED_COORDS_FALSE is required here.
const sampler_t sampler = CLK_NORMALIZED_COORDS_TRUE |
CLK_ADDRESS_CLAMP |
CLK_FILTER_NEAREST;
// Disabled earlier lookup: it divided the packed output coordinates directly
// by the scale factors instead of unpacking them first.
// uint x = (uint)(output_pos.x / scale_w);
// uint y = (uint)(output_pos.y / scale_h);
// half4 input_data = read_imageh(input, sampler, (int2)(x, y));
// write_imageh(output, (int2)(output_pos.x , output_pos.y ), input_data);
half4 input_data = read_imageh(input, sampler, (int2)(input_pos.x, input_pos.y));
write_imageh(output, (int2)(output_pos.x , output_pos.y), input_data);
}
......@@ -38,7 +38,10 @@ void NearestInterpolationKernel<GPU_CL, float>::Compute(
cl_mem output_image = output->GetCLImage();
float scale_h = output->dims()[2] / input->dims()[2];
float scale_w = output->dims()[3] / input->dims()[3];
int in_dims_w = output->dims()[3];
int in_dims_h = input->dims()[2];
int out_dims_h = output->dims()[2];
int in_dims_w = input->dims()[3];
int out_dims_w = output->dims()[3];
cl_int status;
......@@ -50,7 +53,13 @@ void NearestInterpolationKernel<GPU_CL, float>::Compute(
CL_CHECK_ERRORS(status)
status = clSetKernelArg(kernel, 3, sizeof(float), &scale_w);
CL_CHECK_ERRORS(status)
status = clSetKernelArg(kernel, 4, sizeof(int), &in_dims_w);
status = clSetKernelArg(kernel, 4, sizeof(int), &in_dims_h);
CL_CHECK_ERRORS(status)
status = clSetKernelArg(kernel, 5, sizeof(int), &out_dims_h);
CL_CHECK_ERRORS(status)
status = clSetKernelArg(kernel, 6, sizeof(int), &in_dims_w);
CL_CHECK_ERRORS(status)
status = clSetKernelArg(kernel, 7, sizeof(int), &out_dims_w);
CL_CHECK_ERRORS(status)
status = clEnqueueNDRangeKernel(
this->cl_helper_.CLCommandQueue(), kernel, default_work_size.size(), NULL,
......
......@@ -686,7 +686,7 @@ class ConcatParam : public OpParam {
inputs_ = InputMultiFrom<GType>(inputs, *scope);
out_ = OutFrom<GType>(outputs, *scope);
axis_ = GetAttr<int>("axis", attrs);
original_output_dims_size_ = out_->dims().size();
original_output_dims_size_ = inputs_[0]->dims().size();
}
vector<GType *> Inputs() const { return inputs_; }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册