From 1904572ac8edb57dfb528e711588758002a168dd Mon Sep 17 00:00:00 2001 From: Chen Weihang Date: Thu, 17 Mar 2022 21:28:39 +0800 Subject: [PATCH] [Phi] Move assign kernel into phi (#40022) * move assign kernel init commit * change vec to vec * support tensor array * support api declare * fix test_list failed * fix npu and xpu failed * fix infrt failed * remove assign array size in operator * move assign sr header into sr dir * add infermeta for assign * test op success * fix test_list failed * fix kunlun failed * add set host allocator in tests * support tensor array in arg ctx * open set layout in share_meta * fix meta tensor layout error * fix test failed --- paddle/fluid/framework/infershape_utils.cc | 65 ++++++++++++++++--- paddle/fluid/framework/operator.cc | 32 +++++++-- paddle/fluid/framework/operator.h | 4 ++ paddle/fluid/imperative/prepared_operator.h | 26 +++++++- paddle/fluid/operators/assign_op.cc | 61 ++--------------- paddle/fluid/operators/assign_op_npu_test.cc | 2 +- .../dialect/phi/pass/proto_arg_map_context.cc | 4 ++ .../dialect/phi/pass/proto_arg_map_context.h | 1 + paddle/phi/core/compat/arg_map_context.h | 2 + paddle/phi/core/kernel_context.cc | 14 ++++ paddle/phi/core/kernel_context.h | 6 ++ paddle/phi/kernels/assign_kernel.cc | 63 ++++++++++++++++++ paddle/phi/kernels/assign_kernel.h | 34 ++++++++++ paddle/phi/kernels/cpu/copy_kernel.cc | 2 +- .../kernels/selected_rows/assign_kernel.cc | 49 ++++++++++++++ .../phi/kernels/selected_rows/assign_kernel.h | 28 ++++++++ paddle/phi/ops/compat/assign_sig.cc | 35 ++++++++++ paddle/phi/tests/kernels/test_copy_dev_api.cc | 4 ++ .../phi/tests/kernels/test_flatten_dev_api.cc | 4 ++ .../phi/tests/kernels/test_reshape_dev_api.cc | 4 ++ paddle/phi/tests/ops/test_op_signature.h | 5 ++ 21 files changed, 371 insertions(+), 74 deletions(-) create mode 100644 paddle/phi/kernels/assign_kernel.cc create mode 100644 paddle/phi/kernels/assign_kernel.h create mode 100644 paddle/phi/kernels/selected_rows/assign_kernel.cc create mode 100644 paddle/phi/kernels/selected_rows/assign_kernel.h create mode 100644 paddle/phi/ops/compat/assign_sig.cc diff --git a/paddle/fluid/framework/infershape_utils.cc b/paddle/fluid/framework/infershape_utils.cc index dec8d1d846..2babecc6dd 100644 --- a/paddle/fluid/framework/infershape_utils.cc +++ b/paddle/fluid/framework/infershape_utils.cc @@ -78,6 +78,11 @@ class InferShapeArgumentMappingContext : public phi::ArgumentMappingContext { return var_types[0] == proto::VarType::SELECTED_ROWS; } + bool IsDenseTensorVectorInput(const std::string& name) const override { + auto var_types = ctx_.GetInputsVarType(name); + return var_types[0] == proto::VarType::LOD_TENSOR_ARRAY; + } + bool IsDenseTensorOutput(const std::string& name) const override { auto var_types = ctx_.GetOutputsVarType(name); return var_types[0] == proto::VarType::LOD_TENSOR; @@ -125,9 +130,14 @@ class CompatMetaTensor : public phi::MetaTensor { return var->Get().dims(); } else if (var->IsType()) { return var->Get().dims(); + } else if (var->IsType()) { + // use tensor array size as dims + auto& tensor_array = var->Get(); + return phi::make_ddim({static_cast(tensor_array.size())}); } else { PADDLE_THROW(platform::errors::Unimplemented( - "Currently, only can get dims from DenseTensor or SelectedRows.")); + "Currently, only can get dims from DenseTensor or SelectedRows or " + "DenseTensorArray.")); } } else { auto* var = BOOST_GET_CONST(VarDesc*, var_); @@ -144,6 +154,10 @@ class CompatMetaTensor : public phi::MetaTensor { return var->Get().dtype(); } else if (var->IsType()) { return var->Get().dtype(); + } else if (var->IsType()) { + // NOTE(chenweihang): do nothing + // Unsupported get dtype from LoDTensorArray now + return phi::DataType::UNDEFINED; } else { PADDLE_THROW(platform::errors::Unimplemented( "Currently, only can get dtype from DenseTensor or SelectedRows.")); @@ -157,7 +171,19 @@ class CompatMetaTensor : public phi::MetaTensor { DataLayout layout() const override { if (is_runtime_) { auto* var = BOOST_GET_CONST(Variable*, var_); - return var->Get().layout(); + if (var->IsType()) { + return var->Get().layout(); + } else if (var->IsType()) { + return var->Get().layout(); + } else if (var->IsType()) { + // NOTE(chenweihang): do nothing + // Unsupported get layout from LoDTensorArray now + return phi::DataLayout::UNDEFINED; + } else { + PADDLE_THROW(platform::errors::Unimplemented( + "Currently, only can get layout from DenseTensor or " + "SelectedRows.")); + } } else { // NOTE(chenweihang): do nothing // Unsupported get layout for VarDesc now @@ -174,6 +200,16 @@ class CompatMetaTensor : public phi::MetaTensor { } else if (var->IsType()) { auto* tensor = var->GetMutable()->mutable_value(); phi::DenseTensorUtils::GetMutableMeta(tensor)->dims = dims; + } else if (var->IsType()) { + auto* tensor_array = var->GetMutable(); + // Note: Here I want enforce `tensor_array->size() == 0UL`, because + // inplace using on LoDTensorArray is dangerous, but the unittest + // `test_list` contains this behavior + PADDLE_ENFORCE_EQ(dims.size(), 1UL, + platform::errors::InvalidArgument( + "LoDTensorArray can only have one dimension.")); + // only set the array size for LoDTensorArray input + tensor_array->resize(dims[0]); } else { PADDLE_THROW(platform::errors::Unimplemented( "Currently, only can set dims from DenseTensor or SelectedRows.")); @@ -193,6 +229,9 @@ class CompatMetaTensor : public phi::MetaTensor { } else if (var->IsType()) { auto* tensor = var->GetMutable()->mutable_value(); phi::DenseTensorUtils::GetMutableMeta(tensor)->dtype = dtype; + } else if (var->IsType()) { + // NOTE(chenweihang): do nothing + // Unsupported set dtype for LoDTensorArray now } else { PADDLE_THROW(platform::errors::Unimplemented( "Currently, only can set dtype from DenseTensor or SelectedRows.")); @@ -206,10 +245,20 @@ class CompatMetaTensor : public phi::MetaTensor { void set_layout(DataLayout layout) override { if (is_runtime_) { auto* var = BOOST_GET(Variable*, var_); - LoDTensor* tensor = var->GetMutable(); - phi::DenseTensorUtils::GetMutableMeta( - static_cast(tensor)) - ->layout = layout; + if (var->IsType()) { + auto* tensor = var->GetMutable(); + phi::DenseTensorUtils::GetMutableMeta(tensor)->layout = layout; + } else if (var->IsType()) { + auto* tensor = var->GetMutable()->mutable_value(); + phi::DenseTensorUtils::GetMutableMeta(tensor)->layout = layout; + } else if (var->IsType()) { + // NOTE(chenweihang): do nothing + // Unsupported set dtype for LoDTensorArray now + } else { + PADDLE_THROW(platform::errors::Unimplemented( + "Currently, only can set layout from DenseTensor or " + "SelectedRows.")); + } } else { // NOTE(chenweihang): do nothing // Unsupported set layout for VarDesc now @@ -251,9 +300,7 @@ class CompatMetaTensor : public phi::MetaTensor { void share_meta(const MetaTensor& meta_tensor) override { share_dims(meta_tensor); set_dtype(meta_tensor.dtype()); - // VarDesc doesn't contains layout, so we cannot share layout - // set_layout(meta_tensor.layout()); - + set_layout(meta_tensor.layout()); // special case: share lod of LoDTensor share_lod(meta_tensor); } diff --git a/paddle/fluid/framework/operator.cc b/paddle/fluid/framework/operator.cc index ad01adf1a2..ec28c98d59 100644 --- a/paddle/fluid/framework/operator.cc +++ b/paddle/fluid/framework/operator.cc @@ -2103,16 +2103,25 @@ void OperatorWithKernel::BuildPhiKernelContext( auto* var = ins_vector[offset]; if (var->IsType()) { tensor_in = &(var->Get()); + pt_kernel_context->EmplaceBackInputWithoutSetRange(tensor_in); } else if (var->IsType()) { tensor_in = &(var->Get()); + pt_kernel_context->EmplaceBackInputWithoutSetRange(tensor_in); + } else if (var->IsType()) { + paddle::SmallVector tensor_vector; + auto& tensor_array = var->Get(); + for (auto& t : tensor_array) { + tensor_vector.emplace_back(&t); + } + pt_kernel_context->EmplaceBackInputsWithoutSetRange(tensor_vector); + end_idx += tensor_array.size() - 1; } else { PADDLE_THROW(platform::errors::Unimplemented( "Unsupported input `%s` type when call pt kernel.", framework::ToTypeName(var->Type()))); } - - pt_kernel_context->EmplaceBackInputWithoutSetRange(tensor_in); } + // Note: here cannot deal with vector input pt_kernel_context->AssignInputRange(std::make_pair(start_idx, end_idx), i); } VLOG(4) << "Done inputs"; @@ -2140,22 +2149,33 @@ void OperatorWithKernel::BuildPhiKernelContext( for (size_t offset = 0; offset < outs_vector.size(); ++offset) { phi::TensorBase* tensor_out = nullptr; auto* var = outs_vector[offset]; - if (var) { if (var->template IsType()) { tensor_out = var->template GetMutable(); + pt_kernel_context->EmplaceBackOutputWithoutSetRange(tensor_out); } else if (var->template IsType()) { tensor_out = var->template GetMutable(); + pt_kernel_context->EmplaceBackOutputWithoutSetRange(tensor_out); + } else if (var->template IsType()) { + paddle::SmallVector tensor_vector; + auto* tensor_array = + var->template GetMutable(); + // Note: If the input LoDTensorArray size is 0, the output + // LoDTensorArray is also 0 + for (auto& t : *tensor_array) { + tensor_vector.emplace_back(&t); + } + pt_kernel_context->EmplaceBackOutputsWithoutSetRange(tensor_vector); + end_idx += tensor_array->size() - 1; } else { PADDLE_THROW(platform::errors::Unimplemented( "Unsupported output `%s` type when call pt kernel.", framework::ToTypeName(var->Type()))); } + } else { + pt_kernel_context->EmplaceBackOutputWithoutSetRange(tensor_out); } - - pt_kernel_context->EmplaceBackOutputWithoutSetRange(tensor_out); } - pt_kernel_context->AssignOutputRange(std::make_pair(start_idx, end_idx), i); } VLOG(4) << "Done outputs"; diff --git a/paddle/fluid/framework/operator.h b/paddle/fluid/framework/operator.h index 1a1171f1db..6f68c261d2 100644 --- a/paddle/fluid/framework/operator.h +++ b/paddle/fluid/framework/operator.h @@ -483,6 +483,10 @@ class ExecutionArgumentMappingContext : public phi::ArgumentMappingContext { return ctx_.InputVar(name)->IsType(); } + bool IsDenseTensorVectorInput(const std::string& name) const override { + return ctx_.InputVar(name)->IsType(); + } + bool IsDenseTensorOutput(const std::string& name) const override { return ctx_.OutputVar(name)->IsType(); } diff --git a/paddle/fluid/imperative/prepared_operator.h b/paddle/fluid/imperative/prepared_operator.h index 16f2df7924..f70f44878e 100644 --- a/paddle/fluid/imperative/prepared_operator.h +++ b/paddle/fluid/imperative/prepared_operator.h @@ -289,14 +289,23 @@ void BuildDygraphPhiKernelContext( auto& var = ins_vector[offset]->Var(); if (var.template IsType()) { tensor_in = &(var.template Get()); + kernel_ctx->EmplaceBackInputWithoutSetRange(tensor_in); } else if (var.template IsType()) { tensor_in = &(var.template Get()); + kernel_ctx->EmplaceBackInputWithoutSetRange(tensor_in); + } else if (var.template IsType()) { + paddle::SmallVector tensor_vector; + auto& tensor_array = var.template Get(); + for (auto& t : tensor_array) { + tensor_vector.emplace_back(&t); + } + kernel_ctx->EmplaceBackInputsWithoutSetRange(tensor_vector); + end_idx += tensor_array.size() - 1; } else { PADDLE_THROW(platform::errors::Unimplemented( "Unsupported input `%s` type when call pt kernel.", framework::ToTypeName(var.Type()))); } - kernel_ctx->EmplaceBackInputWithoutSetRange(tensor_in); } kernel_ctx->AssignInputRange(std::make_pair(start_idx, end_idx), i); } @@ -326,16 +335,27 @@ void BuildDygraphPhiKernelContext( if (var) { if (var->template IsType()) { tensor_out = var->template GetMutable(); + kernel_ctx->EmplaceBackOutputWithoutSetRange(tensor_out); } else if (var->template IsType()) { tensor_out = var->template GetMutable(); + kernel_ctx->EmplaceBackOutputWithoutSetRange(tensor_out); + } else if (var->template IsType()) { + paddle::SmallVector tensor_vector; + auto* tensor_array = + var->template GetMutable(); + for (auto& t : *tensor_array) { + tensor_vector.emplace_back(&t); + } + kernel_ctx->EmplaceBackOutputsWithoutSetRange(tensor_vector); + end_idx += tensor_array->size() - 1; } else { PADDLE_THROW(platform::errors::Unimplemented( "Unsupported output `%s` type when call pt kernel.", framework::ToTypeName(var->Type()))); } + } else { + kernel_ctx->EmplaceBackOutputWithoutSetRange(tensor_out); } - - kernel_ctx->EmplaceBackOutputWithoutSetRange(tensor_out); } kernel_ctx->AssignOutputRange(std::make_pair(start_idx, end_idx), i); } diff --git a/paddle/fluid/operators/assign_op.cc b/paddle/fluid/operators/assign_op.cc index 684ac5bafd..ea6614cbfb 100644 --- a/paddle/fluid/operators/assign_op.cc +++ b/paddle/fluid/operators/assign_op.cc @@ -16,6 +16,9 @@ limitations under the License. */ #include +#include "paddle/fluid/framework/infershape_utils.h" +#include "paddle/phi/core/infermeta_utils.h" +#include "paddle/phi/infermeta/unary.h" namespace paddle { namespace framework { class OpDesc; @@ -36,26 +39,6 @@ class AssignOp : public framework::OperatorWithKernel { const framework::AttributeMap &attrs) : OperatorWithKernel(type, inputs, outputs, attrs) {} - void InferShape(framework::InferShapeContext *ctx) const override { - if (ctx->HasInput("X")) { - auto type = ctx->GetInputsVarType("X")[0]; - if (type == framework::proto::VarType::SELECTED_ROWS || - type == framework::proto::VarType::LOD_TENSOR) { - ctx->SetOutputDim("Out", ctx->GetInputDim("X")); - if (type == framework::proto::VarType::LOD_TENSOR) { - ctx->ShareLoD("X", /*->*/ "Out"); - } - } else if (type == framework::proto::VarType::LOD_TENSOR_ARRAY) { - if (ctx->IsRuntime()) { - // The runtime output shape is determined in kernel. - return; - } else { - ctx->SetOutputDim("Out", ctx->GetInputDim("X")); - } - } - } - } - protected: framework::OpKernelType GetKernelTypeForVar( const std::string &var_name, const framework::Tensor &tensor, @@ -91,24 +74,6 @@ class AssignInferVarType : public framework::VarTypeInference { } }; -class AssignKernel { - public: - void operator()(const framework::ExecutionContext &ctx) const { - auto *x = ctx.InputVar("X"); - if (x == nullptr) { - return; - } - PADDLE_ENFORCE_EQ( - ctx.HasOutput("Out"), true, - platform::errors::NotFound("Output(Out) of assign_op is not found.")); - auto *out = ctx.OutputVar("Out"); - platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); - auto &dev_ctx = *pool.Get(ctx.GetPlace()); - - framework::VisitVarType(*x, AssignFunctor(out, dev_ctx)); - } -}; - class AssignOpProtoMaker : public framework::OpProtoAndCheckerMaker { public: void Make() override { @@ -147,23 +112,11 @@ DECLARE_INPLACE_OP_INFERER(AssignOpInplaceInferer, {"X", "Out"}); namespace ops = paddle::operators; namespace plat = paddle::platform; + +DECLARE_INFER_SHAPE_FUNCTOR(assign, AssignInferShapeFunctor, + PD_INFER_META(phi::UnchangedInferMeta)); REGISTER_OPERATOR(assign, ops::AssignOp, ops::AssignGradMaker, ops::AssignGradMaker, ops::AssignOpProtoMaker, ops::AssignOpInplaceInferer, - ops::AssignInferVarType); - -REGISTER_OP_CPU_KERNEL_FUNCTOR(assign, float, ops::AssignKernel, double, - ops::AssignKernel, int, ops::AssignKernel, - int64_t, ops::AssignKernel, uint8_t, - ops::AssignKernel, bool, ops::AssignKernel, - plat::float16, ops::AssignKernel, plat::bfloat16, - ops::AssignKernel); - -#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) -REGISTER_OP_CUDA_KERNEL_FUNCTOR(assign, float, ops::AssignKernel, double, - ops::AssignKernel, int, ops::AssignKernel, - int64_t, ops::AssignKernel, uint8_t, - ops::AssignKernel, bool, ops::AssignKernel, - plat::float16, ops::AssignKernel); -#endif + ops::AssignInferVarType, AssignInferShapeFunctor); diff --git a/paddle/fluid/operators/assign_op_npu_test.cc b/paddle/fluid/operators/assign_op_npu_test.cc index b452dea853..b91eb50646 100644 --- a/paddle/fluid/operators/assign_op_npu_test.cc +++ b/paddle/fluid/operators/assign_op_npu_test.cc @@ -29,7 +29,7 @@ limitations under the License. */ namespace f = paddle::framework; namespace p = paddle::platform; -USE_OP(assign); +USE_OP_ITSELF(assign); USE_OP_DEVICE_KERNEL(assign, NPU); template diff --git a/paddle/infrt/dialect/phi/pass/proto_arg_map_context.cc b/paddle/infrt/dialect/phi/pass/proto_arg_map_context.cc index 64b1843597..1cd5b5a855 100644 --- a/paddle/infrt/dialect/phi/pass/proto_arg_map_context.cc +++ b/paddle/infrt/dialect/phi/pass/proto_arg_map_context.cc @@ -60,6 +60,10 @@ bool ProtoArgumentMappingContext::IsSelectedRowsInput( const std::string& name) const { return false; } +bool ProtoArgumentMappingContext::IsDenseTensorVectorInput( + const std::string& name) const { + return false; +} bool ProtoArgumentMappingContext::IsDenseTensorOutput( const std::string& name) const { diff --git a/paddle/infrt/dialect/phi/pass/proto_arg_map_context.h b/paddle/infrt/dialect/phi/pass/proto_arg_map_context.h index 7d08c32161..5cf2ef9790 100644 --- a/paddle/infrt/dialect/phi/pass/proto_arg_map_context.h +++ b/paddle/infrt/dialect/phi/pass/proto_arg_map_context.h @@ -42,6 +42,7 @@ class ProtoArgumentMappingContext : public ::phi::ArgumentMappingContext { bool IsDenseTensorInput(const std::string& name) const override; bool IsSelectedRowsInput(const std::string& name) const override; + bool IsDenseTensorVectorInput(const std::string& name) const override; bool IsDenseTensorOutput(const std::string& name) const override; bool IsSelectedRowsOutput(const std::string& name) const override; diff --git a/paddle/phi/core/compat/arg_map_context.h b/paddle/phi/core/compat/arg_map_context.h index 25b80279ec..71cec01141 100644 --- a/paddle/phi/core/compat/arg_map_context.h +++ b/paddle/phi/core/compat/arg_map_context.h @@ -89,6 +89,8 @@ class ArgumentMappingContext { virtual bool IsDenseTensorInput(const std::string& name) const = 0; virtual bool IsSelectedRowsInput(const std::string& name) const = 0; + // For compatibility with LoDTensorArray + virtual bool IsDenseTensorVectorInput(const std::string& name) const = 0; virtual bool IsDenseTensorOutput(const std::string& name) const = 0; virtual bool IsSelectedRowsOutput(const std::string& name) const = 0; diff --git a/paddle/phi/core/kernel_context.cc b/paddle/phi/core/kernel_context.cc index a32e0e44f4..234e3528c3 100644 --- a/paddle/phi/core/kernel_context.cc +++ b/paddle/phi/core/kernel_context.cc @@ -37,6 +37,13 @@ void KernelContext::EmplaceBackInputs( std::make_move_iterator(inputs.end())); } +void KernelContext::EmplaceBackInputsWithoutSetRange( + paddle::SmallVector inputs) { + inputs_.insert(inputs_.end(), + std::make_move_iterator(inputs.begin()), + std::make_move_iterator(inputs.end())); +} + void KernelContext::EmplaceBackOutput(TensorBase* output) { int index = outputs_.size(); outputs_.emplace_back(output); @@ -59,6 +66,13 @@ void KernelContext::EmplaceBackOutputs( std::make_move_iterator(outputs.end())); } +void KernelContext::EmplaceBackOutputsWithoutSetRange( + paddle::SmallVector outputs) { + outputs_.insert(outputs_.end(), + std::make_move_iterator(outputs.begin()), + std::make_move_iterator(outputs.end())); +} + void KernelContext::EmplaceBackAttr(paddle::any attr) { attrs_.emplace_back(std::move(attr)); } diff --git a/paddle/phi/core/kernel_context.h b/paddle/phi/core/kernel_context.h index 213ac47d30..d3ca1ffc61 100644 --- a/paddle/phi/core/kernel_context.h +++ b/paddle/phi/core/kernel_context.h @@ -52,12 +52,18 @@ class KernelContext { void EmplaceBackInputs(paddle::SmallVector inputs); + void EmplaceBackInputsWithoutSetRange( + paddle::SmallVector inputs); + void EmplaceBackOutput(TensorBase* output); void EmplaceBackOutputWithoutSetRange(TensorBase* output); void EmplaceBackOutputs(paddle::SmallVector outputs); + void EmplaceBackOutputsWithoutSetRange( + paddle::SmallVector outputs); + void EmplaceBackAttr(paddle::any attr); const std::pair& InputRangeAt(size_t idx) const; diff --git a/paddle/phi/kernels/assign_kernel.cc b/paddle/phi/kernels/assign_kernel.cc new file mode 100644 index 0000000000..9faaace691 --- /dev/null +++ b/paddle/phi/kernels/assign_kernel.cc @@ -0,0 +1,63 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/assign_kernel.h" + +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/copy_kernel.h" +#include "paddle/utils/optional.h" + +namespace phi { + +template +void AssignKernel(const Context& dev_ctx, + paddle::optional x, + DenseTensor* out) { + if (!x.is_initialized()) { + return; + } + auto& x_tensor = *x.get_ptr(); + Copy(dev_ctx, x_tensor, x_tensor.place(), false, out); +} + +// Note: use `const paddle::optional&> x` +// as input if needed +template +void AssignArrayKernel(const Context& dev_ctx, + const std::vector& x, + std::vector out) { + for (size_t i = 0; i < x.size(); ++i) { + AssignKernel(dev_ctx, *x[i], out.at(i)); + } +} + +} // namespace phi + +PD_REGISTER_GENERAL_KERNEL( + assign, CPU, ALL_LAYOUT, phi::AssignKernel, ALL_DTYPE) {} +PD_REGISTER_GENERAL_KERNEL(assign_array, + CPU, + ALL_LAYOUT, + phi::AssignArrayKernel, + ALL_DTYPE) {} + +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) +PD_REGISTER_GENERAL_KERNEL( + assign, GPU, ALL_LAYOUT, phi::AssignKernel, ALL_DTYPE) {} +PD_REGISTER_GENERAL_KERNEL(assign_array, + GPU, + ALL_LAYOUT, + phi::AssignArrayKernel, + ALL_DTYPE) {} +#endif diff --git a/paddle/phi/kernels/assign_kernel.h b/paddle/phi/kernels/assign_kernel.h new file mode 100644 index 0000000000..7cc06818dc --- /dev/null +++ b/paddle/phi/kernels/assign_kernel.h @@ -0,0 +1,34 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/phi/core/dense_tensor.h" + +namespace phi { + +// In order to be compatible with the `AsDispensable` input in the original +// assign op maker, the input parameter here needs to be dispensable, but +// this looks weird +template +void AssignKernel(const Context& dev_ctx, + paddle::optional x, + DenseTensor* out); + +template +void AssignArrayKernel(const Context& dev_ctx, + const std::vector& x, + std::vector out); + +} // namespace phi diff --git a/paddle/phi/kernels/cpu/copy_kernel.cc b/paddle/phi/kernels/cpu/copy_kernel.cc index 1af071f23d..fa11fd05bf 100644 --- a/paddle/phi/kernels/cpu/copy_kernel.cc +++ b/paddle/phi/kernels/cpu/copy_kernel.cc @@ -38,7 +38,7 @@ void Copy(const Context& dev_ctx, << src_place; dst->Resize(src.dims()); - auto* dst_ptr = dev_ctx.Alloc(dst, src.dtype()); + auto* dst_ptr = dev_ctx.HostAlloc(dst, src.dtype()); if (src_ptr == dst_ptr) { VLOG(3) << "Skip copy the same data async from " << src_place << " to " diff --git a/paddle/phi/kernels/selected_rows/assign_kernel.cc b/paddle/phi/kernels/selected_rows/assign_kernel.cc new file mode 100644 index 0000000000..fae876facf --- /dev/null +++ b/paddle/phi/kernels/selected_rows/assign_kernel.cc @@ -0,0 +1,49 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/selected_rows/assign_kernel.h" + +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/assign_kernel.h" + +namespace phi { +namespace sr { + +// Note: use `const paddle::optional x` +// as input if needed +template +void AssignKernel(const Context& dev_ctx, + const SelectedRows& x, + SelectedRows* out) { + out->set_rows(x.rows()); + out->set_height(x.height()); + phi::AssignKernel(dev_ctx, x.value(), out->mutable_value()); +} + +} // namespace sr +} // namespace phi + +PD_REGISTER_GENERAL_KERNEL(assign_sr, + CPU, + ALL_LAYOUT, + phi::sr::AssignKernel, + ALL_DTYPE) {} + +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) +PD_REGISTER_GENERAL_KERNEL(assign_sr, + GPU, + ALL_LAYOUT, + phi::sr::AssignKernel, + ALL_DTYPE) {} +#endif diff --git a/paddle/phi/kernels/selected_rows/assign_kernel.h b/paddle/phi/kernels/selected_rows/assign_kernel.h new file mode 100644 index 0000000000..2ba465615a --- /dev/null +++ b/paddle/phi/kernels/selected_rows/assign_kernel.h @@ -0,0 +1,28 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/phi/core/selected_rows.h" + +namespace phi { +namespace sr { + +template +void AssignKernel(const Context& dev_ctx, + const SelectedRows& x, + SelectedRows* out); + +} // namespace sr +} // namespace phi diff --git a/paddle/phi/ops/compat/assign_sig.cc b/paddle/phi/ops/compat/assign_sig.cc new file mode 100644 index 0000000000..d149e8e6a9 --- /dev/null +++ b/paddle/phi/ops/compat/assign_sig.cc @@ -0,0 +1,35 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/core/compat/op_utils.h" + +namespace phi { + +KernelSignature AssignOpArgumentMapping(const ArgumentMappingContext& ctx) { + if (ctx.HasInput("X")) { + if (ctx.IsDenseTensorVectorInput("X")) { + return KernelSignature("assign_array", {"X"}, {}, {"Out"}); + } else if (ctx.IsSelectedRowsInput("X")) { + return KernelSignature("assign_sr", {"X"}, {}, {"Out"}); + } else { + return KernelSignature("assign", {"X"}, {}, {"Out"}); + } + } else { + return KernelSignature("assign", {"X"}, {}, {"Out"}); + } +} + +} // namespace phi + +PD_REGISTER_ARG_MAPPING_FN(assign, phi::AssignOpArgumentMapping); diff --git a/paddle/phi/tests/kernels/test_copy_dev_api.cc b/paddle/phi/tests/kernels/test_copy_dev_api.cc index d69c7b2174..460d85f831 100644 --- a/paddle/phi/tests/kernels/test_copy_dev_api.cc +++ b/paddle/phi/tests/kernels/test_copy_dev_api.cc @@ -61,6 +61,10 @@ TEST(DEV_API, copy) { dev_ctx.SetAllocator(paddle::memory::allocation::AllocatorFacade::Instance() .GetAllocator(paddle::platform::CPUPlace()) .get()); + dev_ctx.SetHostAllocator( + paddle::memory::allocation::AllocatorFacade::Instance() + .GetAllocator(paddle::platform::CPUPlace()) + .get()); dev_ctx.Init(); phi::Copy( dev_ctx, *(dense_src.get()), phi::CPUPlace(), false, dense_dst.get()); diff --git a/paddle/phi/tests/kernels/test_flatten_dev_api.cc b/paddle/phi/tests/kernels/test_flatten_dev_api.cc index dc283728ee..e3f2e8b57e 100644 --- a/paddle/phi/tests/kernels/test_flatten_dev_api.cc +++ b/paddle/phi/tests/kernels/test_flatten_dev_api.cc @@ -58,6 +58,10 @@ TEST(DEV_API, flatten) { dev_ctx.SetAllocator(paddle::memory::allocation::AllocatorFacade::Instance() .GetAllocator(paddle::platform::CPUPlace()) .get()); + dev_ctx.SetHostAllocator( + paddle::memory::allocation::AllocatorFacade::Instance() + .GetAllocator(paddle::platform::CPUPlace()) + .get()); dev_ctx.Init(); // 2. test API diff --git a/paddle/phi/tests/kernels/test_reshape_dev_api.cc b/paddle/phi/tests/kernels/test_reshape_dev_api.cc index 16ad4fc341..7de039372f 100644 --- a/paddle/phi/tests/kernels/test_reshape_dev_api.cc +++ b/paddle/phi/tests/kernels/test_reshape_dev_api.cc @@ -50,6 +50,10 @@ TEST(DEV_API, reshape) { dev_ctx.SetAllocator(paddle::memory::allocation::AllocatorFacade::Instance() .GetAllocator(paddle::platform::CPUPlace()) .get()); + dev_ctx.SetHostAllocator( + paddle::memory::allocation::AllocatorFacade::Instance() + .GetAllocator(paddle::platform::CPUPlace()) + .get()); dev_ctx.Init(); auto out = phi::Reshape(dev_ctx, dense_x, shape); // 3. check result diff --git a/paddle/phi/tests/ops/test_op_signature.h b/paddle/phi/tests/ops/test_op_signature.h index 06048f33d9..8468dad10e 100644 --- a/paddle/phi/tests/ops/test_op_signature.h +++ b/paddle/phi/tests/ops/test_op_signature.h @@ -72,6 +72,11 @@ class TestArgumentMappingContext : public phi::ArgumentMappingContext { return selected_rows_inputs.count(name) > 0; } + // add member if needed + bool IsDenseTensorVectorInput(const std::string& name) const override { + return false; + } + bool IsDenseTensorOutput(const std::string& name) const override { return dense_tensor_outputs.count(name) > 0; } -- GitLab