diff --git a/src/framework/cl/cl_image.h b/src/framework/cl/cl_image.h index 5bcd09984f1985e0465087e610bd528c4a689032..d802e4e6dc24c5beee959a808a34da05c5ced0a7 100644 --- a/src/framework/cl/cl_image.h +++ b/src/framework/cl/cl_image.h @@ -32,8 +32,13 @@ class CLImage { } - const DDim &TensorDim() { - return tensorDims_; + inline CLImage &Resize(const DDim &dims) { + tensorDims_ = dims; + return *this; + } + + const DDim &dims() const { + return DDim(); } std::vector DefaultWorkSize() { diff --git a/src/framework/op_registry.h b/src/framework/op_registry.h index a5a2c2d23cea88f9697c5230d582cbb795622983..657a1f88effcb3db6357994d531609a94f95bfe3 100644 --- a/src/framework/op_registry.h +++ b/src/framework/op_registry.h @@ -122,6 +122,9 @@ class OpRegistry { #define REGISTER_OPERATOR_FPGA(op_type, op_class) \ REGISTER_OPERATOR(op_type, op_class, fpga, paddle_mobile::FPGA); +#define REGISTER_OPERATOR_CL(op_type, op_class) \ + REGISTER_OPERATOR(op_type, op_class, cl, paddle_mobile::GPU_CL); + #define USE_OP(op_type, device_name) \ extern int TouchOpRegistrar_##op_type##_##device_name(); \ static int use_op_itself_##op_type##_##device_name __attribute__((unused)) = \ @@ -133,5 +136,7 @@ class OpRegistry { #define USE_OP_FPGA(op_type) USE_OP(op_type, fpga); +#define USE_OP_CL(op_type) USE_OP(op_type, cl); + } // namespace framework } // namespace paddle_mobile diff --git a/src/operators/conv_op.cpp b/src/operators/conv_op.cpp index c4601995219b32db75f22c7c2ed959e18af85f36..2c70f42f56530c2d21252d6b51c228e7c49ca8bf 100644 --- a/src/operators/conv_op.cpp +++ b/src/operators/conv_op.cpp @@ -62,4 +62,8 @@ REGISTER_OPERATOR_MALI_GPU(conv2d, ops::ConvOp); REGISTER_OPERATOR_FPGA(conv2d, ops::ConvOp); #endif +#ifdef PADDLE_MOBILE_CL +REGISTER_OPERATOR_CL(conv2d, ops::ConvOp); +#endif + #endif diff --git a/src/operators/conv_op.h b/src/operators/conv_op.h index 299c12cdeaed25c114be5c7d3dc4fe74044f9298..719eaa561a712ab5a4ade49ba978b6482fa4dd70 100644 --- a/src/operators/conv_op.h +++ b/src/operators/conv_op.h @@ -45,11 +45,17 @@ class ConvOp : public framework::OperatorWithKernel< #ifdef PADDLE_MOBILE_CPU USE_OP_CPU(conv2d); #endif + #ifdef PADDLE_MOBILE_MALI_GPU USE_OP_MALI_GPU(conv2d); #endif + #ifdef PADDLE_MOBILE_FPGA USE_OP_FPGA(conv2d); #endif +#ifdef PADDLE_MOBILE_CL +USE_OP_CL(conv2d); +#endif + #endif