From 256ba7cbb8675cc217fb471d1f04cd98135ea3d6 Mon Sep 17 00:00:00 2001 From: baojun <32073718+baojun-nervana@users.noreply.github.com> Date: Tue, 16 Jul 2019 23:35:14 -0700 Subject: [PATCH] [NGraph] handle dim element 0 of ngraph op (#18568) --- .../fluid/operators/ngraph/ngraph_engine.cc | 20 ++----------------- paddle/fluid/operators/ngraph/ops/gather_op.h | 15 ++++++++++++++ .../fluid/operators/ngraph/ops/reshape_op.h | 20 ++++++++++++++----- .../fluid/operators/ngraph/ops/transpose_op.h | 13 ++++++++++-- 4 files changed, 43 insertions(+), 25 deletions(-) diff --git a/paddle/fluid/operators/ngraph/ngraph_engine.cc b/paddle/fluid/operators/ngraph/ngraph_engine.cc index ae87687e34..3a94368657 100644 --- a/paddle/fluid/operators/ngraph/ngraph_engine.cc +++ b/paddle/fluid/operators/ngraph/ngraph_engine.cc @@ -38,14 +38,8 @@ namespace operators { static ngraph::Shape Ddim2Shape(const framework::DDim& dims) { ngraph::Shape sp; - if (dims.size() == 1 && dims[0] == 0) { - sp.emplace_back(0); - return sp; - } for (int i = 0; i < dims.size(); ++i) { - int k = dims[i]; - k = k == 0 ? 1 : k; - sp.emplace_back(k); + sp.emplace_back(dims[i]); } return sp; } @@ -639,17 +633,7 @@ void NgraphEngine::Run(const framework::Scope& scope, for (auto& op : fused_ops_) { framework::RuntimeContext ctx(op->Inputs(), op->Outputs(), scope_); - if (op->Type() == "reshape2_grad") { - auto xshape_name = op->Inputs().at("XShape").at(0); - auto* xshape_var = scope_.FindVar(xshape_name); - auto* xshape_tensor = GetLoDTensorOrSelectedRowsValueFromVar(*xshape_var); - auto& xshape_ddim = xshape_tensor->dims(); - auto xgrad_name = op->Outputs().at(framework::GradVarName("X")).at(0); - auto* xgrad_var = scope_.FindVar(xgrad_name); - xgrad_var->GetMutable()->Resize(xshape_ddim); - } else { - op->RuntimeInferShape(scope_, place_, ctx); - } + op->RuntimeInferShape(scope_, place_, ctx); } std::vector> t_out = {}; diff --git a/paddle/fluid/operators/ngraph/ops/gather_op.h b/paddle/fluid/operators/ngraph/ops/gather_op.h index 5d6ac7d8ca..7d369b27d3 100644 --- a/paddle/fluid/operators/ngraph/ops/gather_op.h +++ b/paddle/fluid/operators/ngraph/ops/gather_op.h @@ -34,7 +34,15 @@ void BuildGatherNode( ngb_node_map) { auto x = platform::GetInputNode(op, "X", ngb_node_map); PADDLE_ENFORCE_NOT_NULL(x); + auto index = platform::GetInputNode(op, "Index", ngb_node_map); + auto& index_shape = index->get_shape(); + PADDLE_ENFORCE(index_shape.size() == 1 || + (index_shape.size() == 2 && index_shape[1] == 1)); + if (index_shape.size() == 2) { + index = platform::NgReshaper(index, ngraph::Shape{index_shape[0]}); + } + auto out = std::make_shared(x, index); paddle::platform::SetOutputNode(op, "Out", out, ngb_node_map); @@ -47,7 +55,14 @@ void BuildGatherGradNode( auto dout = platform::GetInputNode(op, "Out@GRAD", ngb_node_map); PADDLE_ENFORCE_NOT_NULL(dout); auto x = platform::GetInputNode(op, "X", ngb_node_map); + auto index = platform::GetInputNode(op, "Index", ngb_node_map); + auto& index_shape = index->get_shape(); + PADDLE_ENFORCE(index_shape.size() == 1 || + (index_shape.size() == 2 && index_shape[1] == 1)); + if (index_shape.size() == 2) { + index = platform::NgReshaper(index, ngraph::Shape{index_shape[0]}); + } std::shared_ptr x0 = paddle::platform::CreateConstant( dout->get_element_type(), x->get_shape(), {0}); diff --git a/paddle/fluid/operators/ngraph/ops/reshape_op.h b/paddle/fluid/operators/ngraph/ops/reshape_op.h index be3d38af49..53a2aebe23 100644 --- a/paddle/fluid/operators/ngraph/ops/reshape_op.h +++ b/paddle/fluid/operators/ngraph/ops/reshape_op.h @@ -14,6 +14,7 @@ limitations under the License. */ #pragma once +#include #include #include #include @@ -76,7 +77,12 @@ static void BuildReshapeNode( } if (is_v2) { - platform::SetOutputNode(op, "XShape", input, ngb_node_map); + ngraph::Shape input_xshape(input_shape.size() + 1); + input_xshape[0] = 0; + std::copy(input_shape.begin(), input_shape.end(), input_xshape.begin() + 1); + auto xshape_node = std::make_shared( + input->get_element_type(), input_xshape, std::vector{}); + platform::SetOutputNode(op, "XShape", xshape_node, ngb_node_map); } platform::SetOutputNode(op, "Out", out, ngb_node_map); } @@ -88,13 +94,17 @@ void BuildReshapeGradNode( std::unordered_map>> ngb_node_map) { auto dout = paddle::platform::GetInputNode(op, "Out@GRAD", ngb_node_map); - std::shared_ptr input; + ngraph::Shape out_shape; if (is_v2) { - input = paddle::platform::GetInputNode(op, "XShape", ngb_node_map); + auto& xshape = + platform::GetInputNode(op, "XShape", ngb_node_map)->get_shape(); + out_shape.resize(xshape.size() - 1); + std::copy(xshape.begin() + 1, xshape.end(), out_shape.begin()); } else { - input = paddle::platform::GetInputNode(op, "X", ngb_node_map); + auto input = paddle::platform::GetInputNode(op, "X", ngb_node_map); + out_shape = input->get_shape(); } - auto dx = platform::NgReshaper(dout, input->get_shape()); + auto dx = platform::NgReshaper(dout, out_shape); paddle::platform::SetOutputNode(op, "X@GRAD", dx, ngb_node_map); } } // namespace ngraphs diff --git a/paddle/fluid/operators/ngraph/ops/transpose_op.h b/paddle/fluid/operators/ngraph/ops/transpose_op.h index 2d4e49f79b..7d9428977a 100644 --- a/paddle/fluid/operators/ngraph/ops/transpose_op.h +++ b/paddle/fluid/operators/ngraph/ops/transpose_op.h @@ -14,6 +14,7 @@ limitations under the License. */ #pragma once +#include #include #include #include @@ -50,7 +51,12 @@ static void BuildTransposeNode( x_transpose = platform::NgReshaper(x_transpose, x_reshape_shape); platform::SetOutputNode(op, "Out", x_transpose, ngb_node_map); if (is_v2) { - platform::SetOutputNode(op, "XShape", input, ngb_node_map); + ngraph::Shape input_xshape(input_shape.size() + 1); + input_xshape[0] = 0; + std::copy(input_shape.begin(), input_shape.end(), input_xshape.begin() + 1); + auto xshape_node = std::make_shared( + input->get_element_type(), input_xshape, std::vector{}); + platform::SetOutputNode(op, "XShape", xshape_node, ngb_node_map); } } @@ -71,7 +77,10 @@ static void BuildTransposeGradNode( ngraph::Shape out_shape; if (is_v2) { - out_shape = platform::GetInputNode(op, "XShape", ngb_node_map)->get_shape(); + auto& xshape = + platform::GetInputNode(op, "XShape", ngb_node_map)->get_shape(); + out_shape.resize(xshape.size() - 1); + std::copy(xshape.begin() + 1, xshape.end(), out_shape.begin()); } else { out_shape = platform::GetInputNode(op, "X", ngb_node_map)->get_shape(); } -- GitLab