Unverified commit c56d6978, authored by wanghuancoder, committed by GitHub

modify fetch logic, use D2H Stream (#35191)

* modify fetch logic, use D2H Stream, test=develop

* refine, test=develop
Parent 7743cdf2
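For orientation before the diff, here is a minimal sketch of the fetch behaviour this commit moves to: the interpreter performs any device-to-host transfer before fetch_v2 runs, so the kernel only ever sees CPU tensors and either deep-copies them or shares the existing buffer, controlled by the new `deepcopy` attribute. The standalone helper below is hypothetical; only the copy/share calls are the real tensor_util.h / Tensor APIs used in the patch.

```cpp
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/tensor_util.h"

namespace fw = paddle::framework;
namespace plat = paddle::platform;

// Hypothetical helper mirroring what the new fetch_v2 kernel does per tensor.
void FetchOnCpu(const fw::LoDTensor &src, bool deepcopy, fw::LoDTensor *dst) {
  // By the time fetch_v2 runs, src is expected to live on CPUPlace already
  // (the interpreter inserts a D2H copy op when the variable is on the GPU).
  if (deepcopy) {
    // Default: blocking copy into a buffer owned by the fetch list.
    fw::TensorCopySync(src, plat::CPUPlace(), dst);
  } else {
    // Set to false by the interpreter when an inserted transfer op already
    // produced a fresh CPU tensor: alias that buffer instead of copying again.
    dst->ShareDataWith(src);
  }
  dst->set_lod(src.lod());
}
```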
@@ -143,8 +143,7 @@ InterpreterCore::InterpreterCore(const platform::Place& place,
       main_program_(main_prog),
       global_scope_(global_scope),
       d2h_ctx_pool_({place}),
-      h2d_ctx_pool_({place}),
-      fetch_context_pool_({place}) {
+      h2d_ctx_pool_({place}) {
   is_build_ = false;
   garbages_.reset(new GarbageQueue());
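The `d2h_ctx_pool_` kept above is what the commit title refers to: device-to-host copies get their own DeviceContext (and thus their own stream) instead of borrowing the compute stream. A rough sketch of such a transfer, assuming a CUDA build; the function and variable names are illustrative, only TensorCopy and DeviceContext::Wait are the real APIs:

```cpp
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/platform/device_context.h"

namespace fw = paddle::framework;
namespace plat = paddle::platform;

// Illustrative D2H transfer issued on a dedicated context/stream.
void AsyncD2H(const fw::LoDTensor &gpu_src, const plat::DeviceContext &d2h_ctx,
              fw::LoDTensor *host_dst) {
  // Queue the copy on d2h_ctx's stream so it can overlap with kernels running
  // on the default compute stream; pinned memory keeps the copy asynchronous.
  fw::TensorCopy(gpu_src, plat::CUDAPinnedPlace(), d2h_ctx, host_dst);
  // A consumer of host_dst must synchronize with the D2H stream first.
  d2h_ctx.Wait();
}
```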
@@ -339,9 +338,6 @@ void InterpreterCore::BuildInstructionCtx(Instruction* instr_node,
       new RuntimeInferShapeContext(*op_base, *instr_node->runtime_ctx_.get()));
 
   auto* dev_ctx = instr_node->dev_ctx_;
-  if (instr_node->kernel_func_.operator_base_->Type() == "fetch_v2") {
-    dev_ctx = fetch_context_pool_.Get(place);
-  }
   Scope scope;
 
   instr_node->execution_ctx_.reset(new ExecutionContext(
@@ -356,12 +352,6 @@ void InterpreterCore::RunInstruction(const Instruction& instr_node) {
           instr_node.kernel_func_.operator_base_)
       ->InferShape(instr_node.infershape_ctx_.get());
 
-  if (instr_node.kernel_func_.operator_base_->Type() == "fetch_v2") {
-    platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
-    auto* dev_ctx = pool.Get(place_);
-    dev_ctx->Wait();  // TODO(wanghuancoder)
-  }
   instr_node.kernel_func_.compute_func_(*instr_node.execution_ctx_.get());
 }
@@ -411,8 +401,6 @@ void InterpreterCore::ExecuteInstructionList(
                               working_var_ref);
   }
 
-  fetch_context_pool_.Get(place)->Wait();
-
   for (size_t i = 0; i < working_var_ref.size(); ++i) {
     if (working_var_ref[i].var_ref_count_ != 0) {
       std::cerr << " var ref is not zero " << i << std::endl;
@@ -671,6 +659,9 @@ void InterpreterCore::BuildOpFuncList(const platform::Place& place,
                                        expected_kernel_key);
       if (!platform::is_same_place(kernel_type_for_var.place_,
                                    expected_kernel_key.place_)) {
+        if (op_base->Type() == "fetch_v2") {
+          op_base->SetAttr("deepcopy", false);
+        }
         // need trans place
         // 1. add var in scope
         // 2. add copy op
@@ -114,8 +114,6 @@ class InterpreterCore {
   size_t max_memory_size_;
   size_t cur_memory_size_;
   std::unique_ptr<WorkQueue> gc_queue_;
-
-  platform::DeviceContextPool fetch_context_pool_;
 };
 }  // namespace framework
 }  // namespace paddle
@@ -36,10 +36,9 @@ struct float16;
 namespace paddle {
 namespace operators {
 
-static void DataCopy(const framework::LoDTensor &src_item,
+static void DeepCopy(const framework::LoDTensor &src_item,
                      const std::string &fetch_var_name,
-                     framework::LoDTensor *dst_item,
-                     const platform::DeviceContext &dev_ctx) {
+                     framework::LoDTensor *dst_item) {
   if (src_item.IsInitialized() && src_item.numel() > 0) {
 #ifdef PADDLE_WITH_MKLDNN
     // Conversion from MKL-DNN to Paddle
@@ -53,26 +52,13 @@ static void DataCopy(const framework::LoDTensor &src_item,
                                  : paddle::platform::MKLDNNDeviceContext::tls()
                                        .get_cur_paddle_data_layout(),
           src_item, &out, platform::CPUPlace());
-      TensorCopy(src_item, platform::CPUPlace(), dev_ctx, dst_item);
+      TensorCopySync(out, platform::CPUPlace(), dst_item);
     } else {
-      if (platform::is_gpu_place(src_item.place())) {
-#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
-        TensorCopy(src_item, platform::CUDAPinnedPlace(), dev_ctx, dst_item);
-#endif
-      } else {
-        TensorCopy(src_item, platform::CPUPlace(), dst_item);
-      }
+      TensorCopySync(src_item, platform::CPUPlace(), dst_item);
     }
 #else
-    if (platform::is_gpu_place(src_item.place())) {
-#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
-      TensorCopy(src_item, platform::CUDAPinnedPlace(), dev_ctx, dst_item);
-#endif
-    } else {
-      TensorCopy(src_item, platform::CPUPlace(), dst_item);
-    }
+    TensorCopySync(src_item, platform::CPUPlace(), dst_item);
 #endif
   } else {
     // Not copy, if the src tensor is empty.
     dst_item->clear();
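The simplification above boils down to swapping the asynchronous TensorCopy (which needs a DeviceContext and a later Wait) for the blocking TensorCopySync. A minimal sketch of the contrast, with placeholder tensors and assuming only the tensor_util.h helpers:

```cpp
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/platform/device_context.h"

namespace fw = paddle::framework;
namespace plat = paddle::platform;

void CopyToHost(const fw::LoDTensor &src, const plat::DeviceContext &dev_ctx,
                fw::LoDTensor *async_dst, fw::LoDTensor *sync_dst) {
  // Asynchronous: enqueued on dev_ctx's stream; the caller must wait on that
  // stream before reading async_dst on the host. The old DataCopy relied on
  // this, which is why the interpreter carried fetch_context_pool_ and the
  // explicit Wait calls removed earlier in this diff.
  fw::TensorCopy(src, plat::CPUPlace(), dev_ctx, async_dst);

  // Blocking: returns only once the data is visible on the CPU, so the new
  // DeepCopy no longer needs a DeviceContext argument at all.
  fw::TensorCopySync(src, plat::CPUPlace(), sync_dst);
}
```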
@@ -92,15 +78,14 @@ class FetchV2Op : public framework::OperatorWithKernel {
       const std::string &var_name, const framework::Tensor &tensor,
       const framework::OpKernelType &expected_kernel_type) const override {
     return framework::OpKernelType(expected_kernel_type.data_type_,
-                                   expected_kernel_type.place_,
-                                   tensor.layout());
+                                   tensor.place(), tensor.layout());
   }
 
   framework::OpKernelType GetExpectedKernelType(
       const framework::ExecutionContext &ctx) const override {
     return framework::OpKernelType(
         OperatorWithKernel::IndicateVarDataType(ctx, "X"),
-        ctx.device_context());
+        platform::CPUPlace());
   }
 };
@@ -119,12 +104,10 @@ class FetchV2Kernel {
     if (fetch_var == nullptr) {
       return;
     }
-    PADDLE_ENFORCE_EQ(ctx.HasOutput("Out"), true,
-                      platform::errors::NotFound(
-                          "Output(Out) of memcpy_d2h_op is not found."));
+    PADDLE_ENFORCE_EQ(
+        ctx.HasOutput("Out"), true,
+        platform::errors::NotFound("Output(Out) of fetch_v2_op is not found."));
     auto *out_var = ctx.OutputVar("Out");
-    // Get dev_ctx from ExecutionContext, it's D2H stream
-    auto &dev_ctx = ctx.device_context();
     int col = ctx.Attr<int>("col");
     PADDLE_ENFORCE_GE(
@@ -140,10 +123,19 @@
       fetch_list->resize(col + 1);
     }
 
+    bool deepcopy = ctx.Attr<bool>("deepcopy");
+
     if (fetch_var->IsType<framework::LoDTensor>()) {
       auto &src_item = fetch_var->Get<framework::LoDTensor>();
       auto *dst_item = &(BOOST_GET(framework::LoDTensor, fetch_list->at(col)));
-      DataCopy(src_item, fetch_var_name, dst_item, dev_ctx);
+      PADDLE_ENFORCE_EQ(platform::is_cpu_place(src_item.place()), true,
+                        platform::errors::InvalidArgument(
+                            "Tensor's place of input(X) must be CPUPlace."));
+      if (deepcopy) {
+        DeepCopy(src_item, fetch_var_name, dst_item);
+      } else {
+        dst_item->ShareDataWith(src_item);
+      }
     } else {
       auto &src_item = fetch_var->Get<framework::LoDTensorArray>();
       framework::LoDTensorArray tmp(src_item.size());
@@ -151,7 +143,14 @@
       auto &dst_item =
           BOOST_GET(framework::LoDTensorArray, fetch_list->at(col));
       for (size_t i = 0; i < src_item.size(); ++i) {
-        DataCopy(src_item[i], fetch_var_name, &dst_item[i], dev_ctx);
+        PADDLE_ENFORCE_EQ(platform::is_cpu_place(src_item[i].place()), true,
+                          platform::errors::InvalidArgument(
+                              "Tensor's place of input(X) must be CPUPlace."));
+        if (deepcopy) {
+          DeepCopy(src_item[i], fetch_var_name, &dst_item[i]);
+        } else {
+          dst_item[i].ShareDataWith(src_item[i]);
+        }
       }
     }
   }
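The practical difference between the two branches above: with `deepcopy == true` the entry in the fetch list owns a fresh allocation, with `deepcopy == false` it aliases the fetch variable's buffer. A small hypothetical check, assuming the `Tensor::Holder()` accessor:

```cpp
#include "paddle/fluid/framework/lod_tensor.h"

// Returns true when `fetched` aliases `src`'s allocation (the ShareDataWith
// path); false when it was deep-copied into its own buffer.
bool SharesBufferWith(const paddle::framework::LoDTensor &src,
                      const paddle::framework::LoDTensor &fetched) {
  return src.IsInitialized() && fetched.IsInitialized() &&
         src.Holder() == fetched.Holder();
}
```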
@@ -167,6 +166,8 @@ class FetchV2OpProtoMaker : public framework::OpProtoAndCheckerMaker {
               "(vector<LoDTensor>) A fetching list of LoDTensor which may have "
               "different dimension, shape and data type.");
     AddAttr<int>("col", "(int) The column index of fetching object.");
+    AddAttr<bool>("deepcopy", "(bool) Whether deep copy is required.")
+        .SetDefault(true);
     AddComment(R"DOC(
 FetchV2 Operator.
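One way the new attribute could be set programmatically, sketched with the generic OpRegistry::CreateOp helper (the wrapper function itself is hypothetical); the interpreter change earlier in this diff does the equivalent via op_base->SetAttr:

```cpp
#include <memory>
#include <string>

#include "paddle/fluid/framework/op_registry.h"

namespace fw = paddle::framework;

// Hypothetical factory for a fetch_v2 op that shares instead of deep-copying.
std::unique_ptr<fw::OperatorBase> MakeFetchV2(const std::string &in_var,
                                              const std::string &out_var,
                                              int col) {
  fw::AttributeMap attrs;
  attrs["col"] = col;
  attrs["deepcopy"] = false;  // alias the CPU buffer produced by the D2H copy
  return fw::OpRegistry::CreateOp("fetch_v2", {{"X", {in_var}}},
                                  {{"Out", {out_var}}}, attrs);
}
```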
@@ -192,19 +193,3 @@ REGISTER_OP_CPU_KERNEL_FUNCTOR(fetch_v2, float, ops::FetchV2Kernel, double,
                                int64_t, ops::FetchV2Kernel, bool,
                                ops::FetchV2Kernel, plat::float16,
                                ops::FetchV2Kernel);
-
-#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_ROCM)
-REGISTER_OP_CUDA_KERNEL_FUNCTOR(fetch_v2, float, ops::FetchV2Kernel, double,
-                                ops::FetchV2Kernel, int, ops::FetchV2Kernel,
-                                int64_t, ops::FetchV2Kernel, bool,
-                                ops::FetchV2Kernel, plat::float16,
-                                ops::FetchV2Kernel);
-#endif
-
-#ifdef PADDLE_WITH_ASCEND_CL
-REGISTER_OP_NPU_KERNEL_FUNCTOR(fetch_v2, float, ops::FetchV2Kernel, double,
-                               ops::FetchV2Kernel, int, ops::FetchV2Kernel,
-                               int64_t, ops::FetchV2Kernel, bool,
-                               ops::FetchV2Kernel, plat::float16,
-                               ops::FetchV2Kernel);
-#endif