From 18e65b0c084ef482492b528985173341a24284cc Mon Sep 17 00:00:00 2001 From: dongzhihong Date: Tue, 11 Jul 2017 10:37:41 +0800 Subject: [PATCH] "support net_proto header" --- paddle/framework/CMakeLists.txt | 2 +- paddle/framework/net.cc | 7 +++++-- paddle/framework/net.h | 14 +++++++------- paddle/framework/net_test.cc | 2 +- paddle/platform/device_context.h | 1 + 5 files changed, 15 insertions(+), 11 deletions(-) diff --git a/paddle/framework/CMakeLists.txt b/paddle/framework/CMakeLists.txt index 39cfb46237..e6e3b79d7b 100644 --- a/paddle/framework/CMakeLists.txt +++ b/paddle/framework/CMakeLists.txt @@ -18,4 +18,4 @@ add_custom_target(framework_py_proto_init ALL COMMAND ${CMAKE_COMMAND} -E touch add_dependencies(framework_py_proto framework_py_proto_init) proto_library(net_proto SRCS net_proto.proto DEPS op_proto) -#cc_library(net SRCS net.cc DEPS net_proto attr_type op_proto) +cc_library(net SRCS net.cc DEPS net_proto attr_type op_proto) diff --git a/paddle/framework/net.cc b/paddle/framework/net.cc index 20c0aef049..f0c128d554 100644 --- a/paddle/framework/net.cc +++ b/paddle/framework/net.cc @@ -11,7 +11,10 @@ void PlainNet::InferShape(Scope* scope) { } } -void PlainNet::Run(Scope* scope) const {} - +void PlainNet::Run(Scope* scope, DeviceContext* ctx) { + for (auto& op : ops_) { + op.Run(ctx); + } +} } // namespace framework } // namespace paddle diff --git a/paddle/framework/net.h b/paddle/framework/net.h index ef50133491..b2894320da 100644 --- a/paddle/framework/net.h +++ b/paddle/framework/net.h @@ -17,9 +17,11 @@ #include "paddle/framework/net_proto.pb.h" #include "paddle/framework/op_proto.pb.h" #include "paddle/framework/scope.h" +#include "paddle/platform/device_context.h" namespace paddle { namespace framework { +using namespace paddle::platform; // operator's index stored in a network. typedef int OpIndex; @@ -30,15 +32,13 @@ typedef int OpIndex; */ struct OpDesc; -struct OpDef; -struct OpContext; struct OpAttrs {}; class Operator { public: Operator(const OpDesc &def) {} void InferShape() {} - void Run() {} + void Run(DeviceContext *ctx) {} }; /** @@ -69,12 +69,12 @@ class Net { * environment for ops. `begin` and `end` specify the scope of `ops_` to run, * If no positive indexes are provided, all operators in `ops_` will run. */ - virtual void Run(Scope *scope) const = 0; + virtual void Run(Scope *scope, DeviceContext *ctx) = 0; /** * @brief Add an Operator according to `def`. */ - virtual OpIndex AddOp(const OpDef &def) = 0; + virtual OpIndex AddOp(const OpProto &def) = 0; /** * @brief Add optimizer operators acctording to `attrs`. @@ -123,12 +123,12 @@ class PlainNet : public Net { * scope will be used instead. If no OpContext is provicded, default context * will be used. */ - virtual void Run(Scope *scope) const override; + virtual void Run(Scope *scope, DeviceContext *ctx) override; /** * @brief Add an operator to this network. */ - virtual OpIndex AddOp(const OpDef &def) override; + virtual OpIndex AddOp(const OpProto &def) override; /** * @brief Add all optimizer operators related into the network. diff --git a/paddle/framework/net_test.cc b/paddle/framework/net_test.cc index 04f5efdf79..a8e31c1497 100644 --- a/paddle/framework/net_test.cc +++ b/paddle/framework/net_test.cc @@ -19,6 +19,6 @@ namespace paddle { namespace framework { -class FakeFC : public OpBase {} +class FakeFC : public Operator {} } // namespace framework } // namespace paddle diff --git a/paddle/platform/device_context.h b/paddle/platform/device_context.h index fcef0a5e30..160eb4e120 100644 --- a/paddle/platform/device_context.h +++ b/paddle/platform/device_context.h @@ -36,6 +36,7 @@ class DeviceContext { class CPUDeviceContext : public DeviceContext {}; #ifndef PADDLE_ONLY_CPU + class GPUPlaceGuard { public: explicit GPUPlaceGuard(GPUPlace new_place) : previous_(GetCurrentDeviceId()) { -- GitLab