提交 ce08dc87 编写于 作者: T typhoonzero

have stream removed error

上级 0bf799a5
......@@ -57,8 +57,11 @@ MultiDevSSAGraphBuilder::MultiDevSSAGraphBuilder(
void MultiDevSSAGraphBuilder::CreateOpHandleIOs(SSAGraph *result, OpDesc *op,
const platform::Place &p,
const size_t &i) const {
const size_t &i,
bool create_output) const {
auto *op_handle = result->ops_.back().get();
op_handle->dev_ctxes_[p] = const_cast<platform::DeviceContext *>(
platform::DeviceContextPool::Instance().Get(p));
auto var_names = op->InputArgumentNames();
......@@ -66,10 +69,12 @@ void MultiDevSSAGraphBuilder::CreateOpHandleIOs(SSAGraph *result, OpDesc *op,
VarHandle *var = CreateOrGetLatestVarHandle(result, each_var_name, p, i);
op_handle->AddInput(var);
}
var_names = op->OutputArgumentNames();
if (create_output) {
var_names = op->OutputArgumentNames();
for (auto &each_var_name : var_names) {
CreateOpOutput(result, op_handle, each_var_name, p, i);
for (auto &each_var_name : var_names) {
CreateOpOutput(result, op_handle, each_var_name, p, i);
}
}
}
......@@ -100,9 +105,11 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
if (!is_forwarding && op->Type() == "send") {
auto &p = places_[0];
auto *s = local_scopes_[0];
size_t i = 0;
result.ops_.emplace_back(new SendOpHandle(*op, s, p));
CreateOpHandleIOs(&result, op, p, i);
// FIXME(wuyi): send op always copy from GPU 0
result.ops_.emplace_back(new SendOpHandle(*op, s));
// Create inputs for output on original place and no ssa output
// is created for send op.
CreateOpHandleIOs(&result, op, p, 0, false);
continue;
}
......@@ -112,23 +119,10 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
result.ops_.emplace_back(new ComputationOpHandle(*op, s, p));
auto *op_handle = result.ops_.back().get();
op_handle->dev_ctxes_[p] = const_cast<platform::DeviceContext *>(
platform::DeviceContextPool::Instance().Get(p));
CreateOpHandleIOs(&result, op, p, i);
// auto var_names = op->InputArgumentNames();
// for (auto &each_var_name : var_names) {
// VarHandle *var =
// CreateOrGetLatestVarHandle(&result, each_var_name, p, i);
// op_handle->AddInput(var);
// }
auto var_names = op->OutputArgumentNames();
// for (auto &each_var_name : var_names) {
// CreateOpOutput(&result, op_handle, each_var_name, p, i);
// }
if (is_forwarding) {
if (var_names.size() == 1 && var_names[0] == loss_var_name_) {
// Insert ScaleCost OpHandle
......
......@@ -46,7 +46,7 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
private:
void CreateOpHandleIOs(SSAGraph *result, OpDesc *op, const platform::Place &p,
const size_t &i) const;
const size_t &i, bool create_output = true) const;
private:
std::string loss_var_name_;
......
......@@ -19,11 +19,9 @@ namespace framework {
namespace details {
SendOpHandle::SendOpHandle(const framework::OpDesc &op_desc,
const Scope *local_scope,
const platform::Place &place)
const Scope *local_scope)
: op_(framework::OpRegistry::CreateOp(op_desc)),
local_scope_(local_scope),
place_(place) {}
local_scope_(local_scope) {}
void SendOpHandle::RunImpl() {
// Wait input done
......@@ -31,8 +29,8 @@ void SendOpHandle::RunImpl() {
auto &p = static_cast<VarHandle *>(in)->place_;
in->generated_op_->Wait(dev_ctxes_[p]);
}
op_->Run(*local_scope_, place_);
platform::CPUPlace cpu;
op_->Run(*local_scope_, cpu);
}
std::string SendOpHandle::Name() const { return "send"; }
......
......@@ -31,10 +31,8 @@ namespace details {
struct SendOpHandle : public OpHandleBase {
std::unique_ptr<OperatorBase> op_;
const Scope* local_scope_;
const platform::Place& place_;
SendOpHandle(const framework::OpDesc& op_desc, const Scope* local_scope,
const platform::Place& place);
SendOpHandle(const framework::OpDesc& op_desc, const Scope* local_scope);
std::string Name() const override;
......
......@@ -255,6 +255,7 @@ class DistributeTranspiler:
def get_trainer_program(self):
# remove optimize ops and add a send op to main_program
self.program.global_block().delete_ops(self.optimize_ops)
self.program.sync_with_cpp()
# FIXME(typhoonzero): serialize once will fix error occurs when clone.
self.program.__str__()
return self.program
......
......@@ -101,7 +101,9 @@ class ParallelExecutor(object):
self.persistable_vars = [
v.name
for v in filter(lambda var: var.persistable, main.list_vars())
for v in filter(lambda var: \
var.persistable and var.type != core.VarDesc.VarType.RAW,
main.list_vars())
]
self.executor = core.ParallelExecutor(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册