提交 371c53f8 编写于 作者: L Liu Yiqun

Add profiling event in feed, fetch and load op.

上级 253ba667
...@@ -70,16 +70,16 @@ void ConvOp::InferShape(framework::InferShapeContext* ctx) const { ...@@ -70,16 +70,16 @@ void ConvOp::InferShape(framework::InferShapeContext* ctx) const {
framework::OpKernelType ConvOp::GetExpectedKernelType( framework::OpKernelType ConvOp::GetExpectedKernelType(
const framework::ExecutionContext& ctx) const { const framework::ExecutionContext& ctx) const {
framework::LibraryType library_{framework::LibraryType::kPlain}; framework::LibraryType library{framework::LibraryType::kPlain};
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
if (platform::CanCUDNNBeUsed(ctx)) { if (platform::CanCUDNNBeUsed(ctx)) {
library_ = framework::LibraryType::kCUDNN; library = framework::LibraryType::kCUDNN;
} }
#endif #endif
#ifdef PADDLE_WITH_MKLDNN #ifdef PADDLE_WITH_MKLDNN
if (library_ == framework::LibraryType::kPlain && if (library == framework::LibraryType::kPlain &&
platform::CanMKLDNNBeUsed(ctx)) { platform::CanMKLDNNBeUsed(ctx)) {
library_ = framework::LibraryType::kMKLDNN; library = framework::LibraryType::kMKLDNN;
} }
#endif #endif
...@@ -91,15 +91,15 @@ framework::OpKernelType ConvOp::GetExpectedKernelType( ...@@ -91,15 +91,15 @@ framework::OpKernelType ConvOp::GetExpectedKernelType(
"input and filter data type should be consistent"); "input and filter data type should be consistent");
if (input_data_type == framework::proto::VarType::FP16) { if (input_data_type == framework::proto::VarType::FP16) {
PADDLE_ENFORCE_EQ(library_, framework::LibraryType::kCUDNN, PADDLE_ENFORCE_EQ(library, framework::LibraryType::kCUDNN,
"float16 can only be used when CUDNN is used"); "float16 can only be used when CUDNN is used");
} }
std::string data_format = ctx.Attr<std::string>("data_format"); std::string data_format = ctx.Attr<std::string>("data_format");
// TODO(pzelazko-intel): enable MKLDNN layout when it's ready // TODO(pzelazko-intel): enable MKLDNN layout when it's ready
framework::DataLayout layout_ = framework::StringToDataLayout(data_format); framework::DataLayout layout = framework::StringToDataLayout(data_format);
return framework::OpKernelType(input_data_type, ctx.GetPlace(), layout_, return framework::OpKernelType(input_data_type, ctx.GetPlace(), layout,
library_); library);
} }
Conv2DOpMaker::Conv2DOpMaker(OpProto* proto, OpAttrChecker* op_checker) Conv2DOpMaker::Conv2DOpMaker(OpProto* proto, OpAttrChecker* op_checker)
......
...@@ -15,6 +15,7 @@ limitations under the License. */ ...@@ -15,6 +15,7 @@ limitations under the License. */
#include "paddle/fluid/framework/feed_fetch_type.h" #include "paddle/fluid/framework/feed_fetch_type.h"
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h" #include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/platform/profiler.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -28,6 +29,10 @@ class FeedOp : public framework::OperatorBase { ...@@ -28,6 +29,10 @@ class FeedOp : public framework::OperatorBase {
private: private:
void RunImpl(const framework::Scope &scope, void RunImpl(const framework::Scope &scope,
const platform::Place &place) const override { const platform::Place &place) const override {
// get device context from pool
auto *dev_ctx = platform::DeviceContextPool::Instance().Get(place);
platform::RecordEvent record_event(Type(), dev_ctx);
auto feed_var_name = Input("X"); auto feed_var_name = Input("X");
auto *feed_var = scope.FindVar(feed_var_name); auto *feed_var = scope.FindVar(feed_var_name);
...@@ -50,14 +55,10 @@ class FeedOp : public framework::OperatorBase { ...@@ -50,14 +55,10 @@ class FeedOp : public framework::OperatorBase {
auto &feed_item = feed_list.at(static_cast<size_t>(col)); auto &feed_item = feed_list.at(static_cast<size_t>(col));
auto *out_item = out_var->GetMutable<framework::FeedFetchType>(); auto *out_item = out_var->GetMutable<framework::FeedFetchType>();
// get device context from pool
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
auto &dev_ctx = *pool.Get(place);
if (platform::is_same_place(feed_item.place(), place)) { if (platform::is_same_place(feed_item.place(), place)) {
out_item->ShareDataWith(feed_item); out_item->ShareDataWith(feed_item);
} else { } else {
framework::TensorCopy(feed_item, place, dev_ctx, out_item); framework::TensorCopy(feed_item, place, *dev_ctx, out_item);
} }
out_item->set_lod(feed_item.lod()); out_item->set_lod(feed_item.lod());
} }
......
...@@ -15,6 +15,7 @@ limitations under the License. */ ...@@ -15,6 +15,7 @@ limitations under the License. */
#include "paddle/fluid/framework/feed_fetch_type.h" #include "paddle/fluid/framework/feed_fetch_type.h"
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/device_context.h" #include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/profiler.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -29,6 +30,9 @@ class FetchOp : public framework::OperatorBase { ...@@ -29,6 +30,9 @@ class FetchOp : public framework::OperatorBase {
private: private:
void RunImpl(const framework::Scope &scope, void RunImpl(const framework::Scope &scope,
const platform::Place &place) const override { const platform::Place &place) const override {
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
platform::RecordEvent record_event(Type(), pool.Get(place));
auto fetch_var_name = Input("X"); auto fetch_var_name = Input("X");
auto *fetch_var = scope.FindVar(fetch_var_name); auto *fetch_var = scope.FindVar(fetch_var_name);
PADDLE_ENFORCE(fetch_var != nullptr, PADDLE_ENFORCE(fetch_var != nullptr,
...@@ -53,7 +57,6 @@ class FetchOp : public framework::OperatorBase { ...@@ -53,7 +57,6 @@ class FetchOp : public framework::OperatorBase {
// FIXME(yuyang18): Should we assume the fetch operator always generate // FIXME(yuyang18): Should we assume the fetch operator always generate
// CPU outputs? // CPU outputs?
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
auto &dev_ctx = *pool.Get(src_item.place()); auto &dev_ctx = *pool.Get(src_item.place());
TensorCopy(src_item, platform::CPUPlace(), dev_ctx, &dst_item); TensorCopy(src_item, platform::CPUPlace(), dev_ctx, &dst_item);
......
...@@ -15,6 +15,7 @@ limitations under the License. */ ...@@ -15,6 +15,7 @@ limitations under the License. */
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/device_context.h" #include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/profiler.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -29,6 +30,9 @@ class LoadOp : public framework::OperatorBase { ...@@ -29,6 +30,9 @@ class LoadOp : public framework::OperatorBase {
private: private:
void RunImpl(const framework::Scope &scope, void RunImpl(const framework::Scope &scope,
const platform::Place &place) const override { const platform::Place &place) const override {
auto *dev_ctx = platform::DeviceContextPool::Instance().Get(place);
platform::RecordEvent record_event(Type(), dev_ctx);
auto filename = Attr<std::string>("file_path"); auto filename = Attr<std::string>("file_path");
std::ifstream fin(filename); std::ifstream fin(filename);
PADDLE_ENFORCE(static_cast<bool>(fin), "Cannot open file %s for load op", PADDLE_ENFORCE(static_cast<bool>(fin), "Cannot open file %s for load op",
...@@ -41,9 +45,7 @@ class LoadOp : public framework::OperatorBase { ...@@ -41,9 +45,7 @@ class LoadOp : public framework::OperatorBase {
auto *tensor = out_var->GetMutable<framework::LoDTensor>(); auto *tensor = out_var->GetMutable<framework::LoDTensor>();
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); DeserializeFromStream(fin, tensor, *dev_ctx);
auto &dev_ctx = *pool.Get(place);
DeserializeFromStream(fin, tensor, dev_ctx);
if (platform::is_gpu_place(place)) { if (platform::is_gpu_place(place)) {
// copy CPU to GPU // copy CPU to GPU
...@@ -55,7 +57,7 @@ class LoadOp : public framework::OperatorBase { ...@@ -55,7 +57,7 @@ class LoadOp : public framework::OperatorBase {
out_var->Clear(); out_var->Clear();
tensor = out_var->GetMutable<framework::LoDTensor>(); tensor = out_var->GetMutable<framework::LoDTensor>();
tensor->set_lod(cpu_tensor.lod()); tensor->set_lod(cpu_tensor.lod());
TensorCopy(cpu_tensor, place, dev_ctx, tensor); TensorCopy(cpu_tensor, place, *dev_ctx, tensor);
} }
} }
}; };
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册