提交 094623ee 编写于 作者: Y Yao,kun

Add op init

上级 9b60c8ee
...@@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#ifdef DROPOUT_OP
#include "operators/dropout_op.h" #include "operators/dropout_op.h"
namespace paddle_mobile { namespace paddle_mobile {
namespace operators { namespace operators {
...@@ -26,5 +27,13 @@ template class DropoutOp<CPU, float>; ...@@ -26,5 +27,13 @@ template class DropoutOp<CPU, float>;
} // namespace paddle_mobile } // namespace paddle_mobile
namespace ops = paddle_mobile::operators; namespace ops = paddle_mobile::operators;
#ifdef PADDLE_MOBILE_CPU
USE_OP(Dropout); USE_OP(Dropout);
REGISTER_OPERATOR(dropout, ops::DropoutOp); REGISTER_OPERATOR(dropout, ops::DropoutOp);
#endif
#ifdef PADDLE_MOBILE_MALI_GPU
#endif
#ifdef PADDLE_MOBILE_FPGA
#endif
#endif
...@@ -12,6 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,6 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#ifdef DROPOUT_OP
#pragma once #pragma once
#include <string> #include <string>
...@@ -26,7 +28,8 @@ namespace operators { ...@@ -26,7 +28,8 @@ namespace operators {
using paddle_mobile::framework::Tensor; using paddle_mobile::framework::Tensor;
template <typename DeviceType, typename T> template <typename DeviceType, typename T>
class DropoutOp : public framework::OperatorWithKernel<DeviceType> { class DropoutOp : public framework::OperatorWithKernel<DeviceType, DropoutParam,
operators::DropoutKernal<DeviceType, T>> {
public: public:
DropoutOp(const std::string &type, const VariableNameMap &inputs, DropoutOp(const std::string &type, const VariableNameMap &inputs,
const VariableNameMap &outputs, const framework::AttributeMap attrs, const VariableNameMap &outputs, const framework::AttributeMap attrs,
...@@ -35,17 +38,15 @@ class DropoutOp : public framework::OperatorWithKernel<DeviceType> { ...@@ -35,17 +38,15 @@ class DropoutOp : public framework::OperatorWithKernel<DeviceType> {
scope), scope),
param_(inputs, outputs, attrs, *scope) {} param_(inputs, outputs, attrs, *scope) {}
void Run() const { using framework::OperatorWithKernel<DeviceType, DropoutParam,
operators::DropoutKernel<DeviceType, T> kernel; operators::DropoutKernel<DeviceType, T>>;
kernel.Compute(param_);
}
using framework::OperatorWithKernel<DeviceType>::OperatorWithKernel;
void InferShape() const override; void InferShape() const override;
protected: protected:
DropoutParam param_;
}; };
} // namespace operators } // namespace operators
} // namespace paddle_mobile } // namespace paddle_mobile
#endif
...@@ -12,6 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,6 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#ifdef IM2SEQUENCE_OP
#include "operators/im2sequence_op.h" #include "operators/im2sequence_op.h"
namespace paddle_mobile { namespace paddle_mobile {
...@@ -51,5 +53,13 @@ template class Im2SequenceOp<CPU, float>; ...@@ -51,5 +53,13 @@ template class Im2SequenceOp<CPU, float>;
} // namespace paddle_mobile } // namespace paddle_mobile
namespace ops = paddle_mobile::operators; namespace ops = paddle_mobile::operators;
#ifdef PADDLE_MOBILE_CPU
USE_OP(im2sequence); USE_OP(im2sequence);
REGISTER_OPERATOR(im2sequence, ops::Im2SequenceOp); REGISTER_OPERATOR(im2sequence, ops::Im2SequenceOp);
#endif
#ifdef PADDLE_MOBILE_MALI_GPU
#endif
#ifdef PADDLE_MOBILE_FPGA
#endif
#endif
...@@ -12,6 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,6 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#ifdef IM2SEQUENCE_OP
#pragma once #pragma once
#include <operators/op_param.h> #include <operators/op_param.h>
...@@ -24,7 +26,8 @@ namespace operators { ...@@ -24,7 +26,8 @@ namespace operators {
using namespace framework; using namespace framework;
template <typename DeviceType, typename T> template <typename DeviceType, typename T>
class Im2SequenceOp : public framework::OperatorWithKernel<DeviceType> { class Im2SequenceOp : public framework::OperatorWithKernel<DeviceType, Im2SequenceParam,
operators::Im2SequenceKernal<DeviceType, T>> {
public: public:
Im2SequenceOp(const std::string &type, const VariableNameMap &inputs, Im2SequenceOp(const std::string &type, const VariableNameMap &inputs,
const VariableNameMap &outputs, const VariableNameMap &outputs,
...@@ -34,18 +37,14 @@ class Im2SequenceOp : public framework::OperatorWithKernel<DeviceType> { ...@@ -34,18 +37,14 @@ class Im2SequenceOp : public framework::OperatorWithKernel<DeviceType> {
scope), scope),
param_(inputs, outputs, attrs, *scope) {} param_(inputs, outputs, attrs, *scope) {}
using framework::OperatorWithKernel<DeviceType>::OperatorWithKernel; using framework::OperatorWithKernel<DeviceType, Im2SequenceParam,
operators::Im2SequenceKernel<DeviceType, T>>::OperatorWithKernel;
void InferShape() const override; void InferShape() const override;
void RunImpl() const {
operators::Im2SequenceKernel<DeviceType, T> kernel;
kernel.Compute(param_);
this->ClearVariables({"X"});
}
private: private:
Im2SequenceParam param_;
}; };
} // namespace operators } // namespace operators
} // namespace paddle_mobile } // namespace paddle_mobile
#endif
...@@ -20,6 +20,11 @@ limitations under the License. */ ...@@ -20,6 +20,11 @@ limitations under the License. */
namespace paddle_mobile { namespace paddle_mobile {
namespace operators { namespace operators {
template <>
bool DropoutKernel<CPU, float>::Init(const DropoutParam &para) const {
return true;
}
template <typename T> template <typename T>
struct DropoutFunctor { struct DropoutFunctor {
inline T operator()(T in) const { return in; } inline T operator()(T in) const { return in; }
......
...@@ -17,6 +17,11 @@ limitations under the License. */ ...@@ -17,6 +17,11 @@ limitations under the License. */
namespace paddle_mobile { namespace paddle_mobile {
namespace operators { namespace operators {
template <>
bool Im2SequenceKernel<CPU, float>::Init(const Im2SequenceParam &para) const {
return true;
}
inline int Im2SeqOutputSize(int input_size, int filter_size, int padding_0, inline int Im2SeqOutputSize(int input_size, int filter_size, int padding_0,
int padding_1, int stride) { int padding_1, int stride) {
const int output_size = const int output_size =
......
...@@ -24,6 +24,7 @@ template <typename DeviceType, typename T> ...@@ -24,6 +24,7 @@ template <typename DeviceType, typename T>
class DropoutKernel : public framework::OpKernelBase<DeviceType, DropoutParam> { class DropoutKernel : public framework::OpKernelBase<DeviceType, DropoutParam> {
public: public:
void Compute(const DropoutParam& param) const; void Compute(const DropoutParam& param) const;
bool Init(const DropoutParam &para) const;
}; };
} // namespace operators } // namespace operators
} // namespace paddle_mobile } // namespace paddle_mobile
...@@ -30,6 +30,7 @@ class Im2SequenceKernel ...@@ -30,6 +30,7 @@ class Im2SequenceKernel
: public framework::OpKernelBase<DeviceType, Im2SequenceParam> { : public framework::OpKernelBase<DeviceType, Im2SequenceParam> {
public: public:
void Compute(const Im2SequenceParam &param) const; void Compute(const Im2SequenceParam &param) const;
bool Init(const Im2SequenceParam &para) const;
}; };
} // namespace operators } // namespace operators
} // namespace paddle_mobile } // namespace paddle_mobile
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册