提交 ac0e3828 编写于 作者: Y Yang Yang

test text

上级 20725f2d
...@@ -13,13 +13,17 @@ See the License for the specific language governing permissions and ...@@ -13,13 +13,17 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/framework/executor.h" #include "paddle/framework/executor.h"
#include <algorithm>
#include <iostream> #include <iostream>
#include <memory> #include <memory>
#include <set>
#include <vector> #include <vector>
#include "paddle/framework/lod_tensor.h" #include "paddle/framework/lod_tensor.h"
#include "paddle/framework/op_registry.h" #include "paddle/framework/op_registry.h"
#include "paddle/framework/scope.h" #include "paddle/framework/scope.h"
#include <boost/range/adaptor/reversed.hpp>
namespace paddle { namespace paddle {
namespace framework { namespace framework {
...@@ -64,26 +68,94 @@ void Executor::Run(const ProgramDesc& pdesc, Scope* scope) { ...@@ -64,26 +68,94 @@ void Executor::Run(const ProgramDesc& pdesc, Scope* scope) {
scope->NewVar(var.name()); scope->NewVar(var.name());
} }
for (auto& op_desc : block.ops()) { std::vector<bool> should_run = Preprocess(pdesc);
auto op = paddle::framework::OpRegistry::CreateOp(op_desc); PADDLE_ENFORCE(should_run.size() == block.ops_size(),
"should_run.size() != block.ops_size()");
for (int i = 0; i < should_run.size(); ++i) {
if (should_run[i]) {
auto op = paddle::framework::OpRegistry::CreateOp(block.ops(i));
std::cout << op->DebugString() << std::endl; std::cout << op->DebugString() << std::endl;
op->Run(*scope, *device); op->Run(*scope, *device);
} }
// TODO(tonyyang-svail): need to test gpu device
for (auto& device_context : device_contexts_) {
device_context->Wait();
} }
// // print tensor value // // print tensor value
for (auto& var : block.vars()) { // for (auto& var : block.vars()) {
std::cout << var.name() << std::endl; // std::cout << var.name() << std::endl;
auto v = scope->FindVar(var.name()); // auto v = scope->FindVar(var.name());
const LoDTensor& t = v->Get<LoDTensor>(); // const LoDTensor& t = v->Get<LoDTensor>();
for (int i = 0; i < t.numel(); ++i) { // for (int i = 0; i < t.numel(); ++i) {
std::cout << t.data<float>()[i] << " "; // std::cout << t.data<float>()[i] << " ";
// }
// std::cout << std::endl;
// }
}
std::vector<bool> Executor::Preprocess(const ProgramDesc& pdesc) {
// TODO(tonyyang-svail):
// - only runs the first block
auto& block = pdesc.blocks(0);
auto& ops = block.ops();
bool expect_feed = true;
for (auto& op_desc : ops) {
PADDLE_ENFORCE(op_desc.type() != "feed" || expect_feed,
"All FeedOps are at the beginning of the ProgramDesc");
expect_feed = (op_desc.type() == "feed");
} }
std::cout << std::endl;
bool expect_fetch = true;
for (auto op_iter = ops.rbegin(); op_iter != ops.rend(); ++op_iter) {
auto& op_desc = *op_iter;
PADDLE_ENFORCE(op_desc.type() != "fetch" || expect_fetch,
"All FetchOps must at the end of the ProgramDesc");
expect_fetch = (op_desc.type() == "fetch");
} }
std::set<std::string> dependent_vars;
std::vector<bool> should_run;
for (auto op_iter = ops.rbegin(); op_iter != ops.rend(); ++op_iter) {
auto& op_desc = *op_iter;
bool found_dependent_vars = false;
for (auto& var : op_desc.outputs()) {
for (auto& argu : var.arguments()) {
if (dependent_vars.count(argu) != 0) {
found_dependent_vars = true;
}
}
}
// TODO(tonyyang-svail): add VLOG here for debugging
if (op_desc.type() == "fetch" || found_dependent_vars) {
// erase its output to the dependency graph
for (auto& var : op_desc.outputs()) {
for (auto& argu : var.arguments()) {
dependent_vars.erase(argu);
}
}
// insert its input to the dependency graph
for (auto& var : op_desc.inputs()) {
for (auto& argu : var.arguments()) {
dependent_vars.insert(argu);
}
}
// this op should be executed
should_run.push_back(true);
} else {
// this op should NOT be executed
should_run.push_back(false);
}
}
// since we are traversing the ProgramDesc in reverse order
// we reverse the should_run vector
std::reverse(should_run.begin(), should_run.end());
return should_run;
} }
} // namespace framework } // namespace framework
......
...@@ -26,8 +26,24 @@ class Executor { ...@@ -26,8 +26,24 @@ class Executor {
public: public:
explicit Executor(const std::vector<platform::Place>& places); explicit Executor(const std::vector<platform::Place>& places);
~Executor(); ~Executor();
/* @Brief
* Runtime evaluation of the given ProgramDesc under certain Scope
*
* @param
* ProgramDesc
* Scope
*/
void Run(const ProgramDesc&, Scope*); void Run(const ProgramDesc&, Scope*);
protected:
/* @Brief
*
* @param
* ProgramDesc
*/
std::vector<bool> Preprocess(const ProgramDesc& pdesc);
private: private:
std::vector<platform::DeviceContext*> device_contexts_; std::vector<platform::DeviceContext*> device_contexts_;
}; };
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册