diff --git a/paddle/framework/CMakeLists.txt b/paddle/framework/CMakeLists.txt index 39cfb4623795e8dff47b9f69a16ee70361dd81ab..e6e3b79d7bd1129e102fde780c8e609a16745e75 100644 --- a/paddle/framework/CMakeLists.txt +++ b/paddle/framework/CMakeLists.txt @@ -18,4 +18,4 @@ add_custom_target(framework_py_proto_init ALL COMMAND ${CMAKE_COMMAND} -E touch add_dependencies(framework_py_proto framework_py_proto_init) proto_library(net_proto SRCS net_proto.proto DEPS op_proto) -#cc_library(net SRCS net.cc DEPS net_proto attr_type op_proto) +cc_library(net SRCS net.cc DEPS net_proto attr_type op_proto) diff --git a/paddle/framework/net.cc b/paddle/framework/net.cc index 20c0aef049ceb6f0617904eb757a904f83ee4260..f0c128d554b296d7fe5c6818d3911aaee5c0adce 100644 --- a/paddle/framework/net.cc +++ b/paddle/framework/net.cc @@ -11,7 +11,10 @@ void PlainNet::InferShape(Scope* scope) { } } -void PlainNet::Run(Scope* scope) const {} - +void PlainNet::Run(Scope* scope, DeviceContext* ctx) { + for (auto& op : ops_) { + op.Run(ctx); + } +} } // namespace framework } // namespace paddle diff --git a/paddle/framework/net.h b/paddle/framework/net.h index ef5013349196a50f9950469f8495c6b3e166a049..b2894320dafdfaf9b8e0bffc8c863a2caae35a61 100644 --- a/paddle/framework/net.h +++ b/paddle/framework/net.h @@ -17,9 +17,11 @@ #include "paddle/framework/net_proto.pb.h" #include "paddle/framework/op_proto.pb.h" #include "paddle/framework/scope.h" +#include "paddle/platform/device_context.h" namespace paddle { namespace framework { +using namespace paddle::platform; // operator's index stored in a network. typedef int OpIndex; @@ -30,15 +32,13 @@ typedef int OpIndex; */ struct OpDesc; -struct OpDef; -struct OpContext; struct OpAttrs {}; class Operator { public: Operator(const OpDesc &def) {} void InferShape() {} - void Run() {} + void Run(DeviceContext *ctx) {} }; /** @@ -69,12 +69,12 @@ class Net { * environment for ops. `begin` and `end` specify the scope of `ops_` to run, * If no positive indexes are provided, all operators in `ops_` will run. */ - virtual void Run(Scope *scope) const = 0; + virtual void Run(Scope *scope, DeviceContext *ctx) = 0; /** * @brief Add an Operator according to `def`. */ - virtual OpIndex AddOp(const OpDef &def) = 0; + virtual OpIndex AddOp(const OpProto &def) = 0; /** * @brief Add optimizer operators acctording to `attrs`. @@ -123,12 +123,12 @@ class PlainNet : public Net { * scope will be used instead. If no OpContext is provicded, default context * will be used. */ - virtual void Run(Scope *scope) const override; + virtual void Run(Scope *scope, DeviceContext *ctx) override; /** * @brief Add an operator to this network. */ - virtual OpIndex AddOp(const OpDef &def) override; + virtual OpIndex AddOp(const OpProto &def) override; /** * @brief Add all optimizer operators related into the network. diff --git a/paddle/framework/net_test.cc b/paddle/framework/net_test.cc index 04f5efdf79bdcf724accc20a7b6365b003d61ff4..a8e31c1497519ce60da004bc0a3e52403593497c 100644 --- a/paddle/framework/net_test.cc +++ b/paddle/framework/net_test.cc @@ -19,6 +19,6 @@ namespace paddle { namespace framework { -class FakeFC : public OpBase {} +class FakeFC : public Operator {} } // namespace framework } // namespace paddle diff --git a/paddle/platform/device_context.h b/paddle/platform/device_context.h index fcef0a5e3058f1c9d54f9e06a54a09286e2454fd..160eb4e12060b36c4fefba499d4e83b9aab92848 100644 --- a/paddle/platform/device_context.h +++ b/paddle/platform/device_context.h @@ -36,6 +36,7 @@ class DeviceContext { class CPUDeviceContext : public DeviceContext {}; #ifndef PADDLE_ONLY_CPU + class GPUPlaceGuard { public: explicit GPUPlaceGuard(GPUPlace new_place) : previous_(GetCurrentDeviceId()) {