未验证 提交 ebc73039 编写于 作者: W Wu Yi 提交者: GitHub

listen_and_serv use local scope (#10663)

* listen_and_serv use localscope

* fix ut
上级 868bdc97
......@@ -228,7 +228,8 @@ static bool has_fetch_operators(
void Executor::Run(const ProgramDesc& program, Scope* scope,
std::map<std::string, const LoDTensor*>* feed_targets,
std::map<std::string, LoDTensor*>* fetch_targets,
bool create_vars, const std::string& feed_holder_name,
bool create_local_scope, bool create_vars,
const std::string& feed_holder_name,
const std::string& fetch_holder_name) {
platform::RecordBlock b(kProgramId);
bool has_feed_ops =
......@@ -290,8 +291,9 @@ void Executor::Run(const ProgramDesc& program, Scope* scope,
}
auto ctx = Prepare(*copy_program, 0);
RunPreparedContext(ctx.get(), scope, feed_targets, fetch_targets, create_vars,
feed_holder_name, fetch_holder_name);
RunPreparedContext(ctx.get(), scope, feed_targets, fetch_targets,
create_local_scope, create_vars, feed_holder_name,
fetch_holder_name);
}
std::unique_ptr<ExecutorPrepareContext> Executor::Prepare(
......@@ -366,8 +368,9 @@ void Executor::RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope,
void Executor::RunPreparedContext(
ExecutorPrepareContext* ctx, Scope* scope,
std::map<std::string, const LoDTensor*>* feed_targets,
std::map<std::string, LoDTensor*>* fetch_targets, bool create_vars,
const std::string& feed_holder_name, const std::string& fetch_holder_name) {
std::map<std::string, LoDTensor*>* fetch_targets, bool create_local_scope,
bool create_vars, const std::string& feed_holder_name,
const std::string& fetch_holder_name) {
auto& global_block = ctx->prog_.Block(ctx->block_id_);
PADDLE_ENFORCE(
......@@ -387,7 +390,7 @@ void Executor::RunPreparedContext(
}
}
RunPreparedContext(ctx, scope, create_vars, create_vars);
RunPreparedContext(ctx, scope, create_local_scope, create_vars);
// obtain the data of fetch_targets from fetch_holder
for (auto* op : global_block.AllOps()) {
......
......@@ -57,7 +57,7 @@ class Executor {
void Run(const ProgramDesc& program, Scope* scope,
std::map<std::string, const LoDTensor*>* feed_targets,
std::map<std::string, LoDTensor*>* fetch_targets,
bool create_vars = true,
bool create_local_scope = true, bool create_vars = true,
const std::string& feed_holder_name = "feed",
const std::string& fetch_holder_name = "fetch");
......@@ -76,6 +76,7 @@ class Executor {
void RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope,
std::map<std::string, const LoDTensor*>* feed_targets,
std::map<std::string, LoDTensor*>* fetch_targets,
bool create_local_scope = true,
bool create_vars = true,
const std::string& feed_holder_name = "feed",
const std::string& fetch_holder_name = "fetch");
......
......@@ -208,10 +208,10 @@ void TestInference(const std::string& dirname,
if (PrepareContext) {
ctx = executor.Prepare(*inference_program, 0);
executor.RunPreparedContext(ctx.get(), scope, &feed_targets,
&fetch_targets, CreateVars);
&fetch_targets, true, CreateVars);
} else {
executor.Run(*inference_program, scope, &feed_targets, &fetch_targets,
CreateVars);
true, CreateVars);
}
// Enable the profiler
......
......@@ -184,7 +184,7 @@ class RequestPrefetch final : public RequestBase {
framework::Scope* local_scope = &scope_->NewScope();
auto* var = local_scope->FindVar(var_name);
InitializeVariable(var, var_desc->GetType());
executor_->RunPreparedContext(prefetch_ctx_, scope_, false, false);
executor_->RunPreparedContext(prefetch_ctx_, scope_);
SerializeToByteBuffer(var_name, var, *dev_ctx_, &reply);
......
......@@ -57,8 +57,7 @@ static void ParallelExecuteBlocks(
framework::Async([&executor, &prepared, &program, &scope, idx]() {
int run_block = idx; // thread local
try {
executor->RunPreparedContext(prepared[run_block].get(), scope,
false, false);
executor->RunPreparedContext(prepared[run_block].get(), scope);
} catch (std::exception &e) {
LOG(ERROR) << "run sub program error " << e.what();
}
......@@ -211,8 +210,8 @@ static void AsyncUpdateThread(
}
auto fs = framework::Async([var_name, &executor, &v, prepared] {
try {
executor->RunPreparedContext(prepared, v.second->GetMutableLocalScope(),
false, false);
executor->RunPreparedContext(prepared,
v.second->GetMutableLocalScope());
} catch (std::exception &e) {
LOG(ERROR) << "run sub program error " << e.what();
}
......
......@@ -92,12 +92,16 @@ void InitSelectedRowsInScope(const p::CPUPlace &place, f::Scope *scope) {
void AddOp(const std::string &type, const f::VariableNameMap &inputs,
const f::VariableNameMap &outputs, f::AttributeMap attrs,
f::BlockDesc *block) {
f::BlockDesc *block, bool is_sparse) {
// insert output
for (auto kv : outputs) {
for (auto v : kv.second) {
auto var = block->Var(v);
var->SetDataType(f::proto::VarType::FP32);
var->SetPersistable(true);
if (is_sparse) {
var->SetType(f::proto::VarType::SELECTED_ROWS);
}
}
}
......@@ -128,7 +132,8 @@ void StartServerNet(bool is_sparse, std::atomic<bool> *initialized) {
auto *optimize_block = program.AppendBlock(root_block);
auto *prefetch_block = program.AppendBlock(root_block);
// X for server side tensors, RX for received tensors, must be of same shape.
AddOp("sum", {{"X", {"x0", "x1"}}}, {{"Out", {"Out"}}}, {}, optimize_block);
AddOp("sum", {{"X", {"x0", "x1"}}}, {{"Out", {"Out"}}}, {}, optimize_block,
is_sparse);
f::AttributeMap attrs;
attrs.insert({"endpoint", std::string("127.0.0.1:0")});
attrs.insert({"Fanin", 1});
......
......@@ -52,15 +52,18 @@ class TestSendOp(unittest.TestCase):
serv = layers.ListenAndServ(
"127.0.0.1:0", ["X"], optimizer_mode=False)
with serv.do():
out_var = main.global_block().create_var(
name="scale_0.tmp_0",
psersistable=True,
dtype="float32",
shape=[32, 32])
x = layers.data(
shape=[32, 32],
dtype='float32',
name="X",
append_batch_size=False)
fluid.initializer.Constant(value=1.0)(x, main.global_block())
o = layers.scale(x=x, scale=10.0)
main.global_block().create_var(
name=o.name, psersistable=False, dtype=o.dtype, shape=o.shape)
layers.scale(x=x, scale=10.0, out=out_var)
self.server_exe = fluid.Executor(place)
self.server_exe.run(main)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册