Commit 1346439b authored by nhzlx

add template for op param

Parent 2808898a
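
The pattern is the same in every hunk below: the operator-parameter classes (BatchNormParam, ConvParam, FusionConvAddParam, and so on) gain a DeviceType template parameter, and every operator and kernel that names them now spells out Param<DeviceType>. A minimal, self-contained sketch of that relationship follows, using simplified stand-in names rather than the real paddle-mobile classes (the real OperatorWithKernel and OpKernelBase carry scopes, attributes, and InferShape, all omitted here):

// Illustrative sketch only -- not part of the commit. Hypothetical simplified
// names; the real classes live under framework/ and operators/op_param.h.
#include <iostream>

struct CPU {};  // device tag type, standing in for paddle-mobile's device types

// Before this commit: a single `struct ConvParam { ... };` shared by every
// device. After: the param is templated on the device type.
template <typename DeviceType>
struct ConvParam {
  int groups = 1;
};

// Mirrors the shape of framework::OpKernelBase<DeviceType, ParamType>.
template <typename DeviceType, typename ParamType>
struct OpKernelBase {
  virtual void Compute(const ParamType &param) const = 0;
  virtual ~OpKernelBase() = default;
};

// The kernel now names ConvParam<DeviceType>, matching hunks such as
//   class ConvKernel : public OpKernelBase<DeviceType, ConvParam<DeviceType>>
template <typename DeviceType, typename T>
struct ConvKernel : OpKernelBase<DeviceType, ConvParam<DeviceType>> {
  void Compute(const ConvParam<DeviceType> &param) const override {
    std::cout << "conv, groups = " << param.groups << "\n";
  }
};

int main() {
  ConvParam<CPU> param;
  ConvKernel<CPU, float> kernel;
  kernel.Compute(param);  // each device instantiation gets its own param type
  return 0;
}
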
......@@ -26,14 +26,15 @@ namespace operators {
using std::string;
template <typename DeviceType, typename T>
class BatchNormOp
: public framework::OperatorWithKernel<DeviceType, BatchNormParam,
: public framework::OperatorWithKernel<DeviceType,
BatchNormParam<DeviceType>,
BatchNormKernel<DeviceType, T>> {
public:
BatchNormOp(const string &type, const VariableNameMap &inputs,
const VariableNameMap &outputs,
const framework::AttributeMap &attrs,
std::shared_ptr<framework::Scope> scope)
: framework::OperatorWithKernel<DeviceType, BatchNormParam,
: framework::OperatorWithKernel<DeviceType, BatchNormParam<DeviceType>,
BatchNormKernel<DeviceType, T>>(
type, inputs, outputs, attrs, scope) {}
......
......@@ -28,20 +28,20 @@ namespace operators {
using paddle_mobile::framework::Tensor;
template <typename DeviceType, typename T>
class BoxCoderOp
: public framework::OperatorWithKernel<
DeviceType, BoxCoderParam, operators::BoxCoderKernel<DeviceType, T>> {
class BoxCoderOp : public framework::OperatorWithKernel<
DeviceType, BoxCoderParam<DeviceType>,
operators::BoxCoderKernel<DeviceType, T>> {
public:
BoxCoderOp(const std::string &type, const VariableNameMap &inputs,
const VariableNameMap &outputs,
const framework::AttributeMap &attrs,
std::shared_ptr<framework::Scope> scope)
: framework::OperatorWithKernel<DeviceType, BoxCoderParam,
: framework::OperatorWithKernel<DeviceType, BoxCoderParam<DeviceType>,
operators::BoxCoderKernel<DeviceType, T>>(
type, inputs, outputs, attrs, scope) {}
using framework::OperatorWithKernel<
DeviceType, BoxCoderParam,
DeviceType, BoxCoderParam<DeviceType>,
operators::BoxCoderKernel<DeviceType, T>>::OperatorWithKernel;
void InferShape() const override;
......
......@@ -24,19 +24,19 @@ namespace paddle_mobile {
namespace operators {
using std::string;
template <typename DeviceType, typename T>
class ConcatOp
: public framework::OperatorWithKernel<
DeviceType, ConcatParam, operators::ConcatKernel<DeviceType, T>> {
class ConcatOp : public framework::OperatorWithKernel<
DeviceType, ConcatParam<DeviceType>,
operators::ConcatKernel<DeviceType, T>> {
public:
ConcatOp(const string &type, const VariableNameMap &inputs,
const VariableNameMap &outputs, const framework::AttributeMap &attrs,
std::shared_ptr<framework::Scope> scope)
: framework::OperatorWithKernel<DeviceType, ConcatParam,
: framework::OperatorWithKernel<DeviceType, ConcatParam<DeviceType>,
operators::ConcatKernel<DeviceType, T>>(
type, inputs, outputs, attrs, scope) {}
using framework::OperatorWithKernel<
DeviceType, ConcatParam,
DeviceType, ConcatParam<DeviceType>,
operators::ConcatKernel<DeviceType, T>>::OperatorWithKernel;
void InferShape() const override;
......
......@@ -24,19 +24,19 @@ namespace paddle_mobile {
namespace operators {
using std::string;
template <typename DeviceType, typename T>
class ConvOp
: public framework::OperatorWithKernel<
DeviceType, ConvParam, operators::ConvKernel<DeviceType, T>> {
class ConvOp : public framework::OperatorWithKernel<
DeviceType, ConvParam<DeviceType>,
operators::ConvKernel<DeviceType, T>> {
public:
ConvOp(const std::string &type, const VariableNameMap &inputs,
const VariableNameMap &outputs, const framework::AttributeMap &attrs,
std::shared_ptr<framework::Scope> scope)
: framework::OperatorWithKernel<DeviceType, ConvParam,
: framework::OperatorWithKernel<DeviceType, ConvParam<DeviceType>,
operators::ConvKernel<DeviceType, T>>(
type, inputs, outputs, attrs, scope) {}
using framework::OperatorWithKernel<
DeviceType, ConvParam,
DeviceType, ConvParam<DeviceType>,
operators::ConvKernel<DeviceType, T>>::OperatorWithKernel;
void InferShape() const override;
......
......@@ -26,7 +26,7 @@ namespace paddle_mobile {
namespace operators {
template <typename DeviceType, typename T>
class ConvOpTranspose : public framework::OperatorWithKernel<
DeviceType, ConvTransposeParam,
DeviceType, ConvTransposeParam<DeviceType>,
operators::ConvTransposeKernel<DeviceType, T>> {
public:
ConvOpTranspose(const std::string &type, const VariableNameMap &inputs,
......@@ -34,7 +34,7 @@ class ConvOpTranspose : public framework::OperatorWithKernel<
const framework::AttributeMap &attrs,
std::shared_ptr<framework::Scope> scope)
: framework::OperatorWithKernel<
DeviceType, ConvTransposeParam,
DeviceType, ConvTransposeParam<DeviceType>,
operators::ConvTransposeKernel<DeviceType, T>>(
type, inputs, outputs, attrs, scope) {}
......
......@@ -25,7 +25,7 @@ namespace operators {
template <typename DeviceType, typename T>
class DepthwiseConvOp : public framework::OperatorWithKernel<
DeviceType, ConvParam,
DeviceType, ConvParam<DeviceType>,
operators::DepthwiseConvKernel<DeviceType, T>> {
public:
DepthwiseConvOp(const std::string &type, const VariableNameMap &inputs,
......@@ -33,12 +33,12 @@ class DepthwiseConvOp : public framework::OperatorWithKernel<
const framework::AttributeMap &attrs,
std::shared_ptr<framework::Scope> scope)
: framework::OperatorWithKernel<
DeviceType, ConvParam,
DeviceType, ConvParam<DeviceType>,
operators::DepthwiseConvKernel<DeviceType, T>>(
type, inputs, outputs, attrs, scope) {}
using framework::OperatorWithKernel<
DeviceType, ConvParam,
DeviceType, ConvParam<DeviceType>,
operators::DepthwiseConvKernel<DeviceType, T>>::OperatorWithKernel;
void InferShape() const override;
......
......@@ -28,18 +28,18 @@ namespace operators {
using paddle_mobile::framework::Tensor;
template <typename DeviceType, typename T>
class DropoutOp
: public framework::OperatorWithKernel<
DeviceType, DropoutParam, operators::DropoutKernel<DeviceType, T>> {
class DropoutOp : public framework::OperatorWithKernel<
DeviceType, DropoutParam<DeviceType>,
operators::DropoutKernel<DeviceType, T>> {
public:
DropoutOp(const std::string &type, const VariableNameMap &inputs,
const VariableNameMap &outputs, const framework::AttributeMap attrs,
std::shared_ptr<framework::Scope> scope)
: framework::OperatorWithKernel<DeviceType, DropoutParam,
: framework::OperatorWithKernel<DeviceType, DropoutParam<DeviceType>,
operators::DropoutKernel<DeviceType, T>>(
type, inputs, outputs, attrs, scope) {}
// using framework::OperatorWithKernel<DeviceType, DropoutParam,
// using framework::OperatorWithKernel<DeviceType, DropoutParam<DeviceType>,
// operators::DropoutKernel<DeviceType,
// T>>;
void InferShape() const override;
......
......@@ -26,7 +26,7 @@ namespace operators {
using std::string;
template <typename DeviceType, typename T>
class ElementwiseAddOp : public framework::OperatorWithKernel<
DeviceType, ElementwiseAddParam,
DeviceType, ElementwiseAddParam<DeviceType>,
operators::ElementwiseAddKernel<DeviceType, T>> {
public:
ElementwiseAddOp(const string &type, const VariableNameMap &inputs,
......@@ -34,12 +34,12 @@ class ElementwiseAddOp : public framework::OperatorWithKernel<
const framework::AttributeMap &attrs,
std::shared_ptr<framework::Scope> scope)
: framework::OperatorWithKernel<
DeviceType, ElementwiseAddParam,
DeviceType, ElementwiseAddParam<DeviceType>,
operators::ElementwiseAddKernel<DeviceType, T>>(
type, inputs, outputs, attrs, scope) {}
using framework::OperatorWithKernel<
DeviceType, ElementwiseAddParam,
DeviceType, ElementwiseAddParam<DeviceType>,
operators::ElementwiseAddKernel<DeviceType, T>>::OperatorWithKernel;
void InferShape() const override;
......
......@@ -61,7 +61,7 @@ class FeedOp : public framework::OperatorBase<DeviceType> {
#endif
protected:
FeedParam param_;
FeedParam<DeviceType> param_;
};
} // namespace operators
......
......@@ -41,7 +41,7 @@ class FetchOp : public framework::OperatorBase<DeviceType> {
}
protected:
FetchParam param_;
FetchParam<DeviceType> param_;
};
} // namespace operators
......
......@@ -45,19 +45,20 @@ class FusionConvAddMatcher : public framework::FusionOpMatcher {
template <typename DeviceType, typename T>
class FusionConvAddOp : public framework::OperatorWithKernel<
DeviceType, FusionConvAddParam,
DeviceType, FusionConvAddParam<DeviceType>,
operators::ConvAddKernel<DeviceType, T>> {
public:
FusionConvAddOp(const string &type, const VariableNameMap &inputs,
const VariableNameMap &outputs,
const framework::AttributeMap &attrs,
std::shared_ptr<framework::Scope> scope)
: framework::OperatorWithKernel<DeviceType, FusionConvAddParam,
: framework::OperatorWithKernel<DeviceType,
FusionConvAddParam<DeviceType>,
operators::ConvAddKernel<DeviceType, T>>(
type, inputs, outputs, attrs, scope) {}
using framework::OperatorWithKernel<
DeviceType, FusionConvAddParam,
DeviceType, FusionConvAddParam<DeviceType>,
operators::ConvAddKernel<DeviceType, T>>::OperatorWithKernel;
void InferShape() const override;
......
......@@ -53,7 +53,7 @@ class FusionConvAddBNMatcher : public framework::FusionOpMatcher {
template <typename DeviceType, typename T>
class FusionConvAddBNOp : public framework::OperatorWithKernel<
DeviceType, FusionConvAddBNParam,
DeviceType, FusionConvAddBNParam<DeviceType>,
operators::ConvAddBNKernel<DeviceType, T>> {
public:
FusionConvAddBNOp(const string &type, const VariableNameMap &inputs,
......@@ -61,7 +61,7 @@ class FusionConvAddBNOp : public framework::OperatorWithKernel<
const framework::AttributeMap &attrs,
std::shared_ptr<framework::Scope> scope)
: framework::OperatorWithKernel<
DeviceType, FusionConvAddBNParam,
DeviceType, FusionConvAddBNParam<DeviceType>,
operators::ConvAddBNKernel<DeviceType, T>>(type, inputs, outputs,
attrs, scope) {}
......
......@@ -55,7 +55,7 @@ class FusionConvAddBNReluMatcher : public framework::FusionOpMatcher {
template <typename DeviceType, typename T>
class FusionConvAddBNReluOp
: public framework::OperatorWithKernel<
DeviceType, FusionConvAddBNReluParam,
DeviceType, FusionConvAddBNReluParam<DeviceType>,
operators::ConvAddBNReluKernel<DeviceType, T>> {
public:
FusionConvAddBNReluOp(const string &type, const VariableNameMap &inputs,
......@@ -63,12 +63,12 @@ class FusionConvAddBNReluOp
const framework::AttributeMap &attrs,
std::shared_ptr<framework::Scope> scope)
: framework::OperatorWithKernel<
DeviceType, FusionConvAddBNReluParam,
DeviceType, FusionConvAddBNReluParam<DeviceType>,
operators::ConvAddBNReluKernel<DeviceType, T>>(
type, inputs, outputs, attrs, scope) {}
using framework::OperatorWithKernel<
DeviceType, FusionConvAddBNReluParam,
DeviceType, FusionConvAddBNReluParam<DeviceType>,
operators::ConvAddBNReluKernel<DeviceType, T>>::OperatorWithKernel;
void InferShape() const override;
......
......@@ -43,7 +43,7 @@ class FusionConvAddReluOpMatcher : public framework::FusionOpMatcher {
template <typename DeviceType, typename T>
class FusionConvAddReluOp : public framework::OperatorWithKernel<
DeviceType, FusionConvAddReluParam,
DeviceType, FusionConvAddReluParam<DeviceType>,
operators::ConvAddReluKernel<DeviceType, T>> {
public:
FusionConvAddReluOp(const string &type, const VariableNameMap &inputs,
......@@ -51,12 +51,12 @@ class FusionConvAddReluOp : public framework::OperatorWithKernel<
const framework::AttributeMap &attrs,
std::shared_ptr<framework::Scope> scope)
: framework::OperatorWithKernel<
DeviceType, FusionConvAddReluParam,
DeviceType, FusionConvAddReluParam<DeviceType>,
operators::ConvAddReluKernel<DeviceType, T>>(type, inputs, outputs,
attrs, scope) {}
using framework::OperatorWithKernel<
DeviceType, FusionConvAddReluParam,
DeviceType, FusionConvAddReluParam<DeviceType>,
operators::ConvAddReluKernel<DeviceType, T>>::OperatorWithKernel;
void InferShape() const override;
......
......@@ -52,7 +52,7 @@ class FusionConvBNReluMatcher : public framework::FusionOpMatcher {
template <typename DeviceType, typename T>
class FusionConvBNReluOp : public framework::OperatorWithKernel<
DeviceType, FusionConvBNReluParam,
DeviceType, FusionConvBNReluParam<DeviceType>,
operators::ConvBNReluKernel<DeviceType, T>> {
public:
FusionConvBNReluOp(const string &type, const VariableNameMap &inputs,
......@@ -60,12 +60,12 @@ class FusionConvBNReluOp : public framework::OperatorWithKernel<
const framework::AttributeMap &attrs,
std::shared_ptr<framework::Scope> scope)
: framework::OperatorWithKernel<
DeviceType, FusionConvBNReluParam,
DeviceType, FusionConvBNReluParam<DeviceType>,
operators::ConvBNReluKernel<DeviceType, T>>(type, inputs, outputs,
attrs, scope) {}
using framework::OperatorWithKernel<
DeviceType, FusionConvBNReluParam,
DeviceType, FusionConvBNReluParam<DeviceType>,
operators::ConvBNReluKernel<DeviceType, T>>::OperatorWithKernel;
void InferShape() const override;
......
......@@ -51,21 +51,22 @@ class FusionDWConvBNReluMatcher : public framework::FusionOpMatcher {
};
template <typename DeviceType, typename T>
class FusionDWConvBNReluOp : public framework::OperatorWithKernel<
DeviceType, FusionDWConvBNReluParam,
operators::DWConvBNReluKernel<DeviceType, T>> {
class FusionDWConvBNReluOp
: public framework::OperatorWithKernel<
DeviceType, FusionDWConvBNReluParam<DeviceType>,
operators::DWConvBNReluKernel<DeviceType, T>> {
public:
FusionDWConvBNReluOp(const string &type, const VariableNameMap &inputs,
const VariableNameMap &outputs,
const framework::AttributeMap &attrs,
std::shared_ptr<framework::Scope> scope)
: framework::OperatorWithKernel<
DeviceType, FusionDWConvBNReluParam,
DeviceType, FusionDWConvBNReluParam<DeviceType>,
operators::DWConvBNReluKernel<DeviceType, T>>(type, inputs, outputs,
attrs, scope) {}
using framework::OperatorWithKernel<
DeviceType, FusionDWConvBNReluParam,
DeviceType, FusionDWConvBNReluParam<DeviceType>,
operators::DWConvBNReluKernel<DeviceType, T>>::OperatorWithKernel;
void InferShape() const override;
......
......@@ -44,7 +44,7 @@ class FusioneElementwiseAddReluMatcher : public framework::FusionOpMatcher {
template <typename DeviceType, typename T>
class FusionElementwiseAddReluOp
: public framework::OperatorWithKernel<
DeviceType, ElementwiseAddReluParam,
DeviceType, ElementwiseAddReluParam<DeviceType>,
operators::ElementwiseAddReluKernel<DeviceType, T>> {
public:
FusionElementwiseAddReluOp(const string &type, const VariableNameMap &inputs,
......@@ -52,7 +52,7 @@ class FusionElementwiseAddReluOp
const framework::AttributeMap &attrs,
std::shared_ptr<framework::Scope> scope)
: framework::OperatorWithKernel<
DeviceType, ElementwiseAddReluParam,
DeviceType, ElementwiseAddReluParam<DeviceType>,
operators::ElementwiseAddReluKernel<DeviceType, T>>(
type, inputs, outputs, attrs, scope) {}
......
......@@ -45,20 +45,20 @@ class FusionFcMatcher : public framework::FusionOpMatcher {
};
template <typename DeviceType, typename T>
class FusionFcOp
: public framework::OperatorWithKernel<
DeviceType, FusionFcParam, operators::FusionFcKernel<DeviceType, T>> {
class FusionFcOp : public framework::OperatorWithKernel<
DeviceType, FusionFcParam<DeviceType>,
operators::FusionFcKernel<DeviceType, T>> {
public:
FusionFcOp(const string &type, const VariableNameMap &inputs,
const VariableNameMap &outputs,
const framework::AttributeMap &attrs,
std::shared_ptr<framework::Scope> scope)
: framework::OperatorWithKernel<DeviceType, FusionFcParam,
: framework::OperatorWithKernel<DeviceType, FusionFcParam<DeviceType>,
operators::FusionFcKernel<DeviceType, T>>(
type, inputs, outputs, attrs, scope) {}
using framework::OperatorWithKernel<
DeviceType, FusionFcParam,
DeviceType, FusionFcParam<DeviceType>,
operators::FusionFcKernel<DeviceType, T>>::OperatorWithKernel;
void InferShape() const override;
......
......@@ -44,7 +44,7 @@ class FusionFcReluMatcher : public framework::FusionOpMatcher {
template <typename DeviceType, typename T>
class FusionFcReluOp : public framework::OperatorWithKernel<
DeviceType, FusionFcReluParam,
DeviceType, FusionFcReluParam<DeviceType>,
operators::FusionFcReluKernel<DeviceType, T>> {
public:
FusionFcReluOp(const string &type, const VariableNameMap &inputs,
......@@ -52,12 +52,12 @@ class FusionFcReluOp : public framework::OperatorWithKernel<
const framework::AttributeMap &attrs,
std::shared_ptr<framework::Scope> scope)
: framework::OperatorWithKernel<
DeviceType, FusionFcReluParam,
DeviceType, FusionFcReluParam<DeviceType>,
operators::FusionFcReluKernel<DeviceType, T>>(type, inputs, outputs,
attrs, scope) {}
using framework::OperatorWithKernel<
DeviceType, FusionFcReluParam,
DeviceType, FusionFcReluParam<DeviceType>,
operators::FusionFcReluKernel<DeviceType, T>>::OperatorWithKernel;
void InferShape() const override;
......
......@@ -27,7 +27,7 @@ using namespace framework;
template <typename DeviceType, typename T>
class Im2SequenceOp : public framework::OperatorWithKernel<
DeviceType, Im2SequenceParam,
DeviceType, Im2SequenceParam<DeviceType>,
operators::Im2SequenceKernel<DeviceType, T>> {
public:
Im2SequenceOp(const std::string &type, const VariableNameMap &inputs,
......@@ -35,12 +35,12 @@ class Im2SequenceOp : public framework::OperatorWithKernel<
const framework::AttributeMap &attrs,
std::shared_ptr<framework::Scope> scope)
: framework::OperatorWithKernel<
DeviceType, Im2SequenceParam,
DeviceType, Im2SequenceParam<DeviceType>,
operators::Im2SequenceKernel<DeviceType, T>>(type, inputs, outputs,
attrs, scope) {}
// using framework::OperatorWithKernel<
// DeviceType, Im2SequenceParam,
// DeviceType, Im2SequenceParam<DeviceType>,
// operators::Im2SequenceKernel<DeviceType, T>>::OperatorWithKernel;
void InferShape() const override;
......
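
Many of the op headers above changed only in their inheriting-constructor declarations: a line of the form "using framework::OperatorWithKernel<...>::OperatorWithKernel;" must name the base class with its complete template-argument list, so each of those lines had to pick up the new <DeviceType> argument on the param type. A small self-contained sketch of that C++ rule, with simplified hypothetical names:

// Illustrative sketch only -- simplified names, not the real framework classes.
struct CPU {};

template <typename DeviceType>
struct ConvParam {};

template <typename DeviceType, typename ParamType>
struct OperatorWithKernel {
  explicit OperatorWithKernel(const ParamType &param) {}
};

template <typename DeviceType, typename T>
struct ConvOp : OperatorWithKernel<DeviceType, ConvParam<DeviceType>> {
  // An inheriting-constructor declaration must spell out the base with its
  // full template-argument list, which is why every such line in the diff
  // changed along with the param type.
  using OperatorWithKernel<DeviceType,
                           ConvParam<DeviceType>>::OperatorWithKernel;
};

int main() {
  ConvOp<CPU, float> op(ConvParam<CPU>{});  // uses the inherited constructor
  return 0;
}
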
......@@ -21,12 +21,13 @@ namespace paddle_mobile {
namespace operators {
template <>
bool BatchNormKernel<CPU, float>::Init(BatchNormParam *param) {
bool BatchNormKernel<CPU, float>::Init(BatchNormParam<CPU> *param) {
return true;
}
template <>
void BatchNormKernel<CPU, float>::Compute(const BatchNormParam &param) const {
void BatchNormKernel<CPU, float>::Compute(
const BatchNormParam<CPU> &param) const {
BatchnormCompute<float>(param);
}
......
......@@ -21,12 +21,13 @@ namespace paddle_mobile {
namespace operators {
template <>
bool BoxCoderKernel<CPU, float>::Init(BoxCoderParam *param) {
bool BoxCoderKernel<CPU, float>::Init(BoxCoderParam<CPU> *param) {
return true;
}
template <>
void BoxCoderKernel<CPU, float>::Compute(const BoxCoderParam &param) const {
void BoxCoderKernel<CPU, float>::Compute(
const BoxCoderParam<CPU> &param) const {
BoxCoderCompute<float>(param);
}
......
......@@ -21,12 +21,12 @@ namespace paddle_mobile {
namespace operators {
template <>
bool ConcatKernel<CPU, float>::Init(ConcatParam *param) {
bool ConcatKernel<CPU, float>::Init(ConcatParam<CPU> *param) {
return true;
}
template <>
void ConcatKernel<CPU, float>::Compute(const ConcatParam &param) const {
void ConcatKernel<CPU, float>::Compute(const ConcatParam<CPU> &param) const {
ConcatCompute<float>(param);
}
......
......@@ -21,7 +21,8 @@ namespace paddle_mobile {
namespace operators {
template <>
bool ConvAddBNReluKernel<CPU, float>::Init(FusionConvAddBNReluParam *param) {
bool ConvAddBNReluKernel<CPU, float>::Init(
FusionConvAddBNReluParam<CPU> *param) {
const Tensor *mean = param->InputMean();
const Tensor *variance = param->InputVariance();
const Tensor *scale = param->InputScale();
......@@ -54,7 +55,7 @@ bool ConvAddBNReluKernel<CPU, float>::Init(FusionConvAddBNReluParam *param) {
template <>
void ConvAddBNReluKernel<CPU, float>::Compute(
const FusionConvAddBNReluParam &param) const {
const FusionConvAddBNReluParam<CPU> &param) const {
ConvAddBNReluCompute<float>(param);
}
template class ConvAddBNReluKernel<CPU, float>;
......
......@@ -20,12 +20,13 @@ namespace paddle_mobile {
namespace operators {
template <>
bool ConvAddKernel<CPU, float>::Init(FusionConvAddParam *param) {
bool ConvAddKernel<CPU, float>::Init(FusionConvAddParam<CPU> *param) {
return true;
}
template <>
void ConvAddKernel<CPU, float>::Compute(const FusionConvAddParam &param) const {
void ConvAddKernel<CPU, float>::Compute(
const FusionConvAddParam<CPU> &param) const {
ConvAddCompute<float>(param);
}
......
......@@ -21,13 +21,13 @@ namespace paddle_mobile {
namespace operators {
template <>
bool ConvAddReluKernel<CPU, float>::Init(FusionConvAddReluParam *param) {
bool ConvAddReluKernel<CPU, float>::Init(FusionConvAddReluParam<CPU> *param) {
return true;
}
template <>
void ConvAddReluKernel<CPU, float>::Compute(
const FusionConvAddReluParam &param) const {
const FusionConvAddReluParam<CPU> &param) const {
ConvAddReluCompute<float>(param);
}
template class ConvAddReluKernel<CPU, float>;
......
......@@ -21,7 +21,7 @@ namespace paddle_mobile {
namespace operators {
template <>
bool ConvBNReluKernel<CPU, float>::Init(FusionConvBNReluParam *param) {
bool ConvBNReluKernel<CPU, float>::Init(FusionConvBNReluParam<CPU> *param) {
const Tensor *mean = param->InputMean();
const Tensor *variance = param->InputVariance();
const Tensor *scale = param->InputScale();
......@@ -57,7 +57,7 @@ bool ConvBNReluKernel<CPU, float>::Init(FusionConvBNReluParam *param) {
template <>
void ConvBNReluKernel<CPU, float>::Compute(
const FusionConvBNReluParam &param) const {
const FusionConvBNReluParam<CPU> &param) const {
ConvBNReluCompute<float>(param);
}
template class ConvBNReluKernel<CPU, float>;
......
......@@ -21,12 +21,12 @@ namespace paddle_mobile {
namespace operators {
template <>
bool ConvKernel<CPU, float>::Init(ConvParam *param) {
bool ConvKernel<CPU, float>::Init(ConvParam<CPU> *param) {
return true;
}
template <>
void ConvKernel<CPU, float>::Compute(const ConvParam &param) const {
void ConvKernel<CPU, float>::Compute(const ConvParam<CPU> &param) const {
ConvCompute<float>(param);
}
......
......@@ -21,13 +21,13 @@ namespace paddle_mobile {
namespace operators {
template <>
bool ConvTransposeKernel<CPU, float>::Init(ConvTransposeParam *param) {
bool ConvTransposeKernel<CPU, float>::Init(ConvTransposeParam<CPU> *param) {
return true;
}
template <>
void ConvTransposeKernel<CPU, float>::Compute(
const ConvTransposeParam &param) const {
const ConvTransposeParam<CPU> &param) const {
ConvTransposeCompute<float>(param);
}
......
......@@ -21,12 +21,13 @@ namespace paddle_mobile {
namespace operators {
template <>
bool DepthwiseConvKernel<CPU, float>::Init(ConvParam *param) {
bool DepthwiseConvKernel<CPU, float>::Init(ConvParam<CPU> *param) {
return true;
}
template <>
void DepthwiseConvKernel<CPU, float>::Compute(const ConvParam &param) const {
void DepthwiseConvKernel<CPU, float>::Compute(
const ConvParam<CPU> &param) const {
DepthwiseConvCompute<float>(param);
}
......
......@@ -21,7 +21,7 @@ namespace paddle_mobile {
namespace operators {
template <>
bool DropoutKernel<CPU, float>::Init(DropoutParam *para) {
bool DropoutKernel<CPU, float>::Init(DropoutParam<CPU> *para) {
return true;
}
......@@ -31,7 +31,7 @@ struct DropoutFunctor {
};
template <>
void DropoutKernel<CPU, float>::Compute(const DropoutParam &param) const {
void DropoutKernel<CPU, float>::Compute(const DropoutParam<CPU> &param) const {
const auto *input_x = param.InputX();
auto *input_x_ptr = input_x->data<float>();
auto *out = param.Out();
......
......@@ -21,7 +21,7 @@ namespace paddle_mobile {
namespace operators {
template <>
bool DWConvBNReluKernel<CPU, float>::Init(FusionDWConvBNReluParam *param) {
bool DWConvBNReluKernel<CPU, float>::Init(FusionDWConvBNReluParam<CPU> *param) {
const Tensor *mean = param->InputMean();
const Tensor *variance = param->InputVariance();
const Tensor *scale = param->InputScale();
......@@ -54,7 +54,7 @@ bool DWConvBNReluKernel<CPU, float>::Init(FusionDWConvBNReluParam *param) {
template <>
void DWConvBNReluKernel<CPU, float>::Compute(
const FusionDWConvBNReluParam &param) const {
const FusionDWConvBNReluParam<CPU> &param) const {
DWConvBNReluCompute<float>(param);
}
template class DWConvBNReluKernel<CPU, float>;
......
......@@ -21,13 +21,13 @@ namespace paddle_mobile {
namespace operators {
template <>
bool ElementwiseAddKernel<CPU, float>::Init(ElementwiseAddParam *param) {
bool ElementwiseAddKernel<CPU, float>::Init(ElementwiseAddParam<CPU> *param) {
return true;
}
template <>
void ElementwiseAddKernel<CPU, float>::Compute(
const ElementwiseAddParam &param) const {
const ElementwiseAddParam<CPU> &param) const {
ElementwiseAddCompute<float>(param);
}
......
......@@ -21,12 +21,13 @@ namespace paddle_mobile {
namespace operators {
template <>
bool FusionFcKernel<CPU, float>::Init(FusionFcParam *param) {
bool FusionFcKernel<CPU, float>::Init(FusionFcParam<CPU> *param) {
return true;
}
template <>
void FusionFcKernel<CPU, float>::Compute(const FusionFcParam &param) const {
void FusionFcKernel<CPU, float>::Compute(
const FusionFcParam<CPU> &param) const {
FusionFcCompute<float>(param);
}
......
......@@ -20,7 +20,7 @@ namespace paddle_mobile {
namespace operators {
template <>
bool Im2SequenceKernel<CPU, float>::Init(Im2SequenceParam *para) {
bool Im2SequenceKernel<CPU, float>::Init(Im2SequenceParam<CPU> *para) {
return true;
}
......@@ -33,7 +33,7 @@ inline int Im2SeqOutputSize(int input_size, int filter_size, int padding_0,
template <>
void Im2SequenceKernel<CPU, float>::Compute(
const Im2SequenceParam &param) const {
const Im2SequenceParam<CPU> &param) const {
const Tensor *in_x = param.Input();
Tensor *out = param.Output();
out->mutable_data<float>();
......
......@@ -21,12 +21,12 @@ namespace paddle_mobile {
namespace operators {
template <>
bool LrnKernel<CPU, float>::Init(LrnParam *param) {
bool LrnKernel<CPU, float>::Init(LrnParam<CPU> *param) {
return true;
}
template <>
void LrnKernel<CPU, float>::Compute(const LrnParam &param) const {
void LrnKernel<CPU, float>::Compute(const LrnParam<CPU> &param) const {
LrnCompute<float>(param);
}
......
......@@ -21,12 +21,12 @@ namespace paddle_mobile {
namespace operators {
template <>
bool MulKernel<CPU, float>::Init(MulParam *param) {
bool MulKernel<CPU, float>::Init(MulParam<CPU> *param) {
return true;
}
template <>
void MulKernel<CPU, float>::Compute(const MulParam &param) const {
void MulKernel<CPU, float>::Compute(const MulParam<CPU> &param) const {
MulCompute<float>(param);
}
......
......@@ -21,13 +21,13 @@ namespace paddle_mobile {
namespace operators {
template <>
bool MultiClassNMSKernel<CPU, float>::Init(MultiClassNMSParam *param) {
bool MultiClassNMSKernel<CPU, float>::Init(MultiClassNMSParam<CPU> *param) {
return true;
}
template <>
void MultiClassNMSKernel<CPU, float>::Compute(
const MultiClassNMSParam &param) const {
const MultiClassNMSParam<CPU> &param) const {
MultiClassNMSCompute<float>(param);
}
......
......@@ -20,12 +20,12 @@ namespace paddle_mobile {
namespace operators {
template <>
bool PoolKernel<CPU, float>::Init(PoolParam *param) {
bool PoolKernel<CPU, float>::Init(PoolParam<CPU> *param) {
return true;
}
template <>
void PoolKernel<CPU, float>::Compute(const PoolParam &param) const {
void PoolKernel<CPU, float>::Compute(const PoolParam<CPU> &param) const {
PoolCompute<float>(param);
}
} // namespace operators
......
......@@ -32,7 +32,7 @@ struct PReluFunctor {
* @b Implementation specialized to a concrete platform; param is passed in from the op layer
* */
template <>
void PReluKernel<CPU, float>::Compute(const PReluParam &param) const {
void PReluKernel<CPU, float>::Compute(const PReluParam<CPU> &param) const {
const auto *input_x = param.InputX();
auto *input_x_ptr = input_x->data<float>();
auto *out = param.Out();
......
......@@ -21,12 +21,13 @@ namespace paddle_mobile {
namespace operators {
template <>
bool PriorBoxKernel<CPU, float>::Init(PriorBoxParam *param) {
bool PriorBoxKernel<CPU, float>::Init(PriorBoxParam<CPU> *param) {
return true;
}
template <>
void PriorBoxKernel<CPU, float>::Compute(const PriorBoxParam &param) const {
void PriorBoxKernel<CPU, float>::Compute(
const PriorBoxParam<CPU> &param) const {
PriorBoxCompute<float>(param);
}
......
......@@ -21,12 +21,12 @@ namespace paddle_mobile {
namespace operators {
template <>
bool ReluKernel<CPU, float>::Init(ReluParam *param) {
bool ReluKernel<CPU, float>::Init(ReluParam<CPU> *param) {
return true;
}
template <>
void ReluKernel<CPU, float>::Compute(const ReluParam &param) const {
void ReluKernel<CPU, float>::Compute(const ReluParam<CPU> &param) const {
ReluCompute<float>(param);
}
......
......@@ -21,12 +21,12 @@ namespace paddle_mobile {
namespace operators {
template <>
bool ReshapeKernel<CPU, float>::Init(ReshapeParam *param) {
bool ReshapeKernel<CPU, float>::Init(ReshapeParam<CPU> *param) {
return true;
}
template <>
void ReshapeKernel<CPU, float>::Compute(const ReshapeParam &param) const {
void ReshapeKernel<CPU, float>::Compute(const ReshapeParam<CPU> &param) const {
ReshapeCompute<float>(param);
}
......
......@@ -108,7 +108,7 @@ void ResizeTensor(const Tensor* src, Tensor* dst) {
}
template <>
void ResizeKernel<CPU, float>::Compute(const ResizeParam& param) const {
void ResizeKernel<CPU, float>::Compute(const ResizeParam<CPU>& param) const {
const auto* input_x = param.InputX();
const auto& input_x_dims = input_x->dims();
auto* out = param.Out();
......
......@@ -23,7 +23,7 @@ namespace operators {
* @b Implementation specialized to a concrete platform; param is passed in from the op layer
* */
template <>
void ScaleKernel<CPU, float>::Compute(const ScaleParam &param) const {
void ScaleKernel<CPU, float>::Compute(const ScaleParam<CPU> &param) const {
const auto *input_x = param.InputX();
auto *input_x_ptr = input_x->data<float>();
auto *out = param.Out();
......
......@@ -27,12 +27,12 @@ using framework::DDim;
using framework::Tensor;
template <>
bool SigmoidKernel<CPU, float>::Init(SigmoidParam *param) {
bool SigmoidKernel<CPU, float>::Init(SigmoidParam<CPU> *param) {
return true;
}
template <>
void SigmoidKernel<CPU, float>::Compute(const SigmoidParam &param) const {
void SigmoidKernel<CPU, float>::Compute(const SigmoidParam<CPU> &param) const {
SigmoidCompute<float>(param);
}
......
......@@ -21,12 +21,12 @@ namespace paddle_mobile {
namespace operators {
template <>
bool SoftmaxKernel<CPU, float>::Init(SoftmaxParam *param) {
bool SoftmaxKernel<CPU, float>::Init(SoftmaxParam<CPU> *param) {
return true;
}
template <>
void SoftmaxKernel<CPU, float>::Compute(const SoftmaxParam &param) const {
void SoftmaxKernel<CPU, float>::Compute(const SoftmaxParam<CPU> &param) const {
SoftmaxCompute<float>(param);
}
......
......@@ -20,12 +20,13 @@ namespace paddle_mobile {
namespace operators {
template <>
bool TransposeKernel<CPU, float>::Init(TransposeParam *param) {
bool TransposeKernel<CPU, float>::Init(TransposeParam<CPU> *param) {
return true;
}
template <>
void TransposeKernel<CPU, float>::Compute(const TransposeParam &param) const {
void TransposeKernel<CPU, float>::Compute(
const TransposeParam<CPU> &param) const {
TransposeCompute<float>(param);
}
......
......@@ -26,10 +26,10 @@ using namespace framework;
template <typename DeviceType, typename T>
class BatchNormKernel
: public framework::OpKernelBase<DeviceType, BatchNormParam> {
: public framework::OpKernelBase<DeviceType, BatchNormParam<DeviceType>> {
public:
void Compute(const BatchNormParam &param) const;
bool Init(BatchNormParam *param);
void Compute(const BatchNormParam<DeviceType> &param) const;
bool Init(BatchNormParam<DeviceType> *param);
};
} // namespace operators
......
......@@ -27,10 +27,10 @@ namespace operators {
template <typename DeviceType, typename T>
class BoxCoderKernel
: public framework::OpKernelBase<DeviceType, BoxCoderParam> {
: public framework::OpKernelBase<DeviceType, BoxCoderParam<DeviceType>> {
public:
void Compute(const BoxCoderParam& param) const;
bool Init(BoxCoderParam* param);
void Compute(const BoxCoderParam<DeviceType>& param) const;
bool Init(BoxCoderParam<DeviceType>* param);
};
} // namespace operators
} // namespace paddle_mobile
......
......@@ -23,7 +23,7 @@ namespace paddle_mobile {
namespace operators {
template <typename P>
void BatchnormCompute(const BatchNormParam &param) {
void BatchnormCompute(const BatchNormParam<CPU> &param) {
const Tensor *input_x = param.InputX();
auto input_x_ptr = input_x->data<float>();
const auto &x_dims = input_x->dims();
......
......@@ -113,7 +113,7 @@ void DecodeCenterSize(const framework::Tensor& target_box,
}
template <typename P>
void BoxCoderCompute(const BoxCoderParam& param) {
void BoxCoderCompute(const BoxCoderParam<CPU>& param) {
const auto* input_priorbox = param.InputPriorBox();
const auto* input_priorboxvar = param.InputPriorBoxVar();
const auto* input_targetbox = param.InputTargetBox();
......
......@@ -54,7 +54,7 @@ class ConcatFunctor {
};
template <typename P>
void ConcatCompute(const ConcatParam &param) {
void ConcatCompute(const ConcatParam<CPU> &param) {
auto inputs = param.Inputs();
auto *out = param.Out();
int64_t axis = param.Axis();
......
......@@ -25,7 +25,7 @@ limitations under the License. */
namespace paddle_mobile {
namespace operators {
void ConvAddBasic(const FusionConvAddParam &param) {
void ConvAddBasic(const FusionConvAddParam<CPU> &param) {
const Tensor *input = param.Input();
Tensor filter = *param.Filter();
Tensor bias = *param.Bias();
......@@ -114,7 +114,7 @@ void ConvAddBasic(const FusionConvAddParam &param) {
}
template <typename P>
void ConvAddCompute(const FusionConvAddParam &param) {
void ConvAddCompute(const FusionConvAddParam<CPU> &param) {
if (param.Groups() == param.Input()->dims()[1] &&
param.Input()->dims()[1] == param.Output()->dims()[1] &&
param.Filter()->dims()[2] == param.Filter()->dims()[3] &&
......
......@@ -25,7 +25,7 @@ limitations under the License. */
namespace paddle_mobile {
namespace operators {
void ConvAddBNReluBasic(const FusionConvAddBNReluParam &param) {
void ConvAddBNReluBasic(const FusionConvAddBNReluParam<CPU> &param) {
const Tensor *input = param.Input();
Tensor filter = *param.Filter();
Tensor new_bias = *param.NewBias();
......@@ -112,7 +112,7 @@ void ConvAddBNReluBasic(const FusionConvAddBNReluParam &param) {
}
}
template <typename P>
void ConvAddBNReluCompute(const FusionConvAddBNReluParam &param) {
void ConvAddBNReluCompute(const FusionConvAddBNReluParam<CPU> &param) {
Tensor Bias;
Bias.mutable_data<float>({param.Groups()});
if (param.Groups() == param.Input()->dims()[1] &&
......
......@@ -26,7 +26,7 @@ namespace paddle_mobile {
namespace operators {
template <typename P>
void ConvAddReluCompute(const FusionConvAddReluParam &param) {
void ConvAddReluCompute(const FusionConvAddReluParam<CPU> &param) {
const Tensor *input = param.Input();
Tensor filter = *param.Filter();
Tensor bias = *param.Bias();
......
......@@ -25,7 +25,7 @@ limitations under the License. */
namespace paddle_mobile {
namespace operators {
inline void ConvBasic(const ConvParam &param) {
inline void ConvBasic(const ConvParam<CPU> &param) {
const Tensor *input = param.Input();
Tensor filter = *param.Filter();
Tensor *output = param.Output();
......@@ -112,7 +112,7 @@ inline void ConvBasic(const ConvParam &param) {
}
template <typename P>
void ConvCompute(const ConvParam &param) {
void ConvCompute(const ConvParam<CPU> &param) {
if (param.Groups() == param.Input()->dims()[1] &&
param.Input()->dims()[1] == param.Output()->dims()[1] &&
param.Filter()->dims()[2] == param.Filter()->dims()[3] &&
......
......@@ -23,7 +23,7 @@ limitations under the License. */
#include "operators/op_param.h"
namespace paddle_mobile {
namespace operators {
void ConvBNReluBasic(const FusionConvBNReluParam &param) {
void ConvBNReluBasic(const FusionConvBNReluParam<CPU> &param) {
const Tensor *input = param.Input();
Tensor filter = *param.Filter();
Tensor new_bias = *param.NewBias();
......@@ -113,7 +113,7 @@ void ConvBNReluBasic(const FusionConvBNReluParam &param) {
}
template <typename P>
void ConvBNReluCompute(const FusionConvBNReluParam &param) {
void ConvBNReluCompute(const FusionConvBNReluParam<CPU> &param) {
if (param.Groups() == param.Input()->dims()[1] &&
param.Input()->dims()[1] == param.Output()->dims()[1] &&
param.Filter()->dims()[2] == param.Filter()->dims()[3] &&
......
......@@ -28,7 +28,7 @@ namespace paddle_mobile {
namespace operators {
template <typename P>
void ConvTransposeCompute(const ConvTransposeParam &param) {
void ConvTransposeCompute(const ConvTransposeParam<CPU> &param) {
const Tensor *input = param.Input();
Tensor filter = *param.Filter();
Tensor *output = param.Output();
......
......@@ -25,7 +25,7 @@ namespace paddle_mobile {
namespace operators {
template <typename P>
void DepthwiseConvCompute(const ConvParam &param) {
void DepthwiseConvCompute(const ConvParam<CPU> &param) {
Tensor Bias;
Bias.mutable_data<float>({param.Groups()});
if (param.Groups() == param.Input()->dims()[1] &&
......
......@@ -23,7 +23,7 @@ limitations under the License. */
#include "operators/op_param.h"
namespace paddle_mobile {
namespace operators {
void DWConvBNReluBasic(const FusionDWConvBNReluParam &param) {
void DWConvBNReluBasic(const FusionDWConvBNReluParam<CPU> &param) {
const Tensor *input = param.Input();
Tensor filter = *param.Filter();
Tensor new_bias = *param.NewBias();
......@@ -111,7 +111,7 @@ void DWConvBNReluBasic(const FusionDWConvBNReluParam &param) {
}
}
template <typename P>
void DWConvBNReluCompute(const FusionDWConvBNReluParam &param) {
void DWConvBNReluCompute(const FusionDWConvBNReluParam<CPU> &param) {
if (param.Groups() == param.Input()->dims()[1] &&
param.Input()->dims()[1] == param.Output()->dims()[1] &&
param.Filter()->dims()[2] == param.Filter()->dims()[3] &&
......
......@@ -27,7 +27,7 @@ struct AddFunctor {
};
template <typename P>
void ElementwiseAddCompute(const ElementwiseAddParam &param) {
void ElementwiseAddCompute(const ElementwiseAddParam<CPU> &param) {
const Tensor *input_x = param.InputX();
const Tensor *input_y = param.InputY();
Tensor *Out = param.Out();
......
......@@ -22,7 +22,7 @@ namespace paddle_mobile {
namespace operators {
template <typename P>
void FusionFcCompute(const FusionFcParam &param) {
void FusionFcCompute(const FusionFcParam<CPU> &param) {
const Tensor *input_x = param.InputX();
const Tensor *input_y = param.InputY();
const Tensor *input_z = param.InputZ();
......
......@@ -20,7 +20,7 @@ namespace paddle_mobile {
namespace operators {
template <typename P>
void LrnCompute(const LrnParam &param) {
void LrnCompute(const LrnParam<CPU> &param) {
const Tensor *input_x = param.InputX();
auto x_dims = input_x->dims();
Tensor *out = param.Out();
......
......@@ -54,7 +54,7 @@ namespace operators {
// Result: x (6 rows x 4 cols) times y (4 rows x 2 cols), multiplied as the matrices in (1); the result out is (6 rows x 2 cols)
template <typename P>
void MulCompute(const MulParam &param) {
void MulCompute(const MulParam<CPU> &param) {
const Tensor *input_x = param.InputX();
const Tensor *input_y = param.InputY();
Tensor *out = param.Out();
......
......@@ -213,7 +213,7 @@ void MultiClassOutput(const framework::Tensor& scores,
}
template <typename P>
void MultiClassNMSCompute(const MultiClassNMSParam& param) {
void MultiClassNMSCompute(const MultiClassNMSParam<CPU>& param) {
const auto* input_bboxes = param.InputBBoxes();
const auto& input_bboxes_dims = input_bboxes->dims();
......
......@@ -38,7 +38,7 @@ inline void PoolBasic(std::string pooling_type, std::vector<int> ksize,
}
}
template <typename P>
void PoolCompute(const PoolParam &param) {
void PoolCompute(const PoolParam<CPU> &param) {
const Tensor *in_x = param.Input();
Tensor *out = param.Output();
std::string pooling_type = param.PoolingType();
......@@ -58,7 +58,8 @@ void PoolCompute(const PoolParam &param) {
paddings[i] = 0;
ksize[i] = static_cast<int>(in_x->dims()[i + 2]);
}
} else if (ksize[0] == 3 && ksize[0] == ksize[1]) {
}
if (ksize[0] == 3 && ksize[0] == ksize[1]) {
if (pooling_type == "max") {
if (strides[0] == strides[1] && strides[0] == 1 &&
paddings[0] == paddings[1] && paddings[1] == 1) {
......
......@@ -29,7 +29,7 @@ struct ClipFunctor {
};
template <typename P>
void PriorBoxCompute(const PriorBoxParam &param) {
void PriorBoxCompute(const PriorBoxParam<CPU> &param) {
const auto *input_ = param.Input();
const auto &input_dims = input_->dims();
......
......@@ -30,7 +30,7 @@ struct ReluFunctor {
* @b Implementation specialized to a concrete platform; param is passed in from the op layer
* */
template <typename P>
void ReluCompute(const ReluParam &param) {
void ReluCompute(const ReluParam<CPU> &param) {
const auto *input_x = param.InputX();
auto *input_x_ptr = input_x->data<float>();
auto *out = param.Out();
......
......@@ -23,7 +23,7 @@ namespace paddle_mobile {
namespace operators {
template <typename P>
void ReshapeCompute(const ReshapeParam &param) {
void ReshapeCompute(const ReshapeParam<CPU> &param) {
const auto *input_x = param.InputX();
const auto &input_x_dims = input_x->dims();
auto *out = param.Out();
......
......@@ -73,7 +73,7 @@ void sigmoid(const Tensor *X, Tensor *Y) {
}
template <typename P>
void SigmoidCompute(const SigmoidParam &param) {
void SigmoidCompute(const SigmoidParam<CPU> &param) {
const Tensor *in_x = param.InputX();
Tensor *out = param.Out();
auto x_dims = in_x->dims();
......
......@@ -19,7 +19,7 @@ limitations under the License. */
namespace paddle_mobile {
namespace operators {
template <typename P>
void SoftmaxCompute(const SoftmaxParam &param) {
void SoftmaxCompute(const SoftmaxParam<CPU> &param) {
const Tensor *in_x = param.InputX();
Tensor *out = param.Out();
auto x_dims = in_x->dims();
......
......@@ -39,7 +39,7 @@ namespace operators {
// }
template <typename P>
void TransposeCompute(const TransposeParam& param) {
void TransposeCompute(const TransposeParam<CPU>& param) {
const auto* input_x = param.InputX();
const auto input_x_dims = input_x->dims();
auto* out = param.Out();
......
......@@ -24,10 +24,11 @@ namespace operators {
using namespace framework;
template <typename DeviceType, typename T>
class ConcatKernel : public framework::OpKernelBase<DeviceType, ConcatParam> {
class ConcatKernel
: public framework::OpKernelBase<DeviceType, ConcatParam<DeviceType>> {
public:
void Compute(const ConcatParam &param) const;
bool Init(ConcatParam *param);
void Compute(const ConcatParam<DeviceType> &param) const;
bool Init(ConcatParam<DeviceType> *param);
};
} // namespace operators
......
......@@ -32,10 +32,11 @@ using framework::DDim;
using framework::OpKernelBase;
template <typename DeviceType, typename T>
class ConvAddBNKernel : public OpKernelBase<DeviceType, FusionConvAddBNParam> {
class ConvAddBNKernel
: public OpKernelBase<DeviceType, FusionConvAddBNParam<DeviceType>> {
public:
void Compute(const FusionConvAddBNParam &param) const;
bool Init(FusionConvAddBNParam *param);
void Compute(const FusionConvAddBNParam<DeviceType> &param) const;
bool Init(FusionConvAddBNParam<DeviceType> *param);
};
} // namespace operators
......
......@@ -33,10 +33,10 @@ using framework::OpKernelBase;
template <typename DeviceType, typename T>
class ConvAddBNReluKernel
: public OpKernelBase<DeviceType, FusionConvAddBNReluParam> {
: public OpKernelBase<DeviceType, FusionConvAddBNReluParam<DeviceType>> {
public:
void Compute(const FusionConvAddBNReluParam &param) const;
bool Init(FusionConvAddBNReluParam *param);
void Compute(const FusionConvAddBNReluParam<DeviceType> &param) const;
bool Init(FusionConvAddBNReluParam<DeviceType> *param);
};
} // namespace operators
......
......@@ -37,10 +37,11 @@ using framework::DDim;
using framework::OpKernelBase;
template <typename DeviceType, typename T>
class ConvAddKernel : public OpKernelBase<DeviceType, FusionConvAddParam> {
class ConvAddKernel
: public OpKernelBase<DeviceType, FusionConvAddParam<DeviceType>> {
public:
void Compute(const FusionConvAddParam &param) const;
bool Init(FusionConvAddParam *param);
void Compute(const FusionConvAddParam<DeviceType> &param) const;
bool Init(FusionConvAddParam<DeviceType> *param);
};
} // namespace operators
......
......@@ -33,10 +33,10 @@ using framework::OpKernelBase;
template <typename DeviceType, typename T>
class ConvAddReluKernel
: public OpKernelBase<DeviceType, FusionConvAddReluParam> {
: public OpKernelBase<DeviceType, FusionConvAddReluParam<DeviceType>> {
public:
void Compute(const FusionConvAddReluParam &param) const;
bool Init(FusionConvAddReluParam *param);
void Compute(const FusionConvAddReluParam<DeviceType> &param) const;
bool Init(FusionConvAddReluParam<DeviceType> *param);
};
} // namespace operators
......
......@@ -33,10 +33,10 @@ using framework::OpKernelBase;
template <typename DeviceType, typename T>
class ConvBNReluKernel
: public OpKernelBase<DeviceType, FusionConvBNReluParam> {
: public OpKernelBase<DeviceType, FusionConvBNReluParam<DeviceType>> {
public:
void Compute(const FusionConvBNReluParam &param) const;
bool Init(FusionConvBNReluParam *param);
void Compute(const FusionConvBNReluParam<DeviceType> &param) const;
bool Init(FusionConvBNReluParam<DeviceType> *param);
};
} // namespace operators
......
......@@ -29,10 +29,10 @@ namespace operators {
using framework::OpKernelBase;
template <typename DeviceType, typename T>
class ConvKernel : public OpKernelBase<DeviceType, ConvParam> {
class ConvKernel : public OpKernelBase<DeviceType, ConvParam<DeviceType>> {
public:
void Compute(const ConvParam &param) const;
bool Init(ConvParam *param);
void Compute(const ConvParam<DeviceType> &param) const;
bool Init(ConvParam<DeviceType> *param);
};
} // namespace operators
......
......@@ -26,11 +26,11 @@ using framework::OpKernelBase;
template <typename DeviceType, typename T>
class ConvTransposeKernel
: public OpKernelBase<DeviceType, ConvTransposeParam> {
: public OpKernelBase<DeviceType, ConvTransposeParam<DeviceType>> {
public:
void Compute(const ConvTransposeParam &param) const;
void Compute(const ConvTransposeParam<DeviceType> &param) const;
bool Init(ConvTransposeParam *param);
bool Init(ConvTransposeParam<DeviceType> *param);
};
} // namespace operators
......
......@@ -28,10 +28,11 @@ namespace operators {
using framework::OpKernelBase;
template <typename DeviceType, typename T>
class DepthwiseConvKernel : public OpKernelBase<DeviceType, ConvParam> {
class DepthwiseConvKernel
: public OpKernelBase<DeviceType, ConvParam<DeviceType>> {
public:
void Compute(const ConvParam &param) const;
bool Init(ConvParam *param);
void Compute(const ConvParam<DeviceType> &param) const;
bool Init(ConvParam<DeviceType> *param);
};
} // namespace operators
} // namespace paddle_mobile
......
......@@ -23,10 +23,11 @@ namespace paddle_mobile {
namespace operators {
template <typename DeviceType, typename T>
class DropoutKernel : public framework::OpKernelBase<DeviceType, DropoutParam> {
class DropoutKernel
: public framework::OpKernelBase<DeviceType, DropoutParam<DeviceType>> {
public:
void Compute(const DropoutParam& param) const;
bool Init(DropoutParam* para);
void Compute(const DropoutParam<DeviceType>& param) const;
bool Init(DropoutParam<DeviceType>* para);
};
} // namespace operators
} // namespace paddle_mobile
......
......@@ -33,10 +33,10 @@ using framework::OpKernelBase;
template <typename DeviceType, typename T>
class DWConvBNReluKernel
: public OpKernelBase<DeviceType, FusionDWConvBNReluParam> {
: public OpKernelBase<DeviceType, FusionDWConvBNReluParam<DeviceType>> {
public:
void Compute(const FusionDWConvBNReluParam &param) const;
bool Init(FusionDWConvBNReluParam *param);
void Compute(const FusionDWConvBNReluParam<DeviceType> &param) const;
bool Init(FusionDWConvBNReluParam<DeviceType> *param);
};
} // namespace operators
......
......@@ -27,10 +27,11 @@ using namespace framework;
template <typename DeviceType, typename T>
class ElementwiseAddKernel
: public framework::OpKernelBase<DeviceType, ElementwiseAddParam> {
: public framework::OpKernelBase<DeviceType,
ElementwiseAddParam<DeviceType>> {
public:
void Compute(const ElementwiseAddParam &param) const;
bool Init(ElementwiseAddParam *param);
void Compute(const ElementwiseAddParam<DeviceType> &param) const;
bool Init(ElementwiseAddParam<DeviceType> *param);
};
} // namespace operators
} // namespace paddle_mobile
......
......@@ -26,10 +26,11 @@ using namespace framework;
template <typename DeviceType, typename T>
class ElementwiseAddReluKernel
: public framework::OpKernelBase<DeviceType, ElementwiseAddReluParam> {
: public framework::OpKernelBase<DeviceType,
ElementwiseAddReluParam<DeviceType>> {
public:
void Compute(const ElementwiseAddReluParam &param) const;
bool Init(ElementwiseAddReluParam *param);
void Compute(const ElementwiseAddReluParam<DeviceType> &param) const;
bool Init(ElementwiseAddReluParam<DeviceType> *param);
};
} // namespace operators
} // namespace paddle_mobile
......
......@@ -25,10 +25,11 @@ namespace operators {
template <typename DeviceType, typename T>
class FusionFcReluKernel
: public framework::OpKernelBase<DeviceType, FusionFcReluParam> {
: public framework::OpKernelBase<DeviceType,
FusionFcReluParam<DeviceType>> {
public:
void Compute(const FusionFcReluParam& param) const;
bool Init(FusionFcReluParam* param);
void Compute(const FusionFcReluParam<DeviceType>& param) const;
bool Init(FusionFcReluParam<DeviceType>* param);
};
} // namespace operators
} // namespace paddle_mobile
......
......@@ -20,12 +20,12 @@ namespace paddle_mobile {
namespace operators {
template <>
bool ConcatKernel<FPGA, float>::Init(ConcatParam *param) {
bool ConcatKernel<FPGA, float>::Init(ConcatParam<FPGA> *param) {
return true;
}
template <>
void ConcatKernel<FPGA, float>::Compute(const ConcatParam &param) const {
void ConcatKernel<FPGA, float>::Compute(const ConcatParam<FPGA> &param) const {
auto inputs = param.Inputs();
auto *out = param.Out();
int64_t axis = param.Axis();
......
......@@ -22,7 +22,7 @@ namespace paddle_mobile {
namespace operators {
template <>
bool ConvAddBNKernel<FPGA, float>::Init(FusionConvAddBNParam *param) {
bool ConvAddBNKernel<FPGA, float>::Init(FusionConvAddBNParam<FPGA> *param) {
bool relu_enabled = false;
const Tensor *input = param->Input();
auto input_ptr = input->data<half>();
......@@ -92,7 +92,7 @@ bool ConvAddBNKernel<FPGA, float>::Init(FusionConvAddBNParam *param) {
template <>
void ConvAddBNKernel<FPGA, float>::Compute(
const FusionConvAddBNParam &param) const {
const FusionConvAddBNParam<FPGA> &param) const {
fpga::ComputeFpgaConv(param.FpgaArgs());
}
template class ConvAddBNKernel<FPGA, float>;
......
......@@ -21,7 +21,8 @@ namespace paddle_mobile {
namespace operators {
template <>
bool ConvAddBNReluKernel<FPGA, float>::Init(FusionConvAddBNReluParam *param) {
bool ConvAddBNReluKernel<FPGA, float>::Init(
FusionConvAddBNReluParam<FPGA> *param) {
bool relu_enabled = true;
const Tensor *input = param->Input();
auto input_ptr = input->data<half>();
......@@ -83,7 +84,7 @@ bool ConvAddBNReluKernel<FPGA, float>::Init(FusionConvAddBNReluParam *param) {
template <>
void ConvAddBNReluKernel<FPGA, float>::Compute(
const FusionConvAddBNReluParam &param) const {
const FusionConvAddBNReluParam<FPGA> &param) const {
fpga::ComputeFpgaConv(param.FpgaArgs());
}
template class ConvAddBNReluKernel<FPGA, float>;
......
......@@ -21,7 +21,7 @@ namespace paddle_mobile {
namespace operators {
template <>
bool ConvAddReluKernel<FPGA, float>::Init(FusionConvAddReluParam *param) {
bool ConvAddReluKernel<FPGA, float>::Init(FusionConvAddReluParam<FPGA> *param) {
bool relu_enabled = true;
const Tensor *input = param->Input();
auto input_ptr = input->data<half>();
......@@ -67,7 +67,7 @@ bool ConvAddReluKernel<FPGA, float>::Init(FusionConvAddReluParam *param) {
template <>
void ConvAddReluKernel<FPGA, float>::Compute(
const FusionConvAddReluParam &param) const {
const FusionConvAddReluParam<FPGA> &param) const {
fpga::ComputeFpgaConv(param.FpgaArgs());
}
template class ConvAddReluKernel<FPGA, float>;
......
......@@ -21,12 +21,12 @@ namespace paddle_mobile {
namespace operators {
template <>
bool ConvKernel<FPGA, float>::Init(ConvParam *param) {
bool ConvKernel<FPGA, float>::Init(ConvParam<FPGA> *param) {
return true;
}
template <>
void ConvKernel<FPGA, float>::Compute(const ConvParam &param) const {
void ConvKernel<FPGA, float>::Compute(const ConvParam<FPGA> &param) const {
// ConvCompute<float>(param);
}
......
......@@ -20,13 +20,14 @@ namespace paddle_mobile {
namespace operators {
template <>
bool DropoutKernel<FPGA, float>::Init(DropoutParam *param) {
bool DropoutKernel<FPGA, float>::Init(DropoutParam<FPGA> *param) {
param->Out()->ShareDataWith(*param->InputX());
return true;
}
template <>
void DropoutKernel<FPGA, float>::Compute(const DropoutParam &param) const {
void DropoutKernel<FPGA, float>::Compute(
const DropoutParam<FPGA> &param) const {
// auto *input_x = param.InputX();
// auto *out = param.Out();
// auto input_x_ptr = input_x->data<float>();
......
......@@ -20,7 +20,7 @@ namespace operators {
template <>
bool ElementwiseAddReluKernel<FPGA, float>::Init(
ElementwiseAddReluParam *param) {
ElementwiseAddReluParam<FPGA> *param) {
bool relu_enabled = true;
const Tensor *input_x = param->InputX();
const Tensor *input_y = param->InputY();
......@@ -57,7 +57,7 @@ bool ElementwiseAddReluKernel<FPGA, float>::Init(
template <>
void ElementwiseAddReluKernel<FPGA, float>::Compute(
const ElementwiseAddReluParam &param) const {
const ElementwiseAddReluParam<FPGA> &param) const {
fpga::ComputeFpgaEWAdd(param.FpgaArgs());
}
} // namespace operators
......
......@@ -19,7 +19,7 @@ namespace paddle_mobile {
namespace operators {
template <>
bool FusionFcReluKernel<FPGA, float>::Init(FusionFcReluParam *param) {
bool FusionFcReluKernel<FPGA, float>::Init(FusionFcReluParam<FPGA> *param) {
bool relu_enabled = true;
const Tensor *input_x = param->InputX();
auto input_x_ptr = input_x->data<half>();
......@@ -66,7 +66,7 @@ bool FusionFcReluKernel<FPGA, float>::Init(FusionFcReluParam *param) {
}
template <>
void FusionFcReluKernel<FPGA, float>::Compute(
const FusionFcReluParam &param) const {
const FusionFcReluParam<FPGA> &param) const {
fpga::ComputeFpgaConv(param.FpgaArgs());
};
......
......@@ -19,7 +19,7 @@ namespace paddle_mobile {
namespace operators {
template <>
bool FusionFcKernel<FPGA, float>::Init(FusionFcParam *param) {
bool FusionFcKernel<FPGA, float>::Init(FusionFcParam<FPGA> *param) {
bool relu_enabled = false;
const Tensor *input_x = param->InputX();
auto input_x_ptr = input_x->data<half>();
......@@ -65,7 +65,8 @@ bool FusionFcKernel<FPGA, float>::Init(FusionFcParam *param) {
}
template <>
void FusionFcKernel<FPGA, float>::Compute(const FusionFcParam &param) const {
void FusionFcKernel<FPGA, float>::Compute(
const FusionFcParam<FPGA> &param) const {
fpga::ComputeFpgaConv(param.FpgaArgs());
}
} // namespace operators
......
......@@ -20,7 +20,7 @@ namespace paddle_mobile {
namespace operators {
template <>
bool PoolKernel<FPGA, float>::Init(PoolParam *param) {
bool PoolKernel<FPGA, float>::Init(PoolParam<FPGA> *param) {
const Tensor *input = param->Input();
auto input_ptr = input->data<half>();
Tensor *output = param->Output();
......@@ -46,7 +46,7 @@ bool PoolKernel<FPGA, float>::Init(PoolParam *param) {
}
template <>
void PoolKernel<FPGA, float>::Compute(const PoolParam &param) const {
void PoolKernel<FPGA, float>::Compute(const PoolParam<FPGA> &param) const {
#ifdef PADDLE_MOBILE_FPGA
fpga::ComputeFpgaPool(param.FpgaArgs());
#endif
......
......@@ -25,10 +25,10 @@ namespace operators {
template <typename DeviceType, typename T>
class FusionFcKernel
: public framework::OpKernelBase<DeviceType, FusionFcParam> {
: public framework::OpKernelBase<DeviceType, FusionFcParam<DeviceType>> {
public:
void Compute(const FusionFcParam& param) const;
bool Init(FusionFcParam* param);
void Compute(const FusionFcParam<DeviceType>& param) const;
bool Init(FusionFcParam<DeviceType>* param);
};
} // namespace operators
} // namespace paddle_mobile
......
......@@ -27,10 +27,10 @@ namespace operators {
template <typename DeviceType, typename T>
class Im2SequenceKernel
: public framework::OpKernelBase<DeviceType, Im2SequenceParam> {
: public framework::OpKernelBase<DeviceType, Im2SequenceParam<DeviceType>> {
public:
void Compute(const Im2SequenceParam& param) const;
bool Init(Im2SequenceParam* para);
void Compute(const Im2SequenceParam<DeviceType>& param) const;
bool Init(Im2SequenceParam<DeviceType>* para);
};
} // namespace operators
} // namespace paddle_mobile
......
......@@ -170,10 +170,11 @@ struct LRNFunctor {
};
template <typename DeviceType, typename T>
class LrnKernel : public framework::OpKernelBase<DeviceType, LrnParam> {
class LrnKernel
: public framework::OpKernelBase<DeviceType, LrnParam<DeviceType>> {
public:
void Compute(const LrnParam &param) const;
bool Init(LrnParam *param);
void Compute(const LrnParam<DeviceType> &param) const;
bool Init(LrnParam<DeviceType> *param);
};
} // namespace operators
} // namespace paddle_mobile
......
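
One likely payoff of the new template parameter, suggested by the FPGA hunks above where the kernels read param.FpgaArgs(): with the param templated on the device, a per-device specialization can carry device-only members without #ifdef clutter in the kernel code. A hedged, self-contained sketch with hypothetical simplified names (FpgaConvArgs below is a stand-in, not the real FPGA argument struct):

// Illustrative sketch only -- simplified, hypothetical names.
#include <iostream>

struct CPU {};
struct FPGA {};

struct FpgaConvArgs {   // placeholder for the real FPGA argument struct
  int filter_num = 0;
};

template <typename DeviceType>
struct ConvParam {      // generic param used by CPU (and other devices)
  int groups = 1;
};

template <>
struct ConvParam<FPGA> {  // FPGA specialization carries device-only state
  int groups = 1;
  FpgaConvArgs args;
  const FpgaConvArgs &FpgaArgs() const { return args; }
};

int main() {
  ConvParam<CPU> cpu_param;
  ConvParam<FPGA> fpga_param;
  std::cout << cpu_param.groups << " "
            << fpga_param.FpgaArgs().filter_num << "\n";
  return 0;
}
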
This diff is collapsed. (The remaining file diffs in this commit are not expanded on this page.)