未验证 提交 7bc57d35 编写于 作者: K kangguangli 提交者: GitHub

transfer memcpy_h2d from fluid to phi (#44932)

* transfer memcpy_h2d from fluid to phi

* use UnchangedInferMeta instead

* restore test_standalone_executor

* add newline to fix codestyle check

* rename pt -> phi

* simplify logic and add check

* make the comment more clear

* remove useless comment

* refine code
上级 a3eb341e
......@@ -15,6 +15,8 @@
#include "paddle/fluid/framework/new_executor/data_transfer.h"
#include "paddle/fluid/framework/convert_utils.h"
#include "paddle/phi/core/kernel_context.h"
#include "paddle/phi/core/kernel_factory.h"
namespace paddle {
namespace framework {
......@@ -110,33 +112,63 @@ void DataTranferHelper::RunAndConstructOpFuncNode(
runtime_context.inputs["X"] = {scope_->FindVar(var_name)};
runtime_context.outputs["Out"] = {scope_->Var(new_var_name)};
InterpretercoreInferShapeContext infer_shape_ctx(*op, runtime_context);
// 2. Execute infer shape and choose kernel
auto& all_op_kernels = OperatorWithKernel::AllOpKernels();
op.get()->Info().infer_shape_(&infer_shape_ctx);
auto kernels_iter = all_op_kernels.find(op_type);
PADDLE_ENFORCE_NE(kernels_iter,
all_op_kernels.end(),
platform::errors::Unavailable(
"There are no kernels which are registered in "
"the %s operator.",
op_type));
OpKernelMap& kernels = kernels_iter->second;
// 2. choose kernel
// prepare a ptr to OperatorWithKernel
OperatorBase* op_ptr = op.get();
if (dynamic_cast<framework::OperatorWithKernel*>(op_ptr) == nullptr) {
PADDLE_THROW(platform::errors::PreconditionNotMet(
"%s should be OperatorWithKernel type.", op_ptr->Type()));
}
auto op_with_kernel = static_cast<framework::OperatorWithKernel*>(op_ptr);
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
auto* dev_ctx = pool.Get(place_);
Scope scope;
auto exec_ctx = ExecutionContext(*op, scope, *dev_ctx, runtime_context);
auto expected_kernel_key =
dynamic_cast<const framework::OperatorWithKernel*>(op.get())
->GetExpectedKernelType(exec_ctx);
auto kernel_iter = kernels.find(expected_kernel_key);
auto exec_ctx = ExecutionContext(*op, Scope(), *dev_ctx, runtime_context);
auto expected_kernel_key = op_with_kernel->GetExpectedKernelType(exec_ctx);
VLOG(6) << "expected_kernel_key " << expected_kernel_key << "\n";
VLOG(6) << "op_with_kernel Type() " << op_with_kernel->Type() << "\n";
bool run_phi_kernel = false;
// check if phi kernel exists
auto phi_kernel_map =
phi::KernelFactory::Instance().SelectKernelMap(op_with_kernel->Type());
if (phi_kernel_map.size() > 0) {
auto phi_kernel_key = op_with_kernel->ChoosePhiKernel(exec_ctx);
VLOG(6) << "phi_kernel_key " << phi_kernel_key << "\n";
// this function is used to construct data transfer op
// we expect that it always has a valid phi kernel
// so no need to fallback to cpu kernel
PADDLE_ENFORCE_EQ(
op_with_kernel->PhiKernel()->IsValid(),
true,
platform::errors::PreconditionNotMet(
"the %s op has no valid phi kernel.", op_with_kernel->Type()));
run_phi_kernel = true;
}
// 3. Execute transfer op and construct OpFuncNode
OpFuncNode new_op_func_node;
new_op_func_node.input_index["X"] = {var_scope_->VarId(var_name)};
new_op_func_node.output_index["Out"] = {var_scope_->VarId(new_var_name)};
new_op_func_node.kernel_func_ = OpKernelComputeFunc(kernel_iter->second);
new_op_func_node.kernel_func_(exec_ctx);
if (!run_phi_kernel) {
op_with_kernel->ChooseKernel(exec_ctx);
new_op_func_node.kernel_func_ = *op_with_kernel->kernel_func();
new_op_func_node.kernel_func_(exec_ctx);
} else {
new_op_func_node.phi_kernel_ = op_with_kernel->PhiKernel();
phi::KernelContext phi_kernel_context;
op_with_kernel->BuildPhiKernelContext(
runtime_context, dev_ctx, &phi_kernel_context);
(*new_op_func_node.phi_kernel_)(&phi_kernel_context);
}
// NOTE(winter-wang): in npu device, D2H kernel is asynchronous. need to
// explicit synchronization.
#ifdef PADDLE_WITH_ASCEND_CL
......
......@@ -13,6 +13,11 @@ limitations under the License. */
#include <string>
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/unary.h"
namespace paddle {
namespace framework {
class OpDesc;
......@@ -32,17 +37,6 @@ class MemcpyH2DOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override {
auto type = ctx->GetInputsVarType("X")[0];
if (type == framework::proto::VarType::SELECTED_ROWS ||
type == framework::proto::VarType::LOD_TENSOR) {
ctx->SetOutputDim("Out", ctx->GetInputDim("X"));
if (type == framework::proto::VarType::LOD_TENSOR) {
ctx->ShareLoD("X", /*->*/ "Out");
}
}
}
protected:
framework::OpKernelType GetKernelTypeForVar(
const std::string &var_name,
......@@ -117,95 +111,18 @@ raise error if the type is not listed above.
namespace ops = paddle::operators;
namespace plat = paddle::platform;
DECLARE_INFER_SHAPE_FUNCTOR(memcpy_h2d,
MemcpyH2DInferShapeFunctor,
PD_INFER_META(phi::UnchangedInferMeta));
REGISTER_OPERATOR(
memcpy_h2d,
ops::MemcpyH2DOp,
ops::MemcpyH2DOpProtoMaker,
ops::MemcpyH2DInferVarType,
paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OP_CPU_KERNEL_FUNCTOR(memcpy_h2d,
float,
ops::MemcpyH2DKernel,
double,
ops::MemcpyH2DKernel,
int8_t,
ops::MemcpyH2DKernel,
uint8_t,
ops::MemcpyH2DKernel,
int,
ops::MemcpyH2DKernel,
int64_t,
ops::MemcpyH2DKernel,
bool,
ops::MemcpyH2DKernel,
paddle::platform::bfloat16,
ops::MemcpyH2DKernel,
paddle::platform::complex<float>,
ops::MemcpyH2DKernel,
paddle::platform::complex<double>,
ops::MemcpyH2DKernel,
plat::float16,
ops::MemcpyH2DKernel,
int16_t,
ops::MemcpyH2DKernel);
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
REGISTER_OP_CUDA_KERNEL_FUNCTOR(memcpy_h2d,
float,
ops::MemcpyH2DKernel,
double,
ops::MemcpyH2DKernel,
int8_t,
ops::MemcpyH2DKernel,
uint8_t,
ops::MemcpyH2DKernel,
int,
ops::MemcpyH2DKernel,
int64_t,
ops::MemcpyH2DKernel,
bool,
ops::MemcpyH2DKernel,
paddle::platform::bfloat16,
ops::MemcpyH2DKernel,
paddle::platform::complex<float>,
ops::MemcpyH2DKernel,
paddle::platform::complex<double>,
ops::MemcpyH2DKernel,
plat::float16,
ops::MemcpyH2DKernel,
int16_t,
ops::MemcpyH2DKernel);
#endif
#ifdef PADDLE_WITH_XPU
REGISTER_OP_XPU_KERNEL_FUNCTOR(memcpy_h2d,
float,
ops::MemcpyH2DKernel,
double,
ops::MemcpyH2DKernel,
int8_t,
ops::MemcpyH2DKernel,
uint8_t,
ops::MemcpyH2DKernel,
int,
ops::MemcpyH2DKernel,
int64_t,
ops::MemcpyH2DKernel,
bool,
ops::MemcpyH2DKernel,
paddle::platform::bfloat16,
ops::MemcpyH2DKernel,
paddle::platform::complex<float>,
ops::MemcpyH2DKernel,
paddle::platform::complex<double>,
ops::MemcpyH2DKernel,
plat::float16,
ops::MemcpyH2DKernel,
int16_t,
ops::MemcpyH2DKernel);
#endif
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>,
MemcpyH2DInferShapeFunctor);
#ifdef PADDLE_WITH_ASCEND_CL
REGISTER_OP_NPU_KERNEL_FUNCTOR(memcpy_h2d,
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/memcpy_h2d_kernel.h"
#include "paddle/phi/common/complex.h"
#include "paddle/phi/core/kernel_registry.h"
namespace phi {
template <typename T, typename Context>
void MemcpyH2DKernel(const Context& dev_ctx,
const DenseTensor& x,
int dst_place_type,
DenseTensor* out) {
PADDLE_ENFORCE_GE(
dst_place_type,
0,
errors::OutOfRange("dst_place_type only support 0-3, but got: %d",
dst_place_type));
PADDLE_ENFORCE_LE(
dst_place_type,
3,
errors::OutOfRange("dst_place_type only support 0-3, but got: %d",
dst_place_type));
// Copy will set the stream of the tensor while setting blocking to false
Copy(dev_ctx, x, dev_ctx.GetPlace(), false, out);
}
} // namespace phi
PD_REGISTER_KERNEL(memcpy_h2d,
CPU,
ALL_LAYOUT,
phi::MemcpyH2DKernel,
float,
double,
int8_t,
uint8_t,
int,
int64_t,
bool,
phi::dtype::bfloat16,
phi::dtype::complex<float>,
phi::dtype::complex<double>,
phi::dtype::float16,
int16_t) {}
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
PD_REGISTER_KERNEL(memcpy_h2d,
GPU,
ALL_LAYOUT,
phi::MemcpyH2DKernel,
float,
double,
int8_t,
uint8_t,
int,
int64_t,
bool,
phi::dtype::bfloat16,
phi::dtype::complex<float>,
phi::dtype::complex<double>,
phi::dtype::float16,
int16_t) {}
#endif
#ifdef PADDLE_WITH_XPU
PD_REGISTER_KERNEL(memcpy_h2d,
XPU,
ALL_LAYOUT,
phi::MemcpyH2DKernel,
float,
double,
int8_t,
uint8_t,
int,
int64_t,
bool,
phi::dtype::bfloat16,
phi::dtype::complex<float>,
phi::dtype::complex<double>,
phi::dtype::float16,
int16_t) {}
#endif
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
// used in new executor, for memory copy from host to device
template <typename T, typename Context>
void MemcpyH2DKernel(const Context& dev_ctx,
const DenseTensor& x,
int dst_place_type,
DenseTensor* out);
} // namespace phi
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册