diff --git a/paddle/framework/executor.cc b/paddle/framework/executor.cc index aafef12554fd0877ead14ed67a17db03a8c89eff..89b83f82fb0e8412bcb7fe2ac1229cca19788172 100644 --- a/paddle/framework/executor.cc +++ b/paddle/framework/executor.cc @@ -13,13 +13,17 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/framework/executor.h" +#include #include #include +#include #include #include "paddle/framework/lod_tensor.h" #include "paddle/framework/op_registry.h" #include "paddle/framework/scope.h" +#include + namespace paddle { namespace framework { @@ -64,26 +68,94 @@ void Executor::Run(const ProgramDesc& pdesc, Scope* scope) { scope->NewVar(var.name()); } - for (auto& op_desc : block.ops()) { - auto op = paddle::framework::OpRegistry::CreateOp(op_desc); - std::cout << op->DebugString() << std::endl; - op->Run(*scope, *device); + std::vector should_run = Preprocess(pdesc); + PADDLE_ENFORCE(should_run.size() == block.ops_size(), + "should_run.size() != block.ops_size()"); + for (int i = 0; i < should_run.size(); ++i) { + if (should_run[i]) { + auto op = paddle::framework::OpRegistry::CreateOp(block.ops(i)); + std::cout << op->DebugString() << std::endl; + op->Run(*scope, *device); + } } - // TODO(tonyyang-svail): need to test gpu device - for (auto& device_context : device_contexts_) { - device_context->Wait(); - } // // print tensor value - for (auto& var : block.vars()) { - std::cout << var.name() << std::endl; - auto v = scope->FindVar(var.name()); - const LoDTensor& t = v->Get(); - for (int i = 0; i < t.numel(); ++i) { - std::cout << t.data()[i] << " "; + // for (auto& var : block.vars()) { + // std::cout << var.name() << std::endl; + // auto v = scope->FindVar(var.name()); + // const LoDTensor& t = v->Get(); + // for (int i = 0; i < t.numel(); ++i) { + // std::cout << t.data()[i] << " "; + // } + // std::cout << std::endl; + // } +} + +std::vector Executor::Preprocess(const ProgramDesc& pdesc) { + // TODO(tonyyang-svail): + // - only runs the first block + + auto& block = pdesc.blocks(0); + auto& ops = block.ops(); + + bool expect_feed = true; + for (auto& op_desc : ops) { + PADDLE_ENFORCE(op_desc.type() != "feed" || expect_feed, + "All FeedOps are at the beginning of the ProgramDesc"); + expect_feed = (op_desc.type() == "feed"); + } + + bool expect_fetch = true; + for (auto op_iter = ops.rbegin(); op_iter != ops.rend(); ++op_iter) { + auto& op_desc = *op_iter; + PADDLE_ENFORCE(op_desc.type() != "fetch" || expect_fetch, + "All FetchOps must at the end of the ProgramDesc"); + expect_fetch = (op_desc.type() == "fetch"); + } + + std::set dependent_vars; + std::vector should_run; + for (auto op_iter = ops.rbegin(); op_iter != ops.rend(); ++op_iter) { + auto& op_desc = *op_iter; + + bool found_dependent_vars = false; + for (auto& var : op_desc.outputs()) { + for (auto& argu : var.arguments()) { + if (dependent_vars.count(argu) != 0) { + found_dependent_vars = true; + } + } + } + + // TODO(tonyyang-svail): add VLOG here for debugging + if (op_desc.type() == "fetch" || found_dependent_vars) { + // erase its output to the dependency graph + for (auto& var : op_desc.outputs()) { + for (auto& argu : var.arguments()) { + dependent_vars.erase(argu); + } + } + + // insert its input to the dependency graph + for (auto& var : op_desc.inputs()) { + for (auto& argu : var.arguments()) { + dependent_vars.insert(argu); + } + } + + // this op should be executed + should_run.push_back(true); + } else { + // this op should NOT be executed + should_run.push_back(false); } - std::cout << std::endl; } + + // since we are traversing the ProgramDesc in reverse order + // we reverse the should_run vector + std::reverse(should_run.begin(), should_run.end()); + + return should_run; } } // namespace framework diff --git a/paddle/framework/executor.h b/paddle/framework/executor.h index 9e443c8fca7ee16335c2192a38771a5473ea9932..1d2e6c96ded0ba83c31cfc31b335b8be6acdd5ed 100644 --- a/paddle/framework/executor.h +++ b/paddle/framework/executor.h @@ -26,8 +26,24 @@ class Executor { public: explicit Executor(const std::vector& places); ~Executor(); + + /* @Brief + * Runtime evaluation of the given ProgramDesc under certain Scope + * + * @param + * ProgramDesc + * Scope + */ void Run(const ProgramDesc&, Scope*); + protected: + /* @Brief + * + * @param + * ProgramDesc + */ + std::vector Preprocess(const ProgramDesc& pdesc); + private: std::vector device_contexts_; };