未验证 提交 c0421379 编写于 作者: T Tao Luo 提交者: GitHub

Merge pull request #9043 from Xreki/core_inference_remove_clone

Remove unnecessary clone of program in C++ Executor.Run
......@@ -113,10 +113,11 @@ void Executor::Run(const ProgramDesc& pdesc, Scope* scope, int block_id,
// and feed_holder_name. Raise exception when any mismatch is found.
// Return true if the block has feed operators and holder of matching info.
static bool has_feed_operators(
BlockDesc* block, std::map<std::string, const LoDTensor*>& feed_targets,
const BlockDesc& block,
std::map<std::string, const LoDTensor*>& feed_targets,
const std::string& feed_holder_name) {
size_t feed_count = 0;
for (auto* op : block->AllOps()) {
for (auto* op : block.AllOps()) {
if (op->Type() == kFeedOpType) {
feed_count++;
PADDLE_ENFORCE_EQ(op->Input("X")[0], feed_holder_name,
......@@ -135,7 +136,7 @@ static bool has_feed_operators(
"The number of feed operators should match 'feed_targets'");
// When feed operator are present, so should be feed_holder
auto var = block->FindVar(feed_holder_name);
auto var = block.FindVar(feed_holder_name);
PADDLE_ENFORCE_NOT_NULL(var, "Block should already have a '%s' variable",
feed_holder_name);
PADDLE_ENFORCE_EQ(var->GetType(), proto::VarType::FEED_MINIBATCH,
......@@ -153,10 +154,10 @@ static bool has_feed_operators(
// and fetch_holder_name. Raise exception when any mismatch is found.
// Return true if the block has fetch operators and holder of matching info.
static bool has_fetch_operators(
BlockDesc* block, std::map<std::string, LoDTensor*>& fetch_targets,
const BlockDesc& block, std::map<std::string, LoDTensor*>& fetch_targets,
const std::string& fetch_holder_name) {
size_t fetch_count = 0;
for (auto* op : block->AllOps()) {
for (auto* op : block.AllOps()) {
if (op->Type() == kFetchOpType) {
fetch_count++;
PADDLE_ENFORCE_EQ(op->Output("Out")[0], fetch_holder_name,
......@@ -175,7 +176,7 @@ static bool has_fetch_operators(
"The number of fetch operators should match 'fetch_targets'");
// When fetch operator are present, so should be fetch_holder
auto var = block->FindVar(fetch_holder_name);
auto var = block.FindVar(fetch_holder_name);
PADDLE_ENFORCE_NOT_NULL(var, "Block should already have a '%s' variable",
fetch_holder_name);
PADDLE_ENFORCE_EQ(var->GetType(), proto::VarType::FETCH_LIST,
......@@ -192,10 +193,19 @@ void Executor::Run(const ProgramDesc& program, Scope* scope,
const std::string& feed_holder_name,
const std::string& fetch_holder_name) {
platform::RecordBlock b(kProgramId);
auto* copy_program = new ProgramDesc(program);
bool has_feed_ops =
has_feed_operators(program.Block(0), feed_targets, feed_holder_name);
bool has_fetch_ops =
has_fetch_operators(program.Block(0), fetch_targets, fetch_holder_name);
ProgramDesc* copy_program = const_cast<ProgramDesc*>(&program);
if (!has_feed_ops || !has_fetch_ops) {
copy_program = std::unique_ptr<ProgramDesc>(new ProgramDesc(program)).get();
}
auto* global_block = copy_program->MutableBlock(0);
if (!has_feed_operators(global_block, feed_targets, feed_holder_name)) {
if (!has_feed_ops) {
// create feed_holder variable
auto* feed_holder = global_block->Var(feed_holder_name);
feed_holder->SetType(proto::VarType::FEED_MINIBATCH);
......@@ -228,7 +238,7 @@ void Executor::Run(const ProgramDesc& program, Scope* scope,
}
}
if (!has_fetch_operators(global_block, fetch_targets, fetch_holder_name)) {
if (!has_fetch_ops) {
// create fetch_holder variable
auto* fetch_holder = global_block->Var(fetch_holder_name);
fetch_holder->SetType(proto::VarType::FETCH_LIST);
......@@ -262,8 +272,6 @@ void Executor::Run(const ProgramDesc& program, Scope* scope,
GetFetchVariable(*scope, fetch_holder_name, idx);
}
}
delete copy_program;
}
ExecutorPrepareContext* Executor::Prepare(const ProgramDesc& program,
......@@ -313,9 +321,8 @@ void Executor::RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope,
} // if (create_vars)
for (auto& op : ctx->ops_) {
VLOG(4) << place_ << " " << op->DebugStringEx(local_scope);
op->Run(*local_scope, place_);
VLOG(3) << place_ << " " << op->DebugStringEx(local_scope);
op->Run(*local_scope, place_);
if (FLAGS_benchmark) {
VLOG(2) << "Memory used after operator " + op->Type() + " running: "
......
......@@ -70,16 +70,16 @@ void ConvOp::InferShape(framework::InferShapeContext* ctx) const {
framework::OpKernelType ConvOp::GetExpectedKernelType(
const framework::ExecutionContext& ctx) const {
framework::LibraryType library_{framework::LibraryType::kPlain};
framework::LibraryType library{framework::LibraryType::kPlain};
#ifdef PADDLE_WITH_CUDA
if (platform::CanCUDNNBeUsed(ctx)) {
library_ = framework::LibraryType::kCUDNN;
library = framework::LibraryType::kCUDNN;
}
#endif
#ifdef PADDLE_WITH_MKLDNN
if (library_ == framework::LibraryType::kPlain &&
if (library == framework::LibraryType::kPlain &&
platform::CanMKLDNNBeUsed(ctx)) {
library_ = framework::LibraryType::kMKLDNN;
library = framework::LibraryType::kMKLDNN;
}
#endif
......@@ -91,15 +91,15 @@ framework::OpKernelType ConvOp::GetExpectedKernelType(
"input and filter data type should be consistent");
if (input_data_type == framework::proto::VarType::FP16) {
PADDLE_ENFORCE_EQ(library_, framework::LibraryType::kCUDNN,
PADDLE_ENFORCE_EQ(library, framework::LibraryType::kCUDNN,
"float16 can only be used when CUDNN is used");
}
std::string data_format = ctx.Attr<std::string>("data_format");
// TODO(pzelazko-intel): enable MKLDNN layout when it's ready
framework::DataLayout layout_ = framework::StringToDataLayout(data_format);
return framework::OpKernelType(input_data_type, ctx.GetPlace(), layout_,
library_);
framework::DataLayout layout = framework::StringToDataLayout(data_format);
return framework::OpKernelType(input_data_type, ctx.GetPlace(), layout,
library);
}
Conv2DOpMaker::Conv2DOpMaker(OpProto* proto, OpAttrChecker* op_checker)
......
......@@ -15,6 +15,7 @@ limitations under the License. */
#include "paddle/fluid/framework/feed_fetch_type.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/platform/profiler.h"
namespace paddle {
namespace operators {
......@@ -28,6 +29,10 @@ class FeedOp : public framework::OperatorBase {
private:
void RunImpl(const framework::Scope &scope,
const platform::Place &place) const override {
// get device context from pool
auto *dev_ctx = platform::DeviceContextPool::Instance().Get(place);
platform::RecordEvent record_event(Type(), dev_ctx);
auto feed_var_name = Input("X");
auto *feed_var = scope.FindVar(feed_var_name);
......@@ -50,14 +55,10 @@ class FeedOp : public framework::OperatorBase {
auto &feed_item = feed_list.at(static_cast<size_t>(col));
auto *out_item = out_var->GetMutable<framework::FeedFetchType>();
// get device context from pool
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
auto &dev_ctx = *pool.Get(place);
if (platform::is_same_place(feed_item.place(), place)) {
out_item->ShareDataWith(feed_item);
} else {
framework::TensorCopy(feed_item, place, dev_ctx, out_item);
framework::TensorCopy(feed_item, place, *dev_ctx, out_item);
}
out_item->set_lod(feed_item.lod());
}
......
......@@ -15,6 +15,7 @@ limitations under the License. */
#include "paddle/fluid/framework/feed_fetch_type.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/profiler.h"
namespace paddle {
namespace operators {
......@@ -29,6 +30,9 @@ class FetchOp : public framework::OperatorBase {
private:
void RunImpl(const framework::Scope &scope,
const platform::Place &place) const override {
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
platform::RecordEvent record_event(Type(), pool.Get(place));
auto fetch_var_name = Input("X");
auto *fetch_var = scope.FindVar(fetch_var_name);
PADDLE_ENFORCE(fetch_var != nullptr,
......@@ -53,7 +57,6 @@ class FetchOp : public framework::OperatorBase {
// FIXME(yuyang18): Should we assume the fetch operator always generate
// CPU outputs?
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
auto &dev_ctx = *pool.Get(src_item.place());
TensorCopy(src_item, platform::CPUPlace(), dev_ctx, &dst_item);
......
......@@ -15,6 +15,7 @@ limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/profiler.h"
namespace paddle {
namespace operators {
......@@ -29,6 +30,9 @@ class LoadOp : public framework::OperatorBase {
private:
void RunImpl(const framework::Scope &scope,
const platform::Place &place) const override {
auto *dev_ctx = platform::DeviceContextPool::Instance().Get(place);
platform::RecordEvent record_event(Type(), dev_ctx);
auto filename = Attr<std::string>("file_path");
std::ifstream fin(filename);
PADDLE_ENFORCE(static_cast<bool>(fin), "Cannot open file %s for load op",
......@@ -41,9 +45,7 @@ class LoadOp : public framework::OperatorBase {
auto *tensor = out_var->GetMutable<framework::LoDTensor>();
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
auto &dev_ctx = *pool.Get(place);
DeserializeFromStream(fin, tensor, dev_ctx);
DeserializeFromStream(fin, tensor, *dev_ctx);
if (platform::is_gpu_place(place)) {
// copy CPU to GPU
......@@ -55,7 +57,7 @@ class LoadOp : public framework::OperatorBase {
out_var->Clear();
tensor = out_var->GetMutable<framework::LoDTensor>();
tensor->set_lod(cpu_tensor.lod());
TensorCopy(cpu_tensor, place, dev_ctx, tensor);
TensorCopy(cpu_tensor, place, *dev_ctx, tensor);
}
}
};
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册