提交 371c53f8 编写于 作者: L Liu Yiqun

Add profiling event in feed, fetch and load op.

上级 253ba667
......@@ -70,16 +70,16 @@ void ConvOp::InferShape(framework::InferShapeContext* ctx) const {
framework::OpKernelType ConvOp::GetExpectedKernelType(
const framework::ExecutionContext& ctx) const {
framework::LibraryType library_{framework::LibraryType::kPlain};
framework::LibraryType library{framework::LibraryType::kPlain};
#ifdef PADDLE_WITH_CUDA
if (platform::CanCUDNNBeUsed(ctx)) {
library_ = framework::LibraryType::kCUDNN;
library = framework::LibraryType::kCUDNN;
}
#endif
#ifdef PADDLE_WITH_MKLDNN
if (library_ == framework::LibraryType::kPlain &&
if (library == framework::LibraryType::kPlain &&
platform::CanMKLDNNBeUsed(ctx)) {
library_ = framework::LibraryType::kMKLDNN;
library = framework::LibraryType::kMKLDNN;
}
#endif
......@@ -91,15 +91,15 @@ framework::OpKernelType ConvOp::GetExpectedKernelType(
"input and filter data type should be consistent");
if (input_data_type == framework::proto::VarType::FP16) {
PADDLE_ENFORCE_EQ(library_, framework::LibraryType::kCUDNN,
PADDLE_ENFORCE_EQ(library, framework::LibraryType::kCUDNN,
"float16 can only be used when CUDNN is used");
}
std::string data_format = ctx.Attr<std::string>("data_format");
// TODO(pzelazko-intel): enable MKLDNN layout when it's ready
framework::DataLayout layout_ = framework::StringToDataLayout(data_format);
return framework::OpKernelType(input_data_type, ctx.GetPlace(), layout_,
library_);
framework::DataLayout layout = framework::StringToDataLayout(data_format);
return framework::OpKernelType(input_data_type, ctx.GetPlace(), layout,
library);
}
Conv2DOpMaker::Conv2DOpMaker(OpProto* proto, OpAttrChecker* op_checker)
......
......@@ -15,6 +15,7 @@ limitations under the License. */
#include "paddle/fluid/framework/feed_fetch_type.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/platform/profiler.h"
namespace paddle {
namespace operators {
......@@ -28,6 +29,10 @@ class FeedOp : public framework::OperatorBase {
private:
void RunImpl(const framework::Scope &scope,
const platform::Place &place) const override {
// get device context from pool
auto *dev_ctx = platform::DeviceContextPool::Instance().Get(place);
platform::RecordEvent record_event(Type(), dev_ctx);
auto feed_var_name = Input("X");
auto *feed_var = scope.FindVar(feed_var_name);
......@@ -50,14 +55,10 @@ class FeedOp : public framework::OperatorBase {
auto &feed_item = feed_list.at(static_cast<size_t>(col));
auto *out_item = out_var->GetMutable<framework::FeedFetchType>();
// get device context from pool
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
auto &dev_ctx = *pool.Get(place);
if (platform::is_same_place(feed_item.place(), place)) {
out_item->ShareDataWith(feed_item);
} else {
framework::TensorCopy(feed_item, place, dev_ctx, out_item);
framework::TensorCopy(feed_item, place, *dev_ctx, out_item);
}
out_item->set_lod(feed_item.lod());
}
......
......@@ -15,6 +15,7 @@ limitations under the License. */
#include "paddle/fluid/framework/feed_fetch_type.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/profiler.h"
namespace paddle {
namespace operators {
......@@ -29,6 +30,9 @@ class FetchOp : public framework::OperatorBase {
private:
void RunImpl(const framework::Scope &scope,
const platform::Place &place) const override {
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
platform::RecordEvent record_event(Type(), pool.Get(place));
auto fetch_var_name = Input("X");
auto *fetch_var = scope.FindVar(fetch_var_name);
PADDLE_ENFORCE(fetch_var != nullptr,
......@@ -53,7 +57,6 @@ class FetchOp : public framework::OperatorBase {
// FIXME(yuyang18): Should we assume the fetch operator always generate
// CPU outputs?
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
auto &dev_ctx = *pool.Get(src_item.place());
TensorCopy(src_item, platform::CPUPlace(), dev_ctx, &dst_item);
......
......@@ -15,6 +15,7 @@ limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/profiler.h"
namespace paddle {
namespace operators {
......@@ -29,6 +30,9 @@ class LoadOp : public framework::OperatorBase {
private:
void RunImpl(const framework::Scope &scope,
const platform::Place &place) const override {
auto *dev_ctx = platform::DeviceContextPool::Instance().Get(place);
platform::RecordEvent record_event(Type(), dev_ctx);
auto filename = Attr<std::string>("file_path");
std::ifstream fin(filename);
PADDLE_ENFORCE(static_cast<bool>(fin), "Cannot open file %s for load op",
......@@ -41,9 +45,7 @@ class LoadOp : public framework::OperatorBase {
auto *tensor = out_var->GetMutable<framework::LoDTensor>();
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
auto &dev_ctx = *pool.Get(place);
DeserializeFromStream(fin, tensor, dev_ctx);
DeserializeFromStream(fin, tensor, *dev_ctx);
if (platform::is_gpu_place(place)) {
// copy CPU to GPU
......@@ -55,7 +57,7 @@ class LoadOp : public framework::OperatorBase {
out_var->Clear();
tensor = out_var->GetMutable<framework::LoDTensor>();
tensor->set_lod(cpu_tensor.lod());
TensorCopy(cpu_tensor, place, dev_ctx, tensor);
TensorCopy(cpu_tensor, place, *dev_ctx, tensor);
}
}
};
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册