提交 ac4cde00 编写于 作者: B baojun 提交者: tensor-tang

Enable accuracy op for ngraph engine (#15592)

* Added accuracy ngraph op test=develop

* fixed name type test=develop
上级 488719ba
...@@ -31,6 +31,7 @@ std::map<std::string, ...@@ -31,6 +31,7 @@ std::map<std::string,
std::shared_ptr<std::unordered_map< std::shared_ptr<std::unordered_map<
std::string, std::shared_ptr<ngraph::Node>>>)>> std::string, std::shared_ptr<ngraph::Node>>>)>>
NgraphBridge::NG_NODE_MAP = { NgraphBridge::NG_NODE_MAP = {
{"accuracy", NG_OPS::BuildAccuracyNode},
{"conv2d", NG_OPS::BuildConv2dNode}, {"conv2d", NG_OPS::BuildConv2dNode},
{"conv2d_grad", NG_OPS::BuildConv2dGradNode}, {"conv2d_grad", NG_OPS::BuildConv2dGradNode},
{"elementwise_add", NG_OPS::BuildElementwiseAddNode}, {"elementwise_add", NG_OPS::BuildElementwiseAddNode},
......
...@@ -21,7 +21,8 @@ limitations under the License. */ ...@@ -21,7 +21,8 @@ limitations under the License. */
#pragma once #pragma once
#include "ops/binary_unnary_op.h" #include "ops/accuracy_op.h"
#include "ops/binary_unary_op.h"
#include "ops/conv2d_op.h" #include "ops/conv2d_op.h"
#include "ops/elementwise_add_op.h" #include "ops/elementwise_add_op.h"
#include "ops/fill_constant_op.h" #include "ops/fill_constant_op.h"
......
/*Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <string>
#include <vector>
#include "ngraph/ngraph.hpp"
#include "paddle/fluid/platform/ngraph_helper.h"
namespace paddle {
namespace operators {
namespace ngraphs {
void BuildAccuracyNode(
const std::shared_ptr<framework::OperatorBase>& op,
std::shared_ptr<
std::unordered_map<std::string, std::shared_ptr<ngraph::Node>>>
ngb_node_map) {
auto indices = platform::GetInputNode(op, "Indices", ngb_node_map);
auto label = platform::GetInputNode(op, "Label", ngb_node_map);
auto inference = platform::GetInputNode(op, "Out", ngb_node_map);
auto inference_shape = inference->get_shape();
size_t num_samples = inference_shape.at(0);
size_t k = inference_shape.at(1);
std::shared_ptr<ngraph::Node> label_k = label;
if (k > 1) {
auto label_1d = std::make_shared<ngraph::op::Reshape>(
label, ngraph::AxisVector{0, 1}, ngraph::Shape{num_samples});
label_k = std::make_shared<ngraph::op::Broadcast>(label_1d, inference_shape,
ngraph::AxisSet{1});
}
auto node_equal = std::make_shared<ngraph::op::Equal>(indices, label_k);
auto node_eq_int =
std::make_shared<ngraph::op::Convert>(node_equal, ngraph::element::i64);
auto num_correct_0d =
std::make_shared<ngraph::op::Sum>(node_eq_int, ngraph::AxisSet{0, 1});
std::shared_ptr<ngraph::Node> num_correct =
platform::NgReshaper(num_correct_0d, ngraph::Shape{1});
std::shared_ptr<ngraph::Node> n_samples = ngraph::op::Constant::create(
ngraph::element::i64, ngraph::Shape{1}, {num_samples});
std::shared_ptr<ngraph::Node> accuracy = std::make_shared<ngraph::op::Divide>(
std::make_shared<ngraph::op::Convert>(num_correct, ngraph::element::f32),
std::make_shared<ngraph::op::Convert>(n_samples, ngraph::element::f32));
platform::SetOutputNode(op, "Accuracy", accuracy, ngb_node_map);
platform::SetOutputNode(op, "Correct", num_correct, ngb_node_map);
platform::SetOutputNode(op, "Total", n_samples, ngb_node_map);
}
} // namespace ngraphs
} // namespace operators
} // namespace paddle
...@@ -36,11 +36,6 @@ void BuildTopKNode( ...@@ -36,11 +36,6 @@ void BuildTopKNode(
std::make_shared<ngraph::op::GetOutputElement>(top_k, 0); std::make_shared<ngraph::op::GetOutputElement>(top_k, 0);
std::shared_ptr<ngraph::Node> out = std::shared_ptr<ngraph::Node> out =
std::make_shared<ngraph::op::GetOutputElement>(top_k, 1); std::make_shared<ngraph::op::GetOutputElement>(top_k, 1);
auto dummy_out = paddle::platform::GetOutputNode(op, "Out", ngb_node_map);
if (dummy_out && dummy_out->get_element_type() != out->get_element_type()) {
out = std::make_shared<ngraph::op::Convert>(out,
dummy_out->get_element_type());
}
paddle::platform::SetOutputNode(op, "Indices", indices, ngb_node_map); paddle::platform::SetOutputNode(op, "Indices", indices, ngb_node_map);
paddle::platform::SetOutputNode(op, "Out", out, ngb_node_map); paddle::platform::SetOutputNode(op, "Out", out, ngb_node_map);
} }
......
...@@ -43,13 +43,14 @@ std::shared_ptr<ngraph::Node> NgReshaper(std::shared_ptr<ngraph::Node> input, ...@@ -43,13 +43,14 @@ std::shared_ptr<ngraph::Node> NgReshaper(std::shared_ptr<ngraph::Node> input,
std::shared_ptr<ngraph::Node> GetNode( std::shared_ptr<ngraph::Node> GetNode(
const std::shared_ptr<paddle::framework::OperatorBase>& op, const std::shared_ptr<paddle::framework::OperatorBase>& op,
const std::string prm, const paddle::framework::VariableNameMap& var_map, const std::string name, const paddle::framework::VariableNameMap& var_map,
std::shared_ptr< std::shared_ptr<
std::unordered_map<std::string, std::shared_ptr<ngraph::Node>>> std::unordered_map<std::string, std::shared_ptr<ngraph::Node>>>
ngb_node_map) { ngb_node_map) {
auto& var_names = var_map.at(prm); auto& var_names = var_map.at(name);
PADDLE_ENFORCE_EQ(var_names.size(), 1, PADDLE_ENFORCE_EQ(var_names.size(), 1,
"op %s prm %s expects one associated var", op->Type(), prm); "op %s name %s expects one associated var", op->Type(),
name);
if (ngb_node_map->find(var_names[0]) != ngb_node_map->end()) { if (ngb_node_map->find(var_names[0]) != ngb_node_map->end()) {
return (*ngb_node_map)[var_names[0]]; return (*ngb_node_map)[var_names[0]];
} else { } else {
...@@ -59,43 +60,53 @@ std::shared_ptr<ngraph::Node> GetNode( ...@@ -59,43 +60,53 @@ std::shared_ptr<ngraph::Node> GetNode(
std::shared_ptr<ngraph::Node> GetInputNode( std::shared_ptr<ngraph::Node> GetInputNode(
const std::shared_ptr<paddle::framework::OperatorBase>& op, const std::shared_ptr<paddle::framework::OperatorBase>& op,
const std::string prm, const std::string name,
std::shared_ptr< std::shared_ptr<
std::unordered_map<std::string, std::shared_ptr<ngraph::Node>>> std::unordered_map<std::string, std::shared_ptr<ngraph::Node>>>
ngb_node_map) { ngb_node_map) {
return GetNode(op, prm, op->Inputs(), ngb_node_map); return GetNode(op, name, op->Inputs(), ngb_node_map);
} }
std::shared_ptr<ngraph::Node> GetOutputNode( std::shared_ptr<ngraph::Node> GetOutputNode(
const std::shared_ptr<paddle::framework::OperatorBase>& op, const std::shared_ptr<paddle::framework::OperatorBase>& op,
const std::string prm, const std::string name,
std::shared_ptr< std::shared_ptr<
std::unordered_map<std::string, std::shared_ptr<ngraph::Node>>> std::unordered_map<std::string, std::shared_ptr<ngraph::Node>>>
ngb_node_map) { ngb_node_map) {
return GetNode(op, prm, op->Outputs(), ngb_node_map); return GetNode(op, name, op->Outputs(), ngb_node_map);
} }
void SetOutputNode( void SetOutputNode(
const std::shared_ptr<paddle::framework::OperatorBase>& op, const std::shared_ptr<paddle::framework::OperatorBase>& op,
const std::string prm, std::shared_ptr<ngraph::Node> node, const std::string name, std::shared_ptr<ngraph::Node> node,
std::shared_ptr< std::shared_ptr<
std::unordered_map<std::string, std::shared_ptr<ngraph::Node>>> std::unordered_map<std::string, std::shared_ptr<ngraph::Node>>>
ngb_node_map) { ngb_node_map) {
auto& var_names = op->Outputs().at(prm); auto& var_names = op->Outputs().at(name);
if (var_names.size() == 1) { if (var_names.size() == 1) {
/* */
auto dummy_out = GetOutputNode(op, name, ngb_node_map);
if (dummy_out && dummy_out->get_shape() != node->get_shape()) {
node = NgReshaper(node, dummy_out->get_shape());
}
if (dummy_out &&
dummy_out->get_element_type() != node->get_element_type()) {
node = std::make_shared<ngraph::op::Convert>(
node, dummy_out->get_element_type());
}
(*ngb_node_map)[var_names[0]] = node; (*ngb_node_map)[var_names[0]] = node;
} else if (var_names.size() == 0) { } else if (var_names.size() == 0) {
(*ngb_node_map)[""] = node; (*ngb_node_map)[""] = node;
} else { } else {
PADDLE_THROW("prm %s has more than 1 var_names.", prm); PADDLE_THROW("name %s has more than 1 var_names.", name);
} }
} }
bool HasOutput(const std::shared_ptr<paddle::framework::OperatorBase>& op, bool HasOutput(const std::shared_ptr<paddle::framework::OperatorBase>& op,
const std::string prm) { const std::string name) {
auto& outputs = op->Outputs(); auto& outputs = op->Outputs();
if (outputs.find(prm) == outputs.end()) return false; if (outputs.find(name) == outputs.end()) return false;
return outputs.at(prm).size() > 0; return outputs.at(name).size() > 0;
} }
inline void GetMidDims(const ngraph::Shape& x_shape, inline void GetMidDims(const ngraph::Shape& x_shape,
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
import paddle.fluid.core as core
from paddle.fluid.tests.unittests.op_test import OpTest
from paddle.fluid.tests.unittests.test_accuracy_op import TestAccuracyOp
class TestNGRAPHAccuracyOp(TestAccuracyOp):
def setUp(self):
super(TestNGRAPHAccuracyOp, self).setUp()
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册