提交 16a9dfe4 编写于 作者: T typhoonzero

finish

上级 ec697681
...@@ -57,8 +57,7 @@ MultiDevSSAGraphBuilder::MultiDevSSAGraphBuilder( ...@@ -57,8 +57,7 @@ MultiDevSSAGraphBuilder::MultiDevSSAGraphBuilder(
void MultiDevSSAGraphBuilder::CreateOpHandleIOs(SSAGraph *result, OpDesc *op, void MultiDevSSAGraphBuilder::CreateOpHandleIOs(SSAGraph *result, OpDesc *op,
const platform::Place &p, const platform::Place &p,
const size_t &i, const size_t &i) const {
bool create_output) const {
auto *op_handle = result->ops_.back().get(); auto *op_handle = result->ops_.back().get();
op_handle->dev_ctxes_[p] = const_cast<platform::DeviceContext *>( op_handle->dev_ctxes_[p] = const_cast<platform::DeviceContext *>(
platform::DeviceContextPool::Instance().Get(p)); platform::DeviceContextPool::Instance().Get(p));
...@@ -69,13 +68,12 @@ void MultiDevSSAGraphBuilder::CreateOpHandleIOs(SSAGraph *result, OpDesc *op, ...@@ -69,13 +68,12 @@ void MultiDevSSAGraphBuilder::CreateOpHandleIOs(SSAGraph *result, OpDesc *op,
VarHandle *var = CreateOrGetLatestVarHandle(result, each_var_name, p, i); VarHandle *var = CreateOrGetLatestVarHandle(result, each_var_name, p, i);
op_handle->AddInput(var); op_handle->AddInput(var);
} }
if (create_output) {
var_names = op->OutputArgumentNames(); var_names = op->OutputArgumentNames();
for (auto &each_var_name : var_names) { for (auto &each_var_name : var_names) {
CreateOpOutput(result, op_handle, each_var_name, p, i); CreateOpOutput(result, op_handle, each_var_name, p, i);
} }
}
} }
std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build( std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
...@@ -106,10 +104,10 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build( ...@@ -106,10 +104,10 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
auto &p = places_[0]; auto &p = places_[0];
auto *s = local_scopes_[0]; auto *s = local_scopes_[0];
// FIXME(wuyi): send op always copy from GPU 0 // FIXME(wuyi): send op always copy from GPU 0
result.ops_.emplace_back(new SendOpHandle(*op, s)); result.ops_.emplace_back(new SendOpHandle(*op, s, p));
// Create inputs for output on original place and no ssa output // Create inputs for output on original place and no ssa output
// is created for send op. // is created for send op.
CreateOpHandleIOs(&result, op, p, 0, false); CreateOpHandleIOs(&result, op, p, 0);
continue; continue;
} }
......
...@@ -46,7 +46,7 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder { ...@@ -46,7 +46,7 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
private: private:
void CreateOpHandleIOs(SSAGraph *result, OpDesc *op, const platform::Place &p, void CreateOpHandleIOs(SSAGraph *result, OpDesc *op, const platform::Place &p,
const size_t &i, bool create_output = true) const; const size_t &i) const;
private: private:
std::string loss_var_name_; std::string loss_var_name_;
......
...@@ -19,18 +19,22 @@ namespace framework { ...@@ -19,18 +19,22 @@ namespace framework {
namespace details { namespace details {
SendOpHandle::SendOpHandle(const framework::OpDesc &op_desc, SendOpHandle::SendOpHandle(const framework::OpDesc &op_desc,
const Scope *local_scope) const Scope *local_scope,
const platform::Place &place)
: op_(framework::OpRegistry::CreateOp(op_desc)), : op_(framework::OpRegistry::CreateOp(op_desc)),
local_scope_(local_scope) {} local_scope_(local_scope),
place_(place) {}
void SendOpHandle::RunImpl() { void SendOpHandle::RunImpl() {
// Wait input done // Wait input done
for (auto *in : inputs_) { for (auto *in : inputs_) {
auto &p = static_cast<VarHandle *>(in)->place_; auto &p = static_cast<VarHandle *>(in)->place_;
if (in->DebugString() == "dummy") { // HACK
continue;
}
in->generated_op_->Wait(dev_ctxes_[p]); in->generated_op_->Wait(dev_ctxes_[p]);
} }
platform::CPUPlace cpu; op_->Run(*local_scope_, place_);
op_->Run(*local_scope_, cpu);
} }
std::string SendOpHandle::Name() const { return "send"; } std::string SendOpHandle::Name() const { return "send"; }
......
...@@ -31,8 +31,10 @@ namespace details { ...@@ -31,8 +31,10 @@ namespace details {
struct SendOpHandle : public OpHandleBase { struct SendOpHandle : public OpHandleBase {
std::unique_ptr<OperatorBase> op_; std::unique_ptr<OperatorBase> op_;
const Scope* local_scope_; const Scope* local_scope_;
const platform::Place& place_;
SendOpHandle(const framework::OpDesc& op_desc, const Scope* local_scope); SendOpHandle(const framework::OpDesc& op_desc, const Scope* local_scope,
const platform::Place& place);
std::string Name() const override; std::string Name() const override;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册