提交 136c9f32 编写于 作者: L liuruilong

add kernel init method

上级 6543b2a5
...@@ -63,6 +63,7 @@ class OperatorBase { ...@@ -63,6 +63,7 @@ class OperatorBase {
std::vector<string> GetOutKeys() const; std::vector<string> GetOutKeys() const;
virtual void RunImpl() const = 0; virtual void RunImpl() const = 0;
virtual void Init() const = 0;
/* /*
* @b op 运算所需的输入, 如上一层的输出结果、卷积核 * @b op 运算所需的输入, 如上一层的输出结果、卷积核
* */ * */
...@@ -111,14 +112,18 @@ class OperatorWithKernel : public OperatorBase<Dtype> { ...@@ -111,14 +112,18 @@ class OperatorWithKernel : public OperatorBase<Dtype> {
std::shared_ptr<Scope> scope) std::shared_ptr<Scope> scope)
: OperatorBase<Dtype>(type, inputs, outputs, attrs, scope), : OperatorBase<Dtype>(type, inputs, outputs, attrs, scope),
param_(inputs, outputs, attrs, *scope) { param_(inputs, outputs, attrs, *scope) {
PADDLE_MOBILE_ENFORCE(kernel_.Init(param_), " %s kernel init failed",
this->type_.c_str());
} }
virtual void RunImpl() const { this->kernel_.Compute(this->param_); } virtual void RunImpl() const { this->kernel_.Compute(this->param_); }
virtual void InferShape() const = 0; virtual void InferShape() const = 0;
// Runs the kernel's one-time setup for this operator; if the kernel reports
// failure, the enforce macro aborts with the operator's type name.
void Init() const {
PADDLE_MOBILE_ENFORCE(kernel_.Init(param_), " %s kernel init failed",
this->type_.c_str());
}
protected: protected:
KernelType kernel_; KernelType kernel_;
ParamType param_; ParamType param_;
......
...@@ -198,6 +198,13 @@ Executor<Dtype, P>::Executor(const framework::Program<Dtype> p, int batch_size, ...@@ -198,6 +198,13 @@ Executor<Dtype, P>::Executor(const framework::Program<Dtype> p, int batch_size,
} else { } else {
InitMemory(); InitMemory();
} }
std::shared_ptr<framework::BlockDesc> to_predict_block =
to_predict_program_->Block(0);
auto &ops = ops_of_block_[*to_predict_block.get()];
for (const auto &op: ops) {
op->Init();
}
} }
template <typename Dtype, Precision P> template <typename Dtype, Precision P>
...@@ -416,6 +423,8 @@ std::shared_ptr<framework::Tensor> Executor<Dtype, P>::Predict( ...@@ -416,6 +423,8 @@ std::shared_ptr<framework::Tensor> Executor<Dtype, P>::Predict(
clock_gettime(CLOCK_MONOTONIC, &ts); clock_gettime(CLOCK_MONOTONIC, &ts);
profile[i].runBegin = (uint64_t)ts.tv_sec * 1e9 + ts.tv_nsec; profile[i].runBegin = (uint64_t)ts.tv_sec * 1e9 + ts.tv_nsec;
#endif #endif
// to Run
ops[i]->Run(); ops[i]->Run();
#ifdef PADDLE_MOBILE_PROFILE #ifdef PADDLE_MOBILE_PROFILE
clock_gettime(CLOCK_MONOTONIC, &ts); clock_gettime(CLOCK_MONOTONIC, &ts);
......
...@@ -32,6 +32,8 @@ class FeedOp : public framework::OperatorBase<DeviceType> { ...@@ -32,6 +32,8 @@ class FeedOp : public framework::OperatorBase<DeviceType> {
param_(inputs, outputs, attrs, *scope) {} param_(inputs, outputs, attrs, *scope) {}
void RunImpl() const { param_.Out()->ShareDataWith(*param_.InputX()); } void RunImpl() const { param_.Out()->ShareDataWith(*param_.InputX()); }
void Init() const {}  // Feed op has no kernel, so there is nothing to initialize.
void InferShape() const { void InferShape() const {
auto out_dims = param_.Out()->dims(); auto out_dims = param_.Out()->dims();
out_dims[0] = param_.BatchSize(); out_dims[0] = param_.BatchSize();
......
...@@ -33,6 +33,8 @@ class FetchOp : public framework::OperatorBase<DeviceType> { ...@@ -33,6 +33,8 @@ class FetchOp : public framework::OperatorBase<DeviceType> {
param_(inputs, outputs, attrs, *scope) {} param_(inputs, outputs, attrs, *scope) {}
void RunImpl() const { param_.Out()->ShareDataWith(*param_.InputX()); } void RunImpl() const { param_.Out()->ShareDataWith(*param_.InputX()); }
void Init() const {}  // Fetch op has no kernel, so there is nothing to initialize.
void InferShape() const { void InferShape() const {
auto x_dims = param_.InputX()->dims(); auto x_dims = param_.InputX()->dims();
param_.Out()->Resize(x_dims); param_.Out()->Resize(x_dims);
......
...@@ -21,6 +21,11 @@ limitations under the License. */ ...@@ -21,6 +21,11 @@ limitations under the License. */
namespace paddle_mobile { namespace paddle_mobile {
namespace operators { namespace operators {
// CPU batch-norm kernel: no one-time setup required; returning true satisfies
// the enforce check performed by the operator framework.
template <>
bool BatchNormKernel<CPU, float>::Init(const BatchNormParam &para) const {
return true;
}
template <> template <>
void BatchNormKernel<CPU, float>::Compute(const BatchNormParam &param) const { void BatchNormKernel<CPU, float>::Compute(const BatchNormParam &param) const {
const Tensor *input_x = param.InputX(); const Tensor *input_x = param.InputX();
......
...@@ -109,6 +109,11 @@ void DecodeCenterSize(const framework::Tensor& target_box, ...@@ -109,6 +109,11 @@ void DecodeCenterSize(const framework::Tensor& target_box,
} }
} }
// CPU box-coder kernel: no one-time setup required; always reports success.
template <>
bool BoxCoderKernel<CPU, float>::Init(const BoxCoderParam &para) const {
return true;
}
template <> template <>
void BoxCoderKernel<CPU, float>::Compute(const BoxCoderParam& param) const { void BoxCoderKernel<CPU, float>::Compute(const BoxCoderParam& param) const {
const auto* input_priorbox = param.InputPriorBox(); const auto* input_priorbox = param.InputPriorBox();
......
...@@ -52,6 +52,11 @@ class ConcatFunctor { ...@@ -52,6 +52,11 @@ class ConcatFunctor {
} }
}; };
// CPU concat kernel: no one-time setup required; always reports success.
template <>
bool ConcatKernel<CPU, float>::Init(const ConcatParam &para) const {
return true;
}
template <> template <>
void ConcatKernel<CPU, float>::Compute(const ConcatParam &param) const { void ConcatKernel<CPU, float>::Compute(const ConcatParam &param) const {
auto inputs = param.Inputs(); auto inputs = param.Inputs();
......
...@@ -18,6 +18,11 @@ limitations under the License. */ ...@@ -18,6 +18,11 @@ limitations under the License. */
namespace paddle_mobile { namespace paddle_mobile {
namespace operators { namespace operators {
// CPU fused conv+add kernel: no one-time setup required; always succeeds.
template <>
bool ConvAddKernel<CPU, float>::Init(const FusionConvAddParam &para) const {
return true;
}
template <> template <>
void ConvAddKernel<CPU, float>::Compute(const FusionConvAddParam &param) const { void ConvAddKernel<CPU, float>::Compute(const FusionConvAddParam &param) const {
const Tensor *input = param.Input(); const Tensor *input = param.Input();
......
...@@ -19,6 +19,11 @@ limitations under the License. */ ...@@ -19,6 +19,11 @@ limitations under the License. */
namespace paddle_mobile { namespace paddle_mobile {
namespace operators { namespace operators {
// CPU fused conv+add+relu kernel: no one-time setup required; always succeeds.
template <>
bool ConvAddReluKernel<CPU, float>::Init(const FusionConvAddReluParam &para) const {
return true;
}
template <> template <>
void ConvAddReluKernel<CPU, float>::Compute( void ConvAddReluKernel<CPU, float>::Compute(
const FusionConvAddReluParam &param) const { const FusionConvAddReluParam &param) const {
......
...@@ -19,6 +19,11 @@ limitations under the License. */ ...@@ -19,6 +19,11 @@ limitations under the License. */
namespace paddle_mobile { namespace paddle_mobile {
namespace operators { namespace operators {
// CPU convolution kernel: no one-time setup required; always reports success.
template <>
bool ConvKernel<CPU, float>::Init(const ConvParam &para) const {
return true;
}
template <> template <>
void ConvKernel<CPU, float>::Compute(const ConvParam &param) const { void ConvKernel<CPU, float>::Compute(const ConvParam &param) const {
const Tensor *input = param.Input(); const Tensor *input = param.Input();
......
...@@ -20,6 +20,11 @@ limitations under the License. */ ...@@ -20,6 +20,11 @@ limitations under the License. */
namespace paddle_mobile { namespace paddle_mobile {
namespace operators { namespace operators {
// CPU depthwise convolution kernel: no one-time setup required; always succeeds.
template <>
bool DepthwiseConvKernel<CPU, float>::Init(const ConvParam &para) const {
return true;
}
template <> template <>
void DepthwiseConvKernel<CPU, float>::Compute(const ConvParam &param) const { void DepthwiseConvKernel<CPU, float>::Compute(const ConvParam &param) const {
LOG(kLOG_DEBUG) << param; LOG(kLOG_DEBUG) << param;
......
...@@ -26,6 +26,11 @@ struct AddFunctor { ...@@ -26,6 +26,11 @@ struct AddFunctor {
inline T operator()(T a, T b) const { return a + b; } inline T operator()(T a, T b) const { return a + b; }
}; };
// CPU elementwise-add kernel: no one-time setup required; always succeeds.
template <>
bool ElementwiseAddKernel<CPU, float>::Init(const ElementwiseAddParam &para) const {
return true;
}
template <> template <>
void ElementwiseAddKernel<CPU, float>::Compute( void ElementwiseAddKernel<CPU, float>::Compute(
const ElementwiseAddParam &param) const { const ElementwiseAddParam &param) const {
......
...@@ -21,6 +21,11 @@ limitations under the License. */ ...@@ -21,6 +21,11 @@ limitations under the License. */
namespace paddle_mobile { namespace paddle_mobile {
namespace operators { namespace operators {
// CPU fused fully-connected kernel: no one-time setup required; always succeeds.
template <>
bool FusionFcKernel<CPU, float>::Init(const FusionFcParam &para) const {
return true;
}
template <> template <>
void FusionFcKernel<CPU, float>::Compute(const FusionFcParam &param) const { void FusionFcKernel<CPU, float>::Compute(const FusionFcParam &param) const {
const Tensor *input_x = param.InputX(); const Tensor *input_x = param.InputX();
......
...@@ -21,6 +21,11 @@ limitations under the License. */ ...@@ -21,6 +21,11 @@ limitations under the License. */
namespace paddle_mobile { namespace paddle_mobile {
namespace operators { namespace operators {
// CPU LRN kernel: no one-time setup required; always reports success.
template <>
bool LrnKernel<CPU, float>::Init(const LrnParam &para) const {
return true;
}
template <> template <>
void LrnKernel<CPU, float>::Compute(const LrnParam &param) const { void LrnKernel<CPU, float>::Compute(const LrnParam &param) const {
const Tensor *input_x = param.InputX(); const Tensor *input_x = param.InputX();
......
...@@ -21,6 +21,11 @@ limitations under the License. */ ...@@ -21,6 +21,11 @@ limitations under the License. */
namespace paddle_mobile { namespace paddle_mobile {
namespace operators { namespace operators {
// CPU mul (matrix multiply) kernel: no one-time setup required; always succeeds.
template <>
bool MulKernel<CPU, float>::Init(const MulParam &para) const {
return true;
}
template <> template <>
void MulKernel<CPU, float>::Compute(const MulParam &param) const { void MulKernel<CPU, float>::Compute(const MulParam &param) const {
const Tensor *input_x = param.InputX(); const Tensor *input_x = param.InputX();
......
...@@ -205,6 +205,11 @@ void MultiClassOutput(const Tensor& scores, const Tensor& bboxes, ...@@ -205,6 +205,11 @@ void MultiClassOutput(const Tensor& scores, const Tensor& bboxes,
} }
} }
// CPU multi-class NMS kernel: no one-time setup required; always succeeds.
template <>
bool MultiClassNMSKernel<CPU, float>::Init(const MultiClassNMSParam &para) const {
return true;
}
template <> template <>
void MultiClassNMSKernel<CPU, float>::Compute( void MultiClassNMSKernel<CPU, float>::Compute(
const MultiClassNMSParam& param) const { const MultiClassNMSParam& param) const {
......
...@@ -35,6 +35,11 @@ inline void PoolBasic(std::string pooling_type, std::vector<int> ksize, ...@@ -35,6 +35,11 @@ inline void PoolBasic(std::string pooling_type, std::vector<int> ksize,
} }
} }
// CPU pooling kernel: no one-time setup required; always reports success.
template <>
bool PoolKernel<CPU, float>::Init(const PoolParam &para) const {
return true;
}
template <> template <>
void PoolKernel<CPU, float>::Compute(const PoolParam &param) const { void PoolKernel<CPU, float>::Compute(const PoolParam &param) const {
const Tensor *in_x = param.Input(); const Tensor *in_x = param.Input();
......
...@@ -28,6 +28,11 @@ struct ClipFunctor { ...@@ -28,6 +28,11 @@ struct ClipFunctor {
} }
}; };
// CPU prior-box kernel: no one-time setup required; always reports success.
template <>
bool PriorBoxKernel<CPU, float>::Init(const PriorBoxParam &para) const {
return true;
}
template <> template <>
void PriorBoxKernel<CPU, float>::Compute(const PriorBoxParam &param) const { void PriorBoxKernel<CPU, float>::Compute(const PriorBoxParam &param) const {
const auto *input_ = param.Input(); const auto *input_ = param.Input();
......
...@@ -27,6 +27,11 @@ struct ReluFunctor { ...@@ -27,6 +27,11 @@ struct ReluFunctor {
inline T operator()(T in) const { return in > 0 ? in : 0; } inline T operator()(T in) const { return in > 0 ? in : 0; }
}; };
// CPU relu kernel: no one-time setup required; always reports success.
template <>
bool ReluKernel<CPU, float>::Init(const ReluParam &para) const {
return true;
}
/* /*
* @b 特化到具体平台的实现, param 从 op 层传入 * @b 特化到具体平台的实现, param 从 op 层传入
* */ * */
......
...@@ -21,6 +21,11 @@ limitations under the License. */ ...@@ -21,6 +21,11 @@ limitations under the License. */
namespace paddle_mobile { namespace paddle_mobile {
namespace operators { namespace operators {
// CPU reshape kernel: no one-time setup required; always reports success.
template <>
bool ReshapeKernel<CPU, float>::Init(const ReshapeParam &para) const {
return true;
}
template <> template <>
void ReshapeKernel<CPU, float>::Compute(const ReshapeParam &param) const { void ReshapeKernel<CPU, float>::Compute(const ReshapeParam &param) const {
const auto *input_x = param.InputX(); const auto *input_x = param.InputX();
......
...@@ -71,6 +71,11 @@ void sigmoid(const Tensor *X, Tensor *Y) { ...@@ -71,6 +71,11 @@ void sigmoid(const Tensor *X, Tensor *Y) {
#endif #endif
} }
// CPU sigmoid kernel: no one-time setup required; always reports success.
template <>
bool SigmoidKernel<CPU, float>::Init(const SigmoidParam &para) const {
return true;
}
template <> template <>
void SigmoidKernel<CPU, float>::Compute(const SigmoidParam &param) const { void SigmoidKernel<CPU, float>::Compute(const SigmoidParam &param) const {
const Tensor *in_x = param.InputX(); const Tensor *in_x = param.InputX();
......
...@@ -19,6 +19,11 @@ limitations under the License. */ ...@@ -19,6 +19,11 @@ limitations under the License. */
namespace paddle_mobile { namespace paddle_mobile {
namespace operators { namespace operators {
// CPU softmax kernel: no one-time setup required; always reports success.
template <>
bool SoftmaxKernel<CPU, float>::Init(const SoftmaxParam &para) const {
return true;
}
template <> template <>
void SoftmaxKernel<CPU, float>::Compute(const SoftmaxParam &param) const { void SoftmaxKernel<CPU, float>::Compute(const SoftmaxParam &param) const {
const Tensor *in_x = param.InputX(); const Tensor *in_x = param.InputX();
......
...@@ -34,6 +34,11 @@ namespace operators { ...@@ -34,6 +34,11 @@ namespace operators {
// } // }
// } // }
// CPU transpose kernel: no one-time setup required; always reports success.
template <>
bool TransposeKernel<CPU, float>::Init(const TransposeParam &para) const {
return true;
}
template <> template <>
void TransposeKernel<CPU, float>::Compute(const TransposeParam& param) const { void TransposeKernel<CPU, float>::Compute(const TransposeParam& param) const {
const auto* input_x = param.InputX(); const auto* input_x = param.InputX();
......
...@@ -29,6 +29,7 @@ class BatchNormKernel ...@@ -29,6 +29,7 @@ class BatchNormKernel
: public framework::OpKernelBase<DeviceType, BatchNormParam> { : public framework::OpKernelBase<DeviceType, BatchNormParam> {
public: public:
void Compute(const BatchNormParam &param) const; void Compute(const BatchNormParam &param) const;
bool Init(const BatchNormParam &para) const;
}; };
} // namespace operators } // namespace operators
......
...@@ -30,6 +30,7 @@ class BoxCoderKernel ...@@ -30,6 +30,7 @@ class BoxCoderKernel
: public framework::OpKernelBase<DeviceType, BoxCoderParam> { : public framework::OpKernelBase<DeviceType, BoxCoderParam> {
public: public:
void Compute(const BoxCoderParam& param) const; void Compute(const BoxCoderParam& param) const;
bool Init(const BoxCoderParam &para) const;
}; };
} // namespace operators } // namespace operators
} // namespace paddle_mobile } // namespace paddle_mobile
......
...@@ -27,6 +27,7 @@ template <typename DeviceType, typename T> ...@@ -27,6 +27,7 @@ template <typename DeviceType, typename T>
class ConcatKernel : public framework::OpKernelBase<DeviceType, ConcatParam> { class ConcatKernel : public framework::OpKernelBase<DeviceType, ConcatParam> {
public: public:
void Compute(const ConcatParam &param) const; void Compute(const ConcatParam &param) const;
bool Init(const ConcatParam &para) const;
}; };
} // namespace operators } // namespace operators
......
...@@ -38,6 +38,7 @@ template <typename DeviceType, typename T> ...@@ -38,6 +38,7 @@ template <typename DeviceType, typename T>
class ConvAddKernel : public OpKernelBase<DeviceType, FusionConvAddParam> { class ConvAddKernel : public OpKernelBase<DeviceType, FusionConvAddParam> {
public: public:
void Compute(const FusionConvAddParam &param) const; void Compute(const FusionConvAddParam &param) const;
bool Init(const FusionConvAddParam &para) const;
}; };
} // namespace operators } // namespace operators
......
...@@ -36,6 +36,7 @@ class ConvAddReluKernel ...@@ -36,6 +36,7 @@ class ConvAddReluKernel
: public OpKernelBase<DeviceType, FusionConvAddReluParam> { : public OpKernelBase<DeviceType, FusionConvAddReluParam> {
public: public:
void Compute(const FusionConvAddReluParam &param) const; void Compute(const FusionConvAddReluParam &param) const;
bool Init(const FusionConvAddReluParam &para) const;
}; };
} // namespace operators } // namespace operators
......
...@@ -32,6 +32,7 @@ template <typename DeviceType, typename T> ...@@ -32,6 +32,7 @@ template <typename DeviceType, typename T>
class ConvKernel : public OpKernelBase<DeviceType, ConvParam> { class ConvKernel : public OpKernelBase<DeviceType, ConvParam> {
public: public:
void Compute(const ConvParam &param) const; void Compute(const ConvParam &param) const;
bool Init(const ConvParam &para) const;
}; };
inline bool IsExpand(const std::vector<int64_t> &filter_dim, inline bool IsExpand(const std::vector<int64_t> &filter_dim,
......
...@@ -31,6 +31,7 @@ template <typename DeviceType, typename T> ...@@ -31,6 +31,7 @@ template <typename DeviceType, typename T>
class DepthwiseConvKernel : public OpKernelBase<DeviceType, ConvParam> { class DepthwiseConvKernel : public OpKernelBase<DeviceType, ConvParam> {
public: public:
void Compute(const ConvParam &param) const; void Compute(const ConvParam &param) const;
bool Init(const ConvParam &para) const;
}; };
} // namespace operators } // namespace operators
} // namespace paddle_mobile } // namespace paddle_mobile
......
...@@ -30,6 +30,7 @@ class ElementwiseAddKernel ...@@ -30,6 +30,7 @@ class ElementwiseAddKernel
: public framework::OpKernelBase<DeviceType, ElementwiseAddParam> { : public framework::OpKernelBase<DeviceType, ElementwiseAddParam> {
public: public:
void Compute(const ElementwiseAddParam &param) const; void Compute(const ElementwiseAddParam &param) const;
bool Init(const ElementwiseAddParam &para) const;
}; };
} // namespace operators } // namespace operators
} // namespace paddle_mobile } // namespace paddle_mobile
......
...@@ -19,6 +19,11 @@ limitations under the License. */ ...@@ -19,6 +19,11 @@ limitations under the License. */
namespace paddle_mobile { namespace paddle_mobile {
namespace operators { namespace operators {
// FPGA convolution kernel: no one-time setup implemented yet; always succeeds.
template <>
bool ConvKernel<FPGA, float>::Init(const ConvParam &para) const {
return true;
}
template <> template <>
void ConvKernel<FPGA, float>::Compute(const ConvParam &param) const {} void ConvKernel<FPGA, float>::Compute(const ConvParam &param) const {}
template class ConvKernel<FPGA, float>; template class ConvKernel<FPGA, float>;
......
...@@ -28,6 +28,7 @@ class FusionFcKernel ...@@ -28,6 +28,7 @@ class FusionFcKernel
: public framework::OpKernelBase<DeviceType, FusionFcParam> { : public framework::OpKernelBase<DeviceType, FusionFcParam> {
public: public:
void Compute(const FusionFcParam& param) const; void Compute(const FusionFcParam& param) const;
bool Init(const FusionFcParam &para) const;
}; };
} // namespace operators } // namespace operators
} // namespace paddle_mobile } // namespace paddle_mobile
......
...@@ -169,6 +169,7 @@ template <typename DeviceType, typename T> ...@@ -169,6 +169,7 @@ template <typename DeviceType, typename T>
class LrnKernel : public framework::OpKernelBase<DeviceType, LrnParam> { class LrnKernel : public framework::OpKernelBase<DeviceType, LrnParam> {
public: public:
void Compute(const LrnParam &param) const; void Compute(const LrnParam &param) const;
bool Init(const LrnParam &para) const;
}; };
} // namespace operators } // namespace operators
} // namespace paddle_mobile } // namespace paddle_mobile
......
...@@ -21,6 +21,11 @@ limitations under the License. */ ...@@ -21,6 +21,11 @@ limitations under the License. */
namespace paddle_mobile { namespace paddle_mobile {
namespace operators { namespace operators {
// Mali GPU batch-norm kernel: no one-time setup implemented yet; always succeeds.
template <>
bool BatchNormKernel<GPU_MALI, float>::Init(const BatchNormParam &para) const {
return true;
}
template <> template <>
void BatchNormKernel<GPU_MALI, float>::Compute( void BatchNormKernel<GPU_MALI, float>::Compute(
const BatchNormParam &param) const {} const BatchNormParam &param) const {}
......
...@@ -19,6 +19,11 @@ limitations under the License. */ ...@@ -19,6 +19,11 @@ limitations under the License. */
namespace paddle_mobile { namespace paddle_mobile {
namespace operators { namespace operators {
// Mali GPU convolution kernel: no one-time setup implemented yet; always succeeds.
template <>
bool ConvKernel<GPU_MALI, float>::Init(const ConvParam &para) const {
return true;
}
template <> template <>
void ConvKernel<GPU_MALI, float>::Compute(const ConvParam &param) const { void ConvKernel<GPU_MALI, float>::Compute(const ConvParam &param) const {
// ArmConvImplement imp; // ArmConvImplement imp;
......
...@@ -29,6 +29,7 @@ template <typename DeviceType, typename T> ...@@ -29,6 +29,7 @@ template <typename DeviceType, typename T>
class MulKernel : public framework::OpKernelBase<DeviceType, MulParam> { class MulKernel : public framework::OpKernelBase<DeviceType, MulParam> {
public: public:
void Compute(const MulParam &param) const; void Compute(const MulParam &param) const;
bool Init(const MulParam &para) const;
}; };
} // namespace operators } // namespace operators
} // namespace paddle_mobile } // namespace paddle_mobile
......
...@@ -28,6 +28,7 @@ class MultiClassNMSKernel ...@@ -28,6 +28,7 @@ class MultiClassNMSKernel
: public framework::OpKernelBase<DeviceType, MultiClassNMSParam> { : public framework::OpKernelBase<DeviceType, MultiClassNMSParam> {
public: public:
void Compute(const MultiClassNMSParam& param) const; void Compute(const MultiClassNMSParam& param) const;
bool Init(const MultiClassNMSParam &para) const;
}; };
} // namespace operators } // namespace operators
} // namespace paddle_mobile } // namespace paddle_mobile
......
...@@ -28,6 +28,7 @@ template <typename DeviceType, typename T> ...@@ -28,6 +28,7 @@ template <typename DeviceType, typename T>
class PoolKernel : public OpKernelBase<DeviceType, PoolParam> { class PoolKernel : public OpKernelBase<DeviceType, PoolParam> {
public: public:
void Compute(const PoolParam &param) const override; void Compute(const PoolParam &param) const override;
bool Init(const PoolParam &para) const;
}; };
} // namespace operators } // namespace operators
} // namespace paddle_mobile } // namespace paddle_mobile
......
...@@ -54,6 +54,7 @@ class PriorBoxKernel ...@@ -54,6 +54,7 @@ class PriorBoxKernel
: public framework::OpKernelBase<DeviceType, PriorBoxParam> { : public framework::OpKernelBase<DeviceType, PriorBoxParam> {
public: public:
void Compute(const PriorBoxParam& param) const; void Compute(const PriorBoxParam& param) const;
bool Init(const PriorBoxParam &para) const;
}; };
} // namespace operators } // namespace operators
} // namespace paddle_mobile } // namespace paddle_mobile
......
...@@ -27,6 +27,7 @@ template <typename DeviceType, typename T> ...@@ -27,6 +27,7 @@ template <typename DeviceType, typename T>
class ReluKernel : public framework::OpKernelBase<DeviceType, ReluParam> { class ReluKernel : public framework::OpKernelBase<DeviceType, ReluParam> {
public: public:
void Compute(const ReluParam& param) const; void Compute(const ReluParam& param) const;
bool Init(const ReluParam &para) const;
}; };
} // namespace operators } // namespace operators
} // namespace paddle_mobile } // namespace paddle_mobile
......
...@@ -71,6 +71,7 @@ template <typename DeviceType, typename T> ...@@ -71,6 +71,7 @@ template <typename DeviceType, typename T>
class ReshapeKernel : public framework::OpKernelBase<DeviceType, ReshapeParam> { class ReshapeKernel : public framework::OpKernelBase<DeviceType, ReshapeParam> {
public: public:
void Compute(const ReshapeParam& param) const; void Compute(const ReshapeParam& param) const;
bool Init(const ReshapeParam &para) const;
}; };
} // namespace operators } // namespace operators
} // namespace paddle_mobile } // namespace paddle_mobile
......
...@@ -26,6 +26,7 @@ template <typename DeviceType, typename T> ...@@ -26,6 +26,7 @@ template <typename DeviceType, typename T>
class SigmoidKernel : public OpKernelBase<DeviceType, SigmoidParam> { class SigmoidKernel : public OpKernelBase<DeviceType, SigmoidParam> {
public: public:
void Compute(const SigmoidParam& param) const override; void Compute(const SigmoidParam& param) const override;
bool Init(const SigmoidParam &para) const;
}; };
} // namespace operators } // namespace operators
} // namespace paddle_mobile } // namespace paddle_mobile
......
...@@ -29,6 +29,7 @@ template <typename DeviceType, typename T> ...@@ -29,6 +29,7 @@ template <typename DeviceType, typename T>
class SoftmaxKernel : public OpKernelBase<DeviceType, SoftmaxParam> { class SoftmaxKernel : public OpKernelBase<DeviceType, SoftmaxParam> {
public: public:
void Compute(const SoftmaxParam &param) const override; void Compute(const SoftmaxParam &param) const override;
bool Init(const SoftmaxParam &para) const;
}; };
} // namespace operators } // namespace operators
} // namespace paddle_mobile } // namespace paddle_mobile
......
...@@ -29,6 +29,7 @@ class TransposeKernel ...@@ -29,6 +29,7 @@ class TransposeKernel
: public framework::OpKernelBase<DeviceType, TransposeParam> { : public framework::OpKernelBase<DeviceType, TransposeParam> {
public: public:
void Compute(const TransposeParam& param) const; void Compute(const TransposeParam& param) const;
bool Init(const TransposeParam &para) const;
}; };
} // namespace operators } // namespace operators
} // namespace paddle_mobile } // namespace paddle_mobile
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册