Commit 298ee7d2 authored by baojun, committed by Tao Luo

Improve ngraph file line coverage (#22155)

Parent d0f0a252
@@ -177,36 +177,6 @@ std::string SerializedBlock(const framework::BlockDesc& bdesc) {
   return block_desc.Proto()->SerializeAsString();
 }
 
-std::string GenerateEngineKey(const framework::BlockDesc& bdesc) {
-  framework::proto::BlockDesc block_proto;
-  framework::BlockDesc block_desc(nullptr, &block_proto);
-  block_desc.Proto()->set_parent_idx(-1);
-  block_desc.Proto()->set_idx(0);
-
-  for (auto& op_desc : bdesc.AllOps()) {
-    auto* op = block_desc.AppendOp();
-    *op->Proto() = *op_desc->Proto();
-  }
-  auto engine_key = std::to_string(
-      std::hash<std::string>()(block_desc.Proto()->SerializeAsString()));
-  return engine_key;
-}
-
-std::string GenerateEngineKey(const std::vector<std::string>& engine_inputs,
-                              const std::vector<std::string>& engine_outputs,
-                              int size) {
-  std::string engine_hash_key = "";
-  for (auto name : engine_inputs) {
-    engine_hash_key += name;
-  }
-  for (auto name : engine_outputs) {
-    engine_hash_key += name;
-  }
-  engine_hash_key += std::to_string(size);
-  auto engine_key = std::to_string(std::hash<std::string>()(engine_hash_key));
-  return engine_key;
-}
-
 void NgraphEngine::FuseNgraphOps(
     const framework::BlockDesc& block_desc,
     std::vector<std::unique_ptr<framework::OperatorBase>>* ops) {
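Note: both `GenerateEngineKey` overloads are deleted outright, which suggests they no longer had callers; dead lines can never execute, so removing them is the most direct way to raise this file's line coverage. For reference, the keying scheme they implemented was a plain string hash over either the serialized block proto or the concatenated input/output names plus a size. A rough Python restatement of the second overload (illustrative only; `hashlib.sha1` stands in for C++ `std::hash<std::string>`, whose value is not stable across platforms anyway):

```python
import hashlib

def generate_engine_key(engine_inputs, engine_outputs, size):
    # Same scheme as the removed overload: concatenate the input names,
    # the output names, and a size, then hash the resulting string.
    engine_hash_key = "".join(engine_inputs) + "".join(engine_outputs) + str(size)
    return hashlib.sha1(engine_hash_key.encode()).hexdigest()

print(generate_engine_key(["x", "y"], ["out"], 3))
```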
@@ -40,23 +40,8 @@ static void BuildCastNode(
   auto out = std::make_shared<ngraph::op::Convert>(input, ng_dtype);
   paddle::platform::SetOutputNode(op, "Out", out, ngb_node_map);
 }
-
-static void BuildCastGradNode(
-    const std::shared_ptr<framework::OperatorBase>& op,
-    std::shared_ptr<
-        std::unordered_map<std::string, std::shared_ptr<ngraph::Node>>>
-        ngb_node_map) {
-  auto input = platform::GetInputNode(op, "Out@GRAD", ngb_node_map);
-  auto op_attrs = framework::AttrReader(op->Attrs());
-  auto ng_dtype =
-      platform::GetNgType(static_cast<paddle::framework::proto::VarType::Type>(
-          op_attrs.Get<int>("out_dtype")));
-  auto out = std::make_shared<ngraph::op::Convert>(input, ng_dtype);
-  platform::SetOutputNode(op, "X@GRAD", out, ngb_node_map);
-}
 }  // namespace ngraphs
 }  // namespace operators
 }  // namespace paddle
 
 REGISTER_NG_OP(cast, BuildCastNode);
-REGISTER_NG_OP(cast_grad, BuildCastGradNode);
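`cast_grad` loses its ngraph bridge and registration here, presumably because no ngraph unit test exercised it. What the removed `BuildCastGradNode` computed is simple: the gradient of a cast is the upstream gradient converted to the dtype named by the grad op's `out_dtype` attribute. In numpy terms (an analogy, not the Paddle or ngraph API):

```python
import numpy as np

# Forward cast, as in BuildCastNode: convert X to the requested dtype.
x = np.array([1.7, 2.3], dtype=np.float32)
out = x.astype(np.int32)

# Backward, as in the removed BuildCastGradNode: X@GRAD is Out@GRAD
# converted per the grad op's out_dtype attribute (here, back to the
# original input dtype).
out_grad = np.ones_like(out)
x_grad = out_grad.astype(x.dtype)
print(x_grad.dtype, x_grad)  # float32 [1. 1.]
```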
@@ -37,9 +37,7 @@ void BuildElementwiseBinaryNode(
   std::shared_ptr<ngraph::Node>& x = nodes.at(0);
   std::shared_ptr<ngraph::Node>& y = nodes.at(1);
 
-  if (x->get_element_type() != y->get_element_type()) {
-    y = std::make_shared<ngraph::op::Convert>(y, x->get_element_type());
-  }
+  y = std::make_shared<ngraph::op::Convert>(y, x->get_element_type());
   auto out = std::make_shared<T>(x, y);
   paddle::platform::SetOutputNode(op, "Out", out, ngb_node_map);
 }
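The dtype guard around `Convert` becomes an unconditional conversion. Tests whose operands already share a dtype never entered the old `if`, leaving its body uncovered; since converting a node to the element type it already has is semantically a no-op, always converting preserves behavior. The same idea in numpy (analogy only):

```python
import numpy as np

def elementwise_binary(x, y, op=np.add):
    # Unconditionally coerce y to x's dtype, as the patched C++ now does.
    # When the dtypes already match, astype changes nothing but identity.
    return op(x, y.astype(x.dtype))

print(elementwise_binary(np.float32([1, 2]), np.float64([3, 4])))  # float32 [4. 6.]
```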
@@ -23,6 +23,7 @@ limitations under the License. */
 
 #include "ngraph/ngraph.hpp"
 #include "paddle/fluid/operators/ngraph/ops/op_bridge.h"
+#include "paddle/fluid/platform/enforce.h"
 #include "paddle/fluid/platform/ngraph_helper.h"
 
 namespace paddle {
@@ -60,20 +61,16 @@ static void BuildReshapeNode(
   std::shared_ptr<ngraph::Node> shape =
       platform::GetInputNode(op, "Shape", ngb_node_map);
+  PADDLE_ENFORCE_EQ(shape, nullptr,
+                    platform::errors::Unimplemented(
+                        "Support for Shape input is not implemented"));
+
   auto op_attrs = framework::AttrReader(op->Attrs());
   std::vector<int> v_shape = op_attrs.Get<std::vector<int>>("shape");
 
-  auto out = input;
-  if (shape != nullptr) {
-    ngraph::Shape new_shape;
-    for (auto& it : shape->get_shape()) {
-      new_shape.push_back(it);
-    }
-    out = platform::NgReshaper(input, shape->get_shape());
-  } else {
-    auto out_shape = calc_output_shape(input_shape, v_shape);
-    out = platform::NgReshaper(input, out_shape);
-  }
+  auto out_shape = calc_output_shape(input_shape, v_shape);
+  auto out = platform::NgReshaper(input, out_shape);
+  platform::SetOutputNode(op, "Out", out, ngb_node_map);
 
   if (is_v2) {
     ngraph::Shape input_xshape(input_shape.size() + 1);
@@ -83,7 +80,6 @@ static void BuildReshapeNode(
         input->get_element_type(), input_xshape, std::vector<std::string>{});
     platform::SetOutputNode(op, "XShape", xshape_node, ngb_node_map);
   }
-  platform::SetOutputNode(op, "Out", out, ngb_node_map);
 }
 
 template <bool is_v2>
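`BuildReshapeNode` now rejects a runtime `Shape` input up front with `PADDLE_ENFORCE_EQ` instead of half-handling it in a branch the tests never reached; the enforce line executes on every call, so it counts as covered. The surviving path delegates to `calc_output_shape` (defined elsewhere in this file, unchanged here). Assuming it follows Paddle's reshape attribute convention, where `0` copies the corresponding input dimension and a single `-1` is inferred from the remaining element count, it behaves like this hypothetical Python restatement (not the actual implementation):

```python
from functools import reduce

def calc_output_shape(input_shape, v_shape):
    # 0 copies the dimension at the same index from input_shape;
    # one -1 entry is inferred so the total element count is preserved.
    out = [input_shape[i] if d == 0 else d for i, d in enumerate(v_shape)]
    numel = reduce(lambda a, b: a * b, input_shape, 1)
    if -1 in out:
        known = reduce(lambda a, b: a * b, (d for d in out if d != -1), 1)
        out[out.index(-1)] = numel // known
    return out

print(calc_output_shape([2, 3, 4], [0, -1]))  # [2, 12]
```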
@@ -14,7 +14,9 @@ limitations under the License. */
 
 #pragma once
 
 #include <memory>
+#include <string>
+#include <unordered_map>
 #include <vector>
 
 #include "ngraph/ngraph.hpp"
@@ -34,19 +36,18 @@ void BuildSumNode(
   for (auto& var_name_item : op->Inputs()) {
     for (auto& var_name : var_name_item.second) {
       op_inputs.push_back(var_name);
-      if (ngb_node_map->find(var_name) == ngb_node_map->end()) {
-        PADDLE_THROW("op % input varname %s is not found in var_node_map",
-                     op->Type(), var_name);
-      }
+      PADDLE_ENFORCE_NE(
+          ngb_node_map->find(var_name), ngb_node_map->end(),
+          platform::errors::NotFound(
+              "op %s input varname %s is not found in var_node_map", op->Type(),
+              var_name));
     }
   }
   std::shared_ptr<ngraph::Node>& sum = ngb_node_map->at(op_inputs[0]);
   for (size_t k = 1; k < op_inputs.size(); ++k) {
     std::shared_ptr<ngraph::Node>& nodek = ngb_node_map->at(op_inputs[k]);
-    if (nodek->get_element_type() != sum->get_element_type()) {
-      nodek =
-          std::make_shared<ngraph::op::Convert>(nodek, sum->get_element_type());
-    }
+    nodek =
+        std::make_shared<ngraph::op::Convert>(nodek, sum->get_element_type());
     sum = sum + nodek;
   }
   platform::SetOutputNode(op, "Out", sum, ngb_node_map);
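Both edits in `BuildSumNode` repeat patterns seen above: the not-found branch becomes a `PADDLE_ENFORCE_NE` whose line always executes (and whose `platform::errors::NotFound` message also fixes the old `%` format typo), and the dtype guard around `Convert` is dropped in favor of an unconditional conversion. The accumulation itself is a left fold over the inputs; in numpy terms (analogy only):

```python
import numpy as np
from functools import reduce

def build_sum(inputs):
    # Left fold mirroring the loop in BuildSumNode: promote each operand
    # to the running sum's dtype, then add it on.
    return reduce(lambda acc, node: acc + node.astype(acc.dtype), inputs)

xs = [np.float32([1, 2]), np.float64([3, 4]), np.float32([5, 6])]
print(build_sum(xs))  # float32 [ 9. 12.]
```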
@@ -17,7 +17,7 @@ from __future__ import print_function
 import unittest
 import sys
 sys.path.append("../")
-import test_compare_op
+from test_compare_op import *
 
 if __name__ == '__main__':
     unittest.main()
@@ -18,11 +18,7 @@ import unittest, sys
 sys.path.append("../")
 import numpy as np
 
-from test_logical_op import create_test_class
-
-create_test_class('logical_and', lambda _a, _b: np.logical_and(_a, _b))
-create_test_class('logical_or', lambda _a, _b: np.logical_or(_a, _b))
-create_test_class('logical_not', lambda _a: np.logical_not(_a), False)
+from test_logical_op import *
 
 if __name__ == '__main__':
     unittest.main()
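The test-file changes look cosmetic but are load-bearing: `unittest.main()` only runs `TestCase` subclasses visible in the current module's namespace, so a bare `import test_compare_op` registers nothing, while `from test_compare_op import *` hoists the classes into scope. The explicit `create_test_class` calls in the logical test can go on the reasonable assumption that `test_logical_op` already creates those classes at module level, so the wildcard import picks them up. A contrived, self-contained demonstration of the discovery rule:

```python
import unittest

class _Elsewhere:
    # Stand-in for a TestCase defined in another module (e.g. test_compare_op).
    class TestAdd(unittest.TestCase):
        def test_add(self):
            self.assertEqual(1 + 1, 2)

# Without this assignment, unittest.main() scans this module, finds no
# TestCase, and runs zero tests; that is the effect of a bare module import.
# `from <module> import *` is what puts the class into scope, like this:
TestAdd = _Elsewhere.TestAdd

if __name__ == '__main__':
    unittest.main()  # discovers and runs TestAdd
```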