提交 ce08dc87 编写于 作者: T typhoonzero

have stream removed error

上级 0bf799a5
...@@ -57,8 +57,11 @@ MultiDevSSAGraphBuilder::MultiDevSSAGraphBuilder( ...@@ -57,8 +57,11 @@ MultiDevSSAGraphBuilder::MultiDevSSAGraphBuilder(
void MultiDevSSAGraphBuilder::CreateOpHandleIOs(SSAGraph *result, OpDesc *op, void MultiDevSSAGraphBuilder::CreateOpHandleIOs(SSAGraph *result, OpDesc *op,
const platform::Place &p, const platform::Place &p,
const size_t &i) const { const size_t &i,
bool create_output) const {
auto *op_handle = result->ops_.back().get(); auto *op_handle = result->ops_.back().get();
op_handle->dev_ctxes_[p] = const_cast<platform::DeviceContext *>(
platform::DeviceContextPool::Instance().Get(p));
auto var_names = op->InputArgumentNames(); auto var_names = op->InputArgumentNames();
...@@ -66,10 +69,12 @@ void MultiDevSSAGraphBuilder::CreateOpHandleIOs(SSAGraph *result, OpDesc *op, ...@@ -66,10 +69,12 @@ void MultiDevSSAGraphBuilder::CreateOpHandleIOs(SSAGraph *result, OpDesc *op,
VarHandle *var = CreateOrGetLatestVarHandle(result, each_var_name, p, i); VarHandle *var = CreateOrGetLatestVarHandle(result, each_var_name, p, i);
op_handle->AddInput(var); op_handle->AddInput(var);
} }
var_names = op->OutputArgumentNames(); if (create_output) {
var_names = op->OutputArgumentNames();
for (auto &each_var_name : var_names) { for (auto &each_var_name : var_names) {
CreateOpOutput(result, op_handle, each_var_name, p, i); CreateOpOutput(result, op_handle, each_var_name, p, i);
}
} }
} }
...@@ -100,9 +105,11 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build( ...@@ -100,9 +105,11 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
if (!is_forwarding && op->Type() == "send") { if (!is_forwarding && op->Type() == "send") {
auto &p = places_[0]; auto &p = places_[0];
auto *s = local_scopes_[0]; auto *s = local_scopes_[0];
size_t i = 0; // FIXME(wuyi): send op always copies from GPU 0
result.ops_.emplace_back(new SendOpHandle(*op, s, p)); result.ops_.emplace_back(new SendOpHandle(*op, s));
CreateOpHandleIOs(&result, op, p, i); // Create inputs for output on the original place; no SSA output
// is created for the send op.
CreateOpHandleIOs(&result, op, p, 0, false);
continue; continue;
} }
...@@ -112,23 +119,10 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build( ...@@ -112,23 +119,10 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
result.ops_.emplace_back(new ComputationOpHandle(*op, s, p)); result.ops_.emplace_back(new ComputationOpHandle(*op, s, p));
auto *op_handle = result.ops_.back().get(); auto *op_handle = result.ops_.back().get();
op_handle->dev_ctxes_[p] = const_cast<platform::DeviceContext *>(
platform::DeviceContextPool::Instance().Get(p));
CreateOpHandleIOs(&result, op, p, i); CreateOpHandleIOs(&result, op, p, i);
// auto var_names = op->InputArgumentNames();
// for (auto &each_var_name : var_names) {
// VarHandle *var =
// CreateOrGetLatestVarHandle(&result, each_var_name, p, i);
// op_handle->AddInput(var);
// }
auto var_names = op->OutputArgumentNames(); auto var_names = op->OutputArgumentNames();
// for (auto &each_var_name : var_names) {
// CreateOpOutput(&result, op_handle, each_var_name, p, i);
// }
if (is_forwarding) { if (is_forwarding) {
if (var_names.size() == 1 && var_names[0] == loss_var_name_) { if (var_names.size() == 1 && var_names[0] == loss_var_name_) {
// Insert ScaleCost OpHandle // Insert ScaleCost OpHandle
......
...@@ -46,7 +46,7 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder { ...@@ -46,7 +46,7 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
private: private:
void CreateOpHandleIOs(SSAGraph *result, OpDesc *op, const platform::Place &p, void CreateOpHandleIOs(SSAGraph *result, OpDesc *op, const platform::Place &p,
const size_t &i) const; const size_t &i, bool create_output = true) const;
private: private:
std::string loss_var_name_; std::string loss_var_name_;
......
...@@ -19,11 +19,9 @@ namespace framework { ...@@ -19,11 +19,9 @@ namespace framework {
namespace details { namespace details {
SendOpHandle::SendOpHandle(const framework::OpDesc &op_desc, SendOpHandle::SendOpHandle(const framework::OpDesc &op_desc,
const Scope *local_scope, const Scope *local_scope)
const platform::Place &place)
: op_(framework::OpRegistry::CreateOp(op_desc)), : op_(framework::OpRegistry::CreateOp(op_desc)),
local_scope_(local_scope), local_scope_(local_scope) {}
place_(place) {}
void SendOpHandle::RunImpl() { void SendOpHandle::RunImpl() {
// Wait input done // Wait input done
...@@ -31,8 +29,8 @@ void SendOpHandle::RunImpl() { ...@@ -31,8 +29,8 @@ void SendOpHandle::RunImpl() {
auto &p = static_cast<VarHandle *>(in)->place_; auto &p = static_cast<VarHandle *>(in)->place_;
in->generated_op_->Wait(dev_ctxes_[p]); in->generated_op_->Wait(dev_ctxes_[p]);
} }
platform::CPUPlace cpu;
op_->Run(*local_scope_, place_); op_->Run(*local_scope_, cpu);
} }
std::string SendOpHandle::Name() const { return "send"; } std::string SendOpHandle::Name() const { return "send"; }
......
...@@ -31,10 +31,8 @@ namespace details { ...@@ -31,10 +31,8 @@ namespace details {
struct SendOpHandle : public OpHandleBase { struct SendOpHandle : public OpHandleBase {
std::unique_ptr<OperatorBase> op_; std::unique_ptr<OperatorBase> op_;
const Scope* local_scope_; const Scope* local_scope_;
const platform::Place& place_;
SendOpHandle(const framework::OpDesc& op_desc, const Scope* local_scope, SendOpHandle(const framework::OpDesc& op_desc, const Scope* local_scope);
const platform::Place& place);
std::string Name() const override; std::string Name() const override;
......
...@@ -255,6 +255,7 @@ class DistributeTranspiler: ...@@ -255,6 +255,7 @@ class DistributeTranspiler:
def get_trainer_program(self): def get_trainer_program(self):
# remove optimize ops and add a send op to main_program # remove optimize ops and add a send op to main_program
self.program.global_block().delete_ops(self.optimize_ops) self.program.global_block().delete_ops(self.optimize_ops)
self.program.sync_with_cpp()
# FIXME(typhoonzero): serializing once fixes an error that occurs when cloning. # FIXME(typhoonzero): serializing once fixes an error that occurs when cloning.
self.program.__str__() self.program.__str__()
return self.program return self.program
......
...@@ -101,7 +101,9 @@ class ParallelExecutor(object): ...@@ -101,7 +101,9 @@ class ParallelExecutor(object):
self.persistable_vars = [ self.persistable_vars = [
v.name v.name
for v in filter(lambda var: var.persistable, main.list_vars()) for v in filter(lambda var: \
var.persistable and var.type != core.VarDesc.VarType.RAW,
main.list_vars())
] ]
self.executor = core.ParallelExecutor( self.executor = core.ParallelExecutor(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册