Commit f2ad30c4 authored by baojun, committed by Tao Luo

Some ngraph op and unittest fix (#19515)

* update ngraph ops test=develop

* update unittest test=develop

* increase coverage test=develop
Parent 49523ea1
......@@ -44,8 +44,7 @@ bool NgraphBridge::isSupported(
if (!isRegister(op_type)) {
if (skip_op_list.count(op_type)) {
if (op_type == "lookup_table" || op_type == "lookup_table_grad") {
if (op_attrs.Get<bool>("is_sparse") ||
(op_attrs.Get<int64_t>("padding_idx") != kNoPadding)) {
if (op_attrs.Get<bool>("is_sparse")) {
result = false;
}
} else if ((op_type == "reshape") || (op_type == "reshape2")) {
......
......@@ -39,7 +39,10 @@ void BuildConcatNode(
}
}
auto op_attrs = framework::AttrReader(op->Attrs());
const size_t axis = op_attrs.Get<int>("axis");
int axis = op_attrs.Get<int>("axis");
if (axis < 0) {
axis = axis + args[0]->get_shape().size();
}
auto out = std::make_shared<ngraph::op::Concat>(args, axis);
platform::SetOutputNode(op, "Out", out, ngb_node_map);
}
......
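
The axis handed to ngraph::op::Concat is an unsigned size (the removed line typed it as size_t), so a negative Paddle axis has to be folded back into range against the input rank first. A small NumPy sketch of the same normalization, purely illustrative and not part of the patch:

import numpy as np

def normalize_axis(axis, rank):
    # A negative axis counts from the end: axis=-1 on rank-3 inputs
    # becomes axis=2, mirroring the C++ change above.
    return axis + rank if axis < 0 else axis

x = np.ones((2, 3, 4))
y = np.ones((2, 3, 5))
out = np.concatenate([x, y], axis=normalize_axis(-1, x.ndim))
assert out.shape == (2, 3, 9)
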
......@@ -80,7 +80,7 @@ std::shared_ptr<ngraph::Node> GroupedGradConvolutionFilter(
auto data_slice = std::make_shared<ngraph::op::Slice>(
data_batch, lower_bound, upper_bound);
-size_t filter_step = data_shape.at(0);
+size_t filter_step = filter_shape.at(0) / groups;
const std::vector<size_t> filter_lower_bound{i * filter_step, 0, 0, 0};
const std::vector<size_t> filter_upper_bound{
......@@ -127,7 +127,7 @@ std::shared_ptr<ngraph::Node> GroupedGradConvolutionData(
auto data_slice = std::make_shared<ngraph::op::Slice>(
data_batch, lower_bound, upper_bound);
-size_t filter_step = data_shape.at(0);
+size_t filter_step = filter_shape.at(0) / groups;
const std::vector<size_t> filter_lower_bound{i * filter_step, 0, 0, 0};
const std::vector<size_t> filter_upper_bound{
......
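
The two hunks above correct the per-group slice step for the filter: it should be the number of output channels per group (filter_shape[0] / groups), not the batch size taken from data_shape[0]. A rough NumPy sketch of that slicing arithmetic, with made-up shapes and an assumed OIHW filter layout:

import numpy as np

groups = 2
data = np.zeros((8, 4, 16, 16))     # NCHW: batch 8, 4 input channels
filters = np.zeros((6, 2, 3, 3))    # OIHW: 6 output channels in total

filter_step = filters.shape[0] // groups   # 3 filters per group
for g in range(groups):
    f_slice = filters[g * filter_step:(g + 1) * filter_step]
    assert f_slice.shape == (3, 2, 3, 3)
    # Using data.shape[0] (= 8) as the step would run past the
    # 6 available filters on the second group.
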
......@@ -29,7 +29,7 @@ namespace ngraphs {
std::shared_ptr<ngraph::Node> remove_trailing_one(
const std::shared_ptr<ngraph::Node>& input) {
auto shape = input->get_shape();
-if (shape.back() == 1) {
+if (shape.back() == 1 && shape.size() > 1) {
shape.pop_back();
return platform::NgReshaper(input, shape);
} else {
......@@ -73,6 +73,7 @@ std::shared_ptr<ngraph::Node> create_xe(
shape.back() = 1;
return platform::NgReshaper(-node_sum, shape);
}
std::shared_ptr<ngraph::Node> create_mask(
const std::shared_ptr<ngraph::Node>& label, int ignore_index) {
auto ignore_node = paddle::platform::CreateConstant(
......
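
The added shape.size() > 1 guard in remove_trailing_one above keeps a rank-1 tensor of shape {1} from being squeezed down to nothing. An illustrative Python equivalent, not part of the patch:

def remove_trailing_one(shape):
    # Drop a trailing 1 only if another dimension remains,
    # e.g. (4, 1) -> (4,) but (1,) stays as it is.
    if len(shape) > 1 and shape[-1] == 1:
        return shape[:-1]
    return shape

assert remove_trailing_one((4, 1)) == (4,)
assert remove_trailing_one((1,)) == (1,)
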
......@@ -41,6 +41,7 @@ static void BuildDropoutNode(
op_attrs.Get<std::string>("dropout_implementation");
auto is_test = op_attrs.Get<bool>("is_test");
auto seed = op_attrs.Get<int>("seed");
auto fix_seed = op_attrs.Get<bool>("fix_seed");
float value = 1.0f - dropout_prob;
bool upscale_in_train = (dropout_implementation == "upscale_in_train");
......@@ -58,7 +59,8 @@ static void BuildDropoutNode(
ngraph::Shape{}, {1});
auto gen_mask = std::make_shared<ngraph::op::GenerateMask>(
-    one, input->get_shape(), input->get_element_type(), seed, value);
+    one, input->get_shape(), input->get_element_type(), seed, value,
+    fix_seed);
if (upscale_in_train) {
auto mask_val = paddle::platform::CreateConstant(
......
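
Forwarding the new fix_seed attribute lets GenerateMask reuse the given seed, so the dropout mask becomes reproducible when the seed is fixed. A hedged NumPy sketch of that behaviour; dropout_mask is a hypothetical helper, not the GenerateMask API:

import numpy as np

def dropout_mask(shape, keep_prob, seed, fix_seed):
    # With fix_seed=True the same seed (and hence the same mask) is
    # reused on every call; otherwise a fresh seed is drawn each time.
    rng = np.random.RandomState(seed if fix_seed else None)
    return (rng.uniform(size=shape) < keep_prob).astype(np.float32)

m1 = dropout_mask((2, 3), keep_prob=0.7, seed=42, fix_seed=True)
m2 = dropout_mask((2, 3), keep_prob=0.7, seed=42, fix_seed=True)
assert (m1 == m2).all()
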
......@@ -47,16 +47,27 @@ void BuildLookupTableNode(
if (is_sparse) {
PADDLE_THROW("Sparsity is not yet supported in nGraph lookup_table op.");
}
+auto ng_w_mask = ng_w;
if (padding_idx != kNoPadding) {
-  PADDLE_THROW("Padding is not yet supported in nGraph lookup_table op.");
+  auto w_shape = ng_w->get_shape();
+  std::vector<int> maskV(w_shape[0], 1);
+  maskV[padding_idx] = 0;
+  auto maskV_node = std::make_shared<ngraph::op::Constant>(
+      ng_w->get_element_type(), ngraph::Shape{w_shape[0]}, maskV);
+  ngraph::AxisSet axis_set;
+  for (unsigned int i = 1; i < w_shape.size(); ++i) axis_set.insert(i);
+  auto maskV_bd =
+      std::make_shared<ngraph::op::Broadcast>(maskV_node, w_shape, axis_set);
+  ng_w_mask = std::make_shared<ngraph::op::Multiply>(ng_w, maskV_bd);
}
}
auto shape = ng_ids->get_shape();
if (shape.back() == 1) {
shape.pop_back();
ng_ids = platform::NgReshaper(ng_ids, shape);
}
-auto ng_lookup = std::make_shared<ngraph::op::Gather>(ng_w, ng_ids);
+auto ng_lookup = std::make_shared<ngraph::op::Gather>(ng_w_mask, ng_ids);
platform::SetOutputNode(op, "Out", ng_lookup, ngb_node_map);
}
......@@ -67,8 +78,6 @@ void BuildLookupTableGradNode(
ngb_node_map) {
auto op_attrs = paddle::framework::AttrReader(op->Attrs());
const bool is_sparse = op_attrs.Get<bool>("is_sparse");
-const int64_t padding_idx = op_attrs.Get<int64_t>("padding_idx");
auto ng_ids = paddle::platform::GetInputNode(op, "Ids", ngb_node_map);
PADDLE_ENFORCE_NOT_NULL(ng_ids);
......@@ -81,9 +90,6 @@ void BuildLookupTableGradNode(
PADDLE_THROW("Sparsity is not yet supported in nGraph lookup_table op.");
}
-if (padding_idx != kNoPadding) {
-  PADDLE_THROW("Padding is not yet supported in nGraph lookup_table op.");
-}
auto shape = ng_ids->get_shape();
if (shape.back() == 1) {
shape.pop_back();
......
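
Instead of throwing on padding_idx, the forward op above now zeroes that row of the weight matrix before the Gather, so lookups of the padding index return all-zero embeddings, and the grad op simply drops its padding check. A small NumPy sketch of the same masking, with made-up shapes:

import numpy as np

W = np.arange(12, dtype=np.float32).reshape(4, 3)   # vocab 4, embedding 3
ids = np.array([0, 2, 3])
padding_idx = 2

mask = np.ones(W.shape[0], dtype=W.dtype)
mask[padding_idx] = 0.0
W_masked = W * mask[:, None]   # broadcast the row mask over the embedding axis

out = W_masked[ids]            # gather: the row for id 2 is all zeros
assert (out[1] == 0).all()
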
......@@ -57,8 +57,18 @@ void BuildSliceNode(
ng_end[axes[i]] = end;
}
auto out = std::make_shared<ngraph::op::Slice>(input, ng_start, ng_end);
platform::SetOutputNode(op, "Out", out, ngb_node_map);
auto out_shape = out->get_shape();
std::vector<size_t> out_axis_vec(out_shape.size());
std::iota(out_axis_vec.begin(), out_axis_vec.end(), 0);
paddle::platform::TrimTrailingSingularDims(&out_shape);
auto out_dim = std::make_shared<ngraph::op::Reshape>(
out, ngraph::AxisVector(out_axis_vec), ngraph::Shape(out_shape));
platform::SetOutputNode(op, "Out", out_dim, ngb_node_map);
}
void BuildSliceGradNode(
const std::shared_ptr<framework::OperatorBase>& op,
std::shared_ptr<
......
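
The added reshape trims trailing size-1 dimensions from the raw Slice output via TrimTrailingSingularDims, presumably so the output rank matches the decrease-dim slice tests enabled below. A rough NumPy sketch of that trim (the loop is a guess at the helper's exact behaviour):

import numpy as np

x = np.arange(24).reshape(2, 3, 4)
out = x[:, :, 2:3]                 # the raw slice keeps shape (2, 3, 1)

shape = list(out.shape)
while len(shape) > 1 and shape[-1] == 1:   # drop trailing size-1 dims
    shape.pop()
out = out.reshape(shape)
assert out.shape == (2, 3)
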
......@@ -16,7 +16,7 @@ from __future__ import print_function
import unittest, sys
sys.path.append("../")
-from test_assign_op import *
+from test_assign_op import TestAssignOp
if __name__ == '__main__':
unittest.main()
......@@ -15,7 +15,7 @@
from __future__ import print_function
import unittest
-from paddle.fluid.tests.unittests.test_concat_op import TestConcatOp, TestConcatOp2, TestConcatOp3
+from paddle.fluid.tests.unittests.test_concat_op import TestConcatOp, TestConcatOp2, TestConcatOp3, TestConcatOp4, TestConcatOp5
if __name__ == '__main__':
unittest.main()
......@@ -15,7 +15,7 @@
from __future__ import print_function
import unittest, sys
sys.path.append("../")
-from test_lookup_table_op import *
+from test_lookup_table_op import TestLookupTableOp, TestLookupTableOpWithTensorIds, TestLookupTableOpWithPadding, TestLookupTableOpWithTensorIdsAndPadding, TestLookupTableWIsSelectedRows, TestLookupTableWithTensorIdsWIsSelectedRows
if __name__ == "__main__":
unittest.main()
......@@ -17,7 +17,7 @@ from __future__ import print_function
import unittest, sys
sys.path.append("../")
-from test_reshape_op import TestReshapeOp, TestReshapeOpDimInfer1, TestReshapeOpDimInfer2, TestReshapeOpWithInputShape
+from test_reshape_op import TestReshapeOp, TestReshapeOpDimInfer1, TestReshapeOpDimInfer2
if __name__ == '__main__':
unittest.main()
......@@ -16,7 +16,7 @@ from __future__ import print_function
import unittest, sys
sys.path.append("../")
-from test_slice_op import TestSliceOp, TestCase1, TestCase2
+from test_slice_op import TestSliceOp, TestSliceOp_decs_dim, TestSliceOp_decs_dim_2, TestSliceOp_decs_dim_3, TestSliceOp_decs_dim_5, TestSliceOp_decs_dim_6, TestCase1, TestCase2
if __name__ == '__main__':
unittest.main()
......@@ -664,6 +664,12 @@ class OpTest(unittest.TestCase):
warnings.warn(
"check inplace_grad for ops using mkldnn is not supported")
return
+use_ngraph = fluid.core.is_compiled_with_ngraph(
+) and fluid.core.get_flags_use_ngraph()
+if use_ngraph:
+    warnings.warn(
+        "check inplace_grad for ops using ngraph is not supported")
+    return
self.check_inplace_grad_output_with_place(
place, no_check_set=no_check_set, inplace_atol=inplace_atol)
......