提交 136c9f32 编写于 作者: L liuruilong

add kernel init method

上级 6543b2a5
......@@ -63,6 +63,7 @@ class OperatorBase {
std::vector<string> GetOutKeys() const;
virtual void RunImpl() const = 0;
virtual void Init() const = 0;
/*
* @b op 运算所需的输入, 如上一层的输出结果、卷积核
* */
......@@ -111,14 +112,18 @@ class OperatorWithKernel : public OperatorBase<Dtype> {
std::shared_ptr<Scope> scope)
: OperatorBase<Dtype>(type, inputs, outputs, attrs, scope),
param_(inputs, outputs, attrs, *scope) {
PADDLE_MOBILE_ENFORCE(kernel_.Init(param_), " %s kernel init failed",
this->type_.c_str());
}
virtual void RunImpl() const { this->kernel_.Compute(this->param_); }
virtual void InferShape() const = 0;
void Init() const {
PADDLE_MOBILE_ENFORCE(kernel_.Init(param_), " %s kernel init failed",
this->type_.c_str());
}
protected:
KernelType kernel_;
ParamType param_;
......
......@@ -198,6 +198,13 @@ Executor<Dtype, P>::Executor(const framework::Program<Dtype> p, int batch_size,
} else {
InitMemory();
}
std::shared_ptr<framework::BlockDesc> to_predict_block =
to_predict_program_->Block(0);
auto &ops = ops_of_block_[*to_predict_block.get()];
for (const auto &op: ops) {
op->Init();
}
}
template <typename Dtype, Precision P>
......@@ -416,6 +423,8 @@ std::shared_ptr<framework::Tensor> Executor<Dtype, P>::Predict(
clock_gettime(CLOCK_MONOTONIC, &ts);
profile[i].runBegin = (uint64_t)ts.tv_sec * 1e9 + ts.tv_nsec;
#endif
// to Run
ops[i]->Run();
#ifdef PADDLE_MOBILE_PROFILE
clock_gettime(CLOCK_MONOTONIC, &ts);
......
......@@ -32,6 +32,8 @@ class FeedOp : public framework::OperatorBase<DeviceType> {
param_(inputs, outputs, attrs, *scope) {}
void RunImpl() const { param_.Out()->ShareDataWith(*param_.InputX()); }
void Init() const {}
void InferShape() const {
auto out_dims = param_.Out()->dims();
out_dims[0] = param_.BatchSize();
......
......@@ -33,6 +33,8 @@ class FetchOp : public framework::OperatorBase<DeviceType> {
param_(inputs, outputs, attrs, *scope) {}
void RunImpl() const { param_.Out()->ShareDataWith(*param_.InputX()); }
void Init() const {}
void InferShape() const {
auto x_dims = param_.InputX()->dims();
param_.Out()->Resize(x_dims);
......
......@@ -21,6 +21,11 @@ limitations under the License. */
namespace paddle_mobile {
namespace operators {
template <>
bool BatchNormKernel<CPU, float>::Init(const BatchNormParam &para) const {
return true;
}
template <>
void BatchNormKernel<CPU, float>::Compute(const BatchNormParam &param) const {
const Tensor *input_x = param.InputX();
......
......@@ -109,6 +109,11 @@ void DecodeCenterSize(const framework::Tensor& target_box,
}
}
template <>
bool BoxCoderKernel<CPU, float>::Init(const BoxCoderParam &para) const {
return true;
}
template <>
void BoxCoderKernel<CPU, float>::Compute(const BoxCoderParam& param) const {
const auto* input_priorbox = param.InputPriorBox();
......
......@@ -52,6 +52,11 @@ class ConcatFunctor {
}
};
template <>
bool ConcatKernel<CPU, float>::Init(const ConcatParam &para) const {
return true;
}
template <>
void ConcatKernel<CPU, float>::Compute(const ConcatParam &param) const {
auto inputs = param.Inputs();
......
......@@ -18,6 +18,11 @@ limitations under the License. */
namespace paddle_mobile {
namespace operators {
template <>
bool ConvAddKernel<CPU, float>::Init(const FusionConvAddParam &para) const {
return true;
}
template <>
void ConvAddKernel<CPU, float>::Compute(const FusionConvAddParam &param) const {
const Tensor *input = param.Input();
......
......@@ -19,6 +19,11 @@ limitations under the License. */
namespace paddle_mobile {
namespace operators {
template <>
bool ConvAddReluKernel<CPU, float>::Init(const FusionConvAddReluParam &para) const {
return true;
}
template <>
void ConvAddReluKernel<CPU, float>::Compute(
const FusionConvAddReluParam &param) const {
......
......@@ -19,6 +19,11 @@ limitations under the License. */
namespace paddle_mobile {
namespace operators {
template <>
bool ConvKernel<CPU, float>::Init(const ConvParam &para) const {
return true;
}
template <>
void ConvKernel<CPU, float>::Compute(const ConvParam &param) const {
const Tensor *input = param.Input();
......
......@@ -20,6 +20,11 @@ limitations under the License. */
namespace paddle_mobile {
namespace operators {
template <>
bool DepthwiseConvKernel<CPU, float>::Init(const ConvParam &para) const {
return true;
}
template <>
void DepthwiseConvKernel<CPU, float>::Compute(const ConvParam &param) const {
LOG(kLOG_DEBUG) << param;
......
......@@ -26,6 +26,11 @@ struct AddFunctor {
inline T operator()(T a, T b) const { return a + b; }
};
template <>
bool ElementwiseAddKernel<CPU, float>::Init(const ElementwiseAddParam &para) const {
return true;
}
template <>
void ElementwiseAddKernel<CPU, float>::Compute(
const ElementwiseAddParam &param) const {
......
......@@ -21,6 +21,11 @@ limitations under the License. */
namespace paddle_mobile {
namespace operators {
template <>
bool FusionFcKernel<CPU, float>::Init(const FusionFcParam &para) const {
return true;
}
template <>
void FusionFcKernel<CPU, float>::Compute(const FusionFcParam &param) const {
const Tensor *input_x = param.InputX();
......
......@@ -21,6 +21,11 @@ limitations under the License. */
namespace paddle_mobile {
namespace operators {
template <>
bool LrnKernel<CPU, float>::Init(const LrnParam &para) const {
return true;
}
template <>
void LrnKernel<CPU, float>::Compute(const LrnParam &param) const {
const Tensor *input_x = param.InputX();
......
......@@ -21,6 +21,11 @@ limitations under the License. */
namespace paddle_mobile {
namespace operators {
template <>
bool MulKernel<CPU, float>::Init(const MulParam &para) const {
return true;
}
template <>
void MulKernel<CPU, float>::Compute(const MulParam &param) const {
const Tensor *input_x = param.InputX();
......
......@@ -205,6 +205,11 @@ void MultiClassOutput(const Tensor& scores, const Tensor& bboxes,
}
}
template <>
bool MultiClassNMSKernel<CPU, float>::Init(const MultiClassNMSParam &para) const {
return true;
}
template <>
void MultiClassNMSKernel<CPU, float>::Compute(
const MultiClassNMSParam& param) const {
......
......@@ -35,6 +35,11 @@ inline void PoolBasic(std::string pooling_type, std::vector<int> ksize,
}
}
template <>
bool PoolKernel<CPU, float>::Init(const PoolParam &para) const {
return true;
}
template <>
void PoolKernel<CPU, float>::Compute(const PoolParam &param) const {
const Tensor *in_x = param.Input();
......
......@@ -28,6 +28,11 @@ struct ClipFunctor {
}
};
template <>
bool PriorBoxKernel<CPU, float>::Init(const PriorBoxParam &para) const {
return true;
}
template <>
void PriorBoxKernel<CPU, float>::Compute(const PriorBoxParam &param) const {
const auto *input_ = param.Input();
......
......@@ -27,6 +27,11 @@ struct ReluFunctor {
inline T operator()(T in) const { return in > 0 ? in : 0; }
};
template <>
bool ReluKernel<CPU, float>::Init(const ReluParam &para) const {
return true;
}
/*
* @b 特化到具体平台的实现, param 从 op 层传入
* */
......
......@@ -21,6 +21,11 @@ limitations under the License. */
namespace paddle_mobile {
namespace operators {
template <>
bool ReshapeKernel<CPU, float>::Init(const ReshapeParam &para) const {
return true;
}
template <>
void ReshapeKernel<CPU, float>::Compute(const ReshapeParam &param) const {
const auto *input_x = param.InputX();
......
......@@ -71,6 +71,11 @@ void sigmoid(const Tensor *X, Tensor *Y) {
#endif
}
template <>
bool SigmoidKernel<CPU, float>::Init(const SigmoidParam &para) const {
return true;
}
template <>
void SigmoidKernel<CPU, float>::Compute(const SigmoidParam &param) const {
const Tensor *in_x = param.InputX();
......
......@@ -19,6 +19,11 @@ limitations under the License. */
namespace paddle_mobile {
namespace operators {
template <>
bool SoftmaxKernel<CPU, float>::Init(const SoftmaxParam &para) const {
return true;
}
template <>
void SoftmaxKernel<CPU, float>::Compute(const SoftmaxParam &param) const {
const Tensor *in_x = param.InputX();
......
......@@ -34,6 +34,11 @@ namespace operators {
// }
// }
template <>
bool TransposeKernel<CPU, float>::Init(const TransposeParam &para) const {
return true;
}
template <>
void TransposeKernel<CPU, float>::Compute(const TransposeParam& param) const {
const auto* input_x = param.InputX();
......
......@@ -29,6 +29,7 @@ class BatchNormKernel
: public framework::OpKernelBase<DeviceType, BatchNormParam> {
public:
void Compute(const BatchNormParam &param) const;
bool Init(const BatchNormParam &para) const;
};
} // namespace operators
......
......@@ -30,6 +30,7 @@ class BoxCoderKernel
: public framework::OpKernelBase<DeviceType, BoxCoderParam> {
public:
void Compute(const BoxCoderParam& param) const;
bool Init(const BoxCoderParam &para) const;
};
} // namespace operators
} // namespace paddle_mobile
......
......@@ -27,6 +27,7 @@ template <typename DeviceType, typename T>
class ConcatKernel : public framework::OpKernelBase<DeviceType, ConcatParam> {
public:
void Compute(const ConcatParam &param) const;
bool Init(const ConcatParam &para) const;
};
} // namespace operators
......
......@@ -38,6 +38,7 @@ template <typename DeviceType, typename T>
class ConvAddKernel : public OpKernelBase<DeviceType, FusionConvAddParam> {
public:
void Compute(const FusionConvAddParam &param) const;
bool Init(const FusionConvAddParam &para) const;
};
} // namespace operators
......
......@@ -36,6 +36,7 @@ class ConvAddReluKernel
: public OpKernelBase<DeviceType, FusionConvAddReluParam> {
public:
void Compute(const FusionConvAddReluParam &param) const;
bool Init(const FusionConvAddReluParam &para) const;
};
} // namespace operators
......
......@@ -32,6 +32,7 @@ template <typename DeviceType, typename T>
class ConvKernel : public OpKernelBase<DeviceType, ConvParam> {
public:
void Compute(const ConvParam &param) const;
bool Init(const ConvParam &para) const;
};
inline bool IsExpand(const std::vector<int64_t> &filter_dim,
......
......@@ -31,6 +31,7 @@ template <typename DeviceType, typename T>
class DepthwiseConvKernel : public OpKernelBase<DeviceType, ConvParam> {
public:
void Compute(const ConvParam &param) const;
bool Init(const ConvParam &para) const;
};
} // namespace operators
} // namespace paddle_mobile
......
......@@ -30,6 +30,7 @@ class ElementwiseAddKernel
: public framework::OpKernelBase<DeviceType, ElementwiseAddParam> {
public:
void Compute(const ElementwiseAddParam &param) const;
bool Init(const ElementwiseAddParam &para) const;
};
} // namespace operators
} // namespace paddle_mobile
......
......@@ -19,6 +19,11 @@ limitations under the License. */
namespace paddle_mobile {
namespace operators {
template <>
bool ConvKernel<FPGA, float>::Init(const ConvParam &para) const {
return true;
}
template <>
void ConvKernel<FPGA, float>::Compute(const ConvParam &param) const {}
template class ConvKernel<FPGA, float>;
......
......@@ -28,6 +28,7 @@ class FusionFcKernel
: public framework::OpKernelBase<DeviceType, FusionFcParam> {
public:
void Compute(const FusionFcParam& param) const;
bool Init(const FusionFcParam &para) const;
};
} // namespace operators
} // namespace paddle_mobile
......
......@@ -169,6 +169,7 @@ template <typename DeviceType, typename T>
class LrnKernel : public framework::OpKernelBase<DeviceType, LrnParam> {
public:
void Compute(const LrnParam &param) const;
bool Init(const LrnParam &para) const;
};
} // namespace operators
} // namespace paddle_mobile
......
......@@ -21,6 +21,11 @@ limitations under the License. */
namespace paddle_mobile {
namespace operators {
template <>
bool BatchNormKernel<GPU_MALI, float>::Init(const BatchNormParam &para) const {
return true;
}
template <>
void BatchNormKernel<GPU_MALI, float>::Compute(
const BatchNormParam &param) const {}
......
......@@ -19,6 +19,11 @@ limitations under the License. */
namespace paddle_mobile {
namespace operators {
template <>
bool ConvKernel<GPU_MALI, float>::Init(const ConvParam &para) const {
return true;
}
template <>
void ConvKernel<GPU_MALI, float>::Compute(const ConvParam &param) const {
// ArmConvImplement imp;
......
......@@ -29,6 +29,7 @@ template <typename DeviceType, typename T>
class MulKernel : public framework::OpKernelBase<DeviceType, MulParam> {
public:
void Compute(const MulParam &param) const;
bool Init(const MulParam &para) const;
};
} // namespace operators
} // namespace paddle_mobile
......
......@@ -28,6 +28,7 @@ class MultiClassNMSKernel
: public framework::OpKernelBase<DeviceType, MultiClassNMSParam> {
public:
void Compute(const MultiClassNMSParam& param) const;
bool Init(const MultiClassNMSParam &para) const;
};
} // namespace operators
} // namespace paddle_mobile
......
......@@ -28,6 +28,7 @@ template <typename DeviceType, typename T>
class PoolKernel : public OpKernelBase<DeviceType, PoolParam> {
public:
void Compute(const PoolParam &param) const override;
bool Init(const PoolParam &para) const;
};
} // namespace operators
} // namespace paddle_mobile
......
......@@ -54,6 +54,7 @@ class PriorBoxKernel
: public framework::OpKernelBase<DeviceType, PriorBoxParam> {
public:
void Compute(const PriorBoxParam& param) const;
bool Init(const PriorBoxParam &para) const;
};
} // namespace operators
} // namespace paddle_mobile
......
......@@ -27,6 +27,7 @@ template <typename DeviceType, typename T>
class ReluKernel : public framework::OpKernelBase<DeviceType, ReluParam> {
public:
void Compute(const ReluParam& param) const;
bool Init(const ReluParam &para) const;
};
} // namespace operators
} // namespace paddle_mobile
......
......@@ -71,6 +71,7 @@ template <typename DeviceType, typename T>
class ReshapeKernel : public framework::OpKernelBase<DeviceType, ReshapeParam> {
public:
void Compute(const ReshapeParam& param) const;
bool Init(const ReshapeParam &para) const;
};
} // namespace operators
} // namespace paddle_mobile
......
......@@ -26,6 +26,7 @@ template <typename DeviceType, typename T>
class SigmoidKernel : public OpKernelBase<DeviceType, SigmoidParam> {
public:
void Compute(const SigmoidParam& param) const override;
bool Init(const SigmoidParam &para) const;
};
} // namespace operators
} // namespace paddle_mobile
......
......@@ -29,6 +29,7 @@ template <typename DeviceType, typename T>
class SoftmaxKernel : public OpKernelBase<DeviceType, SoftmaxParam> {
public:
void Compute(const SoftmaxParam &param) const override;
bool Init(const SoftmaxParam &para) const;
};
} // namespace operators
} // namespace paddle_mobile
......
......@@ -29,6 +29,7 @@ class TransposeKernel
: public framework::OpKernelBase<DeviceType, TransposeParam> {
public:
void Compute(const TransposeParam& param) const;
bool Init(const TransposeParam &para) const;
};
} // namespace operators
} // namespace paddle_mobile
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册