提交 256ba7cb 编写于 作者: B baojun 提交者: tensor-tang

[NGraph] handle dim element 0 of ngraph op (#18568)

上级 a6d468a2
...@@ -38,14 +38,8 @@ namespace operators { ...@@ -38,14 +38,8 @@ namespace operators {
static ngraph::Shape Ddim2Shape(const framework::DDim& dims) { static ngraph::Shape Ddim2Shape(const framework::DDim& dims) {
ngraph::Shape sp; ngraph::Shape sp;
if (dims.size() == 1 && dims[0] == 0) {
sp.emplace_back(0);
return sp;
}
for (int i = 0; i < dims.size(); ++i) { for (int i = 0; i < dims.size(); ++i) {
int k = dims[i]; sp.emplace_back(dims[i]);
k = k == 0 ? 1 : k;
sp.emplace_back(k);
} }
return sp; return sp;
} }
...@@ -639,18 +633,8 @@ void NgraphEngine::Run(const framework::Scope& scope, ...@@ -639,18 +633,8 @@ void NgraphEngine::Run(const framework::Scope& scope,
for (auto& op : fused_ops_) { for (auto& op : fused_ops_) {
framework::RuntimeContext ctx(op->Inputs(), op->Outputs(), scope_); framework::RuntimeContext ctx(op->Inputs(), op->Outputs(), scope_);
if (op->Type() == "reshape2_grad") {
auto xshape_name = op->Inputs().at("XShape").at(0);
auto* xshape_var = scope_.FindVar(xshape_name);
auto* xshape_tensor = GetLoDTensorOrSelectedRowsValueFromVar(*xshape_var);
auto& xshape_ddim = xshape_tensor->dims();
auto xgrad_name = op->Outputs().at(framework::GradVarName("X")).at(0);
auto* xgrad_var = scope_.FindVar(xgrad_name);
xgrad_var->GetMutable<framework::LoDTensor>()->Resize(xshape_ddim);
} else {
op->RuntimeInferShape(scope_, place_, ctx); op->RuntimeInferShape(scope_, place_, ctx);
} }
}
std::vector<std::shared_ptr<ngraph::runtime::Tensor>> t_out = {}; std::vector<std::shared_ptr<ngraph::runtime::Tensor>> t_out = {};
for (size_t i = 0; i < p_var_out->size(); ++i) { for (size_t i = 0; i < p_var_out->size(); ++i) {
......
...@@ -34,7 +34,15 @@ void BuildGatherNode( ...@@ -34,7 +34,15 @@ void BuildGatherNode(
ngb_node_map) { ngb_node_map) {
auto x = platform::GetInputNode(op, "X", ngb_node_map); auto x = platform::GetInputNode(op, "X", ngb_node_map);
PADDLE_ENFORCE_NOT_NULL(x); PADDLE_ENFORCE_NOT_NULL(x);
auto index = platform::GetInputNode(op, "Index", ngb_node_map); auto index = platform::GetInputNode(op, "Index", ngb_node_map);
auto& index_shape = index->get_shape();
PADDLE_ENFORCE(index_shape.size() == 1 ||
(index_shape.size() == 2 && index_shape[1] == 1));
if (index_shape.size() == 2) {
index = platform::NgReshaper(index, ngraph::Shape{index_shape[0]});
}
auto out = std::make_shared<ngraph::op::Gather>(x, index); auto out = std::make_shared<ngraph::op::Gather>(x, index);
paddle::platform::SetOutputNode(op, "Out", out, ngb_node_map); paddle::platform::SetOutputNode(op, "Out", out, ngb_node_map);
...@@ -47,7 +55,14 @@ void BuildGatherGradNode( ...@@ -47,7 +55,14 @@ void BuildGatherGradNode(
auto dout = platform::GetInputNode(op, "Out@GRAD", ngb_node_map); auto dout = platform::GetInputNode(op, "Out@GRAD", ngb_node_map);
PADDLE_ENFORCE_NOT_NULL(dout); PADDLE_ENFORCE_NOT_NULL(dout);
auto x = platform::GetInputNode(op, "X", ngb_node_map); auto x = platform::GetInputNode(op, "X", ngb_node_map);
auto index = platform::GetInputNode(op, "Index", ngb_node_map); auto index = platform::GetInputNode(op, "Index", ngb_node_map);
auto& index_shape = index->get_shape();
PADDLE_ENFORCE(index_shape.size() == 1 ||
(index_shape.size() == 2 && index_shape[1] == 1));
if (index_shape.size() == 2) {
index = platform::NgReshaper(index, ngraph::Shape{index_shape[0]});
}
std::shared_ptr<ngraph::Node> x0 = paddle::platform::CreateConstant( std::shared_ptr<ngraph::Node> x0 = paddle::platform::CreateConstant(
dout->get_element_type(), x->get_shape(), {0}); dout->get_element_type(), x->get_shape(), {0});
......
...@@ -14,6 +14,7 @@ limitations under the License. */ ...@@ -14,6 +14,7 @@ limitations under the License. */
#pragma once #pragma once
#include <algorithm>
#include <functional> #include <functional>
#include <memory> #include <memory>
#include <string> #include <string>
...@@ -76,7 +77,12 @@ static void BuildReshapeNode( ...@@ -76,7 +77,12 @@ static void BuildReshapeNode(
} }
if (is_v2) { if (is_v2) {
platform::SetOutputNode(op, "XShape", input, ngb_node_map); ngraph::Shape input_xshape(input_shape.size() + 1);
input_xshape[0] = 0;
std::copy(input_shape.begin(), input_shape.end(), input_xshape.begin() + 1);
auto xshape_node = std::make_shared<ngraph::op::Constant>(
input->get_element_type(), input_xshape, std::vector<std::string>{});
platform::SetOutputNode(op, "XShape", xshape_node, ngb_node_map);
} }
platform::SetOutputNode(op, "Out", out, ngb_node_map); platform::SetOutputNode(op, "Out", out, ngb_node_map);
} }
...@@ -88,13 +94,17 @@ void BuildReshapeGradNode( ...@@ -88,13 +94,17 @@ void BuildReshapeGradNode(
std::unordered_map<std::string, std::shared_ptr<ngraph::Node>>> std::unordered_map<std::string, std::shared_ptr<ngraph::Node>>>
ngb_node_map) { ngb_node_map) {
auto dout = paddle::platform::GetInputNode(op, "Out@GRAD", ngb_node_map); auto dout = paddle::platform::GetInputNode(op, "Out@GRAD", ngb_node_map);
std::shared_ptr<ngraph::Node> input; ngraph::Shape out_shape;
if (is_v2) { if (is_v2) {
input = paddle::platform::GetInputNode(op, "XShape", ngb_node_map); auto& xshape =
platform::GetInputNode(op, "XShape", ngb_node_map)->get_shape();
out_shape.resize(xshape.size() - 1);
std::copy(xshape.begin() + 1, xshape.end(), out_shape.begin());
} else { } else {
input = paddle::platform::GetInputNode(op, "X", ngb_node_map); auto input = paddle::platform::GetInputNode(op, "X", ngb_node_map);
out_shape = input->get_shape();
} }
auto dx = platform::NgReshaper(dout, input->get_shape()); auto dx = platform::NgReshaper(dout, out_shape);
paddle::platform::SetOutputNode(op, "X@GRAD", dx, ngb_node_map); paddle::platform::SetOutputNode(op, "X@GRAD", dx, ngb_node_map);
} }
} // namespace ngraphs } // namespace ngraphs
......
...@@ -14,6 +14,7 @@ limitations under the License. */ ...@@ -14,6 +14,7 @@ limitations under the License. */
#pragma once #pragma once
#include <algorithm>
#include <functional> #include <functional>
#include <memory> #include <memory>
#include <string> #include <string>
...@@ -50,7 +51,12 @@ static void BuildTransposeNode( ...@@ -50,7 +51,12 @@ static void BuildTransposeNode(
x_transpose = platform::NgReshaper(x_transpose, x_reshape_shape); x_transpose = platform::NgReshaper(x_transpose, x_reshape_shape);
platform::SetOutputNode(op, "Out", x_transpose, ngb_node_map); platform::SetOutputNode(op, "Out", x_transpose, ngb_node_map);
if (is_v2) { if (is_v2) {
platform::SetOutputNode(op, "XShape", input, ngb_node_map); ngraph::Shape input_xshape(input_shape.size() + 1);
input_xshape[0] = 0;
std::copy(input_shape.begin(), input_shape.end(), input_xshape.begin() + 1);
auto xshape_node = std::make_shared<ngraph::op::Constant>(
input->get_element_type(), input_xshape, std::vector<std::string>{});
platform::SetOutputNode(op, "XShape", xshape_node, ngb_node_map);
} }
} }
...@@ -71,7 +77,10 @@ static void BuildTransposeGradNode( ...@@ -71,7 +77,10 @@ static void BuildTransposeGradNode(
ngraph::Shape out_shape; ngraph::Shape out_shape;
if (is_v2) { if (is_v2) {
out_shape = platform::GetInputNode(op, "XShape", ngb_node_map)->get_shape(); auto& xshape =
platform::GetInputNode(op, "XShape", ngb_node_map)->get_shape();
out_shape.resize(xshape.size() - 1);
std::copy(xshape.begin() + 1, xshape.end(), out_shape.begin());
} else { } else {
out_shape = platform::GetInputNode(op, "X", ngb_node_map)->get_shape(); out_shape = platform::GetInputNode(op, "X", ngb_node_map)->get_shape();
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册