Commit 094623ee authored by Yao,kun

Add op init

Parent 9b60c8ee
......@@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef DROPOUT_OP
#include "operators/dropout_op.h"
namespace paddle_mobile {
namespace operators {
......@@ -26,5 +27,13 @@ template class DropoutOp<CPU, float>;
} // namespace paddle_mobile
namespace ops = paddle_mobile::operators;
#ifdef PADDLE_MOBILE_CPU
USE_OP(Dropout);
REGISTER_OPERATOR(dropout, ops::DropoutOp);
#endif
#ifdef PADDLE_MOBILE_MALI_GPU
#endif
#ifdef PADDLE_MOBILE_FPGA
#endif
#endif
......@@ -12,6 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef DROPOUT_OP
#pragma once
#include <string>
......@@ -26,7 +28,8 @@ namespace operators {
using paddle_mobile::framework::Tensor;
template <typename DeviceType, typename T>
class DropoutOp : public framework::OperatorWithKernel<DeviceType> {
class DropoutOp : public framework::OperatorWithKernel<DeviceType, DropoutParam,
operators::DropoutKernel<DeviceType, T>> {
public:
DropoutOp(const std::string &type, const VariableNameMap &inputs,
const VariableNameMap &outputs, const framework::AttributeMap attrs,
......@@ -35,17 +38,15 @@ class DropoutOp : public framework::OperatorWithKernel<DeviceType> {
scope),
param_(inputs, outputs, attrs, *scope) {}
void Run() const {
operators::DropoutKernel<DeviceType, T> kernel;
kernel.Compute(param_);
}
using framework::OperatorWithKernel<DeviceType>::OperatorWithKernel;
using framework::OperatorWithKernel<DeviceType, DropoutParam,
operators::DropoutKernel<DeviceType, T>>::OperatorWithKernel;
void InferShape() const override;
protected:
DropoutParam param_;
};
} // namespace operators
} // namespace paddle_mobile
#endif
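
Note on the header change above: it is an architectural refactor in which `OperatorWithKernel` becomes parameterized on the param type and kernel type in addition to the device, so the per-op `Run()` and `param_` members move into the base class and each op only keeps `InferShape()` plus an inherited constructor. Below is a minimal, self-contained sketch of that pattern with simplified stand-in types (not the real paddle_mobile framework classes):

```cpp
// Self-contained sketch of the refactor with stand-in types (illustrative only).
// The base class carries the param and kernel types, so ops no longer write Run().
#include <iostream>
#include <utility>

struct DropoutParam {          // stand-in for operators::DropoutParam
  float dropout_prob = 0.0f;
};

template <typename ParamType, typename KernelType>
class OperatorWithKernel {
 public:
  explicit OperatorWithKernel(ParamType param) : param_(std::move(param)) {}
  void Run() const {
    KernelType kernel;
    kernel.Compute(param_);    // shared dispatch path for every op
  }
  virtual void InferShape() const = 0;
  virtual ~OperatorWithKernel() = default;

 protected:
  ParamType param_;
};

struct DropoutKernelCPU {
  void Compute(const DropoutParam &p) const {
    std::cout << "dropout, keep probability " << 1.0f - p.dropout_prob << "\n";
  }
};

class DropoutOp : public OperatorWithKernel<DropoutParam, DropoutKernelCPU> {
 public:
  // inherit the base constructor, as the diff does with
  // `using framework::OperatorWithKernel<...>::OperatorWithKernel;`
  using OperatorWithKernel<DropoutParam, DropoutKernelCPU>::OperatorWithKernel;
  void InferShape() const override { /* output dims == input dims */ }
};

int main() {
  DropoutOp op(DropoutParam{0.5f});
  op.InferShape();
  op.Run();   // prints: dropout, keep probability 0.5
}
```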
......@@ -12,6 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef IM2SEQUENCE_OP
#include "operators/im2sequence_op.h"
namespace paddle_mobile {
......@@ -51,5 +53,13 @@ template class Im2SequenceOp<CPU, float>;
} // namespace paddle_mobile
namespace ops = paddle_mobile::operators;
#ifdef PADDLE_MOBILE_CPU
USE_OP(im2sequence);
REGISTER_OPERATOR(im2sequence, ops::Im2SequenceOp);
#endif
#ifdef PADDLE_MOBILE_MALI_GPU
#endif
#ifdef PADDLE_MOBILE_FPGA
#endif
#endif
......@@ -12,6 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef IM2SEQUENCE_OP
#pragma once
#include <operators/op_param.h>
......@@ -24,7 +26,8 @@ namespace operators {
using namespace framework;
template <typename DeviceType, typename T>
class Im2SequenceOp : public framework::OperatorWithKernel<DeviceType> {
class Im2SequenceOp : public framework::OperatorWithKernel<DeviceType, Im2SequenceParam,
operators::Im2SequenceKernel<DeviceType, T>> {
public:
Im2SequenceOp(const std::string &type, const VariableNameMap &inputs,
const VariableNameMap &outputs,
......@@ -34,18 +37,14 @@ class Im2SequenceOp : public framework::OperatorWithKernel<DeviceType> {
scope),
param_(inputs, outputs, attrs, *scope) {}
using framework::OperatorWithKernel<DeviceType>::OperatorWithKernel;
using framework::OperatorWithKernel<DeviceType, Im2SequenceParam,
operators::Im2SequenceKernel<DeviceType, T>>::OperatorWithKernel;
void InferShape() const override;
void RunImpl() const {
operators::Im2SequenceKernel<DeviceType, T> kernel;
kernel.Compute(param_);
this->ClearVariables({"X"});
}
private:
Im2SequenceParam param_;
};
} // namespace operators
} // namespace paddle_mobile
#endif
......@@ -20,6 +20,11 @@ limitations under the License. */
namespace paddle_mobile {
namespace operators {
template <>
bool DropoutKernel<CPU, float>::Init(const DropoutParam &para) const {
return true;
}
template <typename T>
struct DropoutFunctor {
inline T operator()(T in) const { return in; }
......
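
A side note on the functor visible above (its semantics are an assumption read off this snippet, not stated in the diff): at inference time dropout is a pass-through, so the CPU kernel can apply an identity functor elementwise. A standalone illustration:

```cpp
// Standalone illustration (assumed semantics): applying an elementwise functor
// such as DropoutFunctor over a buffer, the way a CPU kernel typically would.
#include <algorithm>
#include <iostream>
#include <vector>

template <typename T>
struct IdentityDropoutFunctor {
  inline T operator()(T in) const { return in; }  // inference-time dropout: pass-through
};

int main() {
  std::vector<float> input{1.f, 2.f, 3.f};
  std::vector<float> output(input.size());
  std::transform(input.begin(), input.end(), output.begin(),
                 IdentityDropoutFunctor<float>{});
  for (float v : output) std::cout << v << " ";   // prints: 1 2 3
}
```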
......@@ -17,6 +17,11 @@ limitations under the License. */
namespace paddle_mobile {
namespace operators {
template <>
bool Im2SequenceKernel<CPU, float>::Init(const Im2SequenceParam &para) const {
return true;
}
inline int Im2SeqOutputSize(int input_size, int filter_size, int padding_0,
int padding_1, int stride) {
const int output_size =
......
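
The helper above is collapsed in this view; for reference, the standard sliding-window output-size relation it presumably computes (an assumption, since the body is not shown here) is:

```cpp
// Assumed formula, shown for reference only -- the actual body is collapsed in the diff.
// Standard sliding-window output size with asymmetric padding:
inline int Im2SeqOutputSize(int input_size, int filter_size, int padding_0,
                            int padding_1, int stride) {
  return (input_size + padding_0 + padding_1 - filter_size) / stride + 1;
}
// Example: Im2SeqOutputSize(28, 3, 1, 1, 1) == 28
```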
......@@ -24,6 +24,7 @@ template <typename DeviceType, typename T>
class DropoutKernel : public framework::OpKernelBase<DeviceType, DropoutParam> {
public:
void Compute(const DropoutParam& param) const;
bool Init(const DropoutParam &para) const;
};
} // namespace operators
} // namespace paddle_mobile
......@@ -30,6 +30,7 @@ class Im2SequenceKernel
: public framework::OpKernelBase<DeviceType, Im2SequenceParam> {
public:
void Compute(const Im2SequenceParam &param) const;
bool Init(const Im2SequenceParam &para) const;
};
} // namespace operators
} // namespace paddle_mobile
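
Taken together with the `Init` specializations added in the kernel .cpp hunks above, the kernels now expose a two-phase interface: `Init(param)` once when the op is set up (hence the commit title "Add op init"), then `Compute(param)` on every run. A minimal standalone sketch of that lifecycle, with stand-in types rather than the real framework classes:

```cpp
// Self-contained lifecycle sketch (assumed usage, not code from this commit):
// Init() runs once at op-init time, Compute() on each forward pass.
#include <iostream>

struct Im2SequenceParam {};  // stand-in for operators::Im2SequenceParam

struct Im2SequenceKernelCPU {
  bool Init(const Im2SequenceParam &) const {
    // one-time setup (validating attributes, preparing buffers) would go here
    return true;
  }
  void Compute(const Im2SequenceParam &) const {
    std::cout << "im2sequence forward pass\n";
  }
};

int main() {
  Im2SequenceParam param;
  Im2SequenceKernelCPU kernel;
  if (!kernel.Init(param)) return 1;  // called once, at op-init time
  for (int step = 0; step < 3; ++step) {
    kernel.Compute(param);            // called on every Run()
  }
  return 0;
}
```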