diff --git a/paddle/fluid/operators/ngraph/ops/cast_op.h b/paddle/fluid/operators/ngraph/ops/cast_op.h
new file mode 100644
index 0000000000000000000000000000000000000000..8e385f61bee10b8d4dfb2fdcc723637a6f3c2a07
--- /dev/null
+++ b/paddle/fluid/operators/ngraph/ops/cast_op.h
@@ -0,0 +1,62 @@
+/*Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "ngraph/ngraph.hpp"
+#include "paddle/fluid/operators/ngraph/ops/op_bridge.h"
+#include "paddle/fluid/platform/ngraph_helper.h"
+
+namespace paddle {
+namespace operators {
+namespace ngraphs {
+
+static void BuildCastNode(
+    const std::shared_ptr<framework::OperatorBase>& op,
+    std::shared_ptr<
+        std::unordered_map<std::string, std::shared_ptr<ngraph::Node>>>
+        ngb_node_map) {
+  auto input = platform::GetInputNode(op, "X", ngb_node_map);
+  auto op_attrs = framework::AttrReader(op->Attrs());
+  auto ng_dtype =
+      platform::GetNgType(static_cast<paddle::framework::proto::VarType::Type>(
+          op_attrs.Get<int>("out_dtype")));
+  auto out = std::make_shared<ngraph::op::Convert>(input, ng_dtype);
+  paddle::platform::SetOutputNode(op, "Out", out, ngb_node_map);
+}
+
+static void BuildCastGradNode(
+    const std::shared_ptr<framework::OperatorBase>& op,
+    std::shared_ptr<
+        std::unordered_map<std::string, std::shared_ptr<ngraph::Node>>>
+        ngb_node_map) {
+  auto input = platform::GetInputNode(op, "Out@GRAD", ngb_node_map);
+  auto op_attrs = framework::AttrReader(op->Attrs());
+  auto ng_dtype =
+      platform::GetNgType(static_cast<paddle::framework::proto::VarType::Type>(
+          op_attrs.Get<int>("out_dtype")));
+  auto out = std::make_shared<ngraph::op::Convert>(input, ng_dtype);
+  platform::SetOutputNode(op, "X@GRAD", out, ngb_node_map);
+}
+}  // namespace ngraphs
+}  // namespace operators
+}  // namespace paddle
+
+REGISTER_NG_OP(cast, BuildCastNode);
+REGISTER_NG_OP(cast_grad, BuildCastGradNode);
diff --git a/paddle/fluid/operators/ngraph/ops/fill_constant_op.h b/paddle/fluid/operators/ngraph/ops/fill_constant_op.h
index 42c2df5259242b7ae28613ab12c237834febc574..fee5f57e4862a8a033a28885a01a0dafea35f7f0 100644
--- a/paddle/fluid/operators/ngraph/ops/fill_constant_op.h
+++ b/paddle/fluid/operators/ngraph/ops/fill_constant_op.h
@@ -38,20 +38,9 @@ void BuildFillConstantNode(
     shape.push_back(sp);
   }
   float value = op_attrs.Get<float>("value");
-  ngraph::element::Type ng_dtype;
-  auto data_type = static_cast<paddle::framework::proto::VarType::Type>(
-      op_attrs.Get<int>("dtype"));
-  if (data_type == paddle::framework::proto::VarType::FP32) {
-    ng_dtype = ngraph::element::f32;
-  } else if (data_type == paddle::framework::proto::VarType::FP64) {
-    ng_dtype = ngraph::element::f64;
-  } else if (data_type == paddle::framework::proto::VarType::INT64) {
-    ng_dtype = ngraph::element::i64;
-  } else if (data_type == paddle::framework::proto::VarType::INT32) {
-    ng_dtype = ngraph::element::i32;
-  } else {
-    PADDLE_THROW("unsupported data type: %s", data_type);
-  }
+  auto ng_dtype =
+      platform::GetNgType(static_cast<paddle::framework::proto::VarType::Type>(
+          op_attrs.Get<int>("dtype")));
   auto out = ngraph::op::Constant::create(ng_dtype, shape, {value});
   paddle::platform::SetOutputNode(op, "Out", out, ngb_node_map);
 }
diff --git a/python/paddle/fluid/tests/unittests/ngraph/test_cast_ngraph_op.py b/python/paddle/fluid/tests/unittests/ngraph/test_cast_ngraph_op.py
new file mode 100644
index 0000000000000000000000000000000000000000..7732637d2299b0e9ea2092e4d244e09cc8a21c0e
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/ngraph/test_cast_ngraph_op.py
@@ -0,0 +1,21 @@
+# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import print_function
+
+import unittest
+from paddle.fluid.tests.unittests.test_cast_op import TestCastOp1
+
+if __name__ == '__main__':
+    unittest.main()
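Reviewer note: both bridges now route dtype handling through platform::GetNgType. Judging from the if/else chain this diff deletes from fill_constant_op.h, that helper maps VarType enums to nGraph element types roughly as in the sketch below. This is an illustration reconstructed from the removed branch, not the actual paddle/fluid/platform/ngraph_helper.h implementation, which may cover additional types.

// Minimal sketch of the VarType -> ngraph::element::Type mapping that
// platform::GetNgType is assumed to provide, reconstructed from the code
// removed from fill_constant_op.h above.
#include "ngraph/ngraph.hpp"
#include "paddle/fluid/framework/framework.pb.h"
#include "paddle/fluid/platform/enforce.h"

ngraph::element::Type GetNgType(
    paddle::framework::proto::VarType::Type data_type) {
  ngraph::element::Type ng_dtype;
  if (data_type == paddle::framework::proto::VarType::FP32) {
    ng_dtype = ngraph::element::f32;
  } else if (data_type == paddle::framework::proto::VarType::FP64) {
    ng_dtype = ngraph::element::f64;
  } else if (data_type == paddle::framework::proto::VarType::INT64) {
    ng_dtype = ngraph::element::i64;
  } else if (data_type == paddle::framework::proto::VarType::INT32) {
    ng_dtype = ngraph::element::i32;
  } else {
    PADDLE_THROW("unsupported data type: %s", data_type);
  }
  return ng_dtype;
}

The new Python file carries no test logic of its own: importing TestCastOp1 into the module namespace is enough for unittest.main() to re-run that existing cast test from the ngraph test directory, presumably with the nGraph engine enabled by the CI configuration for that directory.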