From a9f985d40c69b5b5757531d281e28c869db62289 Mon Sep 17 00:00:00 2001 From: zhaojiaying01 Date: Tue, 9 Oct 2018 16:25:51 +0800 Subject: [PATCH] add cl kernel --- src/operators/fetch_op.cpp | 18 +++++++++- src/operators/fetch_op.h | 20 +++++------ src/operators/kernel/cl/batchnorm_kernel.cpp | 36 ++++++++++++++++++++ src/operators/kernel/cl/fetch_kernel.cpp | 31 +++++++++++++++++ src/operators/kernel/cl/pool_kernel.cpp | 35 +++++++++++++++++++ src/operators/kernel/fetch_kernel.h | 34 ++++++++++++++++++ 6 files changed, 162 insertions(+), 12 deletions(-) create mode 100644 src/operators/kernel/cl/batchnorm_kernel.cpp create mode 100644 src/operators/kernel/cl/fetch_kernel.cpp create mode 100644 src/operators/kernel/cl/pool_kernel.cpp create mode 100644 src/operators/kernel/fetch_kernel.h diff --git a/src/operators/fetch_op.cpp b/src/operators/fetch_op.cpp index 30cddceaa4..cc96934cad 100644 --- a/src/operators/fetch_op.cpp +++ b/src/operators/fetch_op.cpp @@ -14,7 +14,23 @@ limitations under the License. */ #include "fetch_op.h" namespace paddle_mobile { -namespace operators {} +namespace operators { + +template +void FetchOp::InferShape() const { + auto x_dims = this->param_.InputX()->dims(); + this->param_.Out()->Resize(x_dims); +} + +template +void FetchOp::RunImpl() { +#ifdef PADDLE_MOBILE_CL + this->kernel_.Compute(this->param_); +#else + this->param_.Out()->ShareDataWith(*(this->param_.InputX())); +#endif +} +} // namespace operators } // namespace paddle_mobile namespace ops = paddle_mobile::operators; diff --git a/src/operators/fetch_op.h b/src/operators/fetch_op.h index 959beff632..708686c887 100644 --- a/src/operators/fetch_op.h +++ b/src/operators/fetch_op.h @@ -16,6 +16,7 @@ limitations under the License. */ #include #include "framework/operator.h" +#include "operators/kernel/fetch_kernel.h" #include "operators/op_param.h" namespace paddle_mobile { @@ -23,25 +24,22 @@ namespace operators { using std::string; template -class FetchOp : public framework::OperatorBase { +class FetchOp + : public framework::OperatorWithKernel, + FetchKernel> { public: FetchOp(const string &type, const VariableNameMap &inputs, const VariableNameMap &outputs, const framework::AttributeMap attrs, std::shared_ptr scope) - : framework::OperatorBase(type, inputs, outputs, attrs, - scope), - param_(inputs, outputs, attrs, *scope) {} - void RunImpl() { param_.Out()->ShareDataWith(*param_.InputX()); } + : framework::OperatorWithKernel, + FetchKernel>( + type, inputs, outputs, attrs, scope) {} - void Init() {} + void InferShape() const override; - void InferShape() const { - auto x_dims = param_.InputX()->dims(); - param_.Out()->Resize(x_dims); - } + void RunImpl() override; protected: - FetchParam param_; }; } // namespace operators diff --git a/src/operators/kernel/cl/batchnorm_kernel.cpp b/src/operators/kernel/cl/batchnorm_kernel.cpp new file mode 100644 index 0000000000..a096fae81d --- /dev/null +++ b/src/operators/kernel/cl/batchnorm_kernel.cpp @@ -0,0 +1,36 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#ifdef BATCHNORM_OP + +#include "operators/kernel/batchnorm_kernel.h" + +namespace paddle_mobile { +namespace operators { + +template <> +bool BatchNormKernel::Init(BatchNormParam *param) { + return true; +} + +template <> +void BatchNormKernel::Compute( + const BatchNormParam ¶m) {} + +template class BatchNormKernel; + +} // namespace operators +} // namespace paddle_mobile + +#endif diff --git a/src/operators/kernel/cl/fetch_kernel.cpp b/src/operators/kernel/cl/fetch_kernel.cpp new file mode 100644 index 0000000000..d10bfe7a4b --- /dev/null +++ b/src/operators/kernel/cl/fetch_kernel.cpp @@ -0,0 +1,31 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "operators/kernel/fetch_kernel.h" + +namespace paddle_mobile { +namespace operators { + +template <> +bool FetchKernel::Init(FetchParam *param) { + return true; +} + +template <> +void FetchKernel::Compute(const FetchParam ¶m) {} + +template class FetchKernel; + +} // namespace operators +} // namespace paddle_mobile diff --git a/src/operators/kernel/cl/pool_kernel.cpp b/src/operators/kernel/cl/pool_kernel.cpp new file mode 100644 index 0000000000..c24a1babf1 --- /dev/null +++ b/src/operators/kernel/cl/pool_kernel.cpp @@ -0,0 +1,35 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#ifdef POOL_OP + +#include "operators/kernel/pool_kernel.h" + +namespace paddle_mobile { +namespace operators { + +template <> +bool PoolKernel::Init(PoolParam *param) { + return true; +} + +template <> +void PoolKernel::Compute(const PoolParam ¶m) {} + +template class PoolKernel; + +} // namespace operators +} // namespace paddle_mobile + +#endif diff --git a/src/operators/kernel/fetch_kernel.h b/src/operators/kernel/fetch_kernel.h new file mode 100644 index 0000000000..d9ed91855d --- /dev/null +++ b/src/operators/kernel/fetch_kernel.h @@ -0,0 +1,34 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include "framework/operator.h" +#include "operators/op_param.h" + +namespace paddle_mobile { +namespace operators { + +using namespace framework; + +template +class FetchKernel + : public framework::OpKernelBase> { + public: + void Compute(const FetchParam ¶m); + bool Init(FetchParam *param); +}; + +} // namespace operators +} // namespace paddle_mobile -- GitLab