提交 0fd9a2ee 编写于 作者: N nhzlx

add template for op param

上级 d92d1c2f
...@@ -26,14 +26,15 @@ namespace operators { ...@@ -26,14 +26,15 @@ namespace operators {
using std::string; using std::string;
template <typename DeviceType, typename T> template <typename DeviceType, typename T>
class BatchNormOp class BatchNormOp
: public framework::OperatorWithKernel<DeviceType, BatchNormParam, : public framework::OperatorWithKernel<DeviceType,
BatchNormParam<DeviceType>,
BatchNormKernel<DeviceType, T>> { BatchNormKernel<DeviceType, T>> {
public: public:
BatchNormOp(const string &type, const VariableNameMap &inputs, BatchNormOp(const string &type, const VariableNameMap &inputs,
const VariableNameMap &outputs, const VariableNameMap &outputs,
const framework::AttributeMap &attrs, const framework::AttributeMap &attrs,
std::shared_ptr<framework::Scope> scope) std::shared_ptr<framework::Scope> scope)
: framework::OperatorWithKernel<DeviceType, BatchNormParam, : framework::OperatorWithKernel<DeviceType, BatchNormParam<DeviceType>,
BatchNormKernel<DeviceType, T>>( BatchNormKernel<DeviceType, T>>(
type, inputs, outputs, attrs, scope) {} type, inputs, outputs, attrs, scope) {}
......
...@@ -28,20 +28,20 @@ namespace operators { ...@@ -28,20 +28,20 @@ namespace operators {
using paddle_mobile::framework::Tensor; using paddle_mobile::framework::Tensor;
template <typename DeviceType, typename T> template <typename DeviceType, typename T>
class BoxCoderOp class BoxCoderOp : public framework::OperatorWithKernel<
: public framework::OperatorWithKernel< DeviceType, BoxCoderParam<DeviceType>,
DeviceType, BoxCoderParam, operators::BoxCoderKernel<DeviceType, T>> { operators::BoxCoderKernel<DeviceType, T>> {
public: public:
BoxCoderOp(const std::string &type, const VariableNameMap &inputs, BoxCoderOp(const std::string &type, const VariableNameMap &inputs,
const VariableNameMap &outputs, const VariableNameMap &outputs,
const framework::AttributeMap &attrs, const framework::AttributeMap &attrs,
std::shared_ptr<framework::Scope> scope) std::shared_ptr<framework::Scope> scope)
: framework::OperatorWithKernel<DeviceType, BoxCoderParam, : framework::OperatorWithKernel<DeviceType, BoxCoderParam<DeviceType>,
operators::BoxCoderKernel<DeviceType, T>>( operators::BoxCoderKernel<DeviceType, T>>(
type, inputs, outputs, attrs, scope) {} type, inputs, outputs, attrs, scope) {}
using framework::OperatorWithKernel< using framework::OperatorWithKernel<
DeviceType, BoxCoderParam, DeviceType, BoxCoderParam<DeviceType>,
operators::BoxCoderKernel<DeviceType, T>>::OperatorWithKernel; operators::BoxCoderKernel<DeviceType, T>>::OperatorWithKernel;
void InferShape() const override; void InferShape() const override;
......
...@@ -24,19 +24,19 @@ namespace paddle_mobile { ...@@ -24,19 +24,19 @@ namespace paddle_mobile {
namespace operators { namespace operators {
using std::string; using std::string;
template <typename DeviceType, typename T> template <typename DeviceType, typename T>
class ConcatOp class ConcatOp : public framework::OperatorWithKernel<
: public framework::OperatorWithKernel< DeviceType, ConcatParam<DeviceType>,
DeviceType, ConcatParam, operators::ConcatKernel<DeviceType, T>> { operators::ConcatKernel<DeviceType, T>> {
public: public:
ConcatOp(const string &type, const VariableNameMap &inputs, ConcatOp(const string &type, const VariableNameMap &inputs,
const VariableNameMap &outputs, const framework::AttributeMap &attrs, const VariableNameMap &outputs, const framework::AttributeMap &attrs,
std::shared_ptr<framework::Scope> scope) std::shared_ptr<framework::Scope> scope)
: framework::OperatorWithKernel<DeviceType, ConcatParam, : framework::OperatorWithKernel<DeviceType, ConcatParam<DeviceType>,
operators::ConcatKernel<DeviceType, T>>( operators::ConcatKernel<DeviceType, T>>(
type, inputs, outputs, attrs, scope) {} type, inputs, outputs, attrs, scope) {}
using framework::OperatorWithKernel< using framework::OperatorWithKernel<
DeviceType, ConcatParam, DeviceType, ConcatParam<DeviceType>,
operators::ConcatKernel<DeviceType, T>>::OperatorWithKernel; operators::ConcatKernel<DeviceType, T>>::OperatorWithKernel;
void InferShape() const override; void InferShape() const override;
......
...@@ -24,19 +24,19 @@ namespace paddle_mobile { ...@@ -24,19 +24,19 @@ namespace paddle_mobile {
namespace operators { namespace operators {
using std::string; using std::string;
template <typename DeviceType, typename T> template <typename DeviceType, typename T>
class ConvOp class ConvOp : public framework::OperatorWithKernel<
: public framework::OperatorWithKernel< DeviceType, ConvParam<DeviceType>,
DeviceType, ConvParam, operators::ConvKernel<DeviceType, T>> { operators::ConvKernel<DeviceType, T>> {
public: public:
ConvOp(const std::string &type, const VariableNameMap &inputs, ConvOp(const std::string &type, const VariableNameMap &inputs,
const VariableNameMap &outputs, const framework::AttributeMap &attrs, const VariableNameMap &outputs, const framework::AttributeMap &attrs,
std::shared_ptr<framework::Scope> scope) std::shared_ptr<framework::Scope> scope)
: framework::OperatorWithKernel<DeviceType, ConvParam, : framework::OperatorWithKernel<DeviceType, ConvParam<DeviceType>,
operators::ConvKernel<DeviceType, T>>( operators::ConvKernel<DeviceType, T>>(
type, inputs, outputs, attrs, scope) {} type, inputs, outputs, attrs, scope) {}
using framework::OperatorWithKernel< using framework::OperatorWithKernel<
DeviceType, ConvParam, DeviceType, ConvParam<DeviceType>,
operators::ConvKernel<DeviceType, T>>::OperatorWithKernel; operators::ConvKernel<DeviceType, T>>::OperatorWithKernel;
void InferShape() const override; void InferShape() const override;
......
...@@ -26,7 +26,7 @@ namespace paddle_mobile { ...@@ -26,7 +26,7 @@ namespace paddle_mobile {
namespace operators { namespace operators {
template <typename DeviceType, typename T> template <typename DeviceType, typename T>
class ConvOpTranspose : public framework::OperatorWithKernel< class ConvOpTranspose : public framework::OperatorWithKernel<
DeviceType, ConvTransposeParam, DeviceType, ConvTransposeParam<DeviceType>,
operators::ConvTransposeKernel<DeviceType, T>> { operators::ConvTransposeKernel<DeviceType, T>> {
public: public:
ConvOpTranspose(const std::string &type, const VariableNameMap &inputs, ConvOpTranspose(const std::string &type, const VariableNameMap &inputs,
...@@ -34,7 +34,7 @@ class ConvOpTranspose : public framework::OperatorWithKernel< ...@@ -34,7 +34,7 @@ class ConvOpTranspose : public framework::OperatorWithKernel<
const framework::AttributeMap &attrs, const framework::AttributeMap &attrs,
std::shared_ptr<framework::Scope> scope) std::shared_ptr<framework::Scope> scope)
: framework::OperatorWithKernel< : framework::OperatorWithKernel<
DeviceType, ConvTransposeParam, DeviceType, ConvTransposeParam<DeviceType>,
operators::ConvTransposeKernel<DeviceType, T>>( operators::ConvTransposeKernel<DeviceType, T>>(
type, inputs, outputs, attrs, scope) {} type, inputs, outputs, attrs, scope) {}
......
...@@ -25,7 +25,7 @@ namespace operators { ...@@ -25,7 +25,7 @@ namespace operators {
template <typename DeviceType, typename T> template <typename DeviceType, typename T>
class DepthwiseConvOp : public framework::OperatorWithKernel< class DepthwiseConvOp : public framework::OperatorWithKernel<
DeviceType, ConvParam, DeviceType, ConvParam<DeviceType>,
operators::DepthwiseConvKernel<DeviceType, T>> { operators::DepthwiseConvKernel<DeviceType, T>> {
public: public:
DepthwiseConvOp(const std::string &type, const VariableNameMap &inputs, DepthwiseConvOp(const std::string &type, const VariableNameMap &inputs,
...@@ -33,12 +33,12 @@ class DepthwiseConvOp : public framework::OperatorWithKernel< ...@@ -33,12 +33,12 @@ class DepthwiseConvOp : public framework::OperatorWithKernel<
const framework::AttributeMap &attrs, const framework::AttributeMap &attrs,
std::shared_ptr<framework::Scope> scope) std::shared_ptr<framework::Scope> scope)
: framework::OperatorWithKernel< : framework::OperatorWithKernel<
DeviceType, ConvParam, DeviceType, ConvParam<DeviceType>,
operators::DepthwiseConvKernel<DeviceType, T>>( operators::DepthwiseConvKernel<DeviceType, T>>(
type, inputs, outputs, attrs, scope) {} type, inputs, outputs, attrs, scope) {}
using framework::OperatorWithKernel< using framework::OperatorWithKernel<
DeviceType, ConvParam, DeviceType, ConvParam<DeviceType>,
operators::DepthwiseConvKernel<DeviceType, T>>::OperatorWithKernel; operators::DepthwiseConvKernel<DeviceType, T>>::OperatorWithKernel;
void InferShape() const override; void InferShape() const override;
......
...@@ -28,18 +28,18 @@ namespace operators { ...@@ -28,18 +28,18 @@ namespace operators {
using paddle_mobile::framework::Tensor; using paddle_mobile::framework::Tensor;
template <typename DeviceType, typename T> template <typename DeviceType, typename T>
class DropoutOp class DropoutOp : public framework::OperatorWithKernel<
: public framework::OperatorWithKernel< DeviceType, DropoutParam<DeviceType>,
DeviceType, DropoutParam, operators::DropoutKernel<DeviceType, T>> { operators::DropoutKernel<DeviceType, T>> {
public: public:
DropoutOp(const std::string &type, const VariableNameMap &inputs, DropoutOp(const std::string &type, const VariableNameMap &inputs,
const VariableNameMap &outputs, const framework::AttributeMap attrs, const VariableNameMap &outputs, const framework::AttributeMap attrs,
std::shared_ptr<framework::Scope> scope) std::shared_ptr<framework::Scope> scope)
: framework::OperatorWithKernel<DeviceType, DropoutParam, : framework::OperatorWithKernel<DeviceType, DropoutParam<DeviceType>,
operators::DropoutKernel<DeviceType, T>>( operators::DropoutKernel<DeviceType, T>>(
type, inputs, outputs, attrs, scope) {} type, inputs, outputs, attrs, scope) {}
// using framework::OperatorWithKernel<DeviceType, DropoutParam, // using framework::OperatorWithKernel<DeviceType, DropoutParam<DeviceType>,
// operators::DropoutKernel<DeviceType, // operators::DropoutKernel<DeviceType,
// T>>; // T>>;
void InferShape() const override; void InferShape() const override;
......
...@@ -26,7 +26,7 @@ namespace operators { ...@@ -26,7 +26,7 @@ namespace operators {
using std::string; using std::string;
template <typename DeviceType, typename T> template <typename DeviceType, typename T>
class ElementwiseAddOp : public framework::OperatorWithKernel< class ElementwiseAddOp : public framework::OperatorWithKernel<
DeviceType, ElementwiseAddParam, DeviceType, ElementwiseAddParam<DeviceType>,
operators::ElementwiseAddKernel<DeviceType, T>> { operators::ElementwiseAddKernel<DeviceType, T>> {
public: public:
ElementwiseAddOp(const string &type, const VariableNameMap &inputs, ElementwiseAddOp(const string &type, const VariableNameMap &inputs,
...@@ -34,12 +34,12 @@ class ElementwiseAddOp : public framework::OperatorWithKernel< ...@@ -34,12 +34,12 @@ class ElementwiseAddOp : public framework::OperatorWithKernel<
const framework::AttributeMap &attrs, const framework::AttributeMap &attrs,
std::shared_ptr<framework::Scope> scope) std::shared_ptr<framework::Scope> scope)
: framework::OperatorWithKernel< : framework::OperatorWithKernel<
DeviceType, ElementwiseAddParam, DeviceType, ElementwiseAddParam<DeviceType>,
operators::ElementwiseAddKernel<DeviceType, T>>( operators::ElementwiseAddKernel<DeviceType, T>>(
type, inputs, outputs, attrs, scope) {} type, inputs, outputs, attrs, scope) {}
using framework::OperatorWithKernel< using framework::OperatorWithKernel<
DeviceType, ElementwiseAddParam, DeviceType, ElementwiseAddParam<DeviceType>,
operators::ElementwiseAddKernel<DeviceType, T>>::OperatorWithKernel; operators::ElementwiseAddKernel<DeviceType, T>>::OperatorWithKernel;
void InferShape() const override; void InferShape() const override;
......
...@@ -61,7 +61,7 @@ class FeedOp : public framework::OperatorBase<DeviceType> { ...@@ -61,7 +61,7 @@ class FeedOp : public framework::OperatorBase<DeviceType> {
#endif #endif
protected: protected:
FeedParam param_; FeedParam<DeviceType> param_;
}; };
} // namespace operators } // namespace operators
......
...@@ -41,7 +41,7 @@ class FetchOp : public framework::OperatorBase<DeviceType> { ...@@ -41,7 +41,7 @@ class FetchOp : public framework::OperatorBase<DeviceType> {
} }
protected: protected:
FetchParam param_; FetchParam<DeviceType> param_;
}; };
} // namespace operators } // namespace operators
......
...@@ -45,19 +45,20 @@ class FusionConvAddMatcher : public framework::FusionOpMatcher { ...@@ -45,19 +45,20 @@ class FusionConvAddMatcher : public framework::FusionOpMatcher {
template <typename DeviceType, typename T> template <typename DeviceType, typename T>
class FusionConvAddOp : public framework::OperatorWithKernel< class FusionConvAddOp : public framework::OperatorWithKernel<
DeviceType, FusionConvAddParam, DeviceType, FusionConvAddParam<DeviceType>,
operators::ConvAddKernel<DeviceType, T>> { operators::ConvAddKernel<DeviceType, T>> {
public: public:
FusionConvAddOp(const string &type, const VariableNameMap &inputs, FusionConvAddOp(const string &type, const VariableNameMap &inputs,
const VariableNameMap &outputs, const VariableNameMap &outputs,
const framework::AttributeMap &attrs, const framework::AttributeMap &attrs,
std::shared_ptr<framework::Scope> scope) std::shared_ptr<framework::Scope> scope)
: framework::OperatorWithKernel<DeviceType, FusionConvAddParam, : framework::OperatorWithKernel<DeviceType,
FusionConvAddParam<DeviceType>,
operators::ConvAddKernel<DeviceType, T>>( operators::ConvAddKernel<DeviceType, T>>(
type, inputs, outputs, attrs, scope) {} type, inputs, outputs, attrs, scope) {}
using framework::OperatorWithKernel< using framework::OperatorWithKernel<
DeviceType, FusionConvAddParam, DeviceType, FusionConvAddParam<DeviceType>,
operators::ConvAddKernel<DeviceType, T>>::OperatorWithKernel; operators::ConvAddKernel<DeviceType, T>>::OperatorWithKernel;
void InferShape() const override; void InferShape() const override;
......
...@@ -53,7 +53,7 @@ class FusionConvAddBNMatcher : public framework::FusionOpMatcher { ...@@ -53,7 +53,7 @@ class FusionConvAddBNMatcher : public framework::FusionOpMatcher {
template <typename DeviceType, typename T> template <typename DeviceType, typename T>
class FusionConvAddBNOp : public framework::OperatorWithKernel< class FusionConvAddBNOp : public framework::OperatorWithKernel<
DeviceType, FusionConvAddBNParam, DeviceType, FusionConvAddBNParam<DeviceType>,
operators::ConvAddBNKernel<DeviceType, T>> { operators::ConvAddBNKernel<DeviceType, T>> {
public: public:
FusionConvAddBNOp(const string &type, const VariableNameMap &inputs, FusionConvAddBNOp(const string &type, const VariableNameMap &inputs,
...@@ -61,7 +61,7 @@ class FusionConvAddBNOp : public framework::OperatorWithKernel< ...@@ -61,7 +61,7 @@ class FusionConvAddBNOp : public framework::OperatorWithKernel<
const framework::AttributeMap &attrs, const framework::AttributeMap &attrs,
std::shared_ptr<framework::Scope> scope) std::shared_ptr<framework::Scope> scope)
: framework::OperatorWithKernel< : framework::OperatorWithKernel<
DeviceType, FusionConvAddBNParam, DeviceType, FusionConvAddBNParam<DeviceType>,
operators::ConvAddBNKernel<DeviceType, T>>(type, inputs, outputs, operators::ConvAddBNKernel<DeviceType, T>>(type, inputs, outputs,
attrs, scope) {} attrs, scope) {}
......
...@@ -55,7 +55,7 @@ class FusionConvAddBNReluMatcher : public framework::FusionOpMatcher { ...@@ -55,7 +55,7 @@ class FusionConvAddBNReluMatcher : public framework::FusionOpMatcher {
template <typename DeviceType, typename T> template <typename DeviceType, typename T>
class FusionConvAddBNReluOp class FusionConvAddBNReluOp
: public framework::OperatorWithKernel< : public framework::OperatorWithKernel<
DeviceType, FusionConvAddBNReluParam, DeviceType, FusionConvAddBNReluParam<DeviceType>,
operators::ConvAddBNReluKernel<DeviceType, T>> { operators::ConvAddBNReluKernel<DeviceType, T>> {
public: public:
FusionConvAddBNReluOp(const string &type, const VariableNameMap &inputs, FusionConvAddBNReluOp(const string &type, const VariableNameMap &inputs,
...@@ -63,12 +63,12 @@ class FusionConvAddBNReluOp ...@@ -63,12 +63,12 @@ class FusionConvAddBNReluOp
const framework::AttributeMap &attrs, const framework::AttributeMap &attrs,
std::shared_ptr<framework::Scope> scope) std::shared_ptr<framework::Scope> scope)
: framework::OperatorWithKernel< : framework::OperatorWithKernel<
DeviceType, FusionConvAddBNReluParam, DeviceType, FusionConvAddBNReluParam<DeviceType>,
operators::ConvAddBNReluKernel<DeviceType, T>>( operators::ConvAddBNReluKernel<DeviceType, T>>(
type, inputs, outputs, attrs, scope) {} type, inputs, outputs, attrs, scope) {}
using framework::OperatorWithKernel< using framework::OperatorWithKernel<
DeviceType, FusionConvAddBNReluParam, DeviceType, FusionConvAddBNReluParam<DeviceType>,
operators::ConvAddBNReluKernel<DeviceType, T>>::OperatorWithKernel; operators::ConvAddBNReluKernel<DeviceType, T>>::OperatorWithKernel;
void InferShape() const override; void InferShape() const override;
......
...@@ -43,7 +43,7 @@ class FusionConvAddReluOpMatcher : public framework::FusionOpMatcher { ...@@ -43,7 +43,7 @@ class FusionConvAddReluOpMatcher : public framework::FusionOpMatcher {
template <typename DeviceType, typename T> template <typename DeviceType, typename T>
class FusionConvAddReluOp : public framework::OperatorWithKernel< class FusionConvAddReluOp : public framework::OperatorWithKernel<
DeviceType, FusionConvAddReluParam, DeviceType, FusionConvAddReluParam<DeviceType>,
operators::ConvAddReluKernel<DeviceType, T>> { operators::ConvAddReluKernel<DeviceType, T>> {
public: public:
FusionConvAddReluOp(const string &type, const VariableNameMap &inputs, FusionConvAddReluOp(const string &type, const VariableNameMap &inputs,
...@@ -51,12 +51,12 @@ class FusionConvAddReluOp : public framework::OperatorWithKernel< ...@@ -51,12 +51,12 @@ class FusionConvAddReluOp : public framework::OperatorWithKernel<
const framework::AttributeMap &attrs, const framework::AttributeMap &attrs,
std::shared_ptr<framework::Scope> scope) std::shared_ptr<framework::Scope> scope)
: framework::OperatorWithKernel< : framework::OperatorWithKernel<
DeviceType, FusionConvAddReluParam, DeviceType, FusionConvAddReluParam<DeviceType>,
operators::ConvAddReluKernel<DeviceType, T>>(type, inputs, outputs, operators::ConvAddReluKernel<DeviceType, T>>(type, inputs, outputs,
attrs, scope) {} attrs, scope) {}
using framework::OperatorWithKernel< using framework::OperatorWithKernel<
DeviceType, FusionConvAddReluParam, DeviceType, FusionConvAddReluParam<DeviceType>,
operators::ConvAddReluKernel<DeviceType, T>>::OperatorWithKernel; operators::ConvAddReluKernel<DeviceType, T>>::OperatorWithKernel;
void InferShape() const override; void InferShape() const override;
......
...@@ -52,7 +52,7 @@ class FusionConvBNReluMatcher : public framework::FusionOpMatcher { ...@@ -52,7 +52,7 @@ class FusionConvBNReluMatcher : public framework::FusionOpMatcher {
template <typename DeviceType, typename T> template <typename DeviceType, typename T>
class FusionConvBNReluOp : public framework::OperatorWithKernel< class FusionConvBNReluOp : public framework::OperatorWithKernel<
DeviceType, FusionConvBNReluParam, DeviceType, FusionConvBNReluParam<DeviceType>,
operators::ConvBNReluKernel<DeviceType, T>> { operators::ConvBNReluKernel<DeviceType, T>> {
public: public:
FusionConvBNReluOp(const string &type, const VariableNameMap &inputs, FusionConvBNReluOp(const string &type, const VariableNameMap &inputs,
...@@ -60,12 +60,12 @@ class FusionConvBNReluOp : public framework::OperatorWithKernel< ...@@ -60,12 +60,12 @@ class FusionConvBNReluOp : public framework::OperatorWithKernel<
const framework::AttributeMap &attrs, const framework::AttributeMap &attrs,
std::shared_ptr<framework::Scope> scope) std::shared_ptr<framework::Scope> scope)
: framework::OperatorWithKernel< : framework::OperatorWithKernel<
DeviceType, FusionConvBNReluParam, DeviceType, FusionConvBNReluParam<DeviceType>,
operators::ConvBNReluKernel<DeviceType, T>>(type, inputs, outputs, operators::ConvBNReluKernel<DeviceType, T>>(type, inputs, outputs,
attrs, scope) {} attrs, scope) {}
using framework::OperatorWithKernel< using framework::OperatorWithKernel<
DeviceType, FusionConvBNReluParam, DeviceType, FusionConvBNReluParam<DeviceType>,
operators::ConvBNReluKernel<DeviceType, T>>::OperatorWithKernel; operators::ConvBNReluKernel<DeviceType, T>>::OperatorWithKernel;
void InferShape() const override; void InferShape() const override;
......
...@@ -51,21 +51,22 @@ class FusionDWConvBNReluMatcher : public framework::FusionOpMatcher { ...@@ -51,21 +51,22 @@ class FusionDWConvBNReluMatcher : public framework::FusionOpMatcher {
}; };
template <typename DeviceType, typename T> template <typename DeviceType, typename T>
class FusionDWConvBNReluOp : public framework::OperatorWithKernel< class FusionDWConvBNReluOp
DeviceType, FusionDWConvBNReluParam, : public framework::OperatorWithKernel<
operators::DWConvBNReluKernel<DeviceType, T>> { DeviceType, FusionDWConvBNReluParam<DeviceType>,
operators::DWConvBNReluKernel<DeviceType, T>> {
public: public:
FusionDWConvBNReluOp(const string &type, const VariableNameMap &inputs, FusionDWConvBNReluOp(const string &type, const VariableNameMap &inputs,
const VariableNameMap &outputs, const VariableNameMap &outputs,
const framework::AttributeMap &attrs, const framework::AttributeMap &attrs,
std::shared_ptr<framework::Scope> scope) std::shared_ptr<framework::Scope> scope)
: framework::OperatorWithKernel< : framework::OperatorWithKernel<
DeviceType, FusionDWConvBNReluParam, DeviceType, FusionDWConvBNReluParam<DeviceType>,
operators::DWConvBNReluKernel<DeviceType, T>>(type, inputs, outputs, operators::DWConvBNReluKernel<DeviceType, T>>(type, inputs, outputs,
attrs, scope) {} attrs, scope) {}
using framework::OperatorWithKernel< using framework::OperatorWithKernel<
DeviceType, FusionDWConvBNReluParam, DeviceType, FusionDWConvBNReluParam<DeviceType>,
operators::DWConvBNReluKernel<DeviceType, T>>::OperatorWithKernel; operators::DWConvBNReluKernel<DeviceType, T>>::OperatorWithKernel;
void InferShape() const override; void InferShape() const override;
......
...@@ -44,7 +44,7 @@ class FusioneElementwiseAddReluMatcher : public framework::FusionOpMatcher { ...@@ -44,7 +44,7 @@ class FusioneElementwiseAddReluMatcher : public framework::FusionOpMatcher {
template <typename DeviceType, typename T> template <typename DeviceType, typename T>
class FusionElementwiseAddReluOp class FusionElementwiseAddReluOp
: public framework::OperatorWithKernel< : public framework::OperatorWithKernel<
DeviceType, ElementwiseAddReluParam, DeviceType, ElementwiseAddReluParam<DeviceType>,
operators::ElementwiseAddReluKernel<DeviceType, T>> { operators::ElementwiseAddReluKernel<DeviceType, T>> {
public: public:
FusionElementwiseAddReluOp(const string &type, const VariableNameMap &inputs, FusionElementwiseAddReluOp(const string &type, const VariableNameMap &inputs,
...@@ -52,7 +52,7 @@ class FusionElementwiseAddReluOp ...@@ -52,7 +52,7 @@ class FusionElementwiseAddReluOp
const framework::AttributeMap &attrs, const framework::AttributeMap &attrs,
std::shared_ptr<framework::Scope> scope) std::shared_ptr<framework::Scope> scope)
: framework::OperatorWithKernel< : framework::OperatorWithKernel<
DeviceType, ElementwiseAddReluParam, DeviceType, ElementwiseAddReluParam<DeviceType>,
operators::ElementwiseAddReluKernel<DeviceType, T>>( operators::ElementwiseAddReluKernel<DeviceType, T>>(
type, inputs, outputs, attrs, scope) {} type, inputs, outputs, attrs, scope) {}
......
...@@ -45,20 +45,20 @@ class FusionFcMatcher : public framework::FusionOpMatcher { ...@@ -45,20 +45,20 @@ class FusionFcMatcher : public framework::FusionOpMatcher {
}; };
template <typename DeviceType, typename T> template <typename DeviceType, typename T>
class FusionFcOp class FusionFcOp : public framework::OperatorWithKernel<
: public framework::OperatorWithKernel< DeviceType, FusionFcParam<DeviceType>,
DeviceType, FusionFcParam, operators::FusionFcKernel<DeviceType, T>> { operators::FusionFcKernel<DeviceType, T>> {
public: public:
FusionFcOp(const string &type, const VariableNameMap &inputs, FusionFcOp(const string &type, const VariableNameMap &inputs,
const VariableNameMap &outputs, const VariableNameMap &outputs,
const framework::AttributeMap &attrs, const framework::AttributeMap &attrs,
std::shared_ptr<framework::Scope> scope) std::shared_ptr<framework::Scope> scope)
: framework::OperatorWithKernel<DeviceType, FusionFcParam, : framework::OperatorWithKernel<DeviceType, FusionFcParam<DeviceType>,
operators::FusionFcKernel<DeviceType, T>>( operators::FusionFcKernel<DeviceType, T>>(
type, inputs, outputs, attrs, scope) {} type, inputs, outputs, attrs, scope) {}
using framework::OperatorWithKernel< using framework::OperatorWithKernel<
DeviceType, FusionFcParam, DeviceType, FusionFcParam<DeviceType>,
operators::FusionFcKernel<DeviceType, T>>::OperatorWithKernel; operators::FusionFcKernel<DeviceType, T>>::OperatorWithKernel;
void InferShape() const override; void InferShape() const override;
......
...@@ -44,7 +44,7 @@ class FusionFcReluMatcher : public framework::FusionOpMatcher { ...@@ -44,7 +44,7 @@ class FusionFcReluMatcher : public framework::FusionOpMatcher {
template <typename DeviceType, typename T> template <typename DeviceType, typename T>
class FusionFcReluOp : public framework::OperatorWithKernel< class FusionFcReluOp : public framework::OperatorWithKernel<
DeviceType, FusionFcReluParam, DeviceType, FusionFcReluParam<DeviceType>,
operators::FusionFcReluKernel<DeviceType, T>> { operators::FusionFcReluKernel<DeviceType, T>> {
public: public:
FusionFcReluOp(const string &type, const VariableNameMap &inputs, FusionFcReluOp(const string &type, const VariableNameMap &inputs,
...@@ -52,12 +52,12 @@ class FusionFcReluOp : public framework::OperatorWithKernel< ...@@ -52,12 +52,12 @@ class FusionFcReluOp : public framework::OperatorWithKernel<
const framework::AttributeMap &attrs, const framework::AttributeMap &attrs,
std::shared_ptr<framework::Scope> scope) std::shared_ptr<framework::Scope> scope)
: framework::OperatorWithKernel< : framework::OperatorWithKernel<
DeviceType, FusionFcReluParam, DeviceType, FusionFcReluParam<DeviceType>,
operators::FusionFcReluKernel<DeviceType, T>>(type, inputs, outputs, operators::FusionFcReluKernel<DeviceType, T>>(type, inputs, outputs,
attrs, scope) {} attrs, scope) {}
using framework::OperatorWithKernel< using framework::OperatorWithKernel<
DeviceType, FusionFcReluParam, DeviceType, FusionFcReluParam<DeviceType>,
operators::FusionFcReluKernel<DeviceType, T>>::OperatorWithKernel; operators::FusionFcReluKernel<DeviceType, T>>::OperatorWithKernel;
void InferShape() const override; void InferShape() const override;
......
...@@ -27,7 +27,7 @@ using namespace framework; ...@@ -27,7 +27,7 @@ using namespace framework;
template <typename DeviceType, typename T> template <typename DeviceType, typename T>
class Im2SequenceOp : public framework::OperatorWithKernel< class Im2SequenceOp : public framework::OperatorWithKernel<
DeviceType, Im2SequenceParam, DeviceType, Im2SequenceParam<DeviceType>,
operators::Im2SequenceKernel<DeviceType, T>> { operators::Im2SequenceKernel<DeviceType, T>> {
public: public:
Im2SequenceOp(const std::string &type, const VariableNameMap &inputs, Im2SequenceOp(const std::string &type, const VariableNameMap &inputs,
...@@ -35,12 +35,12 @@ class Im2SequenceOp : public framework::OperatorWithKernel< ...@@ -35,12 +35,12 @@ class Im2SequenceOp : public framework::OperatorWithKernel<
const framework::AttributeMap &attrs, const framework::AttributeMap &attrs,
std::shared_ptr<framework::Scope> scope) std::shared_ptr<framework::Scope> scope)
: framework::OperatorWithKernel< : framework::OperatorWithKernel<
DeviceType, Im2SequenceParam, DeviceType, Im2SequenceParam<DeviceType>,
operators::Im2SequenceKernel<DeviceType, T>>(type, inputs, outputs, operators::Im2SequenceKernel<DeviceType, T>>(type, inputs, outputs,
attrs, scope) {} attrs, scope) {}
// using framework::OperatorWithKernel< // using framework::OperatorWithKernel<
// DeviceType, Im2SequenceParam, // DeviceType, Im2SequenceParam<DeviceType>,
// operators::Im2SequenceKernel<DeviceType, T>>::OperatorWithKernel; // operators::Im2SequenceKernel<DeviceType, T>>::OperatorWithKernel;
void InferShape() const override; void InferShape() const override;
......
...@@ -21,12 +21,13 @@ namespace paddle_mobile { ...@@ -21,12 +21,13 @@ namespace paddle_mobile {
namespace operators { namespace operators {
template <> template <>
bool BatchNormKernel<CPU, float>::Init(BatchNormParam *param) { bool BatchNormKernel<CPU, float>::Init(BatchNormParam<CPU> *param) {
return true; return true;
} }
template <> template <>
void BatchNormKernel<CPU, float>::Compute(const BatchNormParam &param) const { void BatchNormKernel<CPU, float>::Compute(
const BatchNormParam<CPU> &param) const {
BatchnormCompute<float>(param); BatchnormCompute<float>(param);
} }
......
...@@ -21,12 +21,13 @@ namespace paddle_mobile { ...@@ -21,12 +21,13 @@ namespace paddle_mobile {
namespace operators { namespace operators {
template <> template <>
bool BoxCoderKernel<CPU, float>::Init(BoxCoderParam *param) { bool BoxCoderKernel<CPU, float>::Init(BoxCoderParam<CPU> *param) {
return true; return true;
} }
template <> template <>
void BoxCoderKernel<CPU, float>::Compute(const BoxCoderParam &param) const { void BoxCoderKernel<CPU, float>::Compute(
const BoxCoderParam<CPU> &param) const {
BoxCoderCompute<float>(param); BoxCoderCompute<float>(param);
} }
......
...@@ -21,12 +21,12 @@ namespace paddle_mobile { ...@@ -21,12 +21,12 @@ namespace paddle_mobile {
namespace operators { namespace operators {
template <> template <>
bool ConcatKernel<CPU, float>::Init(ConcatParam *param) { bool ConcatKernel<CPU, float>::Init(ConcatParam<CPU> *param) {
return true; return true;
} }
template <> template <>
void ConcatKernel<CPU, float>::Compute(const ConcatParam &param) const { void ConcatKernel<CPU, float>::Compute(const ConcatParam<CPU> &param) const {
ConcatCompute<float>(param); ConcatCompute<float>(param);
} }
......
...@@ -21,7 +21,8 @@ namespace paddle_mobile { ...@@ -21,7 +21,8 @@ namespace paddle_mobile {
namespace operators { namespace operators {
template <> template <>
bool ConvAddBNReluKernel<CPU, float>::Init(FusionConvAddBNReluParam *param) { bool ConvAddBNReluKernel<CPU, float>::Init(
FusionConvAddBNReluParam<CPU> *param) {
const Tensor *mean = param->InputMean(); const Tensor *mean = param->InputMean();
const Tensor *variance = param->InputVariance(); const Tensor *variance = param->InputVariance();
const Tensor *scale = param->InputScale(); const Tensor *scale = param->InputScale();
...@@ -54,7 +55,7 @@ bool ConvAddBNReluKernel<CPU, float>::Init(FusionConvAddBNReluParam *param) { ...@@ -54,7 +55,7 @@ bool ConvAddBNReluKernel<CPU, float>::Init(FusionConvAddBNReluParam *param) {
template <> template <>
void ConvAddBNReluKernel<CPU, float>::Compute( void ConvAddBNReluKernel<CPU, float>::Compute(
const FusionConvAddBNReluParam &param) const { const FusionConvAddBNReluParam<CPU> &param) const {
ConvAddBNReluCompute<float>(param); ConvAddBNReluCompute<float>(param);
} }
template class ConvAddBNReluKernel<CPU, float>; template class ConvAddBNReluKernel<CPU, float>;
......
...@@ -20,12 +20,13 @@ namespace paddle_mobile { ...@@ -20,12 +20,13 @@ namespace paddle_mobile {
namespace operators { namespace operators {
template <> template <>
bool ConvAddKernel<CPU, float>::Init(FusionConvAddParam *param) { bool ConvAddKernel<CPU, float>::Init(FusionConvAddParam<CPU> *param) {
return true; return true;
} }
template <> template <>
void ConvAddKernel<CPU, float>::Compute(const FusionConvAddParam &param) const { void ConvAddKernel<CPU, float>::Compute(
const FusionConvAddParam<CPU> &param) const {
ConvAddCompute<float>(param); ConvAddCompute<float>(param);
} }
......
...@@ -21,13 +21,13 @@ namespace paddle_mobile { ...@@ -21,13 +21,13 @@ namespace paddle_mobile {
namespace operators { namespace operators {
template <> template <>
bool ConvAddReluKernel<CPU, float>::Init(FusionConvAddReluParam *param) { bool ConvAddReluKernel<CPU, float>::Init(FusionConvAddReluParam<CPU> *param) {
return true; return true;
} }
template <> template <>
void ConvAddReluKernel<CPU, float>::Compute( void ConvAddReluKernel<CPU, float>::Compute(
const FusionConvAddReluParam &param) const { const FusionConvAddReluParam<CPU> &param) const {
ConvAddReluCompute<float>(param); ConvAddReluCompute<float>(param);
} }
template class ConvAddReluKernel<CPU, float>; template class ConvAddReluKernel<CPU, float>;
......
...@@ -21,7 +21,7 @@ namespace paddle_mobile { ...@@ -21,7 +21,7 @@ namespace paddle_mobile {
namespace operators { namespace operators {
template <> template <>
bool ConvBNReluKernel<CPU, float>::Init(FusionConvBNReluParam *param) { bool ConvBNReluKernel<CPU, float>::Init(FusionConvBNReluParam<CPU> *param) {
const Tensor *mean = param->InputMean(); const Tensor *mean = param->InputMean();
const Tensor *variance = param->InputVariance(); const Tensor *variance = param->InputVariance();
const Tensor *scale = param->InputScale(); const Tensor *scale = param->InputScale();
...@@ -57,7 +57,7 @@ bool ConvBNReluKernel<CPU, float>::Init(FusionConvBNReluParam *param) { ...@@ -57,7 +57,7 @@ bool ConvBNReluKernel<CPU, float>::Init(FusionConvBNReluParam *param) {
template <> template <>
void ConvBNReluKernel<CPU, float>::Compute( void ConvBNReluKernel<CPU, float>::Compute(
const FusionConvBNReluParam &param) const { const FusionConvBNReluParam<CPU> &param) const {
ConvBNReluCompute<float>(param); ConvBNReluCompute<float>(param);
} }
template class ConvBNReluKernel<CPU, float>; template class ConvBNReluKernel<CPU, float>;
......
...@@ -21,12 +21,12 @@ namespace paddle_mobile { ...@@ -21,12 +21,12 @@ namespace paddle_mobile {
namespace operators { namespace operators {
template <> template <>
bool ConvKernel<CPU, float>::Init(ConvParam *param) { bool ConvKernel<CPU, float>::Init(ConvParam<CPU> *param) {
return true; return true;
} }
template <> template <>
void ConvKernel<CPU, float>::Compute(const ConvParam &param) const { void ConvKernel<CPU, float>::Compute(const ConvParam<CPU> &param) const {
ConvCompute<float>(param); ConvCompute<float>(param);
} }
......
...@@ -21,13 +21,13 @@ namespace paddle_mobile { ...@@ -21,13 +21,13 @@ namespace paddle_mobile {
namespace operators { namespace operators {
template <> template <>
bool ConvTransposeKernel<CPU, float>::Init(ConvTransposeParam *param) { bool ConvTransposeKernel<CPU, float>::Init(ConvTransposeParam<CPU> *param) {
return true; return true;
} }
template <> template <>
void ConvTransposeKernel<CPU, float>::Compute( void ConvTransposeKernel<CPU, float>::Compute(
const ConvTransposeParam &param) const { const ConvTransposeParam<CPU> &param) const {
ConvTransposeCompute<float>(param); ConvTransposeCompute<float>(param);
} }
......
...@@ -21,12 +21,13 @@ namespace paddle_mobile { ...@@ -21,12 +21,13 @@ namespace paddle_mobile {
namespace operators { namespace operators {
template <> template <>
bool DepthwiseConvKernel<CPU, float>::Init(ConvParam *param) { bool DepthwiseConvKernel<CPU, float>::Init(ConvParam<CPU> *param) {
return true; return true;
} }
template <> template <>
void DepthwiseConvKernel<CPU, float>::Compute(const ConvParam &param) const { void DepthwiseConvKernel<CPU, float>::Compute(
const ConvParam<CPU> &param) const {
DepthwiseConvCompute<float>(param); DepthwiseConvCompute<float>(param);
} }
......
...@@ -21,7 +21,7 @@ namespace paddle_mobile { ...@@ -21,7 +21,7 @@ namespace paddle_mobile {
namespace operators { namespace operators {
template <> template <>
bool DropoutKernel<CPU, float>::Init(DropoutParam *para) { bool DropoutKernel<CPU, float>::Init(DropoutParam<CPU> *para) {
return true; return true;
} }
...@@ -31,7 +31,7 @@ struct DropoutFunctor { ...@@ -31,7 +31,7 @@ struct DropoutFunctor {
}; };
template <> template <>
void DropoutKernel<CPU, float>::Compute(const DropoutParam &param) const { void DropoutKernel<CPU, float>::Compute(const DropoutParam<CPU> &param) const {
const auto *input_x = param.InputX(); const auto *input_x = param.InputX();
auto *input_x_ptr = input_x->data<float>(); auto *input_x_ptr = input_x->data<float>();
auto *out = param.Out(); auto *out = param.Out();
......
...@@ -21,7 +21,7 @@ namespace paddle_mobile { ...@@ -21,7 +21,7 @@ namespace paddle_mobile {
namespace operators { namespace operators {
template <> template <>
bool DWConvBNReluKernel<CPU, float>::Init(FusionDWConvBNReluParam *param) { bool DWConvBNReluKernel<CPU, float>::Init(FusionDWConvBNReluParam<CPU> *param) {
const Tensor *mean = param->InputMean(); const Tensor *mean = param->InputMean();
const Tensor *variance = param->InputVariance(); const Tensor *variance = param->InputVariance();
const Tensor *scale = param->InputScale(); const Tensor *scale = param->InputScale();
...@@ -54,7 +54,7 @@ bool DWConvBNReluKernel<CPU, float>::Init(FusionDWConvBNReluParam *param) { ...@@ -54,7 +54,7 @@ bool DWConvBNReluKernel<CPU, float>::Init(FusionDWConvBNReluParam *param) {
template <> template <>
void DWConvBNReluKernel<CPU, float>::Compute( void DWConvBNReluKernel<CPU, float>::Compute(
const FusionDWConvBNReluParam &param) const { const FusionDWConvBNReluParam<CPU> &param) const {
DWConvBNReluCompute<float>(param); DWConvBNReluCompute<float>(param);
} }
template class DWConvBNReluKernel<CPU, float>; template class DWConvBNReluKernel<CPU, float>;
......
...@@ -21,13 +21,13 @@ namespace paddle_mobile { ...@@ -21,13 +21,13 @@ namespace paddle_mobile {
namespace operators { namespace operators {
template <> template <>
bool ElementwiseAddKernel<CPU, float>::Init(ElementwiseAddParam *param) { bool ElementwiseAddKernel<CPU, float>::Init(ElementwiseAddParam<CPU> *param) {
return true; return true;
} }
template <> template <>
void ElementwiseAddKernel<CPU, float>::Compute( void ElementwiseAddKernel<CPU, float>::Compute(
const ElementwiseAddParam &param) const { const ElementwiseAddParam<CPU> &param) const {
ElementwiseAddCompute<float>(param); ElementwiseAddCompute<float>(param);
} }
......
...@@ -21,12 +21,13 @@ namespace paddle_mobile { ...@@ -21,12 +21,13 @@ namespace paddle_mobile {
namespace operators { namespace operators {
template <> template <>
bool FusionFcKernel<CPU, float>::Init(FusionFcParam *param) { bool FusionFcKernel<CPU, float>::Init(FusionFcParam<CPU> *param) {
return true; return true;
} }
template <> template <>
void FusionFcKernel<CPU, float>::Compute(const FusionFcParam &param) const { void FusionFcKernel<CPU, float>::Compute(
const FusionFcParam<CPU> &param) const {
FusionFcCompute<float>(param); FusionFcCompute<float>(param);
} }
......
...@@ -20,7 +20,7 @@ namespace paddle_mobile { ...@@ -20,7 +20,7 @@ namespace paddle_mobile {
namespace operators { namespace operators {
template <> template <>
bool Im2SequenceKernel<CPU, float>::Init(Im2SequenceParam *para) { bool Im2SequenceKernel<CPU, float>::Init(Im2SequenceParam<CPU> *para) {
return true; return true;
} }
...@@ -33,7 +33,7 @@ inline int Im2SeqOutputSize(int input_size, int filter_size, int padding_0, ...@@ -33,7 +33,7 @@ inline int Im2SeqOutputSize(int input_size, int filter_size, int padding_0,
template <> template <>
void Im2SequenceKernel<CPU, float>::Compute( void Im2SequenceKernel<CPU, float>::Compute(
const Im2SequenceParam &param) const { const Im2SequenceParam<CPU> &param) const {
const Tensor *in_x = param.Input(); const Tensor *in_x = param.Input();
Tensor *out = param.Output(); Tensor *out = param.Output();
out->mutable_data<float>(); out->mutable_data<float>();
......
...@@ -21,12 +21,12 @@ namespace paddle_mobile { ...@@ -21,12 +21,12 @@ namespace paddle_mobile {
namespace operators { namespace operators {
template <> template <>
bool LrnKernel<CPU, float>::Init(LrnParam *param) { bool LrnKernel<CPU, float>::Init(LrnParam<CPU> *param) {
return true; return true;
} }
template <> template <>
void LrnKernel<CPU, float>::Compute(const LrnParam &param) const { void LrnKernel<CPU, float>::Compute(const LrnParam<CPU> &param) const {
LrnCompute<float>(param); LrnCompute<float>(param);
} }
......
...@@ -21,12 +21,12 @@ namespace paddle_mobile { ...@@ -21,12 +21,12 @@ namespace paddle_mobile {
namespace operators { namespace operators {
template <> template <>
bool MulKernel<CPU, float>::Init(MulParam *param) { bool MulKernel<CPU, float>::Init(MulParam<CPU> *param) {
return true; return true;
} }
template <> template <>
void MulKernel<CPU, float>::Compute(const MulParam &param) const { void MulKernel<CPU, float>::Compute(const MulParam<CPU> &param) const {
MulCompute<float>(param); MulCompute<float>(param);
} }
......
...@@ -21,13 +21,13 @@ namespace paddle_mobile { ...@@ -21,13 +21,13 @@ namespace paddle_mobile {
namespace operators { namespace operators {
template <> template <>
bool MultiClassNMSKernel<CPU, float>::Init(MultiClassNMSParam *param) { bool MultiClassNMSKernel<CPU, float>::Init(MultiClassNMSParam<CPU> *param) {
return true; return true;
} }
template <> template <>
void MultiClassNMSKernel<CPU, float>::Compute( void MultiClassNMSKernel<CPU, float>::Compute(
const MultiClassNMSParam &param) const { const MultiClassNMSParam<CPU> &param) const {
MultiClassNMSCompute<float>(param); MultiClassNMSCompute<float>(param);
} }
......
...@@ -20,12 +20,12 @@ namespace paddle_mobile { ...@@ -20,12 +20,12 @@ namespace paddle_mobile {
namespace operators { namespace operators {
template <> template <>
bool PoolKernel<CPU, float>::Init(PoolParam *param) { bool PoolKernel<CPU, float>::Init(PoolParam<CPU> *param) {
return true; return true;
} }
template <> template <>
void PoolKernel<CPU, float>::Compute(const PoolParam &param) const { void PoolKernel<CPU, float>::Compute(const PoolParam<CPU> &param) const {
PoolCompute<float>(param); PoolCompute<float>(param);
} }
} // namespace operators } // namespace operators
......
...@@ -32,7 +32,7 @@ struct PReluFunctor { ...@@ -32,7 +32,7 @@ struct PReluFunctor {
* @b 特化到具体平台的实现, param 从 op 层传入 * @b 特化到具体平台的实现, param 从 op 层传入
* */ * */
template <> template <>
void PReluKernel<CPU, float>::Compute(const PReluParam &param) const { void PReluKernel<CPU, float>::Compute(const PReluParam<CPU> &param) const {
const auto *input_x = param.InputX(); const auto *input_x = param.InputX();
auto *input_x_ptr = input_x->data<float>(); auto *input_x_ptr = input_x->data<float>();
auto *out = param.Out(); auto *out = param.Out();
......
...@@ -21,12 +21,13 @@ namespace paddle_mobile { ...@@ -21,12 +21,13 @@ namespace paddle_mobile {
namespace operators { namespace operators {
template <> template <>
bool PriorBoxKernel<CPU, float>::Init(PriorBoxParam *param) { bool PriorBoxKernel<CPU, float>::Init(PriorBoxParam<CPU> *param) {
return true; return true;
} }
template <> template <>
void PriorBoxKernel<CPU, float>::Compute(const PriorBoxParam &param) const { void PriorBoxKernel<CPU, float>::Compute(
const PriorBoxParam<CPU> &param) const {
PriorBoxCompute<float>(param); PriorBoxCompute<float>(param);
} }
......
...@@ -21,12 +21,12 @@ namespace paddle_mobile { ...@@ -21,12 +21,12 @@ namespace paddle_mobile {
namespace operators { namespace operators {
template <> template <>
bool ReluKernel<CPU, float>::Init(ReluParam *param) { bool ReluKernel<CPU, float>::Init(ReluParam<CPU> *param) {
return true; return true;
} }
template <> template <>
void ReluKernel<CPU, float>::Compute(const ReluParam &param) const { void ReluKernel<CPU, float>::Compute(const ReluParam<CPU> &param) const {
ReluCompute<float>(param); ReluCompute<float>(param);
} }
......
...@@ -21,12 +21,12 @@ namespace paddle_mobile { ...@@ -21,12 +21,12 @@ namespace paddle_mobile {
namespace operators { namespace operators {
template <> template <>
bool ReshapeKernel<CPU, float>::Init(ReshapeParam *param) { bool ReshapeKernel<CPU, float>::Init(ReshapeParam<CPU> *param) {
return true; return true;
} }
template <> template <>
void ReshapeKernel<CPU, float>::Compute(const ReshapeParam &param) const { void ReshapeKernel<CPU, float>::Compute(const ReshapeParam<CPU> &param) const {
ReshapeCompute<float>(param); ReshapeCompute<float>(param);
} }
......
...@@ -108,7 +108,7 @@ void ResizeTensor(const Tensor* src, Tensor* dst) { ...@@ -108,7 +108,7 @@ void ResizeTensor(const Tensor* src, Tensor* dst) {
} }
template <> template <>
void ResizeKernel<CPU, float>::Compute(const ResizeParam& param) const { void ResizeKernel<CPU, float>::Compute(const ResizeParam<CPU>& param) const {
const auto* input_x = param.InputX(); const auto* input_x = param.InputX();
const auto& input_x_dims = input_x->dims(); const auto& input_x_dims = input_x->dims();
auto* out = param.Out(); auto* out = param.Out();
......
...@@ -23,7 +23,7 @@ namespace operators { ...@@ -23,7 +23,7 @@ namespace operators {
* @b 特化到具体平台的实现, param 从 op 层传入 * @b 特化到具体平台的实现, param 从 op 层传入
* */ * */
template <> template <>
void ScaleKernel<CPU, float>::Compute(const ScaleParam &param) const { void ScaleKernel<CPU, float>::Compute(const ScaleParam<CPU> &param) const {
const auto *input_x = param.InputX(); const auto *input_x = param.InputX();
auto *input_x_ptr = input_x->data<float>(); auto *input_x_ptr = input_x->data<float>();
auto *out = param.Out(); auto *out = param.Out();
......
...@@ -27,12 +27,12 @@ using framework::DDim; ...@@ -27,12 +27,12 @@ using framework::DDim;
using framework::Tensor; using framework::Tensor;
template <> template <>
bool SigmoidKernel<CPU, float>::Init(SigmoidParam *param) { bool SigmoidKernel<CPU, float>::Init(SigmoidParam<CPU> *param) {
return true; return true;
} }
template <> template <>
void SigmoidKernel<CPU, float>::Compute(const SigmoidParam &param) const { void SigmoidKernel<CPU, float>::Compute(const SigmoidParam<CPU> &param) const {
SigmoidCompute<float>(param); SigmoidCompute<float>(param);
} }
......
...@@ -21,12 +21,12 @@ namespace paddle_mobile { ...@@ -21,12 +21,12 @@ namespace paddle_mobile {
namespace operators { namespace operators {
template <> template <>
bool SoftmaxKernel<CPU, float>::Init(SoftmaxParam *param) { bool SoftmaxKernel<CPU, float>::Init(SoftmaxParam<CPU> *param) {
return true; return true;
} }
template <> template <>
void SoftmaxKernel<CPU, float>::Compute(const SoftmaxParam &param) const { void SoftmaxKernel<CPU, float>::Compute(const SoftmaxParam<CPU> &param) const {
SoftmaxCompute<float>(param); SoftmaxCompute<float>(param);
} }
......
...@@ -20,12 +20,13 @@ namespace paddle_mobile { ...@@ -20,12 +20,13 @@ namespace paddle_mobile {
namespace operators { namespace operators {
template <> template <>
bool TransposeKernel<CPU, float>::Init(TransposeParam *param) { bool TransposeKernel<CPU, float>::Init(TransposeParam<CPU> *param) {
return true; return true;
} }
template <> template <>
void TransposeKernel<CPU, float>::Compute(const TransposeParam &param) const { void TransposeKernel<CPU, float>::Compute(
const TransposeParam<CPU> &param) const {
TransposeCompute<float>(param); TransposeCompute<float>(param);
} }
......
...@@ -26,10 +26,10 @@ using namespace framework; ...@@ -26,10 +26,10 @@ using namespace framework;
template <typename DeviceType, typename T> template <typename DeviceType, typename T>
class BatchNormKernel class BatchNormKernel
: public framework::OpKernelBase<DeviceType, BatchNormParam> { : public framework::OpKernelBase<DeviceType, BatchNormParam<DeviceType>> {
public: public:
void Compute(const BatchNormParam &param) const; void Compute(const BatchNormParam<DeviceType> &param) const;
bool Init(BatchNormParam *param); bool Init(BatchNormParam<DeviceType> *param);
}; };
} // namespace operators } // namespace operators
......
...@@ -27,10 +27,10 @@ namespace operators { ...@@ -27,10 +27,10 @@ namespace operators {
template <typename DeviceType, typename T> template <typename DeviceType, typename T>
class BoxCoderKernel class BoxCoderKernel
: public framework::OpKernelBase<DeviceType, BoxCoderParam> { : public framework::OpKernelBase<DeviceType, BoxCoderParam<DeviceType>> {
public: public:
void Compute(const BoxCoderParam& param) const; void Compute(const BoxCoderParam<DeviceType>& param) const;
bool Init(BoxCoderParam* param); bool Init(BoxCoderParam<DeviceType>* param);
}; };
} // namespace operators } // namespace operators
} // namespace paddle_mobile } // namespace paddle_mobile
......
...@@ -23,7 +23,7 @@ namespace paddle_mobile { ...@@ -23,7 +23,7 @@ namespace paddle_mobile {
namespace operators { namespace operators {
template <typename P> template <typename P>
void BatchnormCompute(const BatchNormParam &param) { void BatchnormCompute(const BatchNormParam<CPU> &param) {
const Tensor *input_x = param.InputX(); const Tensor *input_x = param.InputX();
auto input_x_ptr = input_x->data<float>(); auto input_x_ptr = input_x->data<float>();
const auto &x_dims = input_x->dims(); const auto &x_dims = input_x->dims();
......
...@@ -113,7 +113,7 @@ void DecodeCenterSize(const framework::Tensor& target_box, ...@@ -113,7 +113,7 @@ void DecodeCenterSize(const framework::Tensor& target_box,
} }
template <typename P> template <typename P>
void BoxCoderCompute(const BoxCoderParam& param) { void BoxCoderCompute(const BoxCoderParam<CPU>& param) {
const auto* input_priorbox = param.InputPriorBox(); const auto* input_priorbox = param.InputPriorBox();
const auto* input_priorboxvar = param.InputPriorBoxVar(); const auto* input_priorboxvar = param.InputPriorBoxVar();
const auto* input_targetbox = param.InputTargetBox(); const auto* input_targetbox = param.InputTargetBox();
......
...@@ -54,7 +54,7 @@ class ConcatFunctor { ...@@ -54,7 +54,7 @@ class ConcatFunctor {
}; };
template <typename P> template <typename P>
void ConcatCompute(const ConcatParam &param) { void ConcatCompute(const ConcatParam<CPU> &param) {
auto inputs = param.Inputs(); auto inputs = param.Inputs();
auto *out = param.Out(); auto *out = param.Out();
int64_t axis = param.Axis(); int64_t axis = param.Axis();
......
...@@ -25,7 +25,7 @@ limitations under the License. */ ...@@ -25,7 +25,7 @@ limitations under the License. */
namespace paddle_mobile { namespace paddle_mobile {
namespace operators { namespace operators {
void ConvAddBasic(const FusionConvAddParam &param) { void ConvAddBasic(const FusionConvAddParam<CPU> &param) {
const Tensor *input = param.Input(); const Tensor *input = param.Input();
Tensor filter = *param.Filter(); Tensor filter = *param.Filter();
Tensor bias = *param.Bias(); Tensor bias = *param.Bias();
...@@ -114,7 +114,7 @@ void ConvAddBasic(const FusionConvAddParam &param) { ...@@ -114,7 +114,7 @@ void ConvAddBasic(const FusionConvAddParam &param) {
} }
template <typename P> template <typename P>
void ConvAddCompute(const FusionConvAddParam &param) { void ConvAddCompute(const FusionConvAddParam<CPU> &param) {
if (param.Groups() == param.Input()->dims()[1] && if (param.Groups() == param.Input()->dims()[1] &&
param.Input()->dims()[1] == param.Output()->dims()[1] && param.Input()->dims()[1] == param.Output()->dims()[1] &&
param.Filter()->dims()[2] == param.Filter()->dims()[3] && param.Filter()->dims()[2] == param.Filter()->dims()[3] &&
......
...@@ -25,7 +25,7 @@ limitations under the License. */ ...@@ -25,7 +25,7 @@ limitations under the License. */
namespace paddle_mobile { namespace paddle_mobile {
namespace operators { namespace operators {
void ConvAddBNReluBasic(const FusionConvAddBNReluParam &param) { void ConvAddBNReluBasic(const FusionConvAddBNReluParam<CPU> &param) {
const Tensor *input = param.Input(); const Tensor *input = param.Input();
Tensor filter = *param.Filter(); Tensor filter = *param.Filter();
Tensor new_bias = *param.NewBias(); Tensor new_bias = *param.NewBias();
...@@ -112,7 +112,7 @@ void ConvAddBNReluBasic(const FusionConvAddBNReluParam &param) { ...@@ -112,7 +112,7 @@ void ConvAddBNReluBasic(const FusionConvAddBNReluParam &param) {
} }
} }
template <typename P> template <typename P>
void ConvAddBNReluCompute(const FusionConvAddBNReluParam &param) { void ConvAddBNReluCompute(const FusionConvAddBNReluParam<CPU> &param) {
Tensor Bias; Tensor Bias;
Bias.mutable_data<float>({param.Groups()}); Bias.mutable_data<float>({param.Groups()});
if (param.Groups() == param.Input()->dims()[1] && if (param.Groups() == param.Input()->dims()[1] &&
......
...@@ -26,7 +26,7 @@ namespace paddle_mobile { ...@@ -26,7 +26,7 @@ namespace paddle_mobile {
namespace operators { namespace operators {
template <typename P> template <typename P>
void ConvAddReluCompute(const FusionConvAddReluParam &param) { void ConvAddReluCompute(const FusionConvAddReluParam<CPU> &param) {
const Tensor *input = param.Input(); const Tensor *input = param.Input();
Tensor filter = *param.Filter(); Tensor filter = *param.Filter();
Tensor bias = *param.Bias(); Tensor bias = *param.Bias();
......
...@@ -25,7 +25,7 @@ limitations under the License. */ ...@@ -25,7 +25,7 @@ limitations under the License. */
namespace paddle_mobile { namespace paddle_mobile {
namespace operators { namespace operators {
inline void ConvBasic(const ConvParam &param) { inline void ConvBasic(const ConvParam<CPU> &param) {
const Tensor *input = param.Input(); const Tensor *input = param.Input();
Tensor filter = *param.Filter(); Tensor filter = *param.Filter();
Tensor *output = param.Output(); Tensor *output = param.Output();
...@@ -112,7 +112,7 @@ inline void ConvBasic(const ConvParam &param) { ...@@ -112,7 +112,7 @@ inline void ConvBasic(const ConvParam &param) {
} }
template <typename P> template <typename P>
void ConvCompute(const ConvParam &param) { void ConvCompute(const ConvParam<CPU> &param) {
if (param.Groups() == param.Input()->dims()[1] && if (param.Groups() == param.Input()->dims()[1] &&
param.Input()->dims()[1] == param.Output()->dims()[1] && param.Input()->dims()[1] == param.Output()->dims()[1] &&
param.Filter()->dims()[2] == param.Filter()->dims()[3] && param.Filter()->dims()[2] == param.Filter()->dims()[3] &&
......
...@@ -23,7 +23,7 @@ limitations under the License. */ ...@@ -23,7 +23,7 @@ limitations under the License. */
#include "operators/op_param.h" #include "operators/op_param.h"
namespace paddle_mobile { namespace paddle_mobile {
namespace operators { namespace operators {
void ConvBNReluBasic(const FusionConvBNReluParam &param) { void ConvBNReluBasic(const FusionConvBNReluParam<CPU> &param) {
const Tensor *input = param.Input(); const Tensor *input = param.Input();
Tensor filter = *param.Filter(); Tensor filter = *param.Filter();
Tensor new_bias = *param.NewBias(); Tensor new_bias = *param.NewBias();
...@@ -113,7 +113,7 @@ void ConvBNReluBasic(const FusionConvBNReluParam &param) { ...@@ -113,7 +113,7 @@ void ConvBNReluBasic(const FusionConvBNReluParam &param) {
} }
template <typename P> template <typename P>
void ConvBNReluCompute(const FusionConvBNReluParam &param) { void ConvBNReluCompute(const FusionConvBNReluParam<CPU> &param) {
if (param.Groups() == param.Input()->dims()[1] && if (param.Groups() == param.Input()->dims()[1] &&
param.Input()->dims()[1] == param.Output()->dims()[1] && param.Input()->dims()[1] == param.Output()->dims()[1] &&
param.Filter()->dims()[2] == param.Filter()->dims()[3] && param.Filter()->dims()[2] == param.Filter()->dims()[3] &&
......
...@@ -28,7 +28,7 @@ namespace paddle_mobile { ...@@ -28,7 +28,7 @@ namespace paddle_mobile {
namespace operators { namespace operators {
template <typename P> template <typename P>
void ConvTransposeCompute(const ConvTransposeParam &param) { void ConvTransposeCompute(const ConvTransposeParam<CPU> &param) {
const Tensor *input = param.Input(); const Tensor *input = param.Input();
Tensor filter = *param.Filter(); Tensor filter = *param.Filter();
Tensor *output = param.Output(); Tensor *output = param.Output();
......
...@@ -25,7 +25,7 @@ namespace paddle_mobile { ...@@ -25,7 +25,7 @@ namespace paddle_mobile {
namespace operators { namespace operators {
template <typename P> template <typename P>
void DepthwiseConvCompute(const ConvParam &param) { void DepthwiseConvCompute(const ConvParam<CPU> &param) {
Tensor Bias; Tensor Bias;
Bias.mutable_data<float>({param.Groups()}); Bias.mutable_data<float>({param.Groups()});
if (param.Groups() == param.Input()->dims()[1] && if (param.Groups() == param.Input()->dims()[1] &&
......
...@@ -23,7 +23,7 @@ limitations under the License. */ ...@@ -23,7 +23,7 @@ limitations under the License. */
#include "operators/op_param.h" #include "operators/op_param.h"
namespace paddle_mobile { namespace paddle_mobile {
namespace operators { namespace operators {
void DWConvBNReluBasic(const FusionDWConvBNReluParam &param) { void DWConvBNReluBasic(const FusionDWConvBNReluParam<CPU> &param) {
const Tensor *input = param.Input(); const Tensor *input = param.Input();
Tensor filter = *param.Filter(); Tensor filter = *param.Filter();
Tensor new_bias = *param.NewBias(); Tensor new_bias = *param.NewBias();
...@@ -111,7 +111,7 @@ void DWConvBNReluBasic(const FusionDWConvBNReluParam &param) { ...@@ -111,7 +111,7 @@ void DWConvBNReluBasic(const FusionDWConvBNReluParam &param) {
} }
} }
template <typename P> template <typename P>
void DWConvBNReluCompute(const FusionDWConvBNReluParam &param) { void DWConvBNReluCompute(const FusionDWConvBNReluParam<CPU> &param) {
if (param.Groups() == param.Input()->dims()[1] && if (param.Groups() == param.Input()->dims()[1] &&
param.Input()->dims()[1] == param.Output()->dims()[1] && param.Input()->dims()[1] == param.Output()->dims()[1] &&
param.Filter()->dims()[2] == param.Filter()->dims()[3] && param.Filter()->dims()[2] == param.Filter()->dims()[3] &&
......
...@@ -27,7 +27,7 @@ struct AddFunctor { ...@@ -27,7 +27,7 @@ struct AddFunctor {
}; };
template <typename P> template <typename P>
void ElementwiseAddCompute(const ElementwiseAddParam &param) { void ElementwiseAddCompute(const ElementwiseAddParam<CPU> &param) {
const Tensor *input_x = param.InputX(); const Tensor *input_x = param.InputX();
const Tensor *input_y = param.InputY(); const Tensor *input_y = param.InputY();
Tensor *Out = param.Out(); Tensor *Out = param.Out();
......
...@@ -22,7 +22,7 @@ namespace paddle_mobile { ...@@ -22,7 +22,7 @@ namespace paddle_mobile {
namespace operators { namespace operators {
template <typename P> template <typename P>
void FusionFcCompute(const FusionFcParam &param) { void FusionFcCompute(const FusionFcParam<CPU> &param) {
const Tensor *input_x = param.InputX(); const Tensor *input_x = param.InputX();
const Tensor *input_y = param.InputY(); const Tensor *input_y = param.InputY();
const Tensor *input_z = param.InputZ(); const Tensor *input_z = param.InputZ();
......
...@@ -20,7 +20,7 @@ namespace paddle_mobile { ...@@ -20,7 +20,7 @@ namespace paddle_mobile {
namespace operators { namespace operators {
template <typename P> template <typename P>
void LrnCompute(const LrnParam &param) { void LrnCompute(const LrnParam<CPU> &param) {
const Tensor *input_x = param.InputX(); const Tensor *input_x = param.InputX();
auto x_dims = input_x->dims(); auto x_dims = input_x->dims();
Tensor *out = param.Out(); Tensor *out = param.Out();
......
...@@ -54,7 +54,7 @@ namespace operators { ...@@ -54,7 +54,7 @@ namespace operators {
// 结果x(6行4列)乘y(4行2列),按1中矩阵相乘,结果out(6行2列) // 结果x(6行4列)乘y(4行2列),按1中矩阵相乘,结果out(6行2列)
template <typename P> template <typename P>
void MulCompute(const MulParam &param) { void MulCompute(const MulParam<CPU> &param) {
const Tensor *input_x = param.InputX(); const Tensor *input_x = param.InputX();
const Tensor *input_y = param.InputY(); const Tensor *input_y = param.InputY();
Tensor *out = param.Out(); Tensor *out = param.Out();
......
...@@ -213,7 +213,7 @@ void MultiClassOutput(const framework::Tensor& scores, ...@@ -213,7 +213,7 @@ void MultiClassOutput(const framework::Tensor& scores,
} }
template <typename P> template <typename P>
void MultiClassNMSCompute(const MultiClassNMSParam& param) { void MultiClassNMSCompute(const MultiClassNMSParam<CPU>& param) {
const auto* input_bboxes = param.InputBBoxes(); const auto* input_bboxes = param.InputBBoxes();
const auto& input_bboxes_dims = input_bboxes->dims(); const auto& input_bboxes_dims = input_bboxes->dims();
......
...@@ -38,7 +38,7 @@ inline void PoolBasic(std::string pooling_type, std::vector<int> ksize, ...@@ -38,7 +38,7 @@ inline void PoolBasic(std::string pooling_type, std::vector<int> ksize,
} }
} }
template <typename P> template <typename P>
void PoolCompute(const PoolParam &param) { void PoolCompute(const PoolParam<CPU> &param) {
const Tensor *in_x = param.Input(); const Tensor *in_x = param.Input();
Tensor *out = param.Output(); Tensor *out = param.Output();
std::string pooling_type = param.PoolingType(); std::string pooling_type = param.PoolingType();
...@@ -58,7 +58,8 @@ void PoolCompute(const PoolParam &param) { ...@@ -58,7 +58,8 @@ void PoolCompute(const PoolParam &param) {
paddings[i] = 0; paddings[i] = 0;
ksize[i] = static_cast<int>(in_x->dims()[i + 2]); ksize[i] = static_cast<int>(in_x->dims()[i + 2]);
} }
} else if (ksize[0] == 3 && ksize[0] == ksize[1]) { }
if (ksize[0] == 3 && ksize[0] == ksize[1]) {
if (pooling_type == "max") { if (pooling_type == "max") {
if (strides[0] == strides[1] && strides[0] == 1 && if (strides[0] == strides[1] && strides[0] == 1 &&
paddings[0] == paddings[1] && paddings[1] == 1) { paddings[0] == paddings[1] && paddings[1] == 1) {
......
...@@ -29,7 +29,7 @@ struct ClipFunctor { ...@@ -29,7 +29,7 @@ struct ClipFunctor {
}; };
template <typename P> template <typename P>
void PriorBoxCompute(const PriorBoxParam &param) { void PriorBoxCompute(const PriorBoxParam<CPU> &param) {
const auto *input_ = param.Input(); const auto *input_ = param.Input();
const auto &input_dims = input_->dims(); const auto &input_dims = input_->dims();
......
...@@ -30,7 +30,7 @@ struct ReluFunctor { ...@@ -30,7 +30,7 @@ struct ReluFunctor {
* @b 特化到具体平台的实现, param 从 op 层传入 * @b 特化到具体平台的实现, param 从 op 层传入
* */ * */
template <typename P> template <typename P>
void ReluCompute(const ReluParam &param) { void ReluCompute(const ReluParam<CPU> &param) {
const auto *input_x = param.InputX(); const auto *input_x = param.InputX();
auto *input_x_ptr = input_x->data<float>(); auto *input_x_ptr = input_x->data<float>();
auto *out = param.Out(); auto *out = param.Out();
......
...@@ -23,7 +23,7 @@ namespace paddle_mobile { ...@@ -23,7 +23,7 @@ namespace paddle_mobile {
namespace operators { namespace operators {
template <typename P> template <typename P>
void ReshapeCompute(const ReshapeParam &param) { void ReshapeCompute(const ReshapeParam<CPU> &param) {
const auto *input_x = param.InputX(); const auto *input_x = param.InputX();
const auto &input_x_dims = input_x->dims(); const auto &input_x_dims = input_x->dims();
auto *out = param.Out(); auto *out = param.Out();
......
...@@ -73,7 +73,7 @@ void sigmoid(const Tensor *X, Tensor *Y) { ...@@ -73,7 +73,7 @@ void sigmoid(const Tensor *X, Tensor *Y) {
} }
template <typename P> template <typename P>
void SigmoidCompute(const SigmoidParam &param) { void SigmoidCompute(const SigmoidParam<CPU> &param) {
const Tensor *in_x = param.InputX(); const Tensor *in_x = param.InputX();
Tensor *out = param.Out(); Tensor *out = param.Out();
auto x_dims = in_x->dims(); auto x_dims = in_x->dims();
......
...@@ -19,7 +19,7 @@ limitations under the License. */ ...@@ -19,7 +19,7 @@ limitations under the License. */
namespace paddle_mobile { namespace paddle_mobile {
namespace operators { namespace operators {
template <typename P> template <typename P>
void SoftmaxCompute(const SoftmaxParam &param) { void SoftmaxCompute(const SoftmaxParam<CPU> &param) {
const Tensor *in_x = param.InputX(); const Tensor *in_x = param.InputX();
Tensor *out = param.Out(); Tensor *out = param.Out();
auto x_dims = in_x->dims(); auto x_dims = in_x->dims();
......
...@@ -39,7 +39,7 @@ namespace operators { ...@@ -39,7 +39,7 @@ namespace operators {
// } // }
template <typename P> template <typename P>
void TransposeCompute(const TransposeParam& param) { void TransposeCompute(const TransposeParam<CPU>& param) {
const auto* input_x = param.InputX(); const auto* input_x = param.InputX();
const auto input_x_dims = input_x->dims(); const auto input_x_dims = input_x->dims();
auto* out = param.Out(); auto* out = param.Out();
......
...@@ -24,10 +24,11 @@ namespace operators { ...@@ -24,10 +24,11 @@ namespace operators {
using namespace framework; using namespace framework;
template <typename DeviceType, typename T> template <typename DeviceType, typename T>
class ConcatKernel : public framework::OpKernelBase<DeviceType, ConcatParam> { class ConcatKernel
: public framework::OpKernelBase<DeviceType, ConcatParam<DeviceType>> {
public: public:
void Compute(const ConcatParam &param) const; void Compute(const ConcatParam<DeviceType> &param) const;
bool Init(ConcatParam *param); bool Init(ConcatParam<DeviceType> *param);
}; };
} // namespace operators } // namespace operators
......
...@@ -32,10 +32,11 @@ using framework::DDim; ...@@ -32,10 +32,11 @@ using framework::DDim;
using framework::OpKernelBase; using framework::OpKernelBase;
template <typename DeviceType, typename T> template <typename DeviceType, typename T>
class ConvAddBNKernel : public OpKernelBase<DeviceType, FusionConvAddBNParam> { class ConvAddBNKernel
: public OpKernelBase<DeviceType, FusionConvAddBNParam<DeviceType>> {
public: public:
void Compute(const FusionConvAddBNParam &param) const; void Compute(const FusionConvAddBNParam<DeviceType> &param) const;
bool Init(FusionConvAddBNParam *param); bool Init(FusionConvAddBNParam<DeviceType> *param);
}; };
} // namespace operators } // namespace operators
......
...@@ -33,10 +33,10 @@ using framework::OpKernelBase; ...@@ -33,10 +33,10 @@ using framework::OpKernelBase;
template <typename DeviceType, typename T> template <typename DeviceType, typename T>
class ConvAddBNReluKernel class ConvAddBNReluKernel
: public OpKernelBase<DeviceType, FusionConvAddBNReluParam> { : public OpKernelBase<DeviceType, FusionConvAddBNReluParam<DeviceType>> {
public: public:
void Compute(const FusionConvAddBNReluParam &param) const; void Compute(const FusionConvAddBNReluParam<DeviceType> &param) const;
bool Init(FusionConvAddBNReluParam *param); bool Init(FusionConvAddBNReluParam<DeviceType> *param);
}; };
} // namespace operators } // namespace operators
......
...@@ -37,10 +37,11 @@ using framework::DDim; ...@@ -37,10 +37,11 @@ using framework::DDim;
using framework::OpKernelBase; using framework::OpKernelBase;
template <typename DeviceType, typename T> template <typename DeviceType, typename T>
class ConvAddKernel : public OpKernelBase<DeviceType, FusionConvAddParam> { class ConvAddKernel
: public OpKernelBase<DeviceType, FusionConvAddParam<DeviceType>> {
public: public:
void Compute(const FusionConvAddParam &param) const; void Compute(const FusionConvAddParam<DeviceType> &param) const;
bool Init(FusionConvAddParam *param); bool Init(FusionConvAddParam<DeviceType> *param);
}; };
} // namespace operators } // namespace operators
......
...@@ -33,10 +33,10 @@ using framework::OpKernelBase; ...@@ -33,10 +33,10 @@ using framework::OpKernelBase;
template <typename DeviceType, typename T> template <typename DeviceType, typename T>
class ConvAddReluKernel class ConvAddReluKernel
: public OpKernelBase<DeviceType, FusionConvAddReluParam> { : public OpKernelBase<DeviceType, FusionConvAddReluParam<DeviceType>> {
public: public:
void Compute(const FusionConvAddReluParam &param) const; void Compute(const FusionConvAddReluParam<DeviceType> &param) const;
bool Init(FusionConvAddReluParam *param); bool Init(FusionConvAddReluParam<DeviceType> *param);
}; };
} // namespace operators } // namespace operators
......
...@@ -33,10 +33,10 @@ using framework::OpKernelBase; ...@@ -33,10 +33,10 @@ using framework::OpKernelBase;
template <typename DeviceType, typename T> template <typename DeviceType, typename T>
class ConvBNReluKernel class ConvBNReluKernel
: public OpKernelBase<DeviceType, FusionConvBNReluParam> { : public OpKernelBase<DeviceType, FusionConvBNReluParam<DeviceType>> {
public: public:
void Compute(const FusionConvBNReluParam &param) const; void Compute(const FusionConvBNReluParam<DeviceType> &param) const;
bool Init(FusionConvBNReluParam *param); bool Init(FusionConvBNReluParam<DeviceType> *param);
}; };
} // namespace operators } // namespace operators
......
...@@ -29,10 +29,10 @@ namespace operators { ...@@ -29,10 +29,10 @@ namespace operators {
using framework::OpKernelBase; using framework::OpKernelBase;
template <typename DeviceType, typename T> template <typename DeviceType, typename T>
class ConvKernel : public OpKernelBase<DeviceType, ConvParam> { class ConvKernel : public OpKernelBase<DeviceType, ConvParam<DeviceType>> {
public: public:
void Compute(const ConvParam &param) const; void Compute(const ConvParam<DeviceType> &param) const;
bool Init(ConvParam *param); bool Init(ConvParam<DeviceType> *param);
}; };
} // namespace operators } // namespace operators
......
...@@ -26,11 +26,11 @@ using framework::OpKernelBase; ...@@ -26,11 +26,11 @@ using framework::OpKernelBase;
template <typename DeviceType, typename T> template <typename DeviceType, typename T>
class ConvTransposeKernel class ConvTransposeKernel
: public OpKernelBase<DeviceType, ConvTransposeParam> { : public OpKernelBase<DeviceType, ConvTransposeParam<DeviceType>> {
public: public:
void Compute(const ConvTransposeParam &param) const; void Compute(const ConvTransposeParam<DeviceType> &param) const;
bool Init(ConvTransposeParam *param); bool Init(ConvTransposeParam<DeviceType> *param);
}; };
} // namespace operators } // namespace operators
......
...@@ -28,10 +28,11 @@ namespace operators { ...@@ -28,10 +28,11 @@ namespace operators {
using framework::OpKernelBase; using framework::OpKernelBase;
template <typename DeviceType, typename T> template <typename DeviceType, typename T>
class DepthwiseConvKernel : public OpKernelBase<DeviceType, ConvParam> { class DepthwiseConvKernel
: public OpKernelBase<DeviceType, ConvParam<DeviceType>> {
public: public:
void Compute(const ConvParam &param) const; void Compute(const ConvParam<DeviceType> &param) const;
bool Init(ConvParam *param); bool Init(ConvParam<DeviceType> *param);
}; };
} // namespace operators } // namespace operators
} // namespace paddle_mobile } // namespace paddle_mobile
......
...@@ -23,10 +23,11 @@ namespace paddle_mobile { ...@@ -23,10 +23,11 @@ namespace paddle_mobile {
namespace operators { namespace operators {
template <typename DeviceType, typename T> template <typename DeviceType, typename T>
class DropoutKernel : public framework::OpKernelBase<DeviceType, DropoutParam> { class DropoutKernel
: public framework::OpKernelBase<DeviceType, DropoutParam<DeviceType>> {
public: public:
void Compute(const DropoutParam& param) const; void Compute(const DropoutParam<DeviceType>& param) const;
bool Init(DropoutParam* para); bool Init(DropoutParam<DeviceType>* para);
}; };
} // namespace operators } // namespace operators
} // namespace paddle_mobile } // namespace paddle_mobile
......
...@@ -33,10 +33,10 @@ using framework::OpKernelBase; ...@@ -33,10 +33,10 @@ using framework::OpKernelBase;
template <typename DeviceType, typename T> template <typename DeviceType, typename T>
class DWConvBNReluKernel class DWConvBNReluKernel
: public OpKernelBase<DeviceType, FusionDWConvBNReluParam> { : public OpKernelBase<DeviceType, FusionDWConvBNReluParam<DeviceType>> {
public: public:
void Compute(const FusionDWConvBNReluParam &param) const; void Compute(const FusionDWConvBNReluParam<DeviceType> &param) const;
bool Init(FusionDWConvBNReluParam *param); bool Init(FusionDWConvBNReluParam<DeviceType> *param);
}; };
} // namespace operators } // namespace operators
......
...@@ -27,10 +27,11 @@ using namespace framework; ...@@ -27,10 +27,11 @@ using namespace framework;
template <typename DeviceType, typename T> template <typename DeviceType, typename T>
class ElementwiseAddKernel class ElementwiseAddKernel
: public framework::OpKernelBase<DeviceType, ElementwiseAddParam> { : public framework::OpKernelBase<DeviceType,
ElementwiseAddParam<DeviceType>> {
public: public:
void Compute(const ElementwiseAddParam &param) const; void Compute(const ElementwiseAddParam<DeviceType> &param) const;
bool Init(ElementwiseAddParam *param); bool Init(ElementwiseAddParam<DeviceType> *param);
}; };
} // namespace operators } // namespace operators
} // namespace paddle_mobile } // namespace paddle_mobile
......
...@@ -26,10 +26,11 @@ using namespace framework; ...@@ -26,10 +26,11 @@ using namespace framework;
template <typename DeviceType, typename T> template <typename DeviceType, typename T>
class ElementwiseAddReluKernel class ElementwiseAddReluKernel
: public framework::OpKernelBase<DeviceType, ElementwiseAddReluParam> { : public framework::OpKernelBase<DeviceType,
ElementwiseAddReluParam<DeviceType>> {
public: public:
void Compute(const ElementwiseAddReluParam &param) const; void Compute(const ElementwiseAddReluParam<DeviceType> &param) const;
bool Init(ElementwiseAddReluParam *param); bool Init(ElementwiseAddReluParam<DeviceType> *param);
}; };
} // namespace operators } // namespace operators
} // namespace paddle_mobile } // namespace paddle_mobile
......
...@@ -25,10 +25,11 @@ namespace operators { ...@@ -25,10 +25,11 @@ namespace operators {
template <typename DeviceType, typename T> template <typename DeviceType, typename T>
class FusionFcReluKernel class FusionFcReluKernel
: public framework::OpKernelBase<DeviceType, FusionFcReluParam> { : public framework::OpKernelBase<DeviceType,
FusionFcReluParam<DeviceType>> {
public: public:
void Compute(const FusionFcReluParam& param) const; void Compute(const FusionFcReluParam<DeviceType>& param) const;
bool Init(FusionFcReluParam* param); bool Init(FusionFcReluParam<DeviceType>* param);
}; };
} // namespace operators } // namespace operators
} // namespace paddle_mobile } // namespace paddle_mobile
......
...@@ -20,12 +20,12 @@ namespace paddle_mobile { ...@@ -20,12 +20,12 @@ namespace paddle_mobile {
namespace operators { namespace operators {
template <> template <>
bool ConcatKernel<FPGA, float>::Init(ConcatParam *param) { bool ConcatKernel<FPGA, float>::Init(ConcatParam<FPGA> *param) {
return true; return true;
} }
template <> template <>
void ConcatKernel<FPGA, float>::Compute(const ConcatParam &param) const { void ConcatKernel<FPGA, float>::Compute(const ConcatParam<FPGA> &param) const {
auto inputs = param.Inputs(); auto inputs = param.Inputs();
auto *out = param.Out(); auto *out = param.Out();
int64_t axis = param.Axis(); int64_t axis = param.Axis();
......
...@@ -22,7 +22,7 @@ namespace paddle_mobile { ...@@ -22,7 +22,7 @@ namespace paddle_mobile {
namespace operators { namespace operators {
template <> template <>
bool ConvAddBNKernel<FPGA, float>::Init(FusionConvAddBNParam *param) { bool ConvAddBNKernel<FPGA, float>::Init(FusionConvAddBNParam<FPGA> *param) {
bool relu_enabled = false; bool relu_enabled = false;
const Tensor *input = param->Input(); const Tensor *input = param->Input();
auto input_ptr = input->data<half>(); auto input_ptr = input->data<half>();
...@@ -92,7 +92,7 @@ bool ConvAddBNKernel<FPGA, float>::Init(FusionConvAddBNParam *param) { ...@@ -92,7 +92,7 @@ bool ConvAddBNKernel<FPGA, float>::Init(FusionConvAddBNParam *param) {
template <> template <>
void ConvAddBNKernel<FPGA, float>::Compute( void ConvAddBNKernel<FPGA, float>::Compute(
const FusionConvAddBNParam &param) const { const FusionConvAddBNParam<FPGA> &param) const {
fpga::ComputeFpgaConv(param.FpgaArgs()); fpga::ComputeFpgaConv(param.FpgaArgs());
} }
template class ConvAddBNKernel<FPGA, float>; template class ConvAddBNKernel<FPGA, float>;
......
...@@ -21,7 +21,8 @@ namespace paddle_mobile { ...@@ -21,7 +21,8 @@ namespace paddle_mobile {
namespace operators { namespace operators {
template <> template <>
bool ConvAddBNReluKernel<FPGA, float>::Init(FusionConvAddBNReluParam *param) { bool ConvAddBNReluKernel<FPGA, float>::Init(
FusionConvAddBNReluParam<FPGA> *param) {
bool relu_enabled = true; bool relu_enabled = true;
const Tensor *input = param->Input(); const Tensor *input = param->Input();
auto input_ptr = input->data<half>(); auto input_ptr = input->data<half>();
...@@ -83,7 +84,7 @@ bool ConvAddBNReluKernel<FPGA, float>::Init(FusionConvAddBNReluParam *param) { ...@@ -83,7 +84,7 @@ bool ConvAddBNReluKernel<FPGA, float>::Init(FusionConvAddBNReluParam *param) {
template <> template <>
void ConvAddBNReluKernel<FPGA, float>::Compute( void ConvAddBNReluKernel<FPGA, float>::Compute(
const FusionConvAddBNReluParam &param) const { const FusionConvAddBNReluParam<FPGA> &param) const {
fpga::ComputeFpgaConv(param.FpgaArgs()); fpga::ComputeFpgaConv(param.FpgaArgs());
} }
template class ConvAddBNReluKernel<FPGA, float>; template class ConvAddBNReluKernel<FPGA, float>;
......
...@@ -21,7 +21,7 @@ namespace paddle_mobile { ...@@ -21,7 +21,7 @@ namespace paddle_mobile {
namespace operators { namespace operators {
template <> template <>
bool ConvAddReluKernel<FPGA, float>::Init(FusionConvAddReluParam *param) { bool ConvAddReluKernel<FPGA, float>::Init(FusionConvAddReluParam<FPGA> *param) {
bool relu_enabled = true; bool relu_enabled = true;
const Tensor *input = param->Input(); const Tensor *input = param->Input();
auto input_ptr = input->data<half>(); auto input_ptr = input->data<half>();
...@@ -67,7 +67,7 @@ bool ConvAddReluKernel<FPGA, float>::Init(FusionConvAddReluParam *param) { ...@@ -67,7 +67,7 @@ bool ConvAddReluKernel<FPGA, float>::Init(FusionConvAddReluParam *param) {
template <> template <>
void ConvAddReluKernel<FPGA, float>::Compute( void ConvAddReluKernel<FPGA, float>::Compute(
const FusionConvAddReluParam &param) const { const FusionConvAddReluParam<FPGA> &param) const {
fpga::ComputeFpgaConv(param.FpgaArgs()); fpga::ComputeFpgaConv(param.FpgaArgs());
} }
template class ConvAddReluKernel<FPGA, float>; template class ConvAddReluKernel<FPGA, float>;
......
...@@ -21,12 +21,12 @@ namespace paddle_mobile { ...@@ -21,12 +21,12 @@ namespace paddle_mobile {
namespace operators { namespace operators {
template <> template <>
bool ConvKernel<FPGA, float>::Init(ConvParam *param) { bool ConvKernel<FPGA, float>::Init(ConvParam<FPGA> *param) {
return true; return true;
} }
template <> template <>
void ConvKernel<FPGA, float>::Compute(const ConvParam &param) const { void ConvKernel<FPGA, float>::Compute(const ConvParam<FPGA> &param) const {
// ConvCompute<float>(param); // ConvCompute<float>(param);
} }
......
...@@ -20,13 +20,14 @@ namespace paddle_mobile { ...@@ -20,13 +20,14 @@ namespace paddle_mobile {
namespace operators { namespace operators {
template <> template <>
bool DropoutKernel<FPGA, float>::Init(DropoutParam *param) { bool DropoutKernel<FPGA, float>::Init(DropoutParam<FPGA> *param) {
param->Out()->ShareDataWith(*param->InputX()); param->Out()->ShareDataWith(*param->InputX());
return true; return true;
} }
template <> template <>
void DropoutKernel<FPGA, float>::Compute(const DropoutParam &param) const { void DropoutKernel<FPGA, float>::Compute(
const DropoutParam<FPGA> &param) const {
// auto *input_x = param.InputX(); // auto *input_x = param.InputX();
// auto *out = param.Out(); // auto *out = param.Out();
// auto input_x_ptr = input_x->data<float>(); // auto input_x_ptr = input_x->data<float>();
......
...@@ -20,7 +20,7 @@ namespace operators { ...@@ -20,7 +20,7 @@ namespace operators {
template <> template <>
bool ElementwiseAddReluKernel<FPGA, float>::Init( bool ElementwiseAddReluKernel<FPGA, float>::Init(
ElementwiseAddReluParam *param) { ElementwiseAddReluParam<FPGA> *param) {
bool relu_enabled = true; bool relu_enabled = true;
const Tensor *input_x = param->InputX(); const Tensor *input_x = param->InputX();
const Tensor *input_y = param->InputY(); const Tensor *input_y = param->InputY();
...@@ -57,7 +57,7 @@ bool ElementwiseAddReluKernel<FPGA, float>::Init( ...@@ -57,7 +57,7 @@ bool ElementwiseAddReluKernel<FPGA, float>::Init(
template <> template <>
void ElementwiseAddReluKernel<FPGA, float>::Compute( void ElementwiseAddReluKernel<FPGA, float>::Compute(
const ElementwiseAddReluParam &param) const { const ElementwiseAddReluParam<FPGA> &param) const {
fpga::ComputeFpgaEWAdd(param.FpgaArgs()); fpga::ComputeFpgaEWAdd(param.FpgaArgs());
} }
} // namespace operators } // namespace operators
......
...@@ -19,7 +19,7 @@ namespace paddle_mobile { ...@@ -19,7 +19,7 @@ namespace paddle_mobile {
namespace operators { namespace operators {
template <> template <>
bool FusionFcReluKernel<FPGA, float>::Init(FusionFcReluParam *param) { bool FusionFcReluKernel<FPGA, float>::Init(FusionFcReluParam<FPGA> *param) {
bool relu_enabled = true; bool relu_enabled = true;
const Tensor *input_x = param->InputX(); const Tensor *input_x = param->InputX();
auto input_x_ptr = input_x->data<half>(); auto input_x_ptr = input_x->data<half>();
...@@ -66,7 +66,7 @@ bool FusionFcReluKernel<FPGA, float>::Init(FusionFcReluParam *param) { ...@@ -66,7 +66,7 @@ bool FusionFcReluKernel<FPGA, float>::Init(FusionFcReluParam *param) {
} }
template <> template <>
void FusionFcReluKernel<FPGA, float>::Compute( void FusionFcReluKernel<FPGA, float>::Compute(
const FusionFcReluParam &param) const { const FusionFcReluParam<FPGA> &param) const {
fpga::ComputeFpgaConv(param.FpgaArgs()); fpga::ComputeFpgaConv(param.FpgaArgs());
}; };
......
...@@ -19,7 +19,7 @@ namespace paddle_mobile { ...@@ -19,7 +19,7 @@ namespace paddle_mobile {
namespace operators { namespace operators {
template <> template <>
bool FusionFcKernel<FPGA, float>::Init(FusionFcParam *param) { bool FusionFcKernel<FPGA, float>::Init(FusionFcParam<FPGA> *param) {
bool relu_enabled = false; bool relu_enabled = false;
const Tensor *input_x = param->InputX(); const Tensor *input_x = param->InputX();
auto input_x_ptr = input_x->data<half>(); auto input_x_ptr = input_x->data<half>();
...@@ -65,7 +65,8 @@ bool FusionFcKernel<FPGA, float>::Init(FusionFcParam *param) { ...@@ -65,7 +65,8 @@ bool FusionFcKernel<FPGA, float>::Init(FusionFcParam *param) {
} }
template <> template <>
void FusionFcKernel<FPGA, float>::Compute(const FusionFcParam &param) const { void FusionFcKernel<FPGA, float>::Compute(
const FusionFcParam<FPGA> &param) const {
fpga::ComputeFpgaConv(param.FpgaArgs()); fpga::ComputeFpgaConv(param.FpgaArgs());
} }
} // namespace operators } // namespace operators
......
...@@ -20,7 +20,7 @@ namespace paddle_mobile { ...@@ -20,7 +20,7 @@ namespace paddle_mobile {
namespace operators { namespace operators {
template <> template <>
bool PoolKernel<FPGA, float>::Init(PoolParam *param) { bool PoolKernel<FPGA, float>::Init(PoolParam<FPGA> *param) {
const Tensor *input = param->Input(); const Tensor *input = param->Input();
auto input_ptr = input->data<half>(); auto input_ptr = input->data<half>();
Tensor *output = param->Output(); Tensor *output = param->Output();
...@@ -46,7 +46,7 @@ bool PoolKernel<FPGA, float>::Init(PoolParam *param) { ...@@ -46,7 +46,7 @@ bool PoolKernel<FPGA, float>::Init(PoolParam *param) {
} }
template <> template <>
void PoolKernel<FPGA, float>::Compute(const PoolParam &param) const { void PoolKernel<FPGA, float>::Compute(const PoolParam<FPGA> &param) const {
#ifdef PADDLE_MOBILE_FPGA #ifdef PADDLE_MOBILE_FPGA
fpga::ComputeFpgaPool(param.FpgaArgs()); fpga::ComputeFpgaPool(param.FpgaArgs());
#endif #endif
......
...@@ -25,10 +25,10 @@ namespace operators { ...@@ -25,10 +25,10 @@ namespace operators {
template <typename DeviceType, typename T> template <typename DeviceType, typename T>
class FusionFcKernel class FusionFcKernel
: public framework::OpKernelBase<DeviceType, FusionFcParam> { : public framework::OpKernelBase<DeviceType, FusionFcParam<DeviceType>> {
public: public:
void Compute(const FusionFcParam& param) const; void Compute(const FusionFcParam<DeviceType>& param) const;
bool Init(FusionFcParam* param); bool Init(FusionFcParam<DeviceType>* param);
}; };
} // namespace operators } // namespace operators
} // namespace paddle_mobile } // namespace paddle_mobile
......
...@@ -27,10 +27,10 @@ namespace operators { ...@@ -27,10 +27,10 @@ namespace operators {
template <typename DeviceType, typename T> template <typename DeviceType, typename T>
class Im2SequenceKernel class Im2SequenceKernel
: public framework::OpKernelBase<DeviceType, Im2SequenceParam> { : public framework::OpKernelBase<DeviceType, Im2SequenceParam<DeviceType>> {
public: public:
void Compute(const Im2SequenceParam& param) const; void Compute(const Im2SequenceParam<DeviceType>& param) const;
bool Init(Im2SequenceParam* para); bool Init(Im2SequenceParam<DeviceType>* para);
}; };
} // namespace operators } // namespace operators
} // namespace paddle_mobile } // namespace paddle_mobile
......
...@@ -170,10 +170,11 @@ struct LRNFunctor { ...@@ -170,10 +170,11 @@ struct LRNFunctor {
}; };
template <typename DeviceType, typename T> template <typename DeviceType, typename T>
class LrnKernel : public framework::OpKernelBase<DeviceType, LrnParam> { class LrnKernel
: public framework::OpKernelBase<DeviceType, LrnParam<DeviceType>> {
public: public:
void Compute(const LrnParam &param) const; void Compute(const LrnParam<DeviceType> &param) const;
bool Init(LrnParam *param); bool Init(LrnParam<DeviceType> *param);
}; };
} // namespace operators } // namespace operators
} // namespace paddle_mobile } // namespace paddle_mobile
......
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册