提交 18e65b0c 编写于 作者: D dongzhihong

"support net_proto header"

上级 983577dc
......@@ -18,4 +18,4 @@ add_custom_target(framework_py_proto_init ALL COMMAND ${CMAKE_COMMAND} -E touch
add_dependencies(framework_py_proto framework_py_proto_init)
proto_library(net_proto SRCS net_proto.proto DEPS op_proto)
#cc_library(net SRCS net.cc DEPS net_proto attr_type op_proto)
cc_library(net SRCS net.cc DEPS net_proto attr_type op_proto)
......@@ -11,7 +11,10 @@ void PlainNet::InferShape(Scope* scope) {
}
}
void PlainNet::Run(Scope* scope) const {}
void PlainNet::Run(Scope* scope, DeviceContext* ctx) {
for (auto& op : ops_) {
op.Run(ctx);
}
}
} // namespace framework
} // namespace paddle
......@@ -17,9 +17,11 @@
#include "paddle/framework/net_proto.pb.h"
#include "paddle/framework/op_proto.pb.h"
#include "paddle/framework/scope.h"
#include "paddle/platform/device_context.h"
namespace paddle {
namespace framework {
using namespace paddle::platform;
// operator's index stored in a network.
typedef int OpIndex;
......@@ -30,15 +32,13 @@ typedef int OpIndex;
*/
struct OpDesc;
struct OpDef;
struct OpContext;
struct OpAttrs {};
class Operator {
public:
Operator(const OpDesc &def) {}
void InferShape() {}
void Run() {}
void Run(DeviceContext *ctx) {}
};
/**
......@@ -69,12 +69,12 @@ class Net {
* environment for ops. `begin` and `end` specify the scope of `ops_` to run,
* If no positive indexes are provided, all operators in `ops_` will run.
*/
virtual void Run(Scope *scope) const = 0;
virtual void Run(Scope *scope, DeviceContext *ctx) = 0;
/**
* @brief Add an Operator according to `def`.
*/
virtual OpIndex AddOp(const OpDef &def) = 0;
virtual OpIndex AddOp(const OpProto &def) = 0;
/**
* @brief Add optimizer operators acctording to `attrs`.
......@@ -123,12 +123,12 @@ class PlainNet : public Net {
* scope will be used instead. If no OpContext is provicded, default context
* will be used.
*/
virtual void Run(Scope *scope) const override;
virtual void Run(Scope *scope, DeviceContext *ctx) override;
/**
* @brief Add an operator to this network.
*/
virtual OpIndex AddOp(const OpDef &def) override;
virtual OpIndex AddOp(const OpProto &def) override;
/**
* @brief Add all optimizer operators related into the network.
......
......@@ -19,6 +19,6 @@
namespace paddle {
namespace framework {
class FakeFC : public OpBase {}
class FakeFC : public Operator {}
} // namespace framework
} // namespace paddle
......@@ -36,6 +36,7 @@ class DeviceContext {
class CPUDeviceContext : public DeviceContext {};
#ifndef PADDLE_ONLY_CPU
class GPUPlaceGuard {
public:
explicit GPUPlaceGuard(GPUPlace new_place) : previous_(GetCurrentDeviceId()) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册