You need to sign in or sign up before continuing.
未验证 提交 c0421379 编写于 作者: T Tao Luo 提交者: GitHub

Merge pull request #9043 from Xreki/core_inference_remove_clone

Remove unnecessary clone of program in C++ Executor.Run
...@@ -113,10 +113,11 @@ void Executor::Run(const ProgramDesc& pdesc, Scope* scope, int block_id, ...@@ -113,10 +113,11 @@ void Executor::Run(const ProgramDesc& pdesc, Scope* scope, int block_id,
// and feed_holder_name. Raise exception when any mismatch is found. // and feed_holder_name. Raise exception when any mismatch is found.
// Return true if the block has feed operators and holder of matching info. // Return true if the block has feed operators and holder of matching info.
static bool has_feed_operators( static bool has_feed_operators(
BlockDesc* block, std::map<std::string, const LoDTensor*>& feed_targets, const BlockDesc& block,
std::map<std::string, const LoDTensor*>& feed_targets,
const std::string& feed_holder_name) { const std::string& feed_holder_name) {
size_t feed_count = 0; size_t feed_count = 0;
for (auto* op : block->AllOps()) { for (auto* op : block.AllOps()) {
if (op->Type() == kFeedOpType) { if (op->Type() == kFeedOpType) {
feed_count++; feed_count++;
PADDLE_ENFORCE_EQ(op->Input("X")[0], feed_holder_name, PADDLE_ENFORCE_EQ(op->Input("X")[0], feed_holder_name,
...@@ -135,7 +136,7 @@ static bool has_feed_operators( ...@@ -135,7 +136,7 @@ static bool has_feed_operators(
"The number of feed operators should match 'feed_targets'"); "The number of feed operators should match 'feed_targets'");
// When feed operator are present, so should be feed_holder // When feed operator are present, so should be feed_holder
auto var = block->FindVar(feed_holder_name); auto var = block.FindVar(feed_holder_name);
PADDLE_ENFORCE_NOT_NULL(var, "Block should already have a '%s' variable", PADDLE_ENFORCE_NOT_NULL(var, "Block should already have a '%s' variable",
feed_holder_name); feed_holder_name);
PADDLE_ENFORCE_EQ(var->GetType(), proto::VarType::FEED_MINIBATCH, PADDLE_ENFORCE_EQ(var->GetType(), proto::VarType::FEED_MINIBATCH,
...@@ -153,10 +154,10 @@ static bool has_feed_operators( ...@@ -153,10 +154,10 @@ static bool has_feed_operators(
// and fetch_holder_name. Raise exception when any mismatch is found. // and fetch_holder_name. Raise exception when any mismatch is found.
// Return true if the block has fetch operators and holder of matching info. // Return true if the block has fetch operators and holder of matching info.
static bool has_fetch_operators( static bool has_fetch_operators(
BlockDesc* block, std::map<std::string, LoDTensor*>& fetch_targets, const BlockDesc& block, std::map<std::string, LoDTensor*>& fetch_targets,
const std::string& fetch_holder_name) { const std::string& fetch_holder_name) {
size_t fetch_count = 0; size_t fetch_count = 0;
for (auto* op : block->AllOps()) { for (auto* op : block.AllOps()) {
if (op->Type() == kFetchOpType) { if (op->Type() == kFetchOpType) {
fetch_count++; fetch_count++;
PADDLE_ENFORCE_EQ(op->Output("Out")[0], fetch_holder_name, PADDLE_ENFORCE_EQ(op->Output("Out")[0], fetch_holder_name,
...@@ -175,7 +176,7 @@ static bool has_fetch_operators( ...@@ -175,7 +176,7 @@ static bool has_fetch_operators(
"The number of fetch operators should match 'fetch_targets'"); "The number of fetch operators should match 'fetch_targets'");
// When fetch operator are present, so should be fetch_holder // When fetch operator are present, so should be fetch_holder
auto var = block->FindVar(fetch_holder_name); auto var = block.FindVar(fetch_holder_name);
PADDLE_ENFORCE_NOT_NULL(var, "Block should already have a '%s' variable", PADDLE_ENFORCE_NOT_NULL(var, "Block should already have a '%s' variable",
fetch_holder_name); fetch_holder_name);
PADDLE_ENFORCE_EQ(var->GetType(), proto::VarType::FETCH_LIST, PADDLE_ENFORCE_EQ(var->GetType(), proto::VarType::FETCH_LIST,
...@@ -192,10 +193,19 @@ void Executor::Run(const ProgramDesc& program, Scope* scope, ...@@ -192,10 +193,19 @@ void Executor::Run(const ProgramDesc& program, Scope* scope,
const std::string& feed_holder_name, const std::string& feed_holder_name,
const std::string& fetch_holder_name) { const std::string& fetch_holder_name) {
platform::RecordBlock b(kProgramId); platform::RecordBlock b(kProgramId);
auto* copy_program = new ProgramDesc(program); bool has_feed_ops =
has_feed_operators(program.Block(0), feed_targets, feed_holder_name);
bool has_fetch_ops =
has_fetch_operators(program.Block(0), fetch_targets, fetch_holder_name);
ProgramDesc* copy_program = const_cast<ProgramDesc*>(&program);
if (!has_feed_ops || !has_fetch_ops) {
copy_program = std::unique_ptr<ProgramDesc>(new ProgramDesc(program)).get();
}
auto* global_block = copy_program->MutableBlock(0); auto* global_block = copy_program->MutableBlock(0);
if (!has_feed_operators(global_block, feed_targets, feed_holder_name)) { if (!has_feed_ops) {
// create feed_holder variable // create feed_holder variable
auto* feed_holder = global_block->Var(feed_holder_name); auto* feed_holder = global_block->Var(feed_holder_name);
feed_holder->SetType(proto::VarType::FEED_MINIBATCH); feed_holder->SetType(proto::VarType::FEED_MINIBATCH);
...@@ -228,7 +238,7 @@ void Executor::Run(const ProgramDesc& program, Scope* scope, ...@@ -228,7 +238,7 @@ void Executor::Run(const ProgramDesc& program, Scope* scope,
} }
} }
if (!has_fetch_operators(global_block, fetch_targets, fetch_holder_name)) { if (!has_fetch_ops) {
// create fetch_holder variable // create fetch_holder variable
auto* fetch_holder = global_block->Var(fetch_holder_name); auto* fetch_holder = global_block->Var(fetch_holder_name);
fetch_holder->SetType(proto::VarType::FETCH_LIST); fetch_holder->SetType(proto::VarType::FETCH_LIST);
...@@ -262,8 +272,6 @@ void Executor::Run(const ProgramDesc& program, Scope* scope, ...@@ -262,8 +272,6 @@ void Executor::Run(const ProgramDesc& program, Scope* scope,
GetFetchVariable(*scope, fetch_holder_name, idx); GetFetchVariable(*scope, fetch_holder_name, idx);
} }
} }
delete copy_program;
} }
ExecutorPrepareContext* Executor::Prepare(const ProgramDesc& program, ExecutorPrepareContext* Executor::Prepare(const ProgramDesc& program,
...@@ -313,9 +321,8 @@ void Executor::RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope, ...@@ -313,9 +321,8 @@ void Executor::RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope,
} // if (create_vars) } // if (create_vars)
for (auto& op : ctx->ops_) { for (auto& op : ctx->ops_) {
VLOG(4) << place_ << " " << op->DebugStringEx(local_scope);
op->Run(*local_scope, place_);
VLOG(3) << place_ << " " << op->DebugStringEx(local_scope); VLOG(3) << place_ << " " << op->DebugStringEx(local_scope);
op->Run(*local_scope, place_);
if (FLAGS_benchmark) { if (FLAGS_benchmark) {
VLOG(2) << "Memory used after operator " + op->Type() + " running: " VLOG(2) << "Memory used after operator " + op->Type() + " running: "
......
...@@ -70,16 +70,16 @@ void ConvOp::InferShape(framework::InferShapeContext* ctx) const { ...@@ -70,16 +70,16 @@ void ConvOp::InferShape(framework::InferShapeContext* ctx) const {
framework::OpKernelType ConvOp::GetExpectedKernelType( framework::OpKernelType ConvOp::GetExpectedKernelType(
const framework::ExecutionContext& ctx) const { const framework::ExecutionContext& ctx) const {
framework::LibraryType library_{framework::LibraryType::kPlain}; framework::LibraryType library{framework::LibraryType::kPlain};
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
if (platform::CanCUDNNBeUsed(ctx)) { if (platform::CanCUDNNBeUsed(ctx)) {
library_ = framework::LibraryType::kCUDNN; library = framework::LibraryType::kCUDNN;
} }
#endif #endif
#ifdef PADDLE_WITH_MKLDNN #ifdef PADDLE_WITH_MKLDNN
if (library_ == framework::LibraryType::kPlain && if (library == framework::LibraryType::kPlain &&
platform::CanMKLDNNBeUsed(ctx)) { platform::CanMKLDNNBeUsed(ctx)) {
library_ = framework::LibraryType::kMKLDNN; library = framework::LibraryType::kMKLDNN;
} }
#endif #endif
...@@ -91,15 +91,15 @@ framework::OpKernelType ConvOp::GetExpectedKernelType( ...@@ -91,15 +91,15 @@ framework::OpKernelType ConvOp::GetExpectedKernelType(
"input and filter data type should be consistent"); "input and filter data type should be consistent");
if (input_data_type == framework::proto::VarType::FP16) { if (input_data_type == framework::proto::VarType::FP16) {
PADDLE_ENFORCE_EQ(library_, framework::LibraryType::kCUDNN, PADDLE_ENFORCE_EQ(library, framework::LibraryType::kCUDNN,
"float16 can only be used when CUDNN is used"); "float16 can only be used when CUDNN is used");
} }
std::string data_format = ctx.Attr<std::string>("data_format"); std::string data_format = ctx.Attr<std::string>("data_format");
// TODO(pzelazko-intel): enable MKLDNN layout when it's ready // TODO(pzelazko-intel): enable MKLDNN layout when it's ready
framework::DataLayout layout_ = framework::StringToDataLayout(data_format); framework::DataLayout layout = framework::StringToDataLayout(data_format);
return framework::OpKernelType(input_data_type, ctx.GetPlace(), layout_, return framework::OpKernelType(input_data_type, ctx.GetPlace(), layout,
library_); library);
} }
Conv2DOpMaker::Conv2DOpMaker(OpProto* proto, OpAttrChecker* op_checker) Conv2DOpMaker::Conv2DOpMaker(OpProto* proto, OpAttrChecker* op_checker)
......
...@@ -15,6 +15,7 @@ limitations under the License. */ ...@@ -15,6 +15,7 @@ limitations under the License. */
#include "paddle/fluid/framework/feed_fetch_type.h" #include "paddle/fluid/framework/feed_fetch_type.h"
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h" #include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/platform/profiler.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -28,6 +29,10 @@ class FeedOp : public framework::OperatorBase { ...@@ -28,6 +29,10 @@ class FeedOp : public framework::OperatorBase {
private: private:
void RunImpl(const framework::Scope &scope, void RunImpl(const framework::Scope &scope,
const platform::Place &place) const override { const platform::Place &place) const override {
// get device context from pool
auto *dev_ctx = platform::DeviceContextPool::Instance().Get(place);
platform::RecordEvent record_event(Type(), dev_ctx);
auto feed_var_name = Input("X"); auto feed_var_name = Input("X");
auto *feed_var = scope.FindVar(feed_var_name); auto *feed_var = scope.FindVar(feed_var_name);
...@@ -50,14 +55,10 @@ class FeedOp : public framework::OperatorBase { ...@@ -50,14 +55,10 @@ class FeedOp : public framework::OperatorBase {
auto &feed_item = feed_list.at(static_cast<size_t>(col)); auto &feed_item = feed_list.at(static_cast<size_t>(col));
auto *out_item = out_var->GetMutable<framework::FeedFetchType>(); auto *out_item = out_var->GetMutable<framework::FeedFetchType>();
// get device context from pool
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
auto &dev_ctx = *pool.Get(place);
if (platform::is_same_place(feed_item.place(), place)) { if (platform::is_same_place(feed_item.place(), place)) {
out_item->ShareDataWith(feed_item); out_item->ShareDataWith(feed_item);
} else { } else {
framework::TensorCopy(feed_item, place, dev_ctx, out_item); framework::TensorCopy(feed_item, place, *dev_ctx, out_item);
} }
out_item->set_lod(feed_item.lod()); out_item->set_lod(feed_item.lod());
} }
......
...@@ -15,6 +15,7 @@ limitations under the License. */ ...@@ -15,6 +15,7 @@ limitations under the License. */
#include "paddle/fluid/framework/feed_fetch_type.h" #include "paddle/fluid/framework/feed_fetch_type.h"
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/device_context.h" #include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/profiler.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -29,6 +30,9 @@ class FetchOp : public framework::OperatorBase { ...@@ -29,6 +30,9 @@ class FetchOp : public framework::OperatorBase {
private: private:
void RunImpl(const framework::Scope &scope, void RunImpl(const framework::Scope &scope,
const platform::Place &place) const override { const platform::Place &place) const override {
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
platform::RecordEvent record_event(Type(), pool.Get(place));
auto fetch_var_name = Input("X"); auto fetch_var_name = Input("X");
auto *fetch_var = scope.FindVar(fetch_var_name); auto *fetch_var = scope.FindVar(fetch_var_name);
PADDLE_ENFORCE(fetch_var != nullptr, PADDLE_ENFORCE(fetch_var != nullptr,
...@@ -53,7 +57,6 @@ class FetchOp : public framework::OperatorBase { ...@@ -53,7 +57,6 @@ class FetchOp : public framework::OperatorBase {
// FIXME(yuyang18): Should we assume the fetch operator always generate // FIXME(yuyang18): Should we assume the fetch operator always generate
// CPU outputs? // CPU outputs?
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
auto &dev_ctx = *pool.Get(src_item.place()); auto &dev_ctx = *pool.Get(src_item.place());
TensorCopy(src_item, platform::CPUPlace(), dev_ctx, &dst_item); TensorCopy(src_item, platform::CPUPlace(), dev_ctx, &dst_item);
......
...@@ -15,6 +15,7 @@ limitations under the License. */ ...@@ -15,6 +15,7 @@ limitations under the License. */
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/device_context.h" #include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/profiler.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -29,6 +30,9 @@ class LoadOp : public framework::OperatorBase { ...@@ -29,6 +30,9 @@ class LoadOp : public framework::OperatorBase {
private: private:
void RunImpl(const framework::Scope &scope, void RunImpl(const framework::Scope &scope,
const platform::Place &place) const override { const platform::Place &place) const override {
auto *dev_ctx = platform::DeviceContextPool::Instance().Get(place);
platform::RecordEvent record_event(Type(), dev_ctx);
auto filename = Attr<std::string>("file_path"); auto filename = Attr<std::string>("file_path");
std::ifstream fin(filename); std::ifstream fin(filename);
PADDLE_ENFORCE(static_cast<bool>(fin), "Cannot open file %s for load op", PADDLE_ENFORCE(static_cast<bool>(fin), "Cannot open file %s for load op",
...@@ -41,9 +45,7 @@ class LoadOp : public framework::OperatorBase { ...@@ -41,9 +45,7 @@ class LoadOp : public framework::OperatorBase {
auto *tensor = out_var->GetMutable<framework::LoDTensor>(); auto *tensor = out_var->GetMutable<framework::LoDTensor>();
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); DeserializeFromStream(fin, tensor, *dev_ctx);
auto &dev_ctx = *pool.Get(place);
DeserializeFromStream(fin, tensor, dev_ctx);
if (platform::is_gpu_place(place)) { if (platform::is_gpu_place(place)) {
// copy CPU to GPU // copy CPU to GPU
...@@ -55,7 +57,7 @@ class LoadOp : public framework::OperatorBase { ...@@ -55,7 +57,7 @@ class LoadOp : public framework::OperatorBase {
out_var->Clear(); out_var->Clear();
tensor = out_var->GetMutable<framework::LoDTensor>(); tensor = out_var->GetMutable<framework::LoDTensor>();
tensor->set_lod(cpu_tensor.lod()); tensor->set_lod(cpu_tensor.lod());
TensorCopy(cpu_tensor, place, dev_ctx, tensor); TensorCopy(cpu_tensor, place, *dev_ctx, tensor);
} }
} }
}; };
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册