未验证 提交 23b48ed6 编写于 作者: D dolphin8 提交者: GitHub

Merge pull request #1116 from dolphin8/opencl

Opencl
......@@ -24,6 +24,34 @@ __kernel void relu(__read_only image2d_t input,
CLK_FILTER_NEAREST;
half4 in = read_imageh(input, sampler, (int2)(x, y));
in = max((half4)(0.0f,0.0f,0.0f,0.0f), in);
in = max((half4)(0.0f, 0.0f, 0.0f, 0.0f), in);
write_imageh(output, (int2)(x, y), in);
}
__kernel void relu_p0(__read_only image2d_t input,
__write_only image2d_t output){
const int x = get_global_id(0);
const int y = get_global_id(1);
const sampler_t sampler = CLK_NORMALIZED_COORDS_TRUE |
CLK_ADDRESS_CLAMP |
CLK_FILTER_NEAREST;
half4 in = read_imageh(input, sampler, (int2)(x, y));
in = max((half4)(0.0f, 0.0f, 0.0f, 0.0f), in);
write_imageh(output, (int2)(x, y), in);
}
__kernel void relu_p1(__read_only image2d_t input,
__write_only image2d_t output){
const int x = get_global_id(0);
const int y = get_global_id(1);
const sampler_t sampler = CLK_NORMALIZED_COORDS_TRUE |
CLK_ADDRESS_CLAMP |
CLK_FILTER_NEAREST;
half4 in = read_imageh(input, sampler, (int2)(x, y));
write_imageh(output, (int2)(x, y), in);
}
\ No newline at end of file
......@@ -20,23 +20,33 @@ namespace operators {
template <>
bool ReluKernel<GPU_CL, float>::Init(ReluParam<GPU_CL>* param) {
// this->cl_helper_.AddKernel("relu", "relu.cl");
this->cl_helper_.AddKernel("relu", "relu.cl");
this->cl_helper_.AddKernel("relu_p0", "relu.cl");
this->cl_helper_.AddKernel("relu_p1", "relu.cl");
const auto dim = const_cast<framework::CLImage*>(param->InputX())->ImageDims();
param->getMidImage().InitEmptyImage(this->cl_helper_.CLContext(), dim);
return true;
}
template <>
void ReluKernel<GPU_CL, float>::Compute(const ReluParam<GPU_CL>& param) {
// auto kernel = this->cl_helper_.KernelAt(0);
// const auto* input = param.InputX();
// auto* output = param.Out();
// auto default_work_size = this->cl_helper_.DefaultWorkSize(*output);
// auto inputImage = input->GetCLImage();
// auto outputImage = output->GetCLImage();
// clSetKernelArg(kernel, 0, sizeof(cl_mem), &inputImage);
// clSetKernelArg(kernel, 1, sizeof(cl_mem), &outputImage);
// const size_t work_size[2] = {input->ImageWidth(), input->ImageHeight()};
// clEnqueueNDRangeKernel(this->cl_helper_.CLCommandQueue(), kernel, 3, NULL,
// work_size, NULL, 0, NULL, NULL);
auto kernel_p0 = this->cl_helper_.KernelAt(1);
auto kernel_p1 = this->cl_helper_.KernelAt(2);
const auto* input = param.InputX();
auto* output = param.Out();
auto default_work_size = this->cl_helper_.DefaultWorkSize(*output);
auto inputImage = input->GetCLImage();
auto outputImage = output->GetCLImage();
auto tImage = const_cast<ReluParam<GPU_CL>&>(param).getMidImage().GetCLImage();
clSetKernelArg(kernel_p0, 0, sizeof(cl_mem), &inputImage);
clSetKernelArg(kernel_p0, 0, sizeof(cl_mem), &tImage);
clSetKernelArg(kernel_p1, 0, sizeof(cl_mem), &tImage);
clSetKernelArg(kernel_p1, 1, sizeof(cl_mem), &outputImage);
const size_t work_size[2] = {input->ImageWidth(), input->ImageHeight()};
clEnqueueNDRangeKernel(this->cl_helper_.CLCommandQueue(), kernel_p0, 3, NULL,
work_size, NULL, 0, NULL, NULL);
clEnqueueNDRangeKernel(this->cl_helper_.CLCommandQueue(), kernel_p1, 3, NULL,
work_size, NULL, 0, NULL, NULL);
}
template class ReluKernel<GPU_CL, float>;
......
......@@ -1229,12 +1229,12 @@ class ResizeParam : public OpParam {
* @b op 层实例化好这个 param 传递给 kernel 层使用
* */
template <typename Dtype>
class ReluParam : public OpParam {
class ReluParamBase : public OpParam {
typedef typename DtypeTensorTrait<Dtype>::gtype GType;
typedef typename DtypeTensorTrait<Dtype>::rtype RType;
public:
ReluParam(const VariableNameMap &inputs, const VariableNameMap &outputs,
ReluParamBase(const VariableNameMap &inputs, const VariableNameMap &outputs,
const AttributeMap &attrs, const Scope &scope) {
input_x_ = InputXFrom<GType>(inputs, scope);
out_ = OutFrom<GType>(outputs, scope);
......@@ -1248,6 +1248,24 @@ class ReluParam : public OpParam {
RType *input_x_;
RType *out_;
};
template <typename Dtype>
class ReluParam : public ReluParamBase<Dtype> {
public:
using ReluParamBase<Dtype>::ReluParamBase;
};
template <>
class ReluParam<GPU_CL> : public ReluParamBase<GPU_CL> {
public:
using ReluParamBase<GPU_CL>::ReluParamBase;
framework::CLImage& getMidImage() {
return midImage;
}
private:
framework::CLImage midImage;
};
#endif
#ifdef PRELU_OP
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册