提交 298ee7d2 编写于 作者: B baojun 提交者: Tao Luo

Improve ngraph file line coverage (#22155)

上级 d0f0a252
...@@ -177,36 +177,6 @@ std::string SerializedBlock(const framework::BlockDesc& bdesc) { ...@@ -177,36 +177,6 @@ std::string SerializedBlock(const framework::BlockDesc& bdesc) {
return block_desc.Proto()->SerializeAsString(); return block_desc.Proto()->SerializeAsString();
} }
std::string GenerateEngineKey(const framework::BlockDesc& bdesc) {
framework::proto::BlockDesc block_proto;
framework::BlockDesc block_desc(nullptr, &block_proto);
block_desc.Proto()->set_parent_idx(-1);
block_desc.Proto()->set_idx(0);
for (auto& op_desc : bdesc.AllOps()) {
auto* op = block_desc.AppendOp();
*op->Proto() = *op_desc->Proto();
}
auto engine_key = std::to_string(
std::hash<std::string>()(block_desc.Proto()->SerializeAsString()));
return engine_key;
}
std::string GenerateEngineKey(const std::vector<std::string>& engine_inputs,
const std::vector<std::string>& engine_outputs,
int size) {
std::string engine_hash_key = "";
for (auto name : engine_inputs) {
engine_hash_key += name;
}
for (auto name : engine_outputs) {
engine_hash_key += name;
}
engine_hash_key += std::to_string(size);
auto engine_key = std::to_string(std::hash<std::string>()(engine_hash_key));
return engine_key;
}
void NgraphEngine::FuseNgraphOps( void NgraphEngine::FuseNgraphOps(
const framework::BlockDesc& block_desc, const framework::BlockDesc& block_desc,
std::vector<std::unique_ptr<framework::OperatorBase>>* ops) { std::vector<std::unique_ptr<framework::OperatorBase>>* ops) {
......
...@@ -40,23 +40,8 @@ static void BuildCastNode( ...@@ -40,23 +40,8 @@ static void BuildCastNode(
auto out = std::make_shared<ngraph::op::Convert>(input, ng_dtype); auto out = std::make_shared<ngraph::op::Convert>(input, ng_dtype);
paddle::platform::SetOutputNode(op, "Out", out, ngb_node_map); paddle::platform::SetOutputNode(op, "Out", out, ngb_node_map);
} }
static void BuildCastGradNode(
const std::shared_ptr<framework::OperatorBase>& op,
std::shared_ptr<
std::unordered_map<std::string, std::shared_ptr<ngraph::Node>>>
ngb_node_map) {
auto input = platform::GetInputNode(op, "Out@GRAD", ngb_node_map);
auto op_attrs = framework::AttrReader(op->Attrs());
auto ng_dtype =
platform::GetNgType(static_cast<paddle::framework::proto::VarType::Type>(
op_attrs.Get<int>("out_dtype")));
auto out = std::make_shared<ngraph::op::Convert>(input, ng_dtype);
platform::SetOutputNode(op, "X@GRAD", out, ngb_node_map);
}
} // namespace ngraphs } // namespace ngraphs
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
REGISTER_NG_OP(cast, BuildCastNode); REGISTER_NG_OP(cast, BuildCastNode);
REGISTER_NG_OP(cast_grad, BuildCastGradNode);
...@@ -37,9 +37,7 @@ void BuildElementwiseBinaryNode( ...@@ -37,9 +37,7 @@ void BuildElementwiseBinaryNode(
std::shared_ptr<ngraph::Node>& x = nodes.at(0); std::shared_ptr<ngraph::Node>& x = nodes.at(0);
std::shared_ptr<ngraph::Node>& y = nodes.at(1); std::shared_ptr<ngraph::Node>& y = nodes.at(1);
if (x->get_element_type() != y->get_element_type()) { y = std::make_shared<ngraph::op::Convert>(y, x->get_element_type());
y = std::make_shared<ngraph::op::Convert>(y, x->get_element_type());
}
auto out = std::make_shared<T>(x, y); auto out = std::make_shared<T>(x, y);
paddle::platform::SetOutputNode(op, "Out", out, ngb_node_map); paddle::platform::SetOutputNode(op, "Out", out, ngb_node_map);
} }
......
...@@ -23,6 +23,7 @@ limitations under the License. */ ...@@ -23,6 +23,7 @@ limitations under the License. */
#include "ngraph/ngraph.hpp" #include "ngraph/ngraph.hpp"
#include "paddle/fluid/operators/ngraph/ops/op_bridge.h" #include "paddle/fluid/operators/ngraph/ops/op_bridge.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/ngraph_helper.h" #include "paddle/fluid/platform/ngraph_helper.h"
namespace paddle { namespace paddle {
...@@ -60,20 +61,16 @@ static void BuildReshapeNode( ...@@ -60,20 +61,16 @@ static void BuildReshapeNode(
std::shared_ptr<ngraph::Node> shape = std::shared_ptr<ngraph::Node> shape =
platform::GetInputNode(op, "Shape", ngb_node_map); platform::GetInputNode(op, "Shape", ngb_node_map);
PADDLE_ENFORCE_EQ(shape, nullptr,
platform::errors::Unimplemented(
"Support for Shape input is not implemented"));
auto op_attrs = framework::AttrReader(op->Attrs()); auto op_attrs = framework::AttrReader(op->Attrs());
std::vector<int> v_shape = op_attrs.Get<std::vector<int>>("shape"); std::vector<int> v_shape = op_attrs.Get<std::vector<int>>("shape");
auto out = input;
if (shape != nullptr) { auto out_shape = calc_output_shape(input_shape, v_shape);
ngraph::Shape new_shape; auto out = platform::NgReshaper(input, out_shape);
for (auto& it : shape->get_shape()) { platform::SetOutputNode(op, "Out", out, ngb_node_map);
new_shape.push_back(it);
}
out = platform::NgReshaper(input, shape->get_shape());
} else {
auto out_shape = calc_output_shape(input_shape, v_shape);
out = platform::NgReshaper(input, out_shape);
}
if (is_v2) { if (is_v2) {
ngraph::Shape input_xshape(input_shape.size() + 1); ngraph::Shape input_xshape(input_shape.size() + 1);
...@@ -83,7 +80,6 @@ static void BuildReshapeNode( ...@@ -83,7 +80,6 @@ static void BuildReshapeNode(
input->get_element_type(), input_xshape, std::vector<std::string>{}); input->get_element_type(), input_xshape, std::vector<std::string>{});
platform::SetOutputNode(op, "XShape", xshape_node, ngb_node_map); platform::SetOutputNode(op, "XShape", xshape_node, ngb_node_map);
} }
platform::SetOutputNode(op, "Out", out, ngb_node_map);
} }
template <bool is_v2> template <bool is_v2>
......
...@@ -14,7 +14,9 @@ limitations under the License. */ ...@@ -14,7 +14,9 @@ limitations under the License. */
#pragma once #pragma once
#include <memory>
#include <string> #include <string>
#include <unordered_map>
#include <vector> #include <vector>
#include "ngraph/ngraph.hpp" #include "ngraph/ngraph.hpp"
...@@ -34,19 +36,18 @@ void BuildSumNode( ...@@ -34,19 +36,18 @@ void BuildSumNode(
for (auto& var_name_item : op->Inputs()) { for (auto& var_name_item : op->Inputs()) {
for (auto& var_name : var_name_item.second) { for (auto& var_name : var_name_item.second) {
op_inputs.push_back(var_name); op_inputs.push_back(var_name);
if (ngb_node_map->find(var_name) == ngb_node_map->end()) { PADDLE_ENFORCE_NE(
PADDLE_THROW("op % input varname %s is not found in var_node_map", ngb_node_map->find(var_name), ngb_node_map->end(),
op->Type(), var_name); platform::errors::NotFound(
} "op %s input varname %s is not found in var_node_map", op->Type(),
var_name));
} }
} }
std::shared_ptr<ngraph::Node>& sum = ngb_node_map->at(op_inputs[0]); std::shared_ptr<ngraph::Node>& sum = ngb_node_map->at(op_inputs[0]);
for (size_t k = 1; k < op_inputs.size(); ++k) { for (size_t k = 1; k < op_inputs.size(); ++k) {
std::shared_ptr<ngraph::Node>& nodek = ngb_node_map->at(op_inputs[k]); std::shared_ptr<ngraph::Node>& nodek = ngb_node_map->at(op_inputs[k]);
if (nodek->get_element_type() != sum->get_element_type()) { nodek =
nodek = std::make_shared<ngraph::op::Convert>(nodek, sum->get_element_type());
std::make_shared<ngraph::op::Convert>(nodek, sum->get_element_type());
}
sum = sum + nodek; sum = sum + nodek;
} }
platform::SetOutputNode(op, "Out", sum, ngb_node_map); platform::SetOutputNode(op, "Out", sum, ngb_node_map);
......
...@@ -17,7 +17,7 @@ from __future__ import print_function ...@@ -17,7 +17,7 @@ from __future__ import print_function
import unittest import unittest
import sys import sys
sys.path.append("../") sys.path.append("../")
import test_compare_op from test_compare_op import *
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
...@@ -18,11 +18,7 @@ import unittest, sys ...@@ -18,11 +18,7 @@ import unittest, sys
sys.path.append("../") sys.path.append("../")
import numpy as np import numpy as np
from test_logical_op import create_test_class from test_logical_op import *
create_test_class('logical_and', lambda _a, _b: np.logical_and(_a, _b))
create_test_class('logical_or', lambda _a, _b: np.logical_or(_a, _b))
create_test_class('logical_not', lambda _a: np.logical_not(_a), False)
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册