Commit 82267790 authored by zp7, committed by Jiaying Zhao

add reshape2, transpose2, split GPU operator (#1664)

* add reshape2, transpose2, split GPU operator

* fix reshape2op && delete log
Parent 52bd9465
@@ -971,6 +971,23 @@ void Executor<GPU_CL, float>::InitCombineMemory() {
program_.scope->GetCLScpoe()->CommandQueue();
const TensorDesc &desc = var_desc->Tensor_desc();
DDim ddim = cl_image->dims();
bool shouldResize = true;
if (ddim.size() > 4) {
for (int i = 0; i < ddim.size() - 4; ++i) {
if (ddim[i] != 0) {
shouldResize = false;
break;
}
}
if (shouldResize) {
std::vector<int64_t> temp_intput_dims;
temp_intput_dims.reserve(static_cast<size_t>(4));
for (int i = ddim.size() - 4; i < ddim.size(); ++i) {
temp_intput_dims.push_back(ddim[i]);
}
ddim = framework::make_ddim(temp_intput_dims);
}
}
// DDim ddim = make_ddim(desc.Dims());
cl_image->InitEmptyImage(context, command_queue, ddim);
}
......
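The InitCombineMemory hunk above introduces a pattern that recurs throughout this commit: when a tensor reports more than four dimensions, the leading entries are treated as padding and only the trailing four are kept before the CL image is initialized. A minimal standalone sketch of that trimming step, using std::vector<int64_t> in place of framework::DDim (the helper name trim_to_trailing_4d is hypothetical):

#include <cstdint>
#include <vector>

// Keep only the trailing four dims when every leading dim is 0 (the executor
// treats 0 as a placeholder); otherwise return the shape unchanged.
std::vector<int64_t> trim_to_trailing_4d(const std::vector<int64_t> &dims) {
  if (dims.size() <= 4) return dims;
  for (size_t i = 0; i + 4 < dims.size(); ++i) {
    if (dims[i] != 0) return dims;  // a real leading dim, keep the full shape
  }
  return std::vector<int64_t>(dims.end() - 4, dims.end());
}

// e.g. {0, 1, 3, 224, 224} -> {1, 3, 224, 224}, while {2, 1, 3, 224, 224}
// keeps all five dims. The op-level variants later in the diff also accept 1
// as padding.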
@@ -103,7 +103,7 @@ LOAD_OP2(fusion_elementwise_add_relu, CPU, FPGA);
LOAD_FUSION_MATCHER(fusion_elementwise_add_relu);
#endif
#ifdef SPLIT_OP
LOAD_OP1(split, CPU);
LOAD_OP2(split, CPU, GPU_CL);
#endif
#ifdef RESIZE_OP
LOAD_OP1(resize, CPU);
@@ -116,13 +116,13 @@ LOAD_FUSION_MATCHER(fusion_conv_add_bn_relu);
LOAD_OP2(reshape, CPU, GPU_CL);
#endif
#ifdef RESHAPE2_OP
LOAD_OP1(reshape2, CPU);
LOAD_OP2(reshape2, CPU, GPU_CL);
#endif
#ifdef TRANSPOSE_OP
LOAD_OP2(transpose, CPU, GPU_CL);
#endif
#ifdef TRANSPOSE2_OP
LOAD_OP1(transpose2, CPU);
LOAD_OP2(transpose2, CPU, GPU_CL);
#endif
#ifdef PRIORBOX_OP
LOAD_OP2(prior_box, CPU, GPU_CL);
......
@@ -18,7 +18,7 @@ __kernel void scale(__read_only image2d_t input,
__write_only image2d_t output,
__private float scale,
__private float bias,
__private float out_width){
__private int out_width){
const int out_c = get_global_id(0);
const int out_w = get_global_id(1);
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef RESHAPE2_OP
#include "operators/kernel/reshape2_kernel.h"
namespace paddle_mobile {
namespace operators {
template <>
bool Reshape2Kernel<GPU_CL, float>::Init(Reshape2Param<GPU_CL> *param) {
this->cl_helper_.AddKernel("reshape", "reshape.cl");
return true;
}
inline framework::DDim ValidateShape(const std::vector<int> shape,
const framework::DDim &in_dims) {
const int64_t in_size = framework::product(in_dims);
// Only one dimension may be set to -1; its size will be inferred
// automatically.
const int64_t unk_dim_val = -1;
const int64_t copy_dim_val = 0;
std::vector<int64_t> output_shape(shape.size(), 0);
int64_t capacity = 1;
int unk_dim_idx = -1;
for (size_t i = 0; i < shape.size(); ++i) {
if (shape[i] == unk_dim_val) {
PADDLE_MOBILE_ENFORCE(
unk_dim_idx == -1,
"Only one input dimension of Attr(shape) can be unknown.");
unk_dim_idx = i;
} else if (shape[i] == copy_dim_val) {
PADDLE_MOBILE_ENFORCE(
static_cast<int>(i) < in_dims.size(),
"The index of dimension to copy from input shape must be less "
"than the size of input shape.");
} else {
PADDLE_MOBILE_ENFORCE(
shape[i] > 0,
"Each input dimension of Attr(shape) must not be negtive except "
"one unknown dimension.");
}
capacity *= (shape[i] ? shape[i] : in_dims[i]);
output_shape[i] = (shape[i] ? static_cast<int64_t>(shape[i]) : in_dims[i]);
}
if (unk_dim_idx != -1) {
output_shape[unk_dim_idx] = -in_size / capacity;
PADDLE_MOBILE_ENFORCE(output_shape[unk_dim_idx] * capacity == -in_size,
"Invalid shape is given.");
} else {
PADDLE_MOBILE_ENFORCE(capacity == in_size, "Invalid shape is given.");
}
return framework::make_ddim(output_shape);
}
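ValidateShape follows the usual reshape attribute semantics: a 0 in Attr(shape) copies the corresponding input dimension, and a single -1 is inferred so that the total element count is preserved. A hedged worked call, using the 1 x 1000 shape hinted at in the Compute comments below:

// in_dims = {1, 1000, 1, 1}, product(in_dims) = 1000
// shape   = {0, -1}: dim 0 is copied (1), dim 1 is inferred
// capacity after the loop is 1 * (-1) = -1, so
// output_shape[1] = -in_size / capacity = -1000 / -1 = 1000
framework::DDim out =
    ValidateShape({0, -1}, framework::make_ddim({1, 1000, 1, 1}));
// out is {1, 1000}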
template <>
void Reshape2Kernel<GPU_CL, float>::Compute(
const Reshape2Param<GPU_CL> &param) {
auto kernel = this->cl_helper_.KernelAt(0);
auto default_work_size = this->cl_helper_.DefaultWorkSize(*param.Out());
const auto *input = param.InputX();
auto *output = param.Out();
auto input_image = input->GetCLImage();
auto output_image = output->GetCLImage();
const auto &inputDim = input->dims();
const auto &outputDim = output->dims();
int input_dims[4] = {1, 1, 1, 1};
int output_dims[4] = {1, 1, 1, 1};
// e.g. input dims padded to 4-D: 1 1000 1 1
for (int i = 0; i < inputDim.size(); i++) {
input_dims[4 - inputDim.size() + i] = inputDim[i];
}
// e.g. output dims padded to 4-D: 1 1 1 1000
for (int i = 0; i < outputDim.size(); i++) {
output_dims[4 - outputDim.size() + i] = outputDim[i];
}
int out_C = output_dims[1];
int out_H = output_dims[2];
int out_W = output_dims[3];
int in_W = input_dims[3];
int in_H = input_dims[2];
int in_Stride0 = in_W;
int in_Stride1 = input_dims[2] * input_dims[3];
int in_Stride2 = input_dims[1] * input_dims[2] * input_dims[3];
int out_Stride0 = out_W;
int out_Stride1 = out_H * out_W;
int out_Stride2 = out_C * out_H * out_W;
DLOG << "out_C=" << out_C;
DLOG << "out_H=" << out_H;
DLOG << "out_W=" << out_W;
DLOG << "in_W=" << in_W;
DLOG << "default_work_size=" << default_work_size;
DLOG << "in_Stride0=" << in_Stride0;
DLOG << "in_Stride1=" << in_Stride1;
DLOG << "out_Stride0=" << out_Stride0;
DLOG << "out_Stride1=" << out_Stride1;
cl_int status;
status = clSetKernelArg(kernel, 0, sizeof(cl_mem), &input_image);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 1, sizeof(cl_mem), &output_image);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 2, sizeof(int), &out_C);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 3, sizeof(int), &out_H);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 4, sizeof(int), &out_W);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 5, sizeof(int), &in_W);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 6, sizeof(int), &in_H);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 7, sizeof(int), &in_Stride0);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 8, sizeof(int), &in_Stride1);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 9, sizeof(int), &in_Stride2);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 10, sizeof(int), &out_Stride0);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 11, sizeof(int), &out_Stride1);
CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 12, sizeof(int), &out_Stride2);
CL_CHECK_ERRORS(status);
status = clEnqueueNDRangeKernel(
this->cl_helper_.CLCommandQueue(), kernel, default_work_size.size(), NULL,
default_work_size.data(), NULL, 0, NULL, NULL);
CL_CHECK_ERRORS(status);
}
template class Reshape2Kernel<GPU_CL, float>;
} // namespace operators
} // namespace paddle_mobile
#endif
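Compute above left-pads both shapes with 1s to exactly four entries before deriving the strides, as the padded-dims comments in it suggest. A small hedged sketch of that padding step in isolation (pad_to_4d is a hypothetical name; it assumes at most four input dims, which the Reshape2Op::InferShape trimming further below is meant to guarantee):

#include <cstdint>
#include <vector>

// Left-pad a shape with 1s so it always has exactly four entries (NCHW).
std::vector<int> pad_to_4d(const std::vector<int64_t> &dims) {
  std::vector<int> out(4, 1);
  for (size_t i = 0; i < dims.size(); ++i) {
    out[4 - dims.size() + i] = static_cast<int>(dims[i]);
  }
  return out;
}

// {1000}    -> {1, 1, 1, 1000}
// {1, 1000} -> {1, 1, 1, 1000}
// The strides then fall out as {W, H*W, C*H*W}, matching in_Stride0..2 above.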
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef SPLIT_OP
#include "operators/kernel/split_kernel.h"
namespace paddle_mobile {
namespace operators {
template <>
bool SplitKernel<GPU_CL, float>::Init(SplitParam<GPU_CL>* param) {
this->cl_helper_.AddKernel("fetch", "fetch_kernel.cl");
this->cl_helper_.AddKernel("feed", "feed_kernel.cl");
return true;
}
// Strided numel memory copy from src to dst by the specified axis
//
// For example, for a tensor with dims [4, 20, 100], the strided numel is
// [8000, 2000, 100].
//
// NOTE: The src and dst tensors should have the same dims
// except along the specified axis.
template <typename T>
void StridedNumelCopyWithAxis(int64_t axis, T* dst,
const framework::DDim& dst_stride_numel,
const T* src,
const framework::DDim& src_stride_numel,
int64_t size) {
int64_t before = dst_stride_numel[0] / dst_stride_numel[axis];
int64_t src_after = src_stride_numel[axis];
int64_t dst_after = dst_stride_numel[axis];
PADDLE_MOBILE_ENFORCE(src_stride_numel.size() == dst_stride_numel.size(),
"src and dst tensor should have the same dims size.");
for (int64_t i = 0; i < dst_stride_numel.size(); ++i) {
if (i < axis) {
PADDLE_MOBILE_ENFORCE(src_stride_numel[i] / src_stride_numel[axis] ==
dst_stride_numel[i] / dst_stride_numel[axis],
"src and dst should have the same elements "
"except the specified axis.");
} else if (i == axis) {
continue;
} else {
PADDLE_MOBILE_ENFORCE(src_stride_numel[i] == dst_stride_numel[i],
"src and dst should have the same elements "
"except the specified axis.");
}
}
for (int64_t i = 0; i < before; ++i) {
memory::Copy(dst + i * dst_after, src + i * src_after, sizeof(T) * size);
}
}
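For the split below, each call to this helper carves one output's slab out of the input along the split axis. A hedged trace using the [4, 20, 100] shape from the comment above, split along axis 1 into two [4, 10, 100] outputs:

// strided numels: input {8000, 2000, 100}, each output {4000, 1000, 100}
// before    = dst_stride_numel[0] / dst_stride_numel[1] = 4000 / 1000 = 4
// src_after = 2000, dst_after = 1000, size = out_stride[1] = 1000
//
// first output reads from src offset 0, second from src offset 1000:
//   StridedNumelCopyWithAxis<float>(1, dst0, out_stride, src,        in_stride, 1000);
//   StridedNumelCopyWithAxis<float>(1, dst1, out_stride, src + 1000, in_stride, 1000);
//
// Each call copies four blocks of 1000 floats, reading from src + i * 2000 and
// writing to dst + i * 1000, matching the input_offset bookkeeping in Compute
// below.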
template <>
void SplitKernel<GPU_CL, float>::Compute(const SplitParam<GPU_CL>& param) {
auto kernel0 = this->cl_helper_.KernelAt(0);
auto kernel1 = this->cl_helper_.KernelAt(1);
auto* input_image = param.InputX();
auto in_stride = framework::stride_numel(input_image->dims());
auto input_dims = input_image->dims();
auto outs_images = param.Outs();
int64_t axis = param.Axis();
Tensor* input_tensor = new Tensor();
input_tensor->Resize(input_image->dims());
input_tensor->mutable_data<float>();
framework::CLImageToTensor(input_image, input_tensor,
this->cl_helper_.CLContext(),
this->cl_helper_.CLCommandQueue(), kernel0);
size_t input_offset = 0;
for (auto out : outs_images) {
auto out_stride = framework::stride_numel(out->dims());
Tensor* temp_out = new Tensor();
temp_out->Resize(out->dims());
temp_out->mutable_data<float>();
framework::CLImageToTensor(out, temp_out, this->cl_helper_.CLContext(),
this->cl_helper_.CLCommandQueue(), kernel0);
StridedNumelCopyWithAxis<float>(axis, temp_out->data<float>(), out_stride,
input_tensor->data<float>() + input_offset,
in_stride, out_stride[axis]);
input_offset += out_stride[axis];
out->InitEmptyImage(this->cl_helper_.CLContext(),
this->cl_helper_.CLCommandQueue(), temp_out->dims());
framework::TensorToCLImage(temp_out, out, this->cl_helper_.CLContext(),
this->cl_helper_.CLCommandQueue(), kernel1);
delete (temp_out);
}
delete (input_tensor);
}
template class SplitKernel<GPU_CL, float>;
} // namespace operators
} // namespace paddle_mobile
#endif
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef TRANSPOSE2_OP
#include "operators/kernel/transpose2_kernel.h"
namespace paddle_mobile {
namespace operators {
template <>
bool Transpose2Kernel<GPU_CL, float>::Init(Transpose2Param<GPU_CL> *param) {
this->cl_helper_.AddKernel("fetch", "fetch_kernel.cl");
this->cl_helper_.AddKernel("feed", "feed_kernel.cl");
return true;
}
inline bool IsShuffleChannel(const std::vector<int> &axis) {
bool is_shuffle_channel = true;
if (axis.size() > 2 && axis[0] == 0 && axis[1] == 2 && axis[2] == 1) {
for (int i = 3; i < axis.size(); ++i) {
if (axis[i] != i) {
is_shuffle_channel = false;
break;
}
}
} else {
return false;
}
return is_shuffle_channel;
}
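IsShuffleChannel only accepts permutations of the form {0, 2, 1, 3, 4, ...}: dims 1 and 2 swapped and everything else left in place. A few hedged probe calls (the asserts are only for illustration):

#include <cassert>

void probe_is_shuffle_channel() {
  assert(IsShuffleChannel({0, 2, 1, 3}));      // swap dims 1 and 2, keep dim 3
  assert(IsShuffleChannel({0, 2, 1, 3, 4}));   // trailing dims stay in place
  assert(!IsShuffleChannel({0, 2, 1, 4, 3}));  // fails because axis[3] != 3
  assert(!IsShuffleChannel({1, 0}));           // needs more than two entries
}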
template <typename Dtype>
void ShuffleChannelCompute(const Transpose2Param<GPU_CL> &param,
cl_context context, cl_command_queue commandQueue,
cl_kernel kernel0, cl_kernel kernel1) {
auto axis = param.Axis();
int axis_size = axis.size();
bool shouldResize = true;
int diff_dim = 0;
if (axis_size > 4) {
for (int i = 0; i < axis_size - 4; ++i) {
if (axis[i] != i) {
shouldResize = false;
break;
} else {
diff_dim++;
}
}
if (shouldResize) {
std::vector<int> temp_axis_dims;
temp_axis_dims.reserve(static_cast<size_t>(4));
for (int i = axis_size - 4; i < axis_size; ++i) {
temp_axis_dims.push_back(axis[i] - diff_dim);
}
axis.resize(4);
axis.clear();
axis.insert(axis.begin(), temp_axis_dims.begin(), temp_axis_dims.end());
}
}
auto input = param.InputX();
Tensor *input_tensor = new Tensor();
input_tensor->Resize(input->dims());
input_tensor->mutable_data<float>();
framework::CLImageToTensor(input, input_tensor, context, commandQueue,
kernel0);
const Dtype *input_ptr = input_tensor->data<Dtype>();
auto output = param.Out();
Tensor *output_tensor = new Tensor();
output_tensor->Resize(input->dims());
output_tensor->mutable_data<float>();
Dtype *output_ptr = output_tensor->mutable_data<Dtype>();
// The shape dimensions of input and output must be >= 2 and <= 6.
const framework::DDim &in_dim = input->dims();
const framework::DDim &out_dim = output->dims();
size_t offset = 1;
for (int i = 2; i < axis.size(); ++i) {
offset *= in_dim[i];
}
#pragma omp parallel for collapse(2)
for (int c1 = 0; c1 < out_dim[0]; ++c1) {
for (int c2 = 0; c2 < out_dim[1]; ++c2) {
size_t out_offset = (c1 * out_dim[1] + c2) * offset;
size_t in_offset = (c2 * in_dim[1] + c1) * offset;
memcpy(output_ptr + out_offset, input_ptr + in_offset,
offset * sizeof(Dtype));
}
}
output->InitEmptyImage(context, commandQueue, output_tensor->dims());
framework::TensorToCLImage(output_tensor, output, context, commandQueue,
kernel1);
delete (input_tensor);
delete (output_tensor);
}
template <>
void Transpose2Kernel<GPU_CL, float>::Compute(
const Transpose2Param<GPU_CL> &param) {
auto kernel0 = this->cl_helper_.KernelAt(0);
auto kernel1 = this->cl_helper_.KernelAt(1);
const std::vector<int> &axis = param.Axis();
bool shuffle_channel = IsShuffleChannel(axis);
if (shuffle_channel) {
ShuffleChannelCompute<float>(param, this->cl_helper_.CLContext(),
this->cl_helper_.CLCommandQueue(), kernel0,
kernel1);
} else {
PADDLE_MOBILE_THROW_EXCEPTION("axis not support");
}
}
template class Transpose2Kernel<GPU_CL, float>;
} // namespace operators
} // namespace paddle_mobile
#endif
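The host-side loop in ShuffleChannelCompute copies contiguous blocks of offset elements, writing input block (c2, c1) to output block (c1, c2), i.e. it exchanges the leading two block indices between the CLImageToTensor and TensorToCLImage round trip. A purely illustrative trace, assuming the loop sees in_dim = {2, 3, 4}, out_dim = {3, 2, 4} and a three-entry axis (so offset = in_dim[2] = 4):

// output block (c1, c2)  <-  input block (c2, c1), 4 floats at a time
//   (0,0) <- (0,0)   (0,1) <- (1,0)
//   (1,0) <- (0,1)   (1,1) <- (1,1)
//   (2,0) <- (0,2)   (2,1) <- (1,2)
// i.e. out[c1][c2][:] = in[c2][c1][:], a blockwise swap of the first two dims.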
@@ -1362,7 +1362,7 @@ class Transpose2Param : public OpParam {
axis_ = GetAttr<vector<int>>("axis", attrs);
}
const GType *InputX() const { return input_x_; }
GType *InputX() const { return input_x_; }
GType *Out() const { return out_; }
@@ -1510,7 +1510,7 @@ class Reshape2Param : public OpParam {
}
}
const GType *InputX() const { return input_x_; }
GType *InputX() const { return input_x_; }
const GType *InputShape() const { return input_shape_; }
@@ -2807,7 +2807,7 @@ class SplitParam : public OpParam {
// out_ts_.push_back(*scope.FindVar(outs_[i])->GetMutable());
// }
}
const GType *InputX() const { return input_x_; }
GType *InputX() const { return input_x_; }
std::vector<GType *> Outs() const { return outs_; }
int Axis() const { return axis; }
int Num() const { return num; }
......
@@ -24,8 +24,52 @@ template <typename Dtype, typename T>
void Reshape2Op<Dtype, T>::InferShape() const {
auto &shape = this->param_.Shape();
auto input_x_dims = this->param_.InputX()->dims();
#ifdef PADDLE_MOBILE_CL
auto input_dim_size = input_x_dims.size();
bool shouldResize = true;
if (input_dim_size > 4) {
for (int i = 0; i < input_dim_size - 4; ++i) {
if (input_x_dims[i] != 0 && input_x_dims[i] != 1) {
shouldResize = false;
break;
}
}
if (shouldResize) {
std::vector<int64_t> temp_intput_dims;
temp_intput_dims.reserve(static_cast<size_t>(4));
for (int i = input_dim_size - 4; i < input_dim_size; ++i) {
temp_intput_dims.push_back(input_x_dims[i]);
}
framework::DDim temp_ddim = framework::make_ddim(temp_intput_dims);
this->param_.InputX()->Resize(temp_ddim);
input_x_dims = this->param_.InputX()->dims();
}
}
#endif
auto out_dims = ValidateShape(shape, input_x_dims);
this->param_.Out()->Resize(out_dims);
#ifdef PADDLE_MOBILE_CL
input_x_dims = this->param_.InputX()->dims();
shouldResize = true;
if (out_dims.size() > 4) {
for (int i = 0; i < out_dims.size() - 4; ++i) {
if (out_dims[i] != 0 && out_dims[i] != 1) {
shouldResize = false;
break;
}
}
if (shouldResize) {
std::vector<int64_t> temp_output_dims;
temp_output_dims.reserve(static_cast<size_t>(4));
for (int i = out_dims.size() - 4; i < out_dims.size(); ++i) {
temp_output_dims.push_back(out_dims[i]);
}
framework::DDim temp_ddim = framework::make_ddim(temp_output_dims);
this->param_.Out()->Resize(temp_ddim);
}
}
#endif
std::vector<int64_t> xshape_dims(input_x_dims.size() + 1, 0);
for (int i = 0; i < input_x_dims.size(); ++i) {
xshape_dims[i + 1] = input_x_dims[i];
@@ -40,6 +84,9 @@ namespace ops = paddle_mobile::operators;
#ifdef PADDLE_MOBILE_CPU
REGISTER_OPERATOR_CPU(reshape2, ops::Reshape2Op);
#endif
#ifdef PADDLE_MOBILE_CL
REGISTER_OPERATOR_CL(reshape2, ops::Reshape2Op);
#endif
#ifdef PADDLE_MOBILE_FPGA
REGISTER_OPERATOR_FPGA(reshape2, ops::Reshape2Op);
#endif
......
@@ -86,5 +86,8 @@ REGISTER_OPERATOR_CPU(split, ops::SplitOp);
#ifdef PADDLE_MOBILE_FPGA
REGISTER_OPERATOR_FPGA(split, ops::SplitOp);
#endif
#ifdef PADDLE_MOBILE_CL
REGISTER_OPERATOR_CL(split, ops::SplitOp);
#endif
#endif // SPLIT_OP
@@ -29,6 +29,55 @@ void Transpose2Op<Dtype, T>::InferShape() const {
size_t x_dims_size = input_x_dims.size();
size_t axis_size = axis.size();
#ifdef PADDLE_MOBILE_CL
bool shouldResize = true;
int diff_dim = 0;
if (axis_size > 4) {
for (int i = 0; i < axis_size - 4; ++i) {
if (axis[i] != i) {
shouldResize = false;
break;
} else {
diff_dim++;
}
}
if (shouldResize) {
std::vector<int> temp_axis_dims;
temp_axis_dims.reserve(static_cast<size_t>(4));
for (int i = axis_size - 4; i < axis_size; ++i) {
temp_axis_dims.push_back(axis[i] - diff_dim);
}
axis.resize(4);
axis.clear();
axis.insert(axis.begin(), temp_axis_dims.begin(), temp_axis_dims.end());
}
}
auto input_dim_size = input_x_dims.size();
shouldResize = true;
if (input_dim_size > 4) {
for (int i = 0; i < input_dim_size - 4; ++i) {
if (input_x_dims[i] != 0 && input_x_dims[i] != 1) {
shouldResize = false;
break;
}
}
if (shouldResize) {
std::vector<int64_t> temp_intput_dims;
temp_intput_dims.reserve(static_cast<size_t>(4));
for (int i = input_dim_size - 4; i < input_dim_size; ++i) {
temp_intput_dims.push_back(input_x_dims[i]);
}
framework::DDim temp_ddim = framework::make_ddim(temp_intput_dims);
this->param_.InputX()->Resize(temp_ddim);
}
}
axis_size = axis.size();
input_x_dims = this->param_.InputX()->dims();
x_dims_size = input_x_dims.size();
#endif
PADDLE_MOBILE_ENFORCE((x_dims_size == axis_size),
"input_dims must "
"be equal to the axis_size. ")
@@ -63,5 +112,7 @@ REGISTER_OPERATOR_CPU(transpose2, ops::Transpose2Op);
#ifdef PADDLE_MOBILE_FPGA
REGISTER_OPERATOR_FPGA(transpose2, ops::Transpose2Op);
#endif
#ifdef PADDLE_MOBILE_CL
REGISTER_OPERATOR_CL(transpose2, ops::Transpose2Op);
#endif
#endif // TRANSPOSE_OP