diff --git a/paddle/fluid/framework/new_executor/data_transfer.cc b/paddle/fluid/framework/new_executor/data_transfer.cc index 3cf16266baf088dd029c7f9f2fe3b08171459195..33e50b249ad451371a970262c9d4ec24dc50f64d 100644 --- a/paddle/fluid/framework/new_executor/data_transfer.cc +++ b/paddle/fluid/framework/new_executor/data_transfer.cc @@ -15,6 +15,8 @@ #include "paddle/fluid/framework/new_executor/data_transfer.h" #include "paddle/fluid/framework/convert_utils.h" +#include "paddle/phi/core/kernel_context.h" +#include "paddle/phi/core/kernel_factory.h" namespace paddle { namespace framework { @@ -110,33 +112,63 @@ void DataTranferHelper::RunAndConstructOpFuncNode( runtime_context.inputs["X"] = {scope_->FindVar(var_name)}; runtime_context.outputs["Out"] = {scope_->Var(new_var_name)}; InterpretercoreInferShapeContext infer_shape_ctx(*op, runtime_context); - - // 2. Execute infer shape and choose kernel - auto& all_op_kernels = OperatorWithKernel::AllOpKernels(); op.get()->Info().infer_shape_(&infer_shape_ctx); - auto kernels_iter = all_op_kernels.find(op_type); - PADDLE_ENFORCE_NE(kernels_iter, - all_op_kernels.end(), - platform::errors::Unavailable( - "There are no kernels which are registered in " - "the %s operator.", - op_type)); - OpKernelMap& kernels = kernels_iter->second; + + // 2. choose kernel + + // prepare a ptr to OperatorWithKernel + OperatorBase* op_ptr = op.get(); + if (dynamic_cast(op_ptr) == nullptr) { + PADDLE_THROW(platform::errors::PreconditionNotMet( + "%s should be OperatorWithKernel type.", op_ptr->Type())); + } + auto op_with_kernel = static_cast(op_ptr); + platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); auto* dev_ctx = pool.Get(place_); - Scope scope; - auto exec_ctx = ExecutionContext(*op, scope, *dev_ctx, runtime_context); - auto expected_kernel_key = - dynamic_cast(op.get()) - ->GetExpectedKernelType(exec_ctx); - auto kernel_iter = kernels.find(expected_kernel_key); + auto exec_ctx = ExecutionContext(*op, Scope(), *dev_ctx, runtime_context); + auto expected_kernel_key = op_with_kernel->GetExpectedKernelType(exec_ctx); + + VLOG(6) << "expected_kernel_key " << expected_kernel_key << "\n"; + VLOG(6) << "op_with_kernel Type() " << op_with_kernel->Type() << "\n"; + + bool run_phi_kernel = false; + + // check if phi kernel exists + auto phi_kernel_map = + phi::KernelFactory::Instance().SelectKernelMap(op_with_kernel->Type()); + if (phi_kernel_map.size() > 0) { + auto phi_kernel_key = op_with_kernel->ChoosePhiKernel(exec_ctx); + VLOG(6) << "phi_kernel_key " << phi_kernel_key << "\n"; + + // this function is used to construct data transfer op + // we expect that it always has a valid phi kernel + // so no need to fallback to cpu kernel + PADDLE_ENFORCE_EQ( + op_with_kernel->PhiKernel()->IsValid(), + true, + platform::errors::PreconditionNotMet( + "the %s op has no valid phi kernel.", op_with_kernel->Type())); + run_phi_kernel = true; + } // 3. Execute transfer op and construct OpFuncNode OpFuncNode new_op_func_node; new_op_func_node.input_index["X"] = {var_scope_->VarId(var_name)}; new_op_func_node.output_index["Out"] = {var_scope_->VarId(new_var_name)}; - new_op_func_node.kernel_func_ = OpKernelComputeFunc(kernel_iter->second); - new_op_func_node.kernel_func_(exec_ctx); + + if (!run_phi_kernel) { + op_with_kernel->ChooseKernel(exec_ctx); + new_op_func_node.kernel_func_ = *op_with_kernel->kernel_func(); + new_op_func_node.kernel_func_(exec_ctx); + } else { + new_op_func_node.phi_kernel_ = op_with_kernel->PhiKernel(); + phi::KernelContext phi_kernel_context; + op_with_kernel->BuildPhiKernelContext( + runtime_context, dev_ctx, &phi_kernel_context); + (*new_op_func_node.phi_kernel_)(&phi_kernel_context); + } + // NOTE(winter-wang): in npu device, D2H kernel is asynchronous. need to // explicit synchronization. #ifdef PADDLE_WITH_ASCEND_CL diff --git a/paddle/fluid/operators/memcpy_h2d_op.cc b/paddle/fluid/operators/memcpy_h2d_op.cc index ff7b786d04018e0961832a5fe914c58687aa8450..b1126fb12818ee65a50e7a8b1b7c710ebbae6d7d 100644 --- a/paddle/fluid/operators/memcpy_h2d_op.cc +++ b/paddle/fluid/operators/memcpy_h2d_op.cc @@ -13,6 +13,11 @@ limitations under the License. */ #include +#include "paddle/fluid/framework/infershape_utils.h" +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/phi/core/infermeta_utils.h" +#include "paddle/phi/infermeta/unary.h" + namespace paddle { namespace framework { class OpDesc; @@ -32,17 +37,6 @@ class MemcpyH2DOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; - void InferShape(framework::InferShapeContext *ctx) const override { - auto type = ctx->GetInputsVarType("X")[0]; - if (type == framework::proto::VarType::SELECTED_ROWS || - type == framework::proto::VarType::LOD_TENSOR) { - ctx->SetOutputDim("Out", ctx->GetInputDim("X")); - if (type == framework::proto::VarType::LOD_TENSOR) { - ctx->ShareLoD("X", /*->*/ "Out"); - } - } - } - protected: framework::OpKernelType GetKernelTypeForVar( const std::string &var_name, @@ -117,95 +111,18 @@ raise error if the type is not listed above. namespace ops = paddle::operators; namespace plat = paddle::platform; + +DECLARE_INFER_SHAPE_FUNCTOR(memcpy_h2d, + MemcpyH2DInferShapeFunctor, + PD_INFER_META(phi::UnchangedInferMeta)); REGISTER_OPERATOR( memcpy_h2d, ops::MemcpyH2DOp, ops::MemcpyH2DOpProtoMaker, ops::MemcpyH2DInferVarType, paddle::framework::EmptyGradOpMaker, - paddle::framework::EmptyGradOpMaker); - -REGISTER_OP_CPU_KERNEL_FUNCTOR(memcpy_h2d, - float, - ops::MemcpyH2DKernel, - double, - ops::MemcpyH2DKernel, - int8_t, - ops::MemcpyH2DKernel, - uint8_t, - ops::MemcpyH2DKernel, - int, - ops::MemcpyH2DKernel, - int64_t, - ops::MemcpyH2DKernel, - bool, - ops::MemcpyH2DKernel, - paddle::platform::bfloat16, - ops::MemcpyH2DKernel, - paddle::platform::complex, - ops::MemcpyH2DKernel, - paddle::platform::complex, - ops::MemcpyH2DKernel, - plat::float16, - ops::MemcpyH2DKernel, - int16_t, - ops::MemcpyH2DKernel); - -#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) -REGISTER_OP_CUDA_KERNEL_FUNCTOR(memcpy_h2d, - float, - ops::MemcpyH2DKernel, - double, - ops::MemcpyH2DKernel, - int8_t, - ops::MemcpyH2DKernel, - uint8_t, - ops::MemcpyH2DKernel, - int, - ops::MemcpyH2DKernel, - int64_t, - ops::MemcpyH2DKernel, - bool, - ops::MemcpyH2DKernel, - paddle::platform::bfloat16, - ops::MemcpyH2DKernel, - paddle::platform::complex, - ops::MemcpyH2DKernel, - paddle::platform::complex, - ops::MemcpyH2DKernel, - plat::float16, - ops::MemcpyH2DKernel, - int16_t, - ops::MemcpyH2DKernel); -#endif - -#ifdef PADDLE_WITH_XPU -REGISTER_OP_XPU_KERNEL_FUNCTOR(memcpy_h2d, - float, - ops::MemcpyH2DKernel, - double, - ops::MemcpyH2DKernel, - int8_t, - ops::MemcpyH2DKernel, - uint8_t, - ops::MemcpyH2DKernel, - int, - ops::MemcpyH2DKernel, - int64_t, - ops::MemcpyH2DKernel, - bool, - ops::MemcpyH2DKernel, - paddle::platform::bfloat16, - ops::MemcpyH2DKernel, - paddle::platform::complex, - ops::MemcpyH2DKernel, - paddle::platform::complex, - ops::MemcpyH2DKernel, - plat::float16, - ops::MemcpyH2DKernel, - int16_t, - ops::MemcpyH2DKernel); -#endif + paddle::framework::EmptyGradOpMaker, + MemcpyH2DInferShapeFunctor); #ifdef PADDLE_WITH_ASCEND_CL REGISTER_OP_NPU_KERNEL_FUNCTOR(memcpy_h2d, diff --git a/paddle/phi/kernels/memcpy_h2d_kernel.cc b/paddle/phi/kernels/memcpy_h2d_kernel.cc new file mode 100644 index 0000000000000000000000000000000000000000..6d20475b9fa65e0d4a43342d7eae9a266cd2b854 --- /dev/null +++ b/paddle/phi/kernels/memcpy_h2d_kernel.cc @@ -0,0 +1,97 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/memcpy_h2d_kernel.h" + +#include "paddle/phi/common/complex.h" +#include "paddle/phi/core/kernel_registry.h" + +namespace phi { + +template +void MemcpyH2DKernel(const Context& dev_ctx, + const DenseTensor& x, + int dst_place_type, + DenseTensor* out) { + PADDLE_ENFORCE_GE( + dst_place_type, + 0, + errors::OutOfRange("dst_place_type only support 0-3, but got: %d", + dst_place_type)); + PADDLE_ENFORCE_LE( + dst_place_type, + 3, + errors::OutOfRange("dst_place_type only support 0-3, but got: %d", + dst_place_type)); + + // Copy will set the stream of the tensor while setting blocking to false + Copy(dev_ctx, x, dev_ctx.GetPlace(), false, out); +} + +} // namespace phi + +PD_REGISTER_KERNEL(memcpy_h2d, + CPU, + ALL_LAYOUT, + phi::MemcpyH2DKernel, + float, + double, + int8_t, + uint8_t, + int, + int64_t, + bool, + phi::dtype::bfloat16, + phi::dtype::complex, + phi::dtype::complex, + phi::dtype::float16, + int16_t) {} + +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) +PD_REGISTER_KERNEL(memcpy_h2d, + GPU, + ALL_LAYOUT, + phi::MemcpyH2DKernel, + float, + double, + int8_t, + uint8_t, + int, + int64_t, + bool, + phi::dtype::bfloat16, + phi::dtype::complex, + phi::dtype::complex, + phi::dtype::float16, + int16_t) {} +#endif + +#ifdef PADDLE_WITH_XPU +PD_REGISTER_KERNEL(memcpy_h2d, + XPU, + ALL_LAYOUT, + phi::MemcpyH2DKernel, + float, + double, + int8_t, + uint8_t, + int, + int64_t, + bool, + phi::dtype::bfloat16, + phi::dtype::complex, + phi::dtype::complex, + phi::dtype::float16, + int16_t) {} +#endif diff --git a/paddle/phi/kernels/memcpy_h2d_kernel.h b/paddle/phi/kernels/memcpy_h2d_kernel.h new file mode 100644 index 0000000000000000000000000000000000000000..18b642cabe822e84ca791e3b35067a968df73fe1 --- /dev/null +++ b/paddle/phi/kernels/memcpy_h2d_kernel.h @@ -0,0 +1,28 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/phi/core/dense_tensor.h" + +namespace phi { + +// used in new executor, for memory copy from host to device +template +void MemcpyH2DKernel(const Context& dev_ctx, + const DenseTensor& x, + int dst_place_type, + DenseTensor* out); + +} // namespace phi