From e4dabf84461cc6d54145131db37577a33a71c67e Mon Sep 17 00:00:00 2001 From: liuruilong Date: Sat, 29 Sep 2018 15:21:12 +0800 Subject: [PATCH] update cl image --- src/framework/cl/cl_image.h | 9 +++++++-- src/framework/op_registry.h | 5 +++++ src/operators/conv_op.cpp | 4 ++++ src/operators/conv_op.h | 6 ++++++ 4 files changed, 22 insertions(+), 2 deletions(-) diff --git a/src/framework/cl/cl_image.h b/src/framework/cl/cl_image.h index 5bcd09984f..d802e4e6dc 100644 --- a/src/framework/cl/cl_image.h +++ b/src/framework/cl/cl_image.h @@ -32,8 +32,13 @@ class CLImage { } - const DDim &TensorDim() { - return tensorDims_; + inline CLImage &Resize(const DDim &dims) { + tensorDims_ = dims; + return *this; + } + + const DDim &dims() const { + return DDim(); } std::vector DefaultWorkSize() { diff --git a/src/framework/op_registry.h b/src/framework/op_registry.h index a5a2c2d23c..657a1f88ef 100644 --- a/src/framework/op_registry.h +++ b/src/framework/op_registry.h @@ -122,6 +122,9 @@ class OpRegistry { #define REGISTER_OPERATOR_FPGA(op_type, op_class) \ REGISTER_OPERATOR(op_type, op_class, fpga, paddle_mobile::FPGA); +#define REGISTER_OPERATOR_CL(op_type, op_class) \ + REGISTER_OPERATOR(op_type, op_class, cl, paddle_mobile::GPU_CL); + #define USE_OP(op_type, device_name) \ extern int TouchOpRegistrar_##op_type##_##device_name(); \ static int use_op_itself_##op_type##_##device_name __attribute__((unused)) = \ @@ -133,5 +136,7 @@ class OpRegistry { #define USE_OP_FPGA(op_type) USE_OP(op_type, fpga); +#define USE_OP_CL(op_type) USE_OP(op_type, cl); + } // namespace framework } // namespace paddle_mobile diff --git a/src/operators/conv_op.cpp b/src/operators/conv_op.cpp index c460199521..2c70f42f56 100644 --- a/src/operators/conv_op.cpp +++ b/src/operators/conv_op.cpp @@ -62,4 +62,8 @@ REGISTER_OPERATOR_MALI_GPU(conv2d, ops::ConvOp); REGISTER_OPERATOR_FPGA(conv2d, ops::ConvOp); #endif +#ifdef PADDLE_MOBILE_CL +REGISTER_OPERATOR_CL(conv2d, ops::ConvOp); +#endif + #endif diff --git a/src/operators/conv_op.h b/src/operators/conv_op.h index 299c12cdea..719eaa561a 100644 --- a/src/operators/conv_op.h +++ b/src/operators/conv_op.h @@ -45,11 +45,17 @@ class ConvOp : public framework::OperatorWithKernel< #ifdef PADDLE_MOBILE_CPU USE_OP_CPU(conv2d); #endif + #ifdef PADDLE_MOBILE_MALI_GPU USE_OP_MALI_GPU(conv2d); #endif + #ifdef PADDLE_MOBILE_FPGA USE_OP_FPGA(conv2d); #endif +#ifdef PADDLE_MOBILE_CL +USE_OP_CL(conv2d); +#endif + #endif -- GitLab