“c02cdbf60b51b8d98a49185535f5d527a2965142”上不存在“arch/x86/events/perf_event.h”
提交 e5155713 编写于 作者: Y Yang Yang

clean up for review

上级 089cc11d
...@@ -13,11 +13,13 @@ See the License for the specific language governing permissions and ...@@ -13,11 +13,13 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/framework/executor.h" #include "paddle/framework/executor.h"
#include <algorithm> #include <algorithm>
#include <iostream> #include <iostream>
#include <memory> #include <memory>
#include <set> #include <set>
#include <vector> #include <vector>
#include "paddle/framework/lod_tensor.h" #include "paddle/framework/lod_tensor.h"
#include "paddle/framework/op_registry.h" #include "paddle/framework/op_registry.h"
#include "paddle/framework/scope.h" #include "paddle/framework/scope.h"
...@@ -27,7 +29,11 @@ limitations under the License. */ ...@@ -27,7 +29,11 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace framework { namespace framework {
const std::string kFeedOpType = "feed";
const std::string kFetchOpType = "fetch";
Executor::Executor(const std::vector<platform::Place>& places) { Executor::Executor(const std::vector<platform::Place>& places) {
PADDLE_ENFORCE_GT(places.size(), 0);
device_contexts_.resize(places.size()); device_contexts_.resize(places.size());
for (size_t i = 0; i < places.size(); i++) { for (size_t i = 0; i < places.size(); i++) {
if (platform::is_cpu_place(places[i])) { if (platform::is_cpu_place(places[i])) {
...@@ -46,9 +52,7 @@ Executor::Executor(const std::vector<platform::Place>& places) { ...@@ -46,9 +52,7 @@ Executor::Executor(const std::vector<platform::Place>& places) {
Executor::~Executor() { Executor::~Executor() {
for (auto& device_context : device_contexts_) { for (auto& device_context : device_contexts_) {
if (device_context) { delete device_context;
delete device_context;
}
} }
} }
...@@ -56,6 +60,8 @@ void Executor::Run(const ProgramDesc& pdesc, Scope* scope) { ...@@ -56,6 +60,8 @@ void Executor::Run(const ProgramDesc& pdesc, Scope* scope) {
// TODO(tonyyang-svail): // TODO(tonyyang-svail):
// - only runs the first block (i.e. no RNN support) // - only runs the first block (i.e. no RNN support)
// - only runs on the first device (i.e. no interdevice communication) // - only runs on the first device (i.e. no interdevice communication)
// - will change to use multiple blocks for RNN op and Cond Op
PADDLE_ENFORCE_GT(pdesc.blocks_size(), 0);
auto& block = pdesc.blocks(0); auto& block = pdesc.blocks(0);
auto& device = device_contexts_[0]; auto& device = device_contexts_[0];
...@@ -66,12 +72,12 @@ void Executor::Run(const ProgramDesc& pdesc, Scope* scope) { ...@@ -66,12 +72,12 @@ void Executor::Run(const ProgramDesc& pdesc, Scope* scope) {
Scope& local_scope = scope->NewScope(); Scope& local_scope = scope->NewScope();
std::vector<bool> should_run = Preprocess(pdesc); std::vector<bool> should_run = Prune(pdesc);
PADDLE_ENFORCE(should_run.size() == block.ops_size()); PADDLE_ENFORCE_EQ(should_run.size(), block.ops_size());
for (size_t i = 0; i < should_run.size(); ++i) { for (size_t i = 0; i < should_run.size(); ++i) {
if (should_run[i]) { if (should_run[i]) {
for (auto var : block.ops(i).outputs()) { for (auto& var : block.ops(i).outputs()) {
for (auto argu : var.arguments()) { for (auto& argu : var.arguments()) {
if (local_scope.FindVar(argu) == nullptr) { if (local_scope.FindVar(argu) == nullptr) {
local_scope.NewVar(argu); local_scope.NewVar(argu);
} }
...@@ -81,28 +87,32 @@ void Executor::Run(const ProgramDesc& pdesc, Scope* scope) { ...@@ -81,28 +87,32 @@ void Executor::Run(const ProgramDesc& pdesc, Scope* scope) {
op->Run(local_scope, *device); op->Run(local_scope, *device);
} }
} }
// TODO(tonyyang-svail):
// - Destroy local_scope
} }
std::vector<bool> Executor::Preprocess(const ProgramDesc& pdesc) { std::vector<bool> Executor::Prune(const ProgramDesc& pdesc) {
// TODO(tonyyang-svail): // TODO(tonyyang-svail):
// - only runs the first block // - only runs the first block
// - will change to use multiple blocks for RNN op and Cond Op
auto& block = pdesc.blocks(0); auto& block = pdesc.blocks(0);
auto& ops = block.ops(); auto& ops = block.ops();
bool expect_feed = true; bool expect_feed = true;
for (auto& op_desc : ops) { for (auto& op_desc : ops) {
PADDLE_ENFORCE(op_desc.type() != "feed" || expect_feed, PADDLE_ENFORCE(op_desc.type() != kFeedOpType || expect_feed,
"All FeedOps are at the beginning of the ProgramDesc"); "All FeedOps are at the beginning of the ProgramDesc");
expect_feed = (op_desc.type() == "feed"); expect_feed = (op_desc.type() == kFeedOpType);
} }
bool expect_fetch = true; bool expect_fetch = true;
for (auto op_iter = ops.rbegin(); op_iter != ops.rend(); ++op_iter) { for (auto op_iter = ops.rbegin(); op_iter != ops.rend(); ++op_iter) {
auto& op_desc = *op_iter; auto& op_desc = *op_iter;
PADDLE_ENFORCE(op_desc.type() != "fetch" || expect_fetch, PADDLE_ENFORCE(op_desc.type() != kFetchOpType || expect_fetch,
"All FetchOps must at the end of the ProgramDesc"); "All FetchOps must at the end of the ProgramDesc");
expect_fetch = (op_desc.type() == "fetch"); expect_fetch = (op_desc.type() == kFetchOpType);
} }
std::set<std::string> dependent_vars; std::set<std::string> dependent_vars;
...@@ -119,7 +129,7 @@ std::vector<bool> Executor::Preprocess(const ProgramDesc& pdesc) { ...@@ -119,7 +129,7 @@ std::vector<bool> Executor::Preprocess(const ProgramDesc& pdesc) {
} }
} }
if (op_desc.type() == "fetch" || found_dependent_vars) { if (op_desc.type() == kFetchOpType || found_dependent_vars) {
// erase its output to the dependency graph // erase its output to the dependency graph
for (auto& var : op_desc.outputs()) { for (auto& var : op_desc.outputs()) {
for (auto& argu : var.arguments()) { for (auto& argu : var.arguments()) {
...@@ -140,6 +150,10 @@ std::vector<bool> Executor::Preprocess(const ProgramDesc& pdesc) { ...@@ -140,6 +150,10 @@ std::vector<bool> Executor::Preprocess(const ProgramDesc& pdesc) {
} }
} }
// TODO(tonyyang-svail):
// - check this after integration of Init
// PADDLE_ENFORCE(dependent_vars.empty());
// since we are traversing the ProgramDesc in reverse order // since we are traversing the ProgramDesc in reverse order
// we reverse the should_run vector // we reverse the should_run vector
std::reverse(should_run.begin(), should_run.end()); std::reverse(should_run.begin(), should_run.end());
......
...@@ -46,7 +46,7 @@ class Executor { ...@@ -46,7 +46,7 @@ class Executor {
* @return * @return
* vector<bool> Same size as ops. Indicates whether an op should be run. * vector<bool> Same size as ops. Indicates whether an op should be run.
*/ */
std::vector<bool> Preprocess(const ProgramDesc& pdesc); std::vector<bool> Prune(const ProgramDesc& pdesc);
private: private:
std::vector<platform::DeviceContext*> device_contexts_; std::vector<platform::DeviceContext*> device_contexts_;
......
...@@ -13,12 +13,14 @@ See the License for the specific language governing permissions and ...@@ -13,12 +13,14 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/framework/executor.h" #include "paddle/framework/executor.h"
#include <memory>
#include <vector> #include <vector>
#include "gtest/gtest.h" #include "gtest/gtest.h"
#include "paddle/framework/attribute.h" #include "paddle/framework/attribute.h"
#include "paddle/framework/backward.h" #include "paddle/framework/backward.h"
#include "paddle/framework/block_desc.h" #include "paddle/framework/block_desc.h"
// #include "paddle/framework/grad_op_builder.h"
#include "paddle/framework/op_desc.h" #include "paddle/framework/op_desc.h"
#include "paddle/framework/op_registry.h" #include "paddle/framework/op_registry.h"
#include "paddle/framework/operator.h" #include "paddle/framework/operator.h"
...@@ -34,9 +36,6 @@ using std::string; ...@@ -34,9 +36,6 @@ using std::string;
using namespace paddle::platform; using namespace paddle::platform;
using namespace paddle::framework; using namespace paddle::framework;
typedef paddle::framework::BlockDesc proto_block;
typedef paddle::framework::OpDesc proto_op;
void AddOp(const std::string& type, const VariableNameMap& inputs, void AddOp(const std::string& type, const VariableNameMap& inputs,
const VariableNameMap& outputs, AttributeMap attrs, const VariableNameMap& outputs, AttributeMap attrs,
paddle::framework::BlockDescBind* block) { paddle::framework::BlockDescBind* block) {
...@@ -51,10 +50,10 @@ void AddOp(const std::string& type, const VariableNameMap& inputs, ...@@ -51,10 +50,10 @@ void AddOp(const std::string& type, const VariableNameMap& inputs,
// insert op // insert op
auto op = block->AppendOp(); auto op = block->AppendOp();
op->SetType(type); op->SetType(type);
for (auto kv : inputs) { for (auto& kv : inputs) {
op->SetInput(kv.first, kv.second); op->SetInput(kv.first, kv.second);
} }
for (auto kv : outputs) { for (auto& kv : outputs) {
op->SetOutput(kv.first, kv.second); op->SetOutput(kv.first, kv.second);
} }
op->SetAttrMap(attrs); op->SetAttrMap(attrs);
...@@ -65,11 +64,11 @@ std::once_flag set_variable_flag; ...@@ -65,11 +64,11 @@ std::once_flag set_variable_flag;
// Tensors in feed value variable will only be in CPUPlace // Tensors in feed value variable will only be in CPUPlace
// So we can memcpy the data from vector<T> to feed_value // So we can memcpy the data from vector<T> to feed_value
template <typename T> template <typename T>
void set_feed_variable(const std::vector<std::vector<T>>& inputs) { void SetFeedVariable(const std::vector<std::vector<T>>& inputs) {
typedef std::vector<paddle::framework::Tensor> FeedInputs; typedef std::vector<paddle::framework::Tensor> FeedInputs;
Variable* g_feed_value = GetGlobalScope()->FindVar("feed_value"); Variable* g_feed_value = GetGlobalScope()->FindVar("feed_value");
FeedInputs& feed_inputs = *(g_feed_value->GetMutable<FeedInputs>()); FeedInputs& feed_inputs = *(g_feed_value->GetMutable<FeedInputs>());
auto size = inputs.size(); size_t size = inputs.size();
feed_inputs.resize(size); feed_inputs.resize(size);
for (size_t i = 0; i < size; i++) { for (size_t i = 0; i < size; i++) {
T* dst = feed_inputs[i].mutable_data<T>( T* dst = feed_inputs[i].mutable_data<T>(
...@@ -81,12 +80,12 @@ void set_feed_variable(const std::vector<std::vector<T>>& inputs) { ...@@ -81,12 +80,12 @@ void set_feed_variable(const std::vector<std::vector<T>>& inputs) {
// Tensors in fetch value variable will only be in CPUPlace // Tensors in fetch value variable will only be in CPUPlace
// So we can memcpy the data from fetch_value to vector<T> // So we can memcpy the data from fetch_value to vector<T>
template <typename T> template <typename T>
std::vector<std::vector<T>> get_fetch_variable() { std::vector<std::vector<T>> GetFetchVariable() {
typedef std::vector<paddle::framework::Tensor> FetchOutputs; typedef std::vector<paddle::framework::Tensor> FetchOutputs;
Variable* g_fetch_value = GetGlobalScope()->FindVar("fetch_value"); Variable* g_fetch_value = GetGlobalScope()->FindVar("fetch_value");
FetchOutputs& fetch_outputs = *(g_fetch_value->GetMutable<FetchOutputs>()); FetchOutputs& fetch_outputs = *(g_fetch_value->GetMutable<FetchOutputs>());
auto size = fetch_outputs.size(); size_t size = fetch_outputs.size();
std::vector<std::vector<T>> result; std::vector<std::vector<T>> result;
result.reserve(size); result.reserve(size);
for (size_t i = 0; i < size; i++) { for (size_t i = 0; i < size; i++) {
...@@ -105,7 +104,7 @@ class ExecutorTesterRandom : public ::testing::Test { ...@@ -105,7 +104,7 @@ class ExecutorTesterRandom : public ::testing::Test {
virtual void SetUp() override { virtual void SetUp() override {
int input_dim = 5, batch_size = 2, embed_dim = 5; int input_dim = 5, batch_size = 2, embed_dim = 5;
// init pdesc ----------------------------------------- // init pdesc
auto temp_init_root_block = init_pdesc_.add_blocks(); auto temp_init_root_block = init_pdesc_.add_blocks();
temp_init_root_block->set_idx(0); temp_init_root_block->set_idx(0);
temp_init_root_block->set_parent_idx(-1); temp_init_root_block->set_parent_idx(-1);
...@@ -128,7 +127,7 @@ class ExecutorTesterRandom : public ::testing::Test { ...@@ -128,7 +127,7 @@ class ExecutorTesterRandom : public ::testing::Test {
// flush // flush
init_program.Proto(); init_program.Proto();
// run pdesc ----------------------------------------- // run pdesc
auto temp_root_block = pdesc_.add_blocks(); auto temp_root_block = pdesc_.add_blocks();
temp_root_block->set_idx(0); temp_root_block->set_idx(0);
temp_root_block->set_parent_idx(-1); temp_root_block->set_parent_idx(-1);
...@@ -154,9 +153,6 @@ class ExecutorTesterRandom : public ::testing::Test { ...@@ -154,9 +153,6 @@ class ExecutorTesterRandom : public ::testing::Test {
// TODO(tonyyang-svail): // TODO(tonyyang-svail):
// - Test with Backward // - Test with Backward
// AddOp("gaussian_random", {}, {{"Out", {"l2_distance@GRAD"}}},
// {{"dims", std::vector<int>{batch_size, 1}}}, root_block);
// AppendBackward(program, {});
} }
protected: protected:
...@@ -213,12 +209,11 @@ TEST_F(ExecutorTesterRandom, CPU) { ...@@ -213,12 +209,11 @@ TEST_F(ExecutorTesterRandom, CPU) {
// "pointer being freed was not allocated" error will appear. // "pointer being freed was not allocated" error will appear.
paddle::memory::Used(cpu_place); paddle::memory::Used(cpu_place);
Executor* executor = new Executor(places); std::unique_ptr<Executor> executor(new Executor(places));
executor->Run(init_pdesc_, GetGlobalScope()); executor->Run(init_pdesc_, GetGlobalScope());
executor->Run(pdesc_, GetGlobalScope()); executor->Run(pdesc_, GetGlobalScope());
std::vector<std::vector<float>> result = get_fetch_variable<float>(); std::vector<std::vector<float>> result = GetFetchVariable<float>();
delete executor;
} }
TEST_F(ExecutorTesterFeedAndFetch, CPU) { TEST_F(ExecutorTesterFeedAndFetch, CPU) {
...@@ -232,13 +227,12 @@ TEST_F(ExecutorTesterFeedAndFetch, CPU) { ...@@ -232,13 +227,12 @@ TEST_F(ExecutorTesterFeedAndFetch, CPU) {
// "pointer being freed was not allocated" error will appear. // "pointer being freed was not allocated" error will appear.
paddle::memory::Used(cpu_place); paddle::memory::Used(cpu_place);
Executor* executor = new Executor(places); std::unique_ptr<Executor> executor(new Executor(places));
// 3 mini-batch for (int batch_id = 0; batch_id < 3; batch_id++) {
for (int i = 0; i < 3; i++) { SetFeedVariable<float>(inputs_);
set_feed_variable<float>(inputs_);
executor->Run(pdesc_, GetGlobalScope()); executor->Run(pdesc_, GetGlobalScope());
std::vector<std::vector<float>> result = get_fetch_variable<float>(); std::vector<std::vector<float>> result = GetFetchVariable<float>();
PADDLE_ENFORCE_EQ(result.size(), inputs_.size()); PADDLE_ENFORCE_EQ(result.size(), inputs_.size());
for (size_t i = 0; i < result.size(); ++i) { for (size_t i = 0; i < result.size(); ++i) {
PADDLE_ENFORCE_EQ(result[i].size(), inputs_[i].size()); PADDLE_ENFORCE_EQ(result[i].size(), inputs_[i].size());
...@@ -247,8 +241,6 @@ TEST_F(ExecutorTesterFeedAndFetch, CPU) { ...@@ -247,8 +241,6 @@ TEST_F(ExecutorTesterFeedAndFetch, CPU) {
} }
} }
} }
delete executor;
} }
#else #else
TEST_F(ExecutorTesterRandom, GPU) { TEST_F(ExecutorTesterRandom, GPU) {
...@@ -265,13 +257,11 @@ TEST_F(ExecutorTesterRandom, GPU) { ...@@ -265,13 +257,11 @@ TEST_F(ExecutorTesterRandom, GPU) {
paddle::memory::Used(CPUPlace()); paddle::memory::Used(CPUPlace());
paddle::memory::Used(gpu_place); paddle::memory::Used(gpu_place);
Executor* executor = new Executor(places); std::unique_ptr<Executor> executor(new Executor(places));
executor->Run(init_pdesc_, GetGlobalScope()); executor->Run(init_pdesc_, GetGlobalScope());
executor->Run(pdesc_, GetGlobalScope()); executor->Run(pdesc_, GetGlobalScope());
std::vector<std::vector<float>> result = get_fetch_variable<float>(); std::vector<std::vector<float>> result = GetFetchVariable<float>();
delete executor;
} }
TEST_F(ExecutorTesterFeedAndFetch, GPU) { TEST_F(ExecutorTesterFeedAndFetch, GPU) {
...@@ -287,13 +277,12 @@ TEST_F(ExecutorTesterFeedAndFetch, GPU) { ...@@ -287,13 +277,12 @@ TEST_F(ExecutorTesterFeedAndFetch, GPU) {
paddle::memory::Used(CPUPlace()); paddle::memory::Used(CPUPlace());
paddle::memory::Used(gpu_place); paddle::memory::Used(gpu_place);
Executor* executor = new Executor(places); std::unique_ptr<Executor> executor(new Executor(places));
// 3 mini-batch for (int batch_id = 0; batch_id < 3; batch_id++) {
for (int i = 0; i < 3; i++) { SetFeedVariable<float>(inputs_);
set_feed_variable<float>(inputs_);
executor->Run(pdesc_, GetGlobalScope()); executor->Run(pdesc_, GetGlobalScope());
std::vector<std::vector<float>> result = get_fetch_variable<float>(); std::vector<std::vector<float>> result = GetFetchVariable<float>();
PADDLE_ENFORCE_EQ(result.size(), inputs_.size()); PADDLE_ENFORCE_EQ(result.size(), inputs_.size());
for (size_t i = 0; i < result.size(); ++i) { for (size_t i = 0; i < result.size(); ++i) {
PADDLE_ENFORCE_EQ(result[i].size(), inputs_[i].size()); PADDLE_ENFORCE_EQ(result[i].size(), inputs_[i].size());
...@@ -302,6 +291,5 @@ TEST_F(ExecutorTesterFeedAndFetch, GPU) { ...@@ -302,6 +291,5 @@ TEST_F(ExecutorTesterFeedAndFetch, GPU) {
} }
} }
} }
delete executor;
} }
#endif #endif
...@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and ...@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/framework/scope.h" #include "paddle/framework/scope.h"
#include <memory> // for unique_ptr #include <memory> // for unique_ptr
#include <mutex> // for call_once #include <mutex> // for call_once
#include "paddle/string/printf.h" #include "paddle/string/printf.h"
......
...@@ -31,6 +31,7 @@ class FeedOp : public framework::OperatorWithKernel { ...@@ -31,6 +31,7 @@ class FeedOp : public framework::OperatorWithKernel {
const FeedInputs& tensors = g_feed_variable->Get<FeedInputs>(); const FeedInputs& tensors = g_feed_variable->Get<FeedInputs>();
PADDLE_ENFORCE_GT(tensors.size(), col);
auto in_dim = tensors[col].dims(); auto in_dim = tensors[col].dims();
ctx->SetOutputDim("Out", in_dim); ctx->SetOutputDim("Out", in_dim);
// TODO(qijun): need to handle LodTensor later // TODO(qijun): need to handle LodTensor later
......
...@@ -35,6 +35,7 @@ class FetchOp : public framework::OperatorWithKernel { ...@@ -35,6 +35,7 @@ class FetchOp : public framework::OperatorWithKernel {
} }
auto input_dim = ctx->GetInputDim("Input"); auto input_dim = ctx->GetInputDim("Input");
PADDLE_ENFORCE_GT(tensors->size(), col);
(*tensors)[col].Resize(input_dim); (*tensors)[col].Resize(input_dim);
// TODO(qijun): need to handle LodTensor later // TODO(qijun): need to handle LodTensor later
......
...@@ -44,7 +44,7 @@ int GetCurrentDeviceId() { ...@@ -44,7 +44,7 @@ int GetCurrentDeviceId() {
void SetDeviceId(int id) { void SetDeviceId(int id) {
// TODO(qijun): find a better way to cache the cuda device count // TODO(qijun): find a better way to cache the cuda device count
PADDLE_ENFORCE(id < GetCUDADeviceCount(), "id must less than GPU count"); PADDLE_ENFORCE_LT(id, GetCUDADeviceCount(), "id must less than GPU count");
PADDLE_ENFORCE(cudaSetDevice(id), PADDLE_ENFORCE(cudaSetDevice(id),
"cudaSetDevice failed in paddle::platform::SetDeviceId"); "cudaSetDevice failed in paddle::platform::SetDeviceId");
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册