提交 16a9dfe4 编写于 作者: T typhoonzero

finish

上级 ec697681
......@@ -57,8 +57,7 @@ MultiDevSSAGraphBuilder::MultiDevSSAGraphBuilder(
void MultiDevSSAGraphBuilder::CreateOpHandleIOs(SSAGraph *result, OpDesc *op,
const platform::Place &p,
const size_t &i,
bool create_output) const {
const size_t &i) const {
auto *op_handle = result->ops_.back().get();
op_handle->dev_ctxes_[p] = const_cast<platform::DeviceContext *>(
platform::DeviceContextPool::Instance().Get(p));
......@@ -69,13 +68,12 @@ void MultiDevSSAGraphBuilder::CreateOpHandleIOs(SSAGraph *result, OpDesc *op,
VarHandle *var = CreateOrGetLatestVarHandle(result, each_var_name, p, i);
op_handle->AddInput(var);
}
if (create_output) {
var_names = op->OutputArgumentNames();
for (auto &each_var_name : var_names) {
CreateOpOutput(result, op_handle, each_var_name, p, i);
}
}
}
std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
......@@ -106,10 +104,10 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
auto &p = places_[0];
auto *s = local_scopes_[0];
// FIXME(wuyi): send op always copy from GPU 0
result.ops_.emplace_back(new SendOpHandle(*op, s));
result.ops_.emplace_back(new SendOpHandle(*op, s, p));
// Create inputs for output on original place and no ssa output
// is created for send op.
CreateOpHandleIOs(&result, op, p, 0, false);
CreateOpHandleIOs(&result, op, p, 0);
continue;
}
......
......@@ -46,7 +46,7 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
private:
void CreateOpHandleIOs(SSAGraph *result, OpDesc *op, const platform::Place &p,
const size_t &i, bool create_output = true) const;
const size_t &i) const;
private:
std::string loss_var_name_;
......
......@@ -19,18 +19,22 @@ namespace framework {
namespace details {
SendOpHandle::SendOpHandle(const framework::OpDesc &op_desc,
const Scope *local_scope)
const Scope *local_scope,
const platform::Place &place)
: op_(framework::OpRegistry::CreateOp(op_desc)),
local_scope_(local_scope) {}
local_scope_(local_scope),
place_(place) {}
void SendOpHandle::RunImpl() {
// Wait input done
for (auto *in : inputs_) {
auto &p = static_cast<VarHandle *>(in)->place_;
if (in->DebugString() == "dummy") { // HACK
continue;
}
in->generated_op_->Wait(dev_ctxes_[p]);
}
platform::CPUPlace cpu;
op_->Run(*local_scope_, cpu);
op_->Run(*local_scope_, place_);
}
std::string SendOpHandle::Name() const { return "send"; }
......
......@@ -31,8 +31,10 @@ namespace details {
struct SendOpHandle : public OpHandleBase {
std::unique_ptr<OperatorBase> op_;
const Scope* local_scope_;
const platform::Place& place_;
SendOpHandle(const framework::OpDesc& op_desc, const Scope* local_scope);
SendOpHandle(const framework::OpDesc& op_desc, const Scope* local_scope,
const platform::Place& place);
std::string Name() const override;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册