提交 d9a7ecd4 编写于 作者: Y Yao,kun

Add dropout op

上级 e7b73649
......@@ -40,6 +40,7 @@ const std::string G_OP_TYPE_FEED = "feed";
const std::string G_OP_TYPE_FETCH = "fetch";
const std::string G_OP_TYPE_DEPTHWISE_CONV = "depthwise_conv2d";
const std::string G_OP_TYPE_IM2SEQUENCE = "im2sequence";
const std::string G_OP_TYPE_DROPOUT = "dropout";
std::unordered_map<
std::string, std::pair<std::vector<std::string>, std::vector<std::string>>>
......@@ -66,6 +67,7 @@ std::unordered_map<
{G_OP_TYPE_RESHAPE, {{"X"}, {"Out"}}},
{G_OP_TYPE_DEPTHWISE_CONV, {{"Input"}, {"Output"}}},
{G_OP_TYPE_FUSION_CONV_ADD_RELU, {{"Input"}, {"Out"}}},
{G_OP_TYPE_IM2SEQUENCE, {{"X"}, {"Out"}}}};
{G_OP_TYPE_IM2SEQUENCE, {{"X"}, {"Out"}}},
{G_OP_TYPE_DROPOUT, {{"X"}, {"Out"}}}};
} // namespace paddle_mobile
......@@ -95,6 +95,7 @@ extern const std::string G_OP_TYPE_FEED;
extern const std::string G_OP_TYPE_FETCH;
extern const std::string G_OP_TYPE_DEPTHWISE_CONV;
extern const std::string G_OP_TYPE_IM2SEQUENCE;
extern const std::string G_OP_TYPE_DROPOUT;
extern std::unordered_map<
std::string, std::pair<std::vector<std::string>, std::vector<std::string>>>
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "operators/dropout_op.h"
namespace paddle_mobile {
namespace operators {
template <typename Dtype, typename T>
void DropoutOp<Dtype, T>::InferShape() const {
auto input_dims = param_.InputX()->dims();
param_.Out()->Resize(input_dims);
}
template class DropoutOp<CPU, float>;
} // namespace operators
} // namespace paddle_mobile
namespace ops = paddle_mobile::operators;
USE_OP(Dropout);
REGISTER_OPERATOR(dropout, ops::DropoutOp);
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <string>
#include "framework/operator.h"
#include "operators/kernel/dropout_kernel.h"
#include "operators/op_param.h"
namespace paddle_mobile {
namespace operators {
using paddle_mobile::framework::Tensor;
template <typename DeviceType, typename T>
class DropoutOp : public framework::OperatorWithKernel<DeviceType> {
public:
DropoutOp(const std::string &type, const VariableNameMap &inputs,
const VariableNameMap &outputs, const framework::AttributeMap attrs,
std::shared_ptr<framework::Scope> scope)
: framework::OperatorWithKernel<DeviceType>(type, inputs, outputs, attrs,
scope),
param_(inputs, outputs, attrs, *scope) {}
void Run() const {
operators::DropoutKernel<DeviceType, T> kernel;
kernel.Compute(param_);
}
using framework::OperatorWithKernel<DeviceType>::OperatorWithKernel;
void InferShape() const override;
protected:
DropoutParam param_;
};
} // namespace operators
} // namespace paddle_mobile
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "operators/kernel/dropout_kernel.h"
#include <operators/math/transform.h>
namespace paddle_mobile {
namespace operators {
template <typename T>
struct DropoutFunctor {
inline T operator()(T in) const { return in; }
};
template <>
void DropoutKernel<CPU, float>::Compute(const DropoutParam &param) const {
const auto *input_x = param.InputX();
auto *input_x_ptr = input_x->data<float>();
auto *out = param.Out();
auto *out_ptr = out->mutable_data<float>();
DropoutFunctor<float> func_;
math::Transform trans;
trans(input_x_ptr, input_x_ptr + input_x->numel(), out_ptr, func_);
}
} // namespace operators
} // namespace paddle_mobile
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "framework/operator.h"
#include "operators/op_param.h"
#pragma once;
namespace paddle_mobile {
namespace operators {
template <typename DeviceType, typename T>
class DropoutKernel : public framework::OpKernelBase<DeviceType, DropoutParam> {
public:
void Compute(const DropoutParam& param) const;
};
} // namespace operators
} // namespace paddle_mobile
......@@ -877,6 +877,23 @@ class Im2SequenceParam : public OpParam {
vector<int> strides_;
vector<int> paddings_;
};
class DropoutParam : public OpParam {
public:
DropoutParam(const VariableNameMap &inputs, const VariableNameMap &outputs,
const AttributeMap &attrs, const Scope &scope) {
input_x_ = InputXFrom<LoDTensor>(inputs, scope);
out_ = OutFrom<LoDTensor>(outputs, scope);
}
const Tensor *InputX() const { return input_x_; }
Tensor *Out() const { return out_; }
private:
Tensor *input_x_;
Tensor *out_;
};
} // namespace operators
} // namespace paddle_mobile
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册