未验证 提交 8e2dea57 编写于 作者: T Tao Luo 提交者: GitHub

Merge pull request #15538 from baojun-nervana/mv_ng_bridge_file

move ngraph_bridge to ngraph directory 
......@@ -129,10 +129,6 @@ cc_test(version_test SRCS version_test.cc DEPS version)
cc_library(proto_desc SRCS var_desc.cc op_desc.cc block_desc.cc program_desc.cc DEPS shape_inference op_info operator glog version)
if(WITH_NGRAPH)
cc_library(ngraph_bridge SRCS ngraph_bridge.cc DEPS operator framework_proto ngraph)
endif(WITH_NGRAPH)
cc_library(op_registry SRCS op_registry.cc DEPS op_proto_maker op_info operator glog proto_desc)
nv_test(op_registry_test SRCS op_registry_test.cc DEPS op_registry)
......
if(WITH_NGRAPH)
cc_library(ngraph_bridge SRCS ngraph_bridge.cc DEPS operator framework_proto ngraph)
cc_library(ngraph_engine SRCS ngraph_engine.cc DEPS ngraph_bridge framework_proto)
op_library(ngraph_engine_op DEPS ngraph_engine op_registry op_info device_context)
endif()
......@@ -17,39 +17,39 @@ limitations under the License. */
#include <vector>
#include "ngraph/ngraph.hpp"
#include "paddle/fluid/framework/ngraph_bridge.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/operators/ngraph/ngraph_bridge.h"
#include "paddle/fluid/operators/ngraph/ngraph_ops.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/ngraph_helper.h"
namespace paddle {
namespace framework {
namespace operators {
namespace NG_OPS = paddle::operators::ngraphs;
std::map<std::string,
std::function<void(const std::shared_ptr<OperatorBase>&,
std::function<void(const std::shared_ptr<framework::OperatorBase>&,
std::shared_ptr<std::unordered_map<
std::string, std::shared_ptr<ngraph::Node>>>)>>
NgraphBridge::NG_NODE_MAP = {
{"elementwise_add", NG_OPS::BuildElementwiseAddNode},
{"elementwise_add_grad", NG_OPS::BuildElementwiseAddGradNode},
{"fill_constant", paddle::operators::ngraphs::BuildFillConstantNode},
{"mean", paddle::operators::ngraphs::BuildMeanNode},
{"mean_grad", paddle::operators::ngraphs::BuildMeanGradNode},
{"mul", paddle::operators::ngraphs::BuildMulNode},
{"mul_grad", paddle::operators::ngraphs::BuildMulGradNode},
{"softmax", paddle::operators::ngraphs::BuildSoftmaxNode},
{"softmax_grad", paddle::operators::ngraphs::BuildSoftmaxGradNode},
{"scale", paddle::operators::ngraphs::BuildScaleNode},
{"relu", paddle::operators::ngraphs::BuildUnaryNode<ngraph::op::Relu>},
{"tanh", paddle::operators::ngraphs::BuildUnaryNode<ngraph::op::Tanh>},
{"top_k", paddle::operators::ngraphs::BuildTopKNode}};
void NgraphBridge::BuildNgNode(const std::shared_ptr<OperatorBase>& op) {
{"fill_constant", NG_OPS::BuildFillConstantNode},
{"mean", NG_OPS::BuildMeanNode},
{"mean_grad", NG_OPS::BuildMeanGradNode},
{"mul", NG_OPS::BuildMulNode},
{"mul_grad", NG_OPS::BuildMulGradNode},
{"softmax", NG_OPS::BuildSoftmaxNode},
{"softmax_grad", NG_OPS::BuildSoftmaxGradNode},
{"scale", NG_OPS::BuildScaleNode},
{"relu", NG_OPS::BuildUnaryNode<ngraph::op::Relu>},
{"tanh", NG_OPS::BuildUnaryNode<ngraph::op::Tanh>},
{"top_k", NG_OPS::BuildTopKNode}};
void NgraphBridge::BuildNgNode(
const std::shared_ptr<framework::OperatorBase>& op) {
auto& op_type = op->Type();
NG_NODE_MAP[op_type](op, ngb_node_map_);
}
} // namespace framework
} // namespace operators
} // namespace paddle
......@@ -21,16 +21,16 @@ limitations under the License. */
#include "ngraph/node.hpp"
namespace paddle {
namespace framework {
#include "paddle/fluid/framework/operator.h"
class OperatorBase;
namespace paddle {
namespace operators {
class NgraphBridge {
public:
static std::map<
std::string,
std::function<void(const std::shared_ptr<OperatorBase>&,
std::function<void(const std::shared_ptr<framework::OperatorBase>&,
std::shared_ptr<std::unordered_map<
std::string, std::shared_ptr<ngraph::Node>>>)>>
NG_NODE_MAP;
......@@ -41,7 +41,7 @@ class NgraphBridge {
var_node_map)
: ngb_node_map_(var_node_map) {}
void BuildNgNode(const std::shared_ptr<OperatorBase>& op);
void BuildNgNode(const std::shared_ptr<framework::OperatorBase>& op);
private:
std::shared_ptr<
......@@ -49,5 +49,5 @@ class NgraphBridge {
ngb_node_map_;
};
} // namespace framework
} // namespace operators
} // namespace paddle
......@@ -24,11 +24,11 @@ limitations under the License. */
#include "paddle/fluid/framework/feed_fetch_type.h"
#include "paddle/fluid/framework/framework.pb.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/ngraph_bridge.h"
#include "paddle/fluid/framework/op_desc.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/var_desc.h"
#include "paddle/fluid/framework/var_type.h"
#include "paddle/fluid/operators/ngraph/ngraph_bridge.h"
#include "paddle/fluid/operators/ngraph/ngraph_engine.h"
namespace paddle {
......@@ -88,15 +88,14 @@ static std::vector<std::vector<int>> NgraphOpIntervals(
int pivot = left;
while (pivot < right) {
auto op_type = ops.at(pivot)->Type();
if (paddle::framework::NgraphBridge::NG_NODE_MAP.find(op_type) ==
paddle::framework::NgraphBridge::NG_NODE_MAP.end()) {
if (NgraphBridge::NG_NODE_MAP.find(op_type) ==
NgraphBridge::NG_NODE_MAP.end()) {
++pivot;
} else {
int start = pivot, end = start;
while (pivot < right &&
(paddle::framework::NgraphBridge::NG_NODE_MAP.find(
ops.at(pivot)->Type()) !=
paddle::framework::NgraphBridge::NG_NODE_MAP.end())) {
(NgraphBridge::NG_NODE_MAP.find(ops.at(pivot)->Type()) !=
NgraphBridge::NG_NODE_MAP.end())) {
++pivot;
++end;
}
......@@ -283,7 +282,7 @@ void NgraphEngine::BuildNgNodes() {
}
}
}
framework::NgraphBridge ngb(var_node_map_);
NgraphBridge ngb(var_node_map_);
for (auto& op : fused_ops_) {
ngb.BuildNgNode(op);
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册