提交 8e9308a5 编写于 作者: B baojun-nervana

mv ngraph_bridge to ngraph directory test=develop

上级 88bd7e1a
...@@ -129,10 +129,6 @@ cc_test(version_test SRCS version_test.cc DEPS version) ...@@ -129,10 +129,6 @@ cc_test(version_test SRCS version_test.cc DEPS version)
cc_library(proto_desc SRCS var_desc.cc op_desc.cc block_desc.cc program_desc.cc DEPS shape_inference op_info operator glog version) cc_library(proto_desc SRCS var_desc.cc op_desc.cc block_desc.cc program_desc.cc DEPS shape_inference op_info operator glog version)
if(WITH_NGRAPH)
cc_library(ngraph_bridge SRCS ngraph_bridge.cc DEPS operator framework_proto ngraph)
endif(WITH_NGRAPH)
cc_library(op_registry SRCS op_registry.cc DEPS op_proto_maker op_info operator glog proto_desc) cc_library(op_registry SRCS op_registry.cc DEPS op_proto_maker op_info operator glog proto_desc)
nv_test(op_registry_test SRCS op_registry_test.cc DEPS op_registry) nv_test(op_registry_test SRCS op_registry_test.cc DEPS op_registry)
......
if(WITH_NGRAPH) if(WITH_NGRAPH)
cc_library(ngraph_bridge SRCS ngraph_bridge.cc DEPS operator framework_proto ngraph)
cc_library(ngraph_engine SRCS ngraph_engine.cc DEPS ngraph_bridge framework_proto) cc_library(ngraph_engine SRCS ngraph_engine.cc DEPS ngraph_bridge framework_proto)
op_library(ngraph_engine_op DEPS ngraph_engine op_registry op_info device_context) op_library(ngraph_engine_op DEPS ngraph_engine op_registry op_info device_context)
endif() endif()
...@@ -17,39 +17,39 @@ limitations under the License. */ ...@@ -17,39 +17,39 @@ limitations under the License. */
#include <vector> #include <vector>
#include "ngraph/ngraph.hpp" #include "ngraph/ngraph.hpp"
#include "paddle/fluid/framework/ngraph_bridge.h" #include "paddle/fluid/operators/ngraph/ngraph_bridge.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/operators/ngraph/ngraph_ops.h" #include "paddle/fluid/operators/ngraph/ngraph_ops.h"
#include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/ngraph_helper.h" #include "paddle/fluid/platform/ngraph_helper.h"
namespace paddle { namespace paddle {
namespace framework { namespace operators {
namespace NG_OPS = paddle::operators::ngraphs; namespace NG_OPS = paddle::operators::ngraphs;
std::map<std::string, std::map<std::string,
std::function<void(const std::shared_ptr<OperatorBase>&, std::function<void(const std::shared_ptr<framework::OperatorBase>&,
std::shared_ptr<std::unordered_map< std::shared_ptr<std::unordered_map<
std::string, std::shared_ptr<ngraph::Node>>>)>> std::string, std::shared_ptr<ngraph::Node>>>)>>
NgraphBridge::NG_NODE_MAP = { NgraphBridge::NG_NODE_MAP = {
{"elementwise_add", NG_OPS::BuildElementwiseAddNode}, {"elementwise_add", NG_OPS::BuildElementwiseAddNode},
{"elementwise_add_grad", NG_OPS::BuildElementwiseAddGradNode}, {"elementwise_add_grad", NG_OPS::BuildElementwiseAddGradNode},
{"fill_constant", paddle::operators::ngraphs::BuildFillConstantNode}, {"fill_constant", NG_OPS::BuildFillConstantNode},
{"mean", paddle::operators::ngraphs::BuildMeanNode}, {"mean", NG_OPS::BuildMeanNode},
{"mean_grad", paddle::operators::ngraphs::BuildMeanGradNode}, {"mean_grad", NG_OPS::BuildMeanGradNode},
{"mul", paddle::operators::ngraphs::BuildMulNode}, {"mul", NG_OPS::BuildMulNode},
{"mul_grad", paddle::operators::ngraphs::BuildMulGradNode}, {"mul_grad", NG_OPS::BuildMulGradNode},
{"softmax", paddle::operators::ngraphs::BuildSoftmaxNode}, {"softmax", NG_OPS::BuildSoftmaxNode},
{"softmax_grad", paddle::operators::ngraphs::BuildSoftmaxGradNode}, {"softmax_grad", NG_OPS::BuildSoftmaxGradNode},
{"scale", paddle::operators::ngraphs::BuildScaleNode}, {"scale", NG_OPS::BuildScaleNode},
{"relu", paddle::operators::ngraphs::BuildUnaryNode<ngraph::op::Relu>}, {"relu", NG_OPS::BuildUnaryNode<ngraph::op::Relu>},
{"tanh", paddle::operators::ngraphs::BuildUnaryNode<ngraph::op::Tanh>}, {"tanh", NG_OPS::BuildUnaryNode<ngraph::op::Tanh>},
{"top_k", paddle::operators::ngraphs::BuildTopKNode}}; {"top_k", NG_OPS::BuildTopKNode}};
void NgraphBridge::BuildNgNode(const std::shared_ptr<OperatorBase>& op) { void NgraphBridge::BuildNgNode(
const std::shared_ptr<framework::OperatorBase>& op) {
auto& op_type = op->Type(); auto& op_type = op->Type();
NG_NODE_MAP[op_type](op, ngb_node_map_); NG_NODE_MAP[op_type](op, ngb_node_map_);
} }
} // namespace framework } // namespace operators
} // namespace paddle } // namespace paddle
...@@ -21,16 +21,16 @@ limitations under the License. */ ...@@ -21,16 +21,16 @@ limitations under the License. */
#include "ngraph/node.hpp" #include "ngraph/node.hpp"
namespace paddle { #include "paddle/fluid/framework/operator.h"
namespace framework {
class OperatorBase; namespace paddle {
namespace operators {
class NgraphBridge { class NgraphBridge {
public: public:
static std::map< static std::map<
std::string, std::string,
std::function<void(const std::shared_ptr<OperatorBase>&, std::function<void(const std::shared_ptr<framework::OperatorBase>&,
std::shared_ptr<std::unordered_map< std::shared_ptr<std::unordered_map<
std::string, std::shared_ptr<ngraph::Node>>>)>> std::string, std::shared_ptr<ngraph::Node>>>)>>
NG_NODE_MAP; NG_NODE_MAP;
...@@ -41,7 +41,7 @@ class NgraphBridge { ...@@ -41,7 +41,7 @@ class NgraphBridge {
var_node_map) var_node_map)
: ngb_node_map_(var_node_map) {} : ngb_node_map_(var_node_map) {}
void BuildNgNode(const std::shared_ptr<OperatorBase>& op); void BuildNgNode(const std::shared_ptr<framework::OperatorBase>& op);
private: private:
std::shared_ptr< std::shared_ptr<
...@@ -49,5 +49,5 @@ class NgraphBridge { ...@@ -49,5 +49,5 @@ class NgraphBridge {
ngb_node_map_; ngb_node_map_;
}; };
} // namespace framework } // namespace operators
} // namespace paddle } // namespace paddle
...@@ -24,11 +24,11 @@ limitations under the License. */ ...@@ -24,11 +24,11 @@ limitations under the License. */
#include "paddle/fluid/framework/feed_fetch_type.h" #include "paddle/fluid/framework/feed_fetch_type.h"
#include "paddle/fluid/framework/framework.pb.h" #include "paddle/fluid/framework/framework.pb.h"
#include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/ngraph_bridge.h"
#include "paddle/fluid/framework/op_desc.h" #include "paddle/fluid/framework/op_desc.h"
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/var_desc.h" #include "paddle/fluid/framework/var_desc.h"
#include "paddle/fluid/framework/var_type.h" #include "paddle/fluid/framework/var_type.h"
#include "paddle/fluid/operators/ngraph/ngraph_bridge.h"
#include "paddle/fluid/operators/ngraph/ngraph_engine.h" #include "paddle/fluid/operators/ngraph/ngraph_engine.h"
namespace paddle { namespace paddle {
...@@ -88,15 +88,14 @@ static std::vector<std::vector<int>> NgraphOpIntervals( ...@@ -88,15 +88,14 @@ static std::vector<std::vector<int>> NgraphOpIntervals(
int pivot = left; int pivot = left;
while (pivot < right) { while (pivot < right) {
auto op_type = ops.at(pivot)->Type(); auto op_type = ops.at(pivot)->Type();
if (paddle::framework::NgraphBridge::NG_NODE_MAP.find(op_type) == if (NgraphBridge::NG_NODE_MAP.find(op_type) ==
paddle::framework::NgraphBridge::NG_NODE_MAP.end()) { NgraphBridge::NG_NODE_MAP.end()) {
++pivot; ++pivot;
} else { } else {
int start = pivot, end = start; int start = pivot, end = start;
while (pivot < right && while (pivot < right &&
(paddle::framework::NgraphBridge::NG_NODE_MAP.find( (NgraphBridge::NG_NODE_MAP.find(ops.at(pivot)->Type()) !=
ops.at(pivot)->Type()) != NgraphBridge::NG_NODE_MAP.end())) {
paddle::framework::NgraphBridge::NG_NODE_MAP.end())) {
++pivot; ++pivot;
++end; ++end;
} }
...@@ -283,7 +282,7 @@ void NgraphEngine::BuildNgNodes() { ...@@ -283,7 +282,7 @@ void NgraphEngine::BuildNgNodes() {
} }
} }
} }
framework::NgraphBridge ngb(var_node_map_); NgraphBridge ngb(var_node_map_);
for (auto& op : fused_ops_) { for (auto& op : fused_ops_) {
ngb.BuildNgNode(op); ngb.BuildNgNode(op);
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册