提交 a42f8f4f 编写于 作者: M mozga-intel

Enable element_wise_add operator for a ngraph

test=develop
上级 9c027651
...@@ -26,11 +26,14 @@ limitations under the License. */ ...@@ -26,11 +26,14 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace framework { namespace framework {
namespace NG_OPS = paddle::operators::ngraphs;
std::map<std::string, std::map<std::string,
std::function<void(const std::shared_ptr<OperatorBase>&, std::function<void(const std::shared_ptr<OperatorBase>&,
std::shared_ptr<std::unordered_map< std::shared_ptr<std::unordered_map<
std::string, std::shared_ptr<ngraph::Node>>>)>> std::string, std::shared_ptr<ngraph::Node>>>)>>
NgraphBridge::NG_NODE_MAP = { NgraphBridge::NG_NODE_MAP = {
{"elementwise_add", NG_OPS::BuildElementwiseAddNode},
{"elementwise_add_grad", NG_OPS::BuildElementwiseAddGradNode},
{"fill_constant", paddle::operators::ngraphs::BuildFillConstantNode}, {"fill_constant", paddle::operators::ngraphs::BuildFillConstantNode},
{"mean", paddle::operators::ngraphs::BuildMeanNode}, {"mean", paddle::operators::ngraphs::BuildMeanNode},
{"mean_grad", paddle::operators::ngraphs::BuildMeanGradNode}, {"mean_grad", paddle::operators::ngraphs::BuildMeanGradNode},
......
...@@ -22,6 +22,7 @@ limitations under the License. */ ...@@ -22,6 +22,7 @@ limitations under the License. */
#pragma once #pragma once
#include "ops/binary_unnary_op.h" #include "ops/binary_unnary_op.h"
#include "ops/elementwise_add_op.h"
#include "ops/fill_constant_op.h" #include "ops/fill_constant_op.h"
#include "ops/mean_op.h" #include "ops/mean_op.h"
#include "ops/mul_op.h" #include "ops/mul_op.h"
......
/*Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef PADDLE_WITH_NGRAPH
#pragma once
#include <string>
#include <vector>
#include "ngraph/ngraph.hpp"
#include "paddle/fluid/operators/ngraph/ops/elementwise_node.h"
#include "paddle/fluid/platform/ngraph_helper.h"
namespace paddle {
namespace operators {
namespace ngraphs {
void BuildElementwiseAddNode(
const std::shared_ptr<paddle::framework::OperatorBase>& op,
std::shared_ptr<
std::unordered_map<std::string, std::shared_ptr<ngraph::Node>>>
ngb_node_map) {
BuildElementwiseBinaryNode<ngraph::op::Add>(op, ngb_node_map);
}
void BuildElementwiseAddGradNode(
const std::shared_ptr<paddle::framework::OperatorBase>& op,
std::shared_ptr<
std::unordered_map<std::string, std::shared_ptr<ngraph::Node>>>
ngb_node_map) {
auto op_attrs = paddle::framework::AttrReader(op->Attrs());
int axis = op_attrs.Get<int>("axis");
auto dout = paddle::platform::GetInputNode(op, "Out@GRAD", ngb_node_map);
auto y = paddle::platform::GetInputNode(op, "Y", ngb_node_map);
auto dout_shape = dout->get_shape();
auto y_shape = y->get_shape();
if (dout_shape == y_shape) {
paddle::platform::SetOutputNode(op, "X@GRAD", dout, ngb_node_map);
paddle::platform::SetOutputNode(op, "Y@GRAD", dout, ngb_node_map);
} else {
axis = (axis == -1 ? dout_shape.size() - y_shape.size() : axis);
paddle::platform::TrimTrailingSingularDims(&y_shape);
axis = (y_shape.size() == 0 ? dout_shape.size() : axis);
int pre, n, post;
paddle::platform::GetMidDims(dout_shape, y_shape, axis, &pre, &n, &post);
ngraph::Shape lhs_shape{};
lhs_shape.push_back(pre);
lhs_shape.push_back(n);
if (post != 1) {
lhs_shape.push_back(post);
}
std::vector<size_t> lhs_order(dout_shape.size());
std::iota(std::begin(lhs_order), std::end(lhs_order), 0);
auto dout_reshape = std::make_shared<ngraph::op::Reshape>(
dout, ngraph::AxisVector(lhs_order), lhs_shape);
ngraph::AxisSet axis_set{0};
if (post != 1) {
axis_set.insert(2);
}
auto dout_sum = std::make_shared<ngraph::op::Sum>(dout_reshape, axis_set);
auto dy = std::make_shared<ngraph::op::Reshape>(
dout_sum, ngraph::AxisVector{0}, y->get_shape());
paddle::platform::SetOutputNode(op, "X@GRAD", dout, ngb_node_map);
paddle::platform::SetOutputNode(op, "Y@GRAD", dy, ngb_node_map);
}
}
} // namespace ngraphs
} // namespace operators
} // namespace paddle
#endif
/*Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef PADDLE_WITH_NGRAPH
#pragma once
#include <string>
#include <vector>
#include "ngraph/ngraph.hpp"
#include "paddle/fluid/platform/ngraph_helper.h"
namespace paddle {
namespace operators {
namespace ngraphs {
ngraph::NodeVector ElementwiseBinaryNodePrepare(
const std::shared_ptr<paddle::framework::OperatorBase>& op,
std::shared_ptr<
std::unordered_map<std::string, std::shared_ptr<ngraph::Node>>>
ngb_node_map) {
auto op_attrs = paddle::framework::AttrReader(op->Attrs());
int axis = op_attrs.Get<int>("axis");
auto lhs = paddle::platform::GetInputNode(op, "X", ngb_node_map);
auto rhs = paddle::platform::GetInputNode(op, "Y", ngb_node_map);
auto lhs_shape = lhs->get_shape();
auto rhs_shape = rhs->get_shape();
PADDLE_ENFORCE_GE(lhs_shape.size(), rhs_shape.size(),
"Rank of first input must >= rank of second input.");
if (lhs_shape == rhs_shape) {
return ngraph::NodeVector{lhs, rhs};
}
axis = (axis == -1 ? lhs_shape.size() - rhs_shape.size() : axis);
PADDLE_ENFORCE(axis >= 0 && axis < (int)(lhs_shape.size()),
"Axis should be in range [0, lhs_shape)");
paddle::platform::TrimTrailingSingularDims(&rhs_shape);
axis = (rhs_shape.size() == 0) ? lhs_shape.size() : axis;
int pre, n, post;
paddle::platform::GetMidDims(lhs_shape, rhs_shape, axis, &pre, &n, &post);
ngraph::Shape l_shape{};
l_shape.push_back(pre);
l_shape.push_back(n);
l_shape.push_back(post);
std::vector<size_t> rhs_order(rhs->get_shape().size());
std::iota(std::begin(rhs_order), std::end(rhs_order), 0);
ngraph::Shape r_shape{};
r_shape.push_back(n);
auto rhs_reshape = std::make_shared<ngraph::op::Reshape>(
rhs, ngraph::AxisVector(rhs_order), r_shape);
auto rhs_bcast = std::make_shared<ngraph::op::Broadcast>(
rhs_reshape, l_shape, ngraph::AxisSet{0, 2});
std::vector<size_t> bcast_order(rhs_bcast->get_shape().size());
std::iota(std::begin(bcast_order), std::end(bcast_order), 0);
std::shared_ptr<ngraph::Node> rhs_bcast_reshape =
std::make_shared<ngraph::op::Reshape>(
rhs_bcast, ngraph::AxisVector(bcast_order), lhs_shape);
return ngraph::NodeVector{lhs, rhs_bcast_reshape};
}
} // namespace ngraphs
} // namespace operators
} // namespace paddle
#endif
/*Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef PADDLE_WITH_NGRAPH
#pragma once
#include <string>
#include "ngraph/ngraph.hpp"
#include "paddle/fluid/operators/ngraph/ops/elementwise_binary_prepare_node.h"
#include "paddle/fluid/platform/ngraph_helper.h"
namespace paddle {
namespace operators {
namespace ngraphs {
template <typename T>
void BuildElementwiseBinaryNode(
const std::shared_ptr<paddle::framework::OperatorBase>& op,
std::shared_ptr<
std::unordered_map<std::string, std::shared_ptr<ngraph::Node>>>
ngb_node_map) {
auto nodes = ElementwiseBinaryNodePrepare(op, ngb_node_map);
std::shared_ptr<ngraph::Node>& x = nodes.at(0);
std::shared_ptr<ngraph::Node>& y = nodes.at(1);
if (x->get_element_type() != y->get_element_type()) {
y = std::make_shared<ngraph::op::Convert>(y, x->get_element_type());
}
auto out = std::make_shared<T>(x, y);
paddle::platform::SetOutputNode(op, "Out", out, ngb_node_map);
}
template <typename T>
void BuildElementwiseCompareNode(
const std::shared_ptr<paddle::framework::OperatorBase>& op,
std::shared_ptr<
std::unordered_map<std::string, std::shared_ptr<ngraph::Node>>>
ngb_node_map) {
auto nodes = ElementwiseBinaryNodePrepare(op, ngb_node_map);
std::shared_ptr<ngraph::Node>& x = nodes.at(0);
std::shared_ptr<ngraph::Node>& y = nodes.at(1);
if (x->get_element_type() != y->get_element_type()) {
x = std::make_shared<ngraph::op::Convert>(x, ngraph::element::f64);
y = std::make_shared<ngraph::op::Convert>(y, ngraph::element::f64);
}
auto out = std::make_shared<T>(x, y);
paddle::platform::SetOutputNode(op, "Out", out, ngb_node_map);
}
} // namespace ngraphs
} // namespace operators
} // namespace paddle
#endif
...@@ -23,7 +23,7 @@ limitations under the License. */ ...@@ -23,7 +23,7 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace platform { namespace platform {
static ngraph::Shape FlattenTo2d(ngraph::Shape sh, int num) { ngraph::Shape FlattenTo2d(ngraph::Shape sh, int num) {
auto x1 = std::accumulate(std::begin(sh), std::begin(sh) + num, 1, auto x1 = std::accumulate(std::begin(sh), std::begin(sh) + num, 1,
std::multiplies<size_t>()); std::multiplies<size_t>());
auto x2 = std::accumulate(std::begin(sh) + num, std::end(sh), 1, auto x2 = std::accumulate(std::begin(sh) + num, std::end(sh), 1,
...@@ -33,15 +33,15 @@ static ngraph::Shape FlattenTo2d(ngraph::Shape sh, int num) { ...@@ -33,15 +33,15 @@ static ngraph::Shape FlattenTo2d(ngraph::Shape sh, int num) {
return ngraph::Shape{x1_l, x2_l}; return ngraph::Shape{x1_l, x2_l};
} }
static std::shared_ptr<ngraph::Node> NgReshaper( std::shared_ptr<ngraph::Node> NgReshaper(std::shared_ptr<ngraph::Node> input,
std::shared_ptr<ngraph::Node> input, ngraph::Shape shape) { ngraph::Shape shape) {
std::vector<size_t> input_order(input->get_shape().size()); std::vector<size_t> input_order(input->get_shape().size());
std::iota(std::begin(input_order), std::end(input_order), 0); std::iota(std::begin(input_order), std::end(input_order), 0);
return std::make_shared<ngraph::op::Reshape>( return std::make_shared<ngraph::op::Reshape>(
input, ngraph::AxisVector(input_order), shape); input, ngraph::AxisVector(input_order), shape);
} }
static std::shared_ptr<ngraph::Node> GetNode( std::shared_ptr<ngraph::Node> GetNode(
const std::shared_ptr<paddle::framework::OperatorBase>& op, const std::shared_ptr<paddle::framework::OperatorBase>& op,
const std::string prm, const paddle::framework::VariableNameMap& var_map, const std::string prm, const paddle::framework::VariableNameMap& var_map,
std::shared_ptr< std::shared_ptr<
...@@ -57,7 +57,7 @@ static std::shared_ptr<ngraph::Node> GetNode( ...@@ -57,7 +57,7 @@ static std::shared_ptr<ngraph::Node> GetNode(
} }
} }
static std::shared_ptr<ngraph::Node> GetInputNode( std::shared_ptr<ngraph::Node> GetInputNode(
const std::shared_ptr<paddle::framework::OperatorBase>& op, const std::shared_ptr<paddle::framework::OperatorBase>& op,
const std::string prm, const std::string prm,
std::shared_ptr< std::shared_ptr<
...@@ -66,7 +66,7 @@ static std::shared_ptr<ngraph::Node> GetInputNode( ...@@ -66,7 +66,7 @@ static std::shared_ptr<ngraph::Node> GetInputNode(
return GetNode(op, prm, op->Inputs(), ngb_node_map); return GetNode(op, prm, op->Inputs(), ngb_node_map);
} }
static std::shared_ptr<ngraph::Node> GetOutputNode( std::shared_ptr<ngraph::Node> GetOutputNode(
const std::shared_ptr<paddle::framework::OperatorBase>& op, const std::shared_ptr<paddle::framework::OperatorBase>& op,
const std::string prm, const std::string prm,
std::shared_ptr< std::shared_ptr<
...@@ -75,7 +75,7 @@ static std::shared_ptr<ngraph::Node> GetOutputNode( ...@@ -75,7 +75,7 @@ static std::shared_ptr<ngraph::Node> GetOutputNode(
return GetNode(op, prm, op->Outputs(), ngb_node_map); return GetNode(op, prm, op->Outputs(), ngb_node_map);
} }
static void SetOutputNode( void SetOutputNode(
const std::shared_ptr<paddle::framework::OperatorBase>& op, const std::shared_ptr<paddle::framework::OperatorBase>& op,
const std::string prm, std::shared_ptr<ngraph::Node> node, const std::string prm, std::shared_ptr<ngraph::Node> node,
std::shared_ptr< std::shared_ptr<
...@@ -91,14 +91,45 @@ static void SetOutputNode( ...@@ -91,14 +91,45 @@ static void SetOutputNode(
} }
} }
static bool HasOutput( bool HasOutput(const std::shared_ptr<paddle::framework::OperatorBase>& op,
const std::shared_ptr<paddle::framework::OperatorBase>& op,
const std::string prm) { const std::string prm) {
auto& outputs = op->Outputs(); auto& outputs = op->Outputs();
if (outputs.find(prm) == outputs.end()) return false; if (outputs.find(prm) == outputs.end()) return false;
return outputs.at(prm).size() > 0; return outputs.at(prm).size() > 0;
} }
inline void GetMidDims(const ngraph::Shape& x_shape,
const ngraph::Shape& y_shape, int axis, int* pre, int* n,
int* post) {
*pre = 1;
*n = 1;
*post = 1;
for (int i = 0; i < axis; ++i) {
(*pre) *= x_shape[i];
}
for (size_t i = 0; i < y_shape.size(); ++i) {
PADDLE_ENFORCE_EQ(x_shape[i + axis], y_shape[i],
"Broadcast dimension mismatch.");
(*n) *= y_shape[i];
}
for (size_t i = axis + y_shape.size(); i < x_shape.size(); ++i) {
(*post) *= x_shape[i];
}
}
inline void TrimTrailingSingularDims(ngraph::Shape* shape) {
// Remove trailing dimensions of size 1 for y
auto actual_shape_size = shape->size();
for (; actual_shape_size != 0; --actual_shape_size) {
if ((*shape)[actual_shape_size - 1] != 1) {
break;
} else {
shape->pop_back();
}
}
}
} // namespace platform } // namespace platform
} // namespace paddle } // namespace paddle
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
from paddle.fluid.tests.unittests.test_elementwise_add_op import *
class TestNGRAPHElementwiseAddOp(TestElementwiseAddOp):
def init_input_output(self):
super(TestNGRAPHElementwiseAddOp, self).init_input_output()
class TestNGRAPHElementwiseAddOp_scalar(TestElementwiseAddOp_scalar):
def init_input_output(self):
super(TestNGRAPHElementwiseAddOp_scalar, self).init_input_output()
class TestNGRAPHElementwiseAddOp_scalar2(TestElementwiseAddOp_scalar2):
def init_input_output(self):
super(TestNGRAPHElementwiseAddOp_scalar2, self).init_input_output()
class TestNGRAPHElementwiseAddOp_Vector(TestElementwiseAddOp_Vector):
def init_input_output(self):
super(TestNGRAPHElementwiseAddOp_Vector, self).init_input_output()
class TesNGRAPHtElementwiseAddOp_broadcast_0(TestElementwiseAddOp_broadcast_0):
def init_input_output(self):
super(TesNGRAPHtElementwiseAddOp_broadcast_0, self).init_input_output()
class TestNGRAPHElementwiseAddOp_broadcast_1(TestElementwiseAddOp_broadcast_1):
def init_input_output(self):
super(TestNGRAPHElementwiseAddOp_broadcast_1, self).init_input_output()
class TestNGRAPHElementwiseAddOp_broadcast_2(TestElementwiseAddOp_broadcast_2):
def init_input_output(self):
super(TestNGRAPHElementwiseAddOp_broadcast_2, self).init_input_output()
class TestNGRAPHElementwiseAddOp_broadcast_3(TestElementwiseAddOp_broadcast_3):
def init_input_output(self):
super(TestNGRAPHElementwiseAddOp_broadcast_3, self).init_input_output()
class TestNGRAPHElementwiseAddOp_broadcast_4(TestElementwiseAddOp_broadcast_4):
def init_input_output(self):
super(TestNGRAPHElementwiseAddOp_broadcast_4, self).init_input_output()
class TestNGRAPHElementwiseAddOp_rowwise_add_0(
TestElementwiseAddOp_rowwise_add_0):
def init_input_output(self):
super(TestNGRAPHElementwiseAddOp_rowwise_add_0,
self).init_input_output()
class TestNGRAPHElementwiseAddOp_rowwise_add_1(
TestElementwiseAddOp_rowwise_add_1):
def init_input_output(self):
super(TestNGRAPHElementwiseAddOp_rowwise_add_1,
self).init_input_output()
class TestNGRAPHElementwiseAddOp_channelwise_add(
TestElementwiseAddOp_channelwise_add):
def init_input_output(self):
super(TestNGRAPHElementwiseAddOp_channelwise_add,
self).init_input_output()
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册