diff --git a/paddle/fluid/operators/ngraph/ngraph_engine.cc b/paddle/fluid/operators/ngraph/ngraph_engine.cc index 5ef385d2fcbaf01dce5c9b85321b41c103e5655a..c2dce51fe54d107e23b38e013e066e29df0c74ca 100644 --- a/paddle/fluid/operators/ngraph/ngraph_engine.cc +++ b/paddle/fluid/operators/ngraph/ngraph_engine.cc @@ -61,6 +61,7 @@ static std::map {framework::proto::VarType::FP64, ngraph::element::f64}, {framework::proto::VarType::INT32, ngraph::element::i32}, {framework::proto::VarType::INT64, ngraph::element::i64}, + {framework::proto::VarType::UINT8, ngraph::element::u8}, {framework::proto::VarType::BOOL, ngraph::element::boolean}}; static std::map @@ -69,6 +70,7 @@ static std::map {ngraph::element::f64, framework::proto::VarType::FP64}, {ngraph::element::i32, framework::proto::VarType::INT32}, {ngraph::element::i64, framework::proto::VarType::INT64}, + {ngraph::element::u8, framework::proto::VarType::UINT8}, {ngraph::element::boolean, framework::proto::VarType::BOOL}}; std::vector NgraphEngine::feed_vars = {}; diff --git a/paddle/fluid/operators/ngraph/ops/dropout_op.h b/paddle/fluid/operators/ngraph/ops/dropout_op.h new file mode 100644 index 0000000000000000000000000000000000000000..cf19a585735f72796ee1820d63574fd6e725fc2b --- /dev/null +++ b/paddle/fluid/operators/ngraph/ops/dropout_op.h @@ -0,0 +1,110 @@ +/*Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include +#include +#include +#include + +#include "ngraph/ngraph.hpp" +#include "ngraph/op/experimental/generate_mask.hpp" +#include "paddle/fluid/operators/ngraph/ops/elementwise_scalar_op.h" +#include "paddle/fluid/operators/ngraph/ops/op_bridge.h" +#include "paddle/fluid/platform/ngraph_helper.h" + +namespace paddle { +namespace operators { +namespace ngraphs { + +static void BuildDropoutNode( + const std::shared_ptr& op, + std::shared_ptr< + std::unordered_map>> + ngb_node_map) { + auto input = platform::GetInputNode(op, "X", ngb_node_map); + auto op_attrs = framework::AttrReader(op->Attrs()); + auto dropout_prob = op_attrs.Get("dropout_prob"); + auto dropout_implementation = + op_attrs.Get("dropout_implementation"); + auto is_test = op_attrs.Get("is_test"); + auto seed = op_attrs.Get("seed"); + float value = 1.0f - dropout_prob; + bool upscale_in_train = (dropout_implementation == "upscale_in_train"); + + if (is_test) { + if (upscale_in_train) { + platform::SetOutputNode(op, "Out", input, ngb_node_map); + } else { + auto mask_val = paddle::platform::CreateConstant( + input->get_element_type(), input->get_shape(), {value}); + auto out = input * mask_val; + platform::SetOutputNode(op, "Out", out, ngb_node_map); + } + } else { + auto one = paddle::platform::CreateConstant(input->get_element_type(), + ngraph::Shape{}, {1}); + + auto gen_mask = std::make_shared( + one, input->get_shape(), input->get_element_type(), seed, value); + + if (upscale_in_train) { + auto mask_val = paddle::platform::CreateConstant( + input->get_element_type(), input->get_shape(), {value}); + + auto out = value ? input * gen_mask / mask_val : input * gen_mask; + platform::SetOutputNode(op, "Mask", gen_mask, ngb_node_map); + platform::SetOutputNode(op, "Out", out, ngb_node_map); + } else { + auto out = input * gen_mask; + platform::SetOutputNode(op, "Mask", gen_mask, ngb_node_map); + platform::SetOutputNode(op, "Out", out, ngb_node_map); + } + } +} + +static void BuildDropoutGradNode( + const std::shared_ptr& op, + std::shared_ptr< + std::unordered_map>> + ngb_node_map) { + auto dy = platform::GetInputNode(op, "Out@GRAD", ngb_node_map); + auto mask = platform::GetInputNode(op, "Mask", ngb_node_map); + if (dy->get_element_type() != mask->get_element_type()) { + mask = std::make_shared(mask, dy->get_element_type()); + } + + auto op_attrs = framework::AttrReader(op->Attrs()); + auto dropout_prob = op_attrs.Get("dropout_prob"); + auto dropout_implementation = + op_attrs.Get("dropout_implementation"); + auto dx = dy * mask; + + if (dropout_implementation == "upscale_in_train") { + if (dropout_prob == 1.0f) { + dx = ElementwiseScalar(0., dy); + } else { + dx = + ElementwiseScalar(1. / (1. - dropout_prob), dx); + } + } + platform::SetOutputNode(op, "X@GRAD", dx, ngb_node_map); +} +} // namespace ngraphs +} // namespace operators +} // namespace paddle + +REGISTER_NG_OP(dropout, BuildDropoutNode); +REGISTER_NG_OP(dropout_grad, BuildDropoutGradNode); diff --git a/python/paddle/fluid/tests/unittests/ngraph/test_dropout_ngraph_op.py b/python/paddle/fluid/tests/unittests/ngraph/test_dropout_ngraph_op.py new file mode 100644 index 0000000000000000000000000000000000000000..0448bed10204fb8ddba8546608750313191c4cc9 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/ngraph/test_dropout_ngraph_op.py @@ -0,0 +1,21 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import unittest +from paddle.fluid.tests.unittests.test_dropout_op import TestDropoutOp, TestDropoutOp2, TestDropoutOp3, TestDropoutOp4, TestDropoutOp5, TestDropoutOp6, TestDropoutOp7, TestDropoutOp8, TestDropoutOp9 + +if __name__ == '__main__': + unittest.main()