diff --git a/paddle/fluid/framework/CMakeLists.txt b/paddle/fluid/framework/CMakeLists.txt index 8cb0c4e668a71e1c06e2cf13ad6b25854077e705..2ba2437de66f31549a87f20360dbb97b48ea6fbe 100644 --- a/paddle/fluid/framework/CMakeLists.txt +++ b/paddle/fluid/framework/CMakeLists.txt @@ -129,10 +129,6 @@ cc_test(version_test SRCS version_test.cc DEPS version) cc_library(proto_desc SRCS var_desc.cc op_desc.cc block_desc.cc program_desc.cc DEPS shape_inference op_info operator glog version) -if(WITH_NGRAPH) - cc_library(ngraph_bridge SRCS ngraph_bridge.cc DEPS operator framework_proto ngraph) -endif(WITH_NGRAPH) - cc_library(op_registry SRCS op_registry.cc DEPS op_proto_maker op_info operator glog proto_desc) nv_test(op_registry_test SRCS op_registry_test.cc DEPS op_registry) diff --git a/paddle/fluid/operators/ngraph/CMakeLists.txt b/paddle/fluid/operators/ngraph/CMakeLists.txt index 83f78d505d7444cd12105aee40b0d03349b07be3..6b256ef02666c21ec1db3f6922b56bb23363b4a0 100644 --- a/paddle/fluid/operators/ngraph/CMakeLists.txt +++ b/paddle/fluid/operators/ngraph/CMakeLists.txt @@ -1,4 +1,5 @@ if(WITH_NGRAPH) + cc_library(ngraph_bridge SRCS ngraph_bridge.cc DEPS operator framework_proto ngraph) cc_library(ngraph_engine SRCS ngraph_engine.cc DEPS ngraph_bridge framework_proto) op_library(ngraph_engine_op DEPS ngraph_engine op_registry op_info device_context) endif() diff --git a/paddle/fluid/framework/ngraph_bridge.cc b/paddle/fluid/operators/ngraph/ngraph_bridge.cc similarity index 55% rename from paddle/fluid/framework/ngraph_bridge.cc rename to paddle/fluid/operators/ngraph/ngraph_bridge.cc index 365870c54eb3861ad6c273d3866dcd32d1c4166a..d6e897ed4666261cdd0bd6565f61abb218d971e5 100644 --- a/paddle/fluid/framework/ngraph_bridge.cc +++ b/paddle/fluid/operators/ngraph/ngraph_bridge.cc @@ -17,39 +17,39 @@ limitations under the License. */ #include #include "ngraph/ngraph.hpp" -#include "paddle/fluid/framework/ngraph_bridge.h" -#include "paddle/fluid/framework/operator.h" +#include "paddle/fluid/operators/ngraph/ngraph_bridge.h" #include "paddle/fluid/operators/ngraph/ngraph_ops.h" #include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/ngraph_helper.h" namespace paddle { -namespace framework { +namespace operators { namespace NG_OPS = paddle::operators::ngraphs; std::map&, + std::function&, std::shared_ptr>>)>> NgraphBridge::NG_NODE_MAP = { {"elementwise_add", NG_OPS::BuildElementwiseAddNode}, {"elementwise_add_grad", NG_OPS::BuildElementwiseAddGradNode}, - {"fill_constant", paddle::operators::ngraphs::BuildFillConstantNode}, - {"mean", paddle::operators::ngraphs::BuildMeanNode}, - {"mean_grad", paddle::operators::ngraphs::BuildMeanGradNode}, - {"mul", paddle::operators::ngraphs::BuildMulNode}, - {"mul_grad", paddle::operators::ngraphs::BuildMulGradNode}, - {"softmax", paddle::operators::ngraphs::BuildSoftmaxNode}, - {"softmax_grad", paddle::operators::ngraphs::BuildSoftmaxGradNode}, - {"scale", paddle::operators::ngraphs::BuildScaleNode}, - {"relu", paddle::operators::ngraphs::BuildUnaryNode}, - {"tanh", paddle::operators::ngraphs::BuildUnaryNode}, - {"top_k", paddle::operators::ngraphs::BuildTopKNode}}; - -void NgraphBridge::BuildNgNode(const std::shared_ptr& op) { + {"fill_constant", NG_OPS::BuildFillConstantNode}, + {"mean", NG_OPS::BuildMeanNode}, + {"mean_grad", NG_OPS::BuildMeanGradNode}, + {"mul", NG_OPS::BuildMulNode}, + {"mul_grad", NG_OPS::BuildMulGradNode}, + {"softmax", NG_OPS::BuildSoftmaxNode}, + {"softmax_grad", NG_OPS::BuildSoftmaxGradNode}, + {"scale", NG_OPS::BuildScaleNode}, + {"relu", NG_OPS::BuildUnaryNode}, + {"tanh", NG_OPS::BuildUnaryNode}, + {"top_k", NG_OPS::BuildTopKNode}}; + +void NgraphBridge::BuildNgNode( + const std::shared_ptr& op) { auto& op_type = op->Type(); NG_NODE_MAP[op_type](op, ngb_node_map_); } -} // namespace framework +} // namespace operators } // namespace paddle diff --git a/paddle/fluid/framework/ngraph_bridge.h b/paddle/fluid/operators/ngraph/ngraph_bridge.h similarity index 84% rename from paddle/fluid/framework/ngraph_bridge.h rename to paddle/fluid/operators/ngraph/ngraph_bridge.h index 5ad7b8daeb6a782515e50fc87ca7188b46308390..c57988f8f6322e76678c572aa21ff5b17b9e3c22 100644 --- a/paddle/fluid/framework/ngraph_bridge.h +++ b/paddle/fluid/operators/ngraph/ngraph_bridge.h @@ -21,16 +21,16 @@ limitations under the License. */ #include "ngraph/node.hpp" -namespace paddle { -namespace framework { +#include "paddle/fluid/framework/operator.h" -class OperatorBase; +namespace paddle { +namespace operators { class NgraphBridge { public: static std::map< std::string, - std::function&, + std::function&, std::shared_ptr>>)>> NG_NODE_MAP; @@ -41,7 +41,7 @@ class NgraphBridge { var_node_map) : ngb_node_map_(var_node_map) {} - void BuildNgNode(const std::shared_ptr& op); + void BuildNgNode(const std::shared_ptr& op); private: std::shared_ptr< @@ -49,5 +49,5 @@ class NgraphBridge { ngb_node_map_; }; -} // namespace framework +} // namespace operators } // namespace paddle diff --git a/paddle/fluid/operators/ngraph/ngraph_engine.cc b/paddle/fluid/operators/ngraph/ngraph_engine.cc index fde3a5ba55bf13a6dc60cce0915d71f27f640e90..bec4b514a218715134d2366dd7efd7cf5b377b68 100644 --- a/paddle/fluid/operators/ngraph/ngraph_engine.cc +++ b/paddle/fluid/operators/ngraph/ngraph_engine.cc @@ -24,11 +24,11 @@ limitations under the License. */ #include "paddle/fluid/framework/feed_fetch_type.h" #include "paddle/fluid/framework/framework.pb.h" #include "paddle/fluid/framework/lod_tensor.h" -#include "paddle/fluid/framework/ngraph_bridge.h" #include "paddle/fluid/framework/op_desc.h" #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/var_desc.h" #include "paddle/fluid/framework/var_type.h" +#include "paddle/fluid/operators/ngraph/ngraph_bridge.h" #include "paddle/fluid/operators/ngraph/ngraph_engine.h" namespace paddle { @@ -88,15 +88,14 @@ static std::vector> NgraphOpIntervals( int pivot = left; while (pivot < right) { auto op_type = ops.at(pivot)->Type(); - if (paddle::framework::NgraphBridge::NG_NODE_MAP.find(op_type) == - paddle::framework::NgraphBridge::NG_NODE_MAP.end()) { + if (NgraphBridge::NG_NODE_MAP.find(op_type) == + NgraphBridge::NG_NODE_MAP.end()) { ++pivot; } else { int start = pivot, end = start; while (pivot < right && - (paddle::framework::NgraphBridge::NG_NODE_MAP.find( - ops.at(pivot)->Type()) != - paddle::framework::NgraphBridge::NG_NODE_MAP.end())) { + (NgraphBridge::NG_NODE_MAP.find(ops.at(pivot)->Type()) != + NgraphBridge::NG_NODE_MAP.end())) { ++pivot; ++end; } @@ -283,7 +282,7 @@ void NgraphEngine::BuildNgNodes() { } } } - framework::NgraphBridge ngb(var_node_map_); + NgraphBridge ngb(var_node_map_); for (auto& op : fused_ops_) { ngb.BuildNgNode(op); }