From d9a7ecd4178eebc900974b711f1f50b951f8eaa4 Mon Sep 17 00:00:00 2001 From: "Yao,kun" Date: Fri, 29 Jun 2018 18:49:08 +0800 Subject: [PATCH] Add dropout op --- src/common/types.cpp | 4 +- src/common/types.h | 1 + src/operators/dropout_op.cpp | 30 ++++++++++++ src/operators/dropout_op.h | 51 +++++++++++++++++++++ src/operators/kernel/arm/dropout_kernel.cpp | 41 +++++++++++++++++ src/operators/kernel/dropout_kernel.h | 29 ++++++++++++ src/operators/op_param.h | 17 +++++++ 7 files changed, 172 insertions(+), 1 deletion(-) create mode 100644 src/operators/dropout_op.cpp create mode 100644 src/operators/dropout_op.h create mode 100644 src/operators/kernel/arm/dropout_kernel.cpp create mode 100644 src/operators/kernel/dropout_kernel.h diff --git a/src/common/types.cpp b/src/common/types.cpp index 8c06b6f0f6..8e8084fc57 100644 --- a/src/common/types.cpp +++ b/src/common/types.cpp @@ -40,6 +40,7 @@ const std::string G_OP_TYPE_FEED = "feed"; const std::string G_OP_TYPE_FETCH = "fetch"; const std::string G_OP_TYPE_DEPTHWISE_CONV = "depthwise_conv2d"; const std::string G_OP_TYPE_IM2SEQUENCE = "im2sequence"; +const std::string G_OP_TYPE_DROPOUT = "dropout"; std::unordered_map< std::string, std::pair, std::vector>> @@ -66,6 +67,7 @@ std::unordered_map< {G_OP_TYPE_RESHAPE, {{"X"}, {"Out"}}}, {G_OP_TYPE_DEPTHWISE_CONV, {{"Input"}, {"Output"}}}, {G_OP_TYPE_FUSION_CONV_ADD_RELU, {{"Input"}, {"Out"}}}, - {G_OP_TYPE_IM2SEQUENCE, {{"X"}, {"Out"}}}}; + {G_OP_TYPE_IM2SEQUENCE, {{"X"}, {"Out"}}}, + {G_OP_TYPE_DROPOUT, {{"X"}, {"Out"}}}}; } // namespace paddle_mobile diff --git a/src/common/types.h b/src/common/types.h index e632c0b52f..045de236ed 100644 --- a/src/common/types.h +++ b/src/common/types.h @@ -95,6 +95,7 @@ extern const std::string G_OP_TYPE_FEED; extern const std::string G_OP_TYPE_FETCH; extern const std::string G_OP_TYPE_DEPTHWISE_CONV; extern const std::string G_OP_TYPE_IM2SEQUENCE; +extern const std::string G_OP_TYPE_DROPOUT; extern std::unordered_map< std::string, std::pair, std::vector>> diff --git a/src/operators/dropout_op.cpp b/src/operators/dropout_op.cpp new file mode 100644 index 0000000000..5f68b5f8b5 --- /dev/null +++ b/src/operators/dropout_op.cpp @@ -0,0 +1,30 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "operators/dropout_op.h" +namespace paddle_mobile { +namespace operators { + +template +void DropoutOp::InferShape() const { + auto input_dims = param_.InputX()->dims(); + param_.Out()->Resize(input_dims); +} +template class DropoutOp; +} // namespace operators +} // namespace paddle_mobile + +namespace ops = paddle_mobile::operators; +USE_OP(Dropout); +REGISTER_OPERATOR(dropout, ops::DropoutOp); diff --git a/src/operators/dropout_op.h b/src/operators/dropout_op.h new file mode 100644 index 0000000000..3943eed162 --- /dev/null +++ b/src/operators/dropout_op.h @@ -0,0 +1,51 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include + +#include "framework/operator.h" +#include "operators/kernel/dropout_kernel.h" +#include "operators/op_param.h" + +namespace paddle_mobile { +namespace operators { + +using paddle_mobile::framework::Tensor; + +template +class DropoutOp : public framework::OperatorWithKernel { + public: + DropoutOp(const std::string &type, const VariableNameMap &inputs, + const VariableNameMap &outputs, const framework::AttributeMap attrs, + std::shared_ptr scope) + : framework::OperatorWithKernel(type, inputs, outputs, attrs, + scope), + param_(inputs, outputs, attrs, *scope) {} + + void Run() const { + operators::DropoutKernel kernel; + kernel.Compute(param_); + } + + using framework::OperatorWithKernel::OperatorWithKernel; + void InferShape() const override; + + protected: + DropoutParam param_; +}; + +} // namespace operators +} // namespace paddle_mobile diff --git a/src/operators/kernel/arm/dropout_kernel.cpp b/src/operators/kernel/arm/dropout_kernel.cpp new file mode 100644 index 0000000000..342d5ea582 --- /dev/null +++ b/src/operators/kernel/arm/dropout_kernel.cpp @@ -0,0 +1,41 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include "operators/kernel/dropout_kernel.h" +#include + +namespace paddle_mobile { +namespace operators { + +template +struct DropoutFunctor { + inline T operator()(T in) const { return in; } +}; + +template <> +void DropoutKernel::Compute(const DropoutParam ¶m) const { + const auto *input_x = param.InputX(); + auto *input_x_ptr = input_x->data(); + auto *out = param.Out(); + auto *out_ptr = out->mutable_data(); + + DropoutFunctor func_; + math::Transform trans; + trans(input_x_ptr, input_x_ptr + input_x->numel(), out_ptr, func_); + +} +} // namespace operators +} // namespace paddle_mobile diff --git a/src/operators/kernel/dropout_kernel.h b/src/operators/kernel/dropout_kernel.h new file mode 100644 index 0000000000..92caab1c7a --- /dev/null +++ b/src/operators/kernel/dropout_kernel.h @@ -0,0 +1,29 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "framework/operator.h" +#include "operators/op_param.h" + +#pragma once; + +namespace paddle_mobile { +namespace operators { + +template +class DropoutKernel : public framework::OpKernelBase { + public: + void Compute(const DropoutParam& param) const; +}; +} // namespace operators +} // namespace paddle_mobile diff --git a/src/operators/op_param.h b/src/operators/op_param.h index 9955462e86..761d756161 100644 --- a/src/operators/op_param.h +++ b/src/operators/op_param.h @@ -877,6 +877,23 @@ class Im2SequenceParam : public OpParam { vector strides_; vector paddings_; }; + +class DropoutParam : public OpParam { +public: + DropoutParam(const VariableNameMap &inputs, const VariableNameMap &outputs, + const AttributeMap &attrs, const Scope &scope) { + input_x_ = InputXFrom(inputs, scope); + out_ = OutFrom(outputs, scope); + } + + const Tensor *InputX() const { return input_x_; } + + Tensor *Out() const { return out_; } + +private: + Tensor *input_x_; + Tensor *out_; +}; } // namespace operators } // namespace paddle_mobile -- GitLab