未验证 提交 c56d6978 编写于 作者: W wanghuancoder 提交者: GitHub

modify fetch logic, use D2H Stream (#35191)

* modify fetch logic, use D2H Stream, test=develop

* refine, test=develop
上级 7743cdf2
...@@ -143,8 +143,7 @@ InterpreterCore::InterpreterCore(const platform::Place& place, ...@@ -143,8 +143,7 @@ InterpreterCore::InterpreterCore(const platform::Place& place,
main_program_(main_prog), main_program_(main_prog),
global_scope_(global_scope), global_scope_(global_scope),
d2h_ctx_pool_({place}), d2h_ctx_pool_({place}),
h2d_ctx_pool_({place}), h2d_ctx_pool_({place}) {
fetch_context_pool_({place}) {
is_build_ = false; is_build_ = false;
garbages_.reset(new GarbageQueue()); garbages_.reset(new GarbageQueue());
...@@ -339,9 +338,6 @@ void InterpreterCore::BuildInstructionCtx(Instruction* instr_node, ...@@ -339,9 +338,6 @@ void InterpreterCore::BuildInstructionCtx(Instruction* instr_node,
new RuntimeInferShapeContext(*op_base, *instr_node->runtime_ctx_.get())); new RuntimeInferShapeContext(*op_base, *instr_node->runtime_ctx_.get()));
auto* dev_ctx = instr_node->dev_ctx_; auto* dev_ctx = instr_node->dev_ctx_;
if (instr_node->kernel_func_.operator_base_->Type() == "fetch_v2") {
dev_ctx = fetch_context_pool_.Get(place);
}
Scope scope; Scope scope;
instr_node->execution_ctx_.reset(new ExecutionContext( instr_node->execution_ctx_.reset(new ExecutionContext(
...@@ -356,12 +352,6 @@ void InterpreterCore::RunInstruction(const Instruction& instr_node) { ...@@ -356,12 +352,6 @@ void InterpreterCore::RunInstruction(const Instruction& instr_node) {
instr_node.kernel_func_.operator_base_) instr_node.kernel_func_.operator_base_)
->InferShape(instr_node.infershape_ctx_.get()); ->InferShape(instr_node.infershape_ctx_.get());
if (instr_node.kernel_func_.operator_base_->Type() == "fetch_v2") {
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
auto* dev_ctx = pool.Get(place_);
dev_ctx->Wait(); // TODO(wanghuancoder)
}
instr_node.kernel_func_.compute_func_(*instr_node.execution_ctx_.get()); instr_node.kernel_func_.compute_func_(*instr_node.execution_ctx_.get());
} }
...@@ -411,8 +401,6 @@ void InterpreterCore::ExecuteInstructionList( ...@@ -411,8 +401,6 @@ void InterpreterCore::ExecuteInstructionList(
working_var_ref); working_var_ref);
} }
fetch_context_pool_.Get(place)->Wait();
for (size_t i = 0; i < working_var_ref.size(); ++i) { for (size_t i = 0; i < working_var_ref.size(); ++i) {
if (working_var_ref[i].var_ref_count_ != 0) { if (working_var_ref[i].var_ref_count_ != 0) {
std::cerr << " var ref is not zero " << i << std::endl; std::cerr << " var ref is not zero " << i << std::endl;
...@@ -671,6 +659,9 @@ void InterpreterCore::BuildOpFuncList(const platform::Place& place, ...@@ -671,6 +659,9 @@ void InterpreterCore::BuildOpFuncList(const platform::Place& place,
expected_kernel_key); expected_kernel_key);
if (!platform::is_same_place(kernel_type_for_var.place_, if (!platform::is_same_place(kernel_type_for_var.place_,
expected_kernel_key.place_)) { expected_kernel_key.place_)) {
if (op_base->Type() == "fetch_v2") {
op_base->SetAttr("deepcopy", false);
}
// need trans place // need trans place
// 1. add var in scope // 1. add var in scope
// 2. add copy op // 2. add copy op
......
...@@ -114,8 +114,6 @@ class InterpreterCore { ...@@ -114,8 +114,6 @@ class InterpreterCore {
size_t max_memory_size_; size_t max_memory_size_;
size_t cur_memory_size_; size_t cur_memory_size_;
std::unique_ptr<WorkQueue> gc_queue_; std::unique_ptr<WorkQueue> gc_queue_;
platform::DeviceContextPool fetch_context_pool_;
}; };
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -36,10 +36,9 @@ struct float16; ...@@ -36,10 +36,9 @@ struct float16;
namespace paddle { namespace paddle {
namespace operators { namespace operators {
static void DataCopy(const framework::LoDTensor &src_item, static void DeepCopy(const framework::LoDTensor &src_item,
const std::string &fetch_var_name, const std::string &fetch_var_name,
framework::LoDTensor *dst_item, framework::LoDTensor *dst_item) {
const platform::DeviceContext &dev_ctx) {
if (src_item.IsInitialized() && src_item.numel() > 0) { if (src_item.IsInitialized() && src_item.numel() > 0) {
#ifdef PADDLE_WITH_MKLDNN #ifdef PADDLE_WITH_MKLDNN
// Conversion from MKL-DNN to Paddle // Conversion from MKL-DNN to Paddle
...@@ -53,26 +52,13 @@ static void DataCopy(const framework::LoDTensor &src_item, ...@@ -53,26 +52,13 @@ static void DataCopy(const framework::LoDTensor &src_item,
: paddle::platform::MKLDNNDeviceContext::tls() : paddle::platform::MKLDNNDeviceContext::tls()
.get_cur_paddle_data_layout(), .get_cur_paddle_data_layout(),
src_item, &out, platform::CPUPlace()); src_item, &out, platform::CPUPlace());
TensorCopy(src_item, platform::CPUPlace(), dev_ctx, dst_item); TensorCopySync(out, platform::CPUPlace(), dst_item);
} else { } else {
if (platform::is_gpu_place(src_item.place())) { TensorCopySync(src_item, platform::CPUPlace(), dst_item);
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
TensorCopy(src_item, platform::CUDAPinnedPlace(), dev_ctx, dst_item);
#endif
} else {
TensorCopy(src_item, platform::CPUPlace(), dst_item);
}
} }
#else #else
if (platform::is_gpu_place(src_item.place())) { TensorCopySync(src_item, platform::CPUPlace(), dst_item);
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
TensorCopy(src_item, platform::CUDAPinnedPlace(), dev_ctx, dst_item);
#endif #endif
} else {
TensorCopy(src_item, platform::CPUPlace(), dst_item);
}
#endif
} else { } else {
// Not copy, if the src tensor is empty. // Not copy, if the src tensor is empty.
dst_item->clear(); dst_item->clear();
...@@ -92,15 +78,14 @@ class FetchV2Op : public framework::OperatorWithKernel { ...@@ -92,15 +78,14 @@ class FetchV2Op : public framework::OperatorWithKernel {
const std::string &var_name, const framework::Tensor &tensor, const std::string &var_name, const framework::Tensor &tensor,
const framework::OpKernelType &expected_kernel_type) const override { const framework::OpKernelType &expected_kernel_type) const override {
return framework::OpKernelType(expected_kernel_type.data_type_, return framework::OpKernelType(expected_kernel_type.data_type_,
expected_kernel_type.place_, tensor.place(), tensor.layout());
tensor.layout());
} }
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override { const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType( return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "X"), OperatorWithKernel::IndicateVarDataType(ctx, "X"),
ctx.device_context()); platform::CPUPlace());
} }
}; };
...@@ -119,12 +104,10 @@ class FetchV2Kernel { ...@@ -119,12 +104,10 @@ class FetchV2Kernel {
if (fetch_var == nullptr) { if (fetch_var == nullptr) {
return; return;
} }
PADDLE_ENFORCE_EQ(ctx.HasOutput("Out"), true, PADDLE_ENFORCE_EQ(
platform::errors::NotFound( ctx.HasOutput("Out"), true,
"Output(Out) of memcpy_d2h_op is not found.")); platform::errors::NotFound("Output(Out) of fetch_v2_op is not found."));
auto *out_var = ctx.OutputVar("Out"); auto *out_var = ctx.OutputVar("Out");
// Get dev_ctx from ExecutionContext, it's D2H stream
auto &dev_ctx = ctx.device_context();
int col = ctx.Attr<int>("col"); int col = ctx.Attr<int>("col");
PADDLE_ENFORCE_GE( PADDLE_ENFORCE_GE(
...@@ -140,10 +123,19 @@ class FetchV2Kernel { ...@@ -140,10 +123,19 @@ class FetchV2Kernel {
fetch_list->resize(col + 1); fetch_list->resize(col + 1);
} }
bool deepcopy = ctx.Attr<bool>("deepcopy");
if (fetch_var->IsType<framework::LoDTensor>()) { if (fetch_var->IsType<framework::LoDTensor>()) {
auto &src_item = fetch_var->Get<framework::LoDTensor>(); auto &src_item = fetch_var->Get<framework::LoDTensor>();
auto *dst_item = &(BOOST_GET(framework::LoDTensor, fetch_list->at(col))); auto *dst_item = &(BOOST_GET(framework::LoDTensor, fetch_list->at(col)));
DataCopy(src_item, fetch_var_name, dst_item, dev_ctx); PADDLE_ENFORCE_EQ(platform::is_cpu_place(src_item.place()), true,
platform::errors::InvalidArgument(
"Tensor's place of input(X) must be CPUPlace."));
if (deepcopy) {
DeepCopy(src_item, fetch_var_name, dst_item);
} else {
dst_item->ShareDataWith(src_item);
}
} else { } else {
auto &src_item = fetch_var->Get<framework::LoDTensorArray>(); auto &src_item = fetch_var->Get<framework::LoDTensorArray>();
framework::LoDTensorArray tmp(src_item.size()); framework::LoDTensorArray tmp(src_item.size());
...@@ -151,7 +143,14 @@ class FetchV2Kernel { ...@@ -151,7 +143,14 @@ class FetchV2Kernel {
auto &dst_item = auto &dst_item =
BOOST_GET(framework::LoDTensorArray, fetch_list->at(col)); BOOST_GET(framework::LoDTensorArray, fetch_list->at(col));
for (size_t i = 0; i < src_item.size(); ++i) { for (size_t i = 0; i < src_item.size(); ++i) {
DataCopy(src_item[i], fetch_var_name, &dst_item[i], dev_ctx); PADDLE_ENFORCE_EQ(platform::is_cpu_place(src_item[i].place()), true,
platform::errors::InvalidArgument(
"Tensor's place of input(X) must be CPUPlace."));
if (deepcopy) {
DeepCopy(src_item[i], fetch_var_name, &dst_item[i]);
} else {
dst_item[i].ShareDataWith(src_item[i]);
}
} }
} }
} }
...@@ -167,6 +166,8 @@ class FetchV2OpProtoMaker : public framework::OpProtoAndCheckerMaker { ...@@ -167,6 +166,8 @@ class FetchV2OpProtoMaker : public framework::OpProtoAndCheckerMaker {
"(vector<LoDTensor>) A fetching list of LoDTensor which may have " "(vector<LoDTensor>) A fetching list of LoDTensor which may have "
"different dimension, shape and data type."); "different dimension, shape and data type.");
AddAttr<int>("col", "(int) The column index of fetching object."); AddAttr<int>("col", "(int) The column index of fetching object.");
AddAttr<bool>("deepcopy", "(bool) Whether deep copy is required.")
.SetDefault(true);
AddComment(R"DOC( AddComment(R"DOC(
FetchV2 Operator. FetchV2 Operator.
...@@ -192,19 +193,3 @@ REGISTER_OP_CPU_KERNEL_FUNCTOR(fetch_v2, float, ops::FetchV2Kernel, double, ...@@ -192,19 +193,3 @@ REGISTER_OP_CPU_KERNEL_FUNCTOR(fetch_v2, float, ops::FetchV2Kernel, double,
int64_t, ops::FetchV2Kernel, bool, int64_t, ops::FetchV2Kernel, bool,
ops::FetchV2Kernel, plat::float16, ops::FetchV2Kernel, plat::float16,
ops::FetchV2Kernel); ops::FetchV2Kernel);
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_ROCM)
REGISTER_OP_CUDA_KERNEL_FUNCTOR(fetch_v2, float, ops::FetchV2Kernel, double,
ops::FetchV2Kernel, int, ops::FetchV2Kernel,
int64_t, ops::FetchV2Kernel, bool,
ops::FetchV2Kernel, plat::float16,
ops::FetchV2Kernel);
#endif
#ifdef PADDLE_WITH_ASCEND_CL
REGISTER_OP_NPU_KERNEL_FUNCTOR(fetch_v2, float, ops::FetchV2Kernel, double,
ops::FetchV2Kernel, int, ops::FetchV2Kernel,
int64_t, ops::FetchV2Kernel, bool,
ops::FetchV2Kernel, plat::float16,
ops::FetchV2Kernel);
#endif
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册