Commit 256ba7cb authored by baojun, committed by tensor-tang

[NGraph] handle dim element 0 of ngraph op (#18568)

Parent a6d468a2
@@ -38,14 +38,8 @@ namespace operators {
 static ngraph::Shape Ddim2Shape(const framework::DDim& dims) {
   ngraph::Shape sp;
-  if (dims.size() == 1 && dims[0] == 0) {
-    sp.emplace_back(0);
-    return sp;
-  }
   for (int i = 0; i < dims.size(); ++i) {
-    int k = dims[i];
-    k = k == 0 ? 1 : k;
-    sp.emplace_back(k);
+    sp.emplace_back(dims[i]);
   }
   return sp;
 }
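This hunk changes the DDim-to-nGraph shape conversion: previously a lone `{0}` dim was special-cased and any other 0 entry was silently rewritten to 1; now every entry is copied through verbatim, so nGraph sees zero-sized dimensions (such as the leading 0 of an XShape dim) as-is. A minimal standalone sketch of the before/after mapping, using `std::vector` stand-ins for `framework::DDim` and `ngraph::Shape`:

```cpp
#include <cstddef>
#include <iostream>
#include <vector>

// Old behavior: every 0 entry was rewritten to 1.
std::vector<size_t> Ddim2ShapeOld(const std::vector<int>& dims) {
  std::vector<size_t> sp;
  for (int d : dims) sp.push_back(d == 0 ? 1 : d);
  return sp;
}

// New behavior: dims are copied through verbatim, so 0 stays 0.
std::vector<size_t> Ddim2ShapeNew(const std::vector<int>& dims) {
  std::vector<size_t> sp;
  for (int d : dims) sp.push_back(d);
  return sp;
}

int main() {
  std::vector<int> dims = {0, 2, 3};  // e.g. an XShape dim with its leading 0
  for (size_t d : Ddim2ShapeOld(dims)) std::cout << d << ' ';  // prints: 1 2 3
  std::cout << '\n';
  for (size_t d : Ddim2ShapeNew(dims)) std::cout << d << ' ';  // prints: 0 2 3
  std::cout << '\n';
}
```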
@@ -639,17 +633,7 @@ void NgraphEngine::Run(const framework::Scope& scope,
   for (auto& op : fused_ops_) {
     framework::RuntimeContext ctx(op->Inputs(), op->Outputs(), scope_);
-    if (op->Type() == "reshape2_grad") {
-      auto xshape_name = op->Inputs().at("XShape").at(0);
-      auto* xshape_var = scope_.FindVar(xshape_name);
-      auto* xshape_tensor = GetLoDTensorOrSelectedRowsValueFromVar(*xshape_var);
-      auto& xshape_ddim = xshape_tensor->dims();
-      auto xgrad_name = op->Outputs().at(framework::GradVarName("X")).at(0);
-      auto* xgrad_var = scope_.FindVar(xgrad_name);
-      xgrad_var->GetMutable<framework::LoDTensor>()->Resize(xshape_ddim);
-    } else {
-      op->RuntimeInferShape(scope_, place_, ctx);
-    }
+    op->RuntimeInferShape(scope_, place_, ctx);
   }
   std::vector<std::shared_ptr<ngraph::runtime::Tensor>> t_out = {};
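The engine previously had to resize X@GRAD by hand for `reshape2_grad`, because the XShape input carried no usable shape on the nGraph side. With the bridge now emitting a real XShape constant (see the reshape and transpose hunks below), the op's ordinary `RuntimeInferShape` recovers the shape itself. Paddle's XShape convention stores the input shape prefixed with a 0; a minimal standalone sketch of that recovery (`SliceLeadingZero` is a hypothetical helper, not Paddle API):

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// XShape convention: dims are the real input shape prefixed with a 0,
// e.g. an input of shape {2, 3} is recorded as XShape {0, 2, 3}.
// Hypothetical helper (not Paddle API) recovering the input shape.
std::vector<int64_t> SliceLeadingZero(const std::vector<int64_t>& xshape) {
  assert(!xshape.empty() && xshape[0] == 0);
  return std::vector<int64_t>(xshape.begin() + 1, xshape.end());
}

int main() {
  std::vector<int64_t> xshape = {0, 2, 3};
  auto x_grad_dims = SliceLeadingZero(xshape);  // {2, 3}
  assert(x_grad_dims == (std::vector<int64_t>{2, 3}));
}
```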
@@ -34,7 +34,15 @@ void BuildGatherNode(
         ngb_node_map) {
   auto x = platform::GetInputNode(op, "X", ngb_node_map);
   PADDLE_ENFORCE_NOT_NULL(x);
   auto index = platform::GetInputNode(op, "Index", ngb_node_map);
+  auto& index_shape = index->get_shape();
+  PADDLE_ENFORCE(index_shape.size() == 1 ||
+                 (index_shape.size() == 2 && index_shape[1] == 1));
+  if (index_shape.size() == 2) {
+    index = platform::NgReshaper(index, ngraph::Shape{index_shape[0]});
+  }
   auto out = std::make_shared<ngraph::op::Gather>(x, index);
   paddle::platform::SetOutputNode(op, "Out", out, ngb_node_map);
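Paddle's gather op accepts an index tensor shaped either `{N}` or `{N, 1}`, while `ngraph::op::Gather` is given a rank-1 index here, so the new lines enforce those two layouts and flatten `{N, 1}` to `{N}` via `NgReshaper` before building the node. A standalone sketch of the same normalization on plain vectors:

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Accept an index of shape {N} or {N, 1}; normalize to rank-1 {N}.
std::vector<size_t> NormalizeIndexShape(const std::vector<size_t>& shape) {
  assert(shape.size() == 1 || (shape.size() == 2 && shape[1] == 1));
  if (shape.size() == 2) return {shape[0]};  // flatten {N, 1} -> {N}
  return shape;
}

int main() {
  assert(NormalizeIndexShape({5, 1}) == (std::vector<size_t>{5}));
  assert(NormalizeIndexShape({5}) == (std::vector<size_t>{5}));
}
```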
@@ -47,7 +55,14 @@ void BuildGatherGradNode(
   auto dout = platform::GetInputNode(op, "Out@GRAD", ngb_node_map);
   PADDLE_ENFORCE_NOT_NULL(dout);
   auto x = platform::GetInputNode(op, "X", ngb_node_map);
   auto index = platform::GetInputNode(op, "Index", ngb_node_map);
+  auto& index_shape = index->get_shape();
+  PADDLE_ENFORCE(index_shape.size() == 1 ||
+                 (index_shape.size() == 2 && index_shape[1] == 1));
+  if (index_shape.size() == 2) {
+    index = platform::NgReshaper(index, ngraph::Shape{index_shape[0]});
+  }
   std::shared_ptr<ngraph::Node> x0 = paddle::platform::CreateConstant(
       dout->get_element_type(), x->get_shape(), {0});
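The visible lines build a zero constant `x0` shaped like `x`; the rest of the function is elided in this view, but the gradient is presumably accumulated into that buffer at the gathered rows (a scatter-add). A standalone sketch of that accumulation over 1-D rows, with duplicate indices summing as a gradient must:

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// dx starts as zeros shaped like x; each dout row is added at index[i].
// Duplicate indices accumulate, matching gather's gradient semantics.
std::vector<float> GatherGrad(size_t x_rows,
                              const std::vector<size_t>& index,
                              const std::vector<float>& dout) {
  std::vector<float> dx(x_rows, 0.0f);  // the x0 zero constant
  for (size_t i = 0; i < index.size(); ++i) dx[index[i]] += dout[i];
  return dx;
}

int main() {
  // x has 4 rows; row 1 is gathered twice and row 3 once.
  auto dx = GatherGrad(4, {1, 1, 3}, {0.5f, 0.25f, 1.0f});
  assert(dx == (std::vector<float>{0.0f, 0.75f, 0.0f, 1.0f}));
}
```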
@@ -14,6 +14,7 @@ limitations under the License. */
 #pragma once
 #include <algorithm>
 #include <functional>
+#include <memory>
 #include <string>
@@ -76,7 +77,12 @@ static void BuildReshapeNode(
   }
   if (is_v2) {
-    platform::SetOutputNode(op, "XShape", input, ngb_node_map);
+    ngraph::Shape input_xshape(input_shape.size() + 1);
+    input_xshape[0] = 0;
+    std::copy(input_shape.begin(), input_shape.end(), input_xshape.begin() + 1);
+    auto xshape_node = std::make_shared<ngraph::op::Constant>(
+        input->get_element_type(), input_xshape, std::vector<std::string>{});
+    platform::SetOutputNode(op, "XShape", xshape_node, ngb_node_map);
   }
   platform::SetOutputNode(op, "Out", out, ngb_node_map);
 }
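Rather than aliasing the input node as XShape, the bridge now emits a dedicated constant whose shape is the input shape with a 0 prepended. Since any shape containing a 0 has zero total elements, the empty `std::vector<std::string>{}` is a valid (empty) payload for the constant. A standalone sketch of building that shape and checking the element count:

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <functional>
#include <numeric>
#include <vector>

// Build the XShape dims: the input shape with a 0 prepended.
std::vector<size_t> MakeXShape(const std::vector<size_t>& input_shape) {
  std::vector<size_t> xshape(input_shape.size() + 1);
  xshape[0] = 0;
  std::copy(input_shape.begin(), input_shape.end(), xshape.begin() + 1);
  return xshape;
}

int main() {
  auto xshape = MakeXShape({2, 3});  // {0, 2, 3}
  // The 0 dimension makes the total element count 0, so an empty
  // value vector is a valid payload for the constant.
  size_t n = std::accumulate(xshape.begin(), xshape.end(), size_t{1},
                             std::multiplies<size_t>());
  assert(xshape == (std::vector<size_t>{0, 2, 3}) && n == 0);
}
```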
@@ -88,13 +94,17 @@ void BuildReshapeGradNode(
         std::unordered_map<std::string, std::shared_ptr<ngraph::Node>>>
         ngb_node_map) {
   auto dout = paddle::platform::GetInputNode(op, "Out@GRAD", ngb_node_map);
-  std::shared_ptr<ngraph::Node> input;
+  ngraph::Shape out_shape;
   if (is_v2) {
-    input = paddle::platform::GetInputNode(op, "XShape", ngb_node_map);
+    auto& xshape =
+        platform::GetInputNode(op, "XShape", ngb_node_map)->get_shape();
+    out_shape.resize(xshape.size() - 1);
+    std::copy(xshape.begin() + 1, xshape.end(), out_shape.begin());
   } else {
-    input = paddle::platform::GetInputNode(op, "X", ngb_node_map);
+    auto input = paddle::platform::GetInputNode(op, "X", ngb_node_map);
+    out_shape = input->get_shape();
   }
-  auto dx = platform::NgReshaper(dout, input->get_shape());
+  auto dx = platform::NgReshaper(dout, out_shape);
   paddle::platform::SetOutputNode(op, "X@GRAD", dx, ngb_node_map);
 }
 }  // namespace ngraphs
@@ -14,6 +14,7 @@ limitations under the License. */
 #pragma once
 #include <algorithm>
 #include <functional>
+#include <memory>
 #include <string>
@@ -50,7 +51,12 @@ static void BuildTransposeNode(
   x_transpose = platform::NgReshaper(x_transpose, x_reshape_shape);
   platform::SetOutputNode(op, "Out", x_transpose, ngb_node_map);
   if (is_v2) {
-    platform::SetOutputNode(op, "XShape", input, ngb_node_map);
+    ngraph::Shape input_xshape(input_shape.size() + 1);
+    input_xshape[0] = 0;
+    std::copy(input_shape.begin(), input_shape.end(), input_xshape.begin() + 1);
+    auto xshape_node = std::make_shared<ngraph::op::Constant>(
+        input->get_element_type(), input_xshape, std::vector<std::string>{});
+    platform::SetOutputNode(op, "XShape", xshape_node, ngb_node_map);
   }
 }
@@ -71,7 +77,10 @@ static void BuildTransposeGradNode(
   ngraph::Shape out_shape;
   if (is_v2) {
-    out_shape = platform::GetInputNode(op, "XShape", ngb_node_map)->get_shape();
+    auto& xshape =
+        platform::GetInputNode(op, "XShape", ngb_node_map)->get_shape();
+    out_shape.resize(xshape.size() - 1);
+    std::copy(xshape.begin() + 1, xshape.end(), out_shape.begin());
   } else {
     out_shape = platform::GetInputNode(op, "X", ngb_node_map)->get_shape();
   }
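As in the reshape2_grad hunk, the v2 branch now recovers the pre-transpose shape from the XShape dims by dropping the leading 0 instead of reading a node shape directly. The grad pass must also undo the forward axis permutation; a standalone sketch of inverting a transpose axis list (a hypothetical helper; the inversion itself is not shown in this diff):

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// If the forward pass permutes axes by `axis`, the gradient must apply
// the inverse permutation: inv[axis[i]] = i.
std::vector<size_t> InverseAxis(const std::vector<size_t>& axis) {
  std::vector<size_t> inv(axis.size());
  for (size_t i = 0; i < axis.size(); ++i) inv[axis[i]] = i;
  return inv;
}

int main() {
  // Forward transpose with axis {1, 2, 0}; the grad pass uses {2, 0, 1}.
  assert(InverseAxis({1, 2, 0}) == (std::vector<size_t>{2, 0, 1}));
}
```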