提交 18e65b0c 编写于 作者: D dongzhihong

"support net_proto header"

上级 983577dc
...@@ -18,4 +18,4 @@ add_custom_target(framework_py_proto_init ALL COMMAND ${CMAKE_COMMAND} -E touch ...@@ -18,4 +18,4 @@ add_custom_target(framework_py_proto_init ALL COMMAND ${CMAKE_COMMAND} -E touch
add_dependencies(framework_py_proto framework_py_proto_init) add_dependencies(framework_py_proto framework_py_proto_init)
proto_library(net_proto SRCS net_proto.proto DEPS op_proto) proto_library(net_proto SRCS net_proto.proto DEPS op_proto)
#cc_library(net SRCS net.cc DEPS net_proto attr_type op_proto) cc_library(net SRCS net.cc DEPS net_proto attr_type op_proto)
...@@ -11,7 +11,10 @@ void PlainNet::InferShape(Scope* scope) { ...@@ -11,7 +11,10 @@ void PlainNet::InferShape(Scope* scope) {
} }
} }
void PlainNet::Run(Scope* scope) const {} void PlainNet::Run(Scope* scope, DeviceContext* ctx) {
for (auto& op : ops_) {
op.Run(ctx);
}
}
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -17,9 +17,11 @@ ...@@ -17,9 +17,11 @@
#include "paddle/framework/net_proto.pb.h" #include "paddle/framework/net_proto.pb.h"
#include "paddle/framework/op_proto.pb.h" #include "paddle/framework/op_proto.pb.h"
#include "paddle/framework/scope.h" #include "paddle/framework/scope.h"
#include "paddle/platform/device_context.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
using namespace paddle::platform;
// operator's index stored in a network. // operator's index stored in a network.
typedef int OpIndex; typedef int OpIndex;
...@@ -30,15 +32,13 @@ typedef int OpIndex; ...@@ -30,15 +32,13 @@ typedef int OpIndex;
*/ */
struct OpDesc; struct OpDesc;
struct OpDef;
struct OpContext;
struct OpAttrs {}; struct OpAttrs {};
class Operator { class Operator {
public: public:
Operator(const OpDesc &def) {} Operator(const OpDesc &def) {}
void InferShape() {} void InferShape() {}
void Run() {} void Run(DeviceContext *ctx) {}
}; };
/** /**
...@@ -69,12 +69,12 @@ class Net { ...@@ -69,12 +69,12 @@ class Net {
* environment for ops. `begin` and `end` specify the scope of `ops_` to run, * environment for ops. `begin` and `end` specify the scope of `ops_` to run,
* If no positive indexes are provided, all operators in `ops_` will run. * If no positive indexes are provided, all operators in `ops_` will run.
*/ */
virtual void Run(Scope *scope) const = 0; virtual void Run(Scope *scope, DeviceContext *ctx) = 0;
/** /**
* @brief Add an Operator according to `def`. * @brief Add an Operator according to `def`.
*/ */
virtual OpIndex AddOp(const OpDef &def) = 0; virtual OpIndex AddOp(const OpProto &def) = 0;
/** /**
* @brief Add optimizer operators acctording to `attrs`. * @brief Add optimizer operators acctording to `attrs`.
...@@ -123,12 +123,12 @@ class PlainNet : public Net { ...@@ -123,12 +123,12 @@ class PlainNet : public Net {
* scope will be used instead. If no OpContext is provicded, default context * scope will be used instead. If no OpContext is provicded, default context
* will be used. * will be used.
*/ */
virtual void Run(Scope *scope) const override; virtual void Run(Scope *scope, DeviceContext *ctx) override;
/** /**
* @brief Add an operator to this network. * @brief Add an operator to this network.
*/ */
virtual OpIndex AddOp(const OpDef &def) override; virtual OpIndex AddOp(const OpProto &def) override;
/** /**
* @brief Add all optimizer operators related into the network. * @brief Add all optimizer operators related into the network.
......
...@@ -19,6 +19,6 @@ ...@@ -19,6 +19,6 @@
namespace paddle { namespace paddle {
namespace framework { namespace framework {
class FakeFC : public OpBase {} class FakeFC : public Operator {}
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -36,6 +36,7 @@ class DeviceContext { ...@@ -36,6 +36,7 @@ class DeviceContext {
class CPUDeviceContext : public DeviceContext {}; class CPUDeviceContext : public DeviceContext {};
#ifndef PADDLE_ONLY_CPU #ifndef PADDLE_ONLY_CPU
class GPUPlaceGuard { class GPUPlaceGuard {
public: public:
explicit GPUPlaceGuard(GPUPlace new_place) : previous_(GetCurrentDeviceId()) { explicit GPUPlaceGuard(GPUPlace new_place) : previous_(GetCurrentDeviceId()) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册