未验证 提交 1904572a 编写于 作者: C Chen Weihang 提交者: GitHub

[Phi] Move assign kernel into phi (#40022)

* move assign kernel init commit

* change vec<tensor> to vec<tensor*>

* support tensor array

* support api declare

* fix test_list failed

* fix npu and xpu failed

* fix infrt failed

* remove assign array size in operator

* move assign sr header into sr dir

* add infermeta for assign

* test op success

* fix test_list failed

* fix kunlun failed

* add set host allocator in tests

* support tensor array in arg ctx

* open set layout in share_meta

* fix meta tensor layout error

* fix test failed
上级 31776199
......@@ -78,6 +78,11 @@ class InferShapeArgumentMappingContext : public phi::ArgumentMappingContext {
return var_types[0] == proto::VarType::SELECTED_ROWS;
}
// Reports whether the first variable type recorded for input `name` is
// LOD_TENSOR_ARRAY. Returns false when no variable type is recorded for
// `name`, rather than indexing into an empty container.
bool IsDenseTensorVectorInput(const std::string& name) const override {
  const auto var_types = ctx_.GetInputsVarType(name);
  return !var_types.empty() &&
         var_types[0] == proto::VarType::LOD_TENSOR_ARRAY;
}
bool IsDenseTensorOutput(const std::string& name) const override {
auto var_types = ctx_.GetOutputsVarType(name);
return var_types[0] == proto::VarType::LOD_TENSOR;
......@@ -125,9 +130,14 @@ class CompatMetaTensor : public phi::MetaTensor {
return var->Get<phi::DenseTensor>().dims();
} else if (var->IsType<phi::SelectedRows>()) {
return var->Get<phi::SelectedRows>().dims();
} else if (var->IsType<framework::LoDTensorArray>()) {
// use tensor array size as dims
auto& tensor_array = var->Get<framework::LoDTensorArray>();
return phi::make_ddim({static_cast<int64_t>(tensor_array.size())});
} else {
PADDLE_THROW(platform::errors::Unimplemented(
"Currently, only can get dims from DenseTensor or SelectedRows."));
"Currently, only can get dims from DenseTensor or SelectedRows or "
"DenseTensorArray."));
}
} else {
auto* var = BOOST_GET_CONST(VarDesc*, var_);
......@@ -144,6 +154,10 @@ class CompatMetaTensor : public phi::MetaTensor {
return var->Get<phi::DenseTensor>().dtype();
} else if (var->IsType<phi::SelectedRows>()) {
return var->Get<phi::SelectedRows>().dtype();
} else if (var->IsType<framework::LoDTensorArray>()) {
// NOTE(chenweihang): do nothing
// Unsupported get dtype from LoDTensorArray now
return phi::DataType::UNDEFINED;
} else {
PADDLE_THROW(platform::errors::Unimplemented(
"Currently, only can get dtype from DenseTensor or SelectedRows."));
......@@ -157,7 +171,19 @@ class CompatMetaTensor : public phi::MetaTensor {
DataLayout layout() const override {
if (is_runtime_) {
auto* var = BOOST_GET_CONST(Variable*, var_);
return var->Get<LoDTensor>().layout();
if (var->IsType<phi::DenseTensor>()) {
return var->Get<phi::DenseTensor>().layout();
} else if (var->IsType<phi::SelectedRows>()) {
return var->Get<phi::SelectedRows>().layout();
} else if (var->IsType<framework::LoDTensorArray>()) {
// NOTE(chenweihang): do nothing
// Unsupported get layout from LoDTensorArray now
return phi::DataLayout::UNDEFINED;
} else {
PADDLE_THROW(platform::errors::Unimplemented(
"Currently, only can get layout from DenseTensor or "
"SelectedRows."));
}
} else {
// NOTE(chenweihang): do nothing
// Unsupported get layout for VarDesc now
......@@ -174,6 +200,16 @@ class CompatMetaTensor : public phi::MetaTensor {
} else if (var->IsType<phi::SelectedRows>()) {
auto* tensor = var->GetMutable<phi::SelectedRows>()->mutable_value();
phi::DenseTensorUtils::GetMutableMeta(tensor)->dims = dims;
} else if (var->IsType<framework::LoDTensorArray>()) {
auto* tensor_array = var->GetMutable<framework::LoDTensorArray>();
// Note: Here I want enforce `tensor_array->size() == 0UL`, because
// inplace using on LoDTensorArray is dangerous, but the unittest
// `test_list` contains this behavior
PADDLE_ENFORCE_EQ(dims.size(), 1UL,
platform::errors::InvalidArgument(
"LoDTensorArray can only have one dimension."));
// only set the array size for LoDTensorArray input
tensor_array->resize(dims[0]);
} else {
PADDLE_THROW(platform::errors::Unimplemented(
"Currently, only can set dims from DenseTensor or SelectedRows."));
......@@ -193,6 +229,9 @@ class CompatMetaTensor : public phi::MetaTensor {
} else if (var->IsType<phi::SelectedRows>()) {
auto* tensor = var->GetMutable<phi::SelectedRows>()->mutable_value();
phi::DenseTensorUtils::GetMutableMeta(tensor)->dtype = dtype;
} else if (var->IsType<framework::LoDTensorArray>()) {
// NOTE(chenweihang): do nothing
// Unsupported set dtype for LoDTensorArray now
} else {
PADDLE_THROW(platform::errors::Unimplemented(
"Currently, only can set dtype from DenseTensor or SelectedRows."));
......@@ -206,10 +245,20 @@ class CompatMetaTensor : public phi::MetaTensor {
void set_layout(DataLayout layout) override {
if (is_runtime_) {
auto* var = BOOST_GET(Variable*, var_);
LoDTensor* tensor = var->GetMutable<LoDTensor>();
phi::DenseTensorUtils::GetMutableMeta(
static_cast<phi::DenseTensor*>(tensor))
->layout = layout;
if (var->IsType<phi::DenseTensor>()) {
auto* tensor = var->GetMutable<phi::DenseTensor>();
phi::DenseTensorUtils::GetMutableMeta(tensor)->layout = layout;
} else if (var->IsType<phi::SelectedRows>()) {
auto* tensor = var->GetMutable<phi::SelectedRows>()->mutable_value();
phi::DenseTensorUtils::GetMutableMeta(tensor)->layout = layout;
} else if (var->IsType<framework::LoDTensorArray>()) {
// NOTE(chenweihang): do nothing
// Unsupported set layout for LoDTensorArray now
} else {
PADDLE_THROW(platform::errors::Unimplemented(
"Currently, only can set layout from DenseTensor or "
"SelectedRows."));
}
} else {
// NOTE(chenweihang): do nothing
// Unsupported set layout for VarDesc now
......@@ -251,9 +300,7 @@ class CompatMetaTensor : public phi::MetaTensor {
// Copies the meta information of `meta_tensor` onto this tensor by calling,
// in order: share_dims, set_dtype, set_layout and share_lod.
void share_meta(const MetaTensor& meta_tensor) override {
share_dims(meta_tensor);
set_dtype(meta_tensor.dtype());
// Layout is now shared as well.
// NOTE(review): the layout()/set_layout() notes in this class say layout is
// unsupported for VarDesc, so this presumably only takes effect at runtime
// -- confirm.
set_layout(meta_tensor.layout());
// special case: share lod of LoDTensor
share_lod(meta_tensor);
}
......
......@@ -2103,16 +2103,25 @@ void OperatorWithKernel::BuildPhiKernelContext(
auto* var = ins_vector[offset];
if (var->IsType<framework::LoDTensor>()) {
tensor_in = &(var->Get<framework::LoDTensor>());
pt_kernel_context->EmplaceBackInputWithoutSetRange(tensor_in);
} else if (var->IsType<phi::SelectedRows>()) {
tensor_in = &(var->Get<phi::SelectedRows>());
pt_kernel_context->EmplaceBackInputWithoutSetRange(tensor_in);
} else if (var->IsType<framework::LoDTensorArray>()) {
paddle::SmallVector<const phi::TensorBase*> tensor_vector;
auto& tensor_array = var->Get<framework::LoDTensorArray>();
for (auto& t : tensor_array) {
tensor_vector.emplace_back(&t);
}
pt_kernel_context->EmplaceBackInputsWithoutSetRange(tensor_vector);
end_idx += tensor_array.size() - 1;
} else {
PADDLE_THROW(platform::errors::Unimplemented(
"Unsupported input `%s` type when call pt kernel.",
framework::ToTypeName(var->Type())));
}
pt_kernel_context->EmplaceBackInputWithoutSetRange(tensor_in);
}
// Note: here cannot deal with vector<LoDTensorArray> input
pt_kernel_context->AssignInputRange(std::make_pair(start_idx, end_idx), i);
}
VLOG(4) << "Done inputs";
......@@ -2140,22 +2149,33 @@ void OperatorWithKernel::BuildPhiKernelContext(
for (size_t offset = 0; offset < outs_vector.size(); ++offset) {
phi::TensorBase* tensor_out = nullptr;
auto* var = outs_vector[offset];
if (var) {
if (var->template IsType<framework::LoDTensor>()) {
tensor_out = var->template GetMutable<framework::LoDTensor>();
pt_kernel_context->EmplaceBackOutputWithoutSetRange(tensor_out);
} else if (var->template IsType<phi::SelectedRows>()) {
tensor_out = var->template GetMutable<phi::SelectedRows>();
pt_kernel_context->EmplaceBackOutputWithoutSetRange(tensor_out);
} else if (var->template IsType<framework::LoDTensorArray>()) {
paddle::SmallVector<phi::TensorBase*> tensor_vector;
auto* tensor_array =
var->template GetMutable<framework::LoDTensorArray>();
// Note: If the input LoDTensorArray size is 0, the output
// LoDTensorArray is also 0
for (auto& t : *tensor_array) {
tensor_vector.emplace_back(&t);
}
pt_kernel_context->EmplaceBackOutputsWithoutSetRange(tensor_vector);
end_idx += tensor_array->size() - 1;
} else {
PADDLE_THROW(platform::errors::Unimplemented(
"Unsupported output `%s` type when call pt kernel.",
framework::ToTypeName(var->Type())));
}
} else {
pt_kernel_context->EmplaceBackOutputWithoutSetRange(tensor_out);
}
pt_kernel_context->EmplaceBackOutputWithoutSetRange(tensor_out);
}
pt_kernel_context->AssignOutputRange(std::make_pair(start_idx, end_idx), i);
}
VLOG(4) << "Done outputs";
......
......@@ -483,6 +483,10 @@ class ExecutionArgumentMappingContext : public phi::ArgumentMappingContext {
return ctx_.InputVar(name)->IsType<phi::SelectedRows>();
}
// True when the variable bound to input `name` holds a LoDTensorArray.
bool IsDenseTensorVectorInput(const std::string& name) const override {
  const auto* input_var = ctx_.InputVar(name);
  return input_var->IsType<framework::LoDTensorArray>();
}
bool IsDenseTensorOutput(const std::string& name) const override {
return ctx_.OutputVar(name)->IsType<framework::LoDTensor>();
}
......
......@@ -289,14 +289,23 @@ void BuildDygraphPhiKernelContext(
auto& var = ins_vector[offset]->Var();
if (var.template IsType<phi::DenseTensor>()) {
tensor_in = &(var.template Get<phi::DenseTensor>());
kernel_ctx->EmplaceBackInputWithoutSetRange(tensor_in);
} else if (var.template IsType<phi::SelectedRows>()) {
tensor_in = &(var.template Get<phi::SelectedRows>());
kernel_ctx->EmplaceBackInputWithoutSetRange(tensor_in);
} else if (var.template IsType<framework::LoDTensorArray>()) {
paddle::SmallVector<const phi::TensorBase*> tensor_vector;
auto& tensor_array = var.template Get<framework::LoDTensorArray>();
for (auto& t : tensor_array) {
tensor_vector.emplace_back(&t);
}
kernel_ctx->EmplaceBackInputsWithoutSetRange(tensor_vector);
end_idx += tensor_array.size() - 1;
} else {
PADDLE_THROW(platform::errors::Unimplemented(
"Unsupported input `%s` type when call pt kernel.",
framework::ToTypeName(var.Type())));
}
kernel_ctx->EmplaceBackInputWithoutSetRange(tensor_in);
}
kernel_ctx->AssignInputRange(std::make_pair(start_idx, end_idx), i);
}
......@@ -326,16 +335,27 @@ void BuildDygraphPhiKernelContext(
if (var) {
if (var->template IsType<phi::DenseTensor>()) {
tensor_out = var->template GetMutable<phi::DenseTensor>();
kernel_ctx->EmplaceBackOutputWithoutSetRange(tensor_out);
} else if (var->template IsType<phi::SelectedRows>()) {
tensor_out = var->template GetMutable<phi::SelectedRows>();
kernel_ctx->EmplaceBackOutputWithoutSetRange(tensor_out);
} else if (var->template IsType<framework::LoDTensorArray>()) {
paddle::SmallVector<phi::TensorBase*> tensor_vector;
auto* tensor_array =
var->template GetMutable<framework::LoDTensorArray>();
for (auto& t : *tensor_array) {
tensor_vector.emplace_back(&t);
}
kernel_ctx->EmplaceBackOutputsWithoutSetRange(tensor_vector);
end_idx += tensor_array->size() - 1;
} else {
PADDLE_THROW(platform::errors::Unimplemented(
"Unsupported output `%s` type when call pt kernel.",
framework::ToTypeName(var->Type())));
}
} else {
kernel_ctx->EmplaceBackOutputWithoutSetRange(tensor_out);
}
kernel_ctx->EmplaceBackOutputWithoutSetRange(tensor_out);
}
kernel_ctx->AssignOutputRange(std::make_pair(start_idx, end_idx), i);
}
......
......@@ -16,6 +16,9 @@ limitations under the License. */
#include <string>
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/unary.h"
namespace paddle {
namespace framework {
class OpDesc;
......@@ -36,26 +39,6 @@ class AssignOp : public framework::OperatorWithKernel {
const framework::AttributeMap &attrs)
: OperatorWithKernel(type, inputs, outputs, attrs) {}
// Infers the shape of output "Out" from input "X".
//  - No "X" input: nothing is done.
//  - LOD_TENSOR: "Out" takes the dim of "X" and shares its LoD.
//  - SELECTED_ROWS: "Out" takes the dim of "X".
//  - LOD_TENSOR_ARRAY: the dim is only set when not running at runtime.
//  - Any other type: nothing is done.
void InferShape(framework::InferShapeContext *ctx) const override {
  if (!ctx->HasInput("X")) {
    return;
  }
  const auto x_type = ctx->GetInputsVarType("X")[0];
  switch (x_type) {
    case framework::proto::VarType::LOD_TENSOR:
      ctx->SetOutputDim("Out", ctx->GetInputDim("X"));
      ctx->ShareLoD("X", /*->*/ "Out");
      break;
    case framework::proto::VarType::SELECTED_ROWS:
      ctx->SetOutputDim("Out", ctx->GetInputDim("X"));
      break;
    case framework::proto::VarType::LOD_TENSOR_ARRAY:
      // The runtime output shape is determined in kernel.
      if (!ctx->IsRuntime()) {
        ctx->SetOutputDim("Out", ctx->GetInputDim("X"));
      }
      break;
    default:
      break;
  }
}
protected:
framework::OpKernelType GetKernelTypeForVar(
const std::string &var_name, const framework::Tensor &tensor,
......@@ -91,24 +74,6 @@ class AssignInferVarType : public framework::VarTypeInference {
}
};
// Callable kernel for the assign op. It hands input variable "X" and an
// AssignFunctor built from output variable "Out" and the device context of
// the execution place to framework::VisitVarType.
class AssignKernel {
 public:
  void operator()(const framework::ExecutionContext &ctx) const {
    const auto *src_var = ctx.InputVar("X");
    if (src_var == nullptr) {
      // No "X" variable bound: nothing to assign.
      return;
    }
    PADDLE_ENFORCE_EQ(
        ctx.HasOutput("Out"), true,
        platform::errors::NotFound("Output(Out) of assign_op is not found."));
    auto *dst_var = ctx.OutputVar("Out");
    auto &dev_ctx =
        *platform::DeviceContextPool::Instance().Get(ctx.GetPlace());
    framework::VisitVarType(*src_var, AssignFunctor(dst_var, dev_ctx));
  }
};
class AssignOpProtoMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
......@@ -147,23 +112,11 @@ DECLARE_INPLACE_OP_INFERER(AssignOpInplaceInferer, {"X", "Out"});
namespace ops = paddle::operators;
namespace plat = paddle::platform;
DECLARE_INFER_SHAPE_FUNCTOR(assign, AssignInferShapeFunctor,
PD_INFER_META(phi::UnchangedInferMeta));
REGISTER_OPERATOR(assign, ops::AssignOp,
ops::AssignGradMaker<paddle::framework::OpDesc>,
ops::AssignGradMaker<paddle::imperative::OpBase>,
ops::AssignOpProtoMaker, ops::AssignOpInplaceInferer,
ops::AssignInferVarType);
REGISTER_OP_CPU_KERNEL_FUNCTOR(assign, float, ops::AssignKernel, double,
ops::AssignKernel, int, ops::AssignKernel,
int64_t, ops::AssignKernel, uint8_t,
ops::AssignKernel, bool, ops::AssignKernel,
plat::float16, ops::AssignKernel, plat::bfloat16,
ops::AssignKernel);
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
REGISTER_OP_CUDA_KERNEL_FUNCTOR(assign, float, ops::AssignKernel, double,
ops::AssignKernel, int, ops::AssignKernel,
int64_t, ops::AssignKernel, uint8_t,
ops::AssignKernel, bool, ops::AssignKernel,
plat::float16, ops::AssignKernel);
#endif
ops::AssignInferVarType, AssignInferShapeFunctor);
......@@ -29,7 +29,7 @@ limitations under the License. */
namespace f = paddle::framework;
namespace p = paddle::platform;
USE_OP(assign);
USE_OP_ITSELF(assign);
USE_OP_DEVICE_KERNEL(assign, NPU);
template <typename T>
......
......@@ -60,6 +60,10 @@ bool ProtoArgumentMappingContext::IsSelectedRowsInput(
const std::string& name) const {
return false;
}
// Unconditionally returns false: this context never reports an input as a
// dense-tensor vector, regardless of `name`.
bool ProtoArgumentMappingContext::IsDenseTensorVectorInput(
const std::string& name) const {
return false;
}
bool ProtoArgumentMappingContext::IsDenseTensorOutput(
const std::string& name) const {
......
......@@ -42,6 +42,7 @@ class ProtoArgumentMappingContext : public ::phi::ArgumentMappingContext {
bool IsDenseTensorInput(const std::string& name) const override;
bool IsSelectedRowsInput(const std::string& name) const override;
bool IsDenseTensorVectorInput(const std::string& name) const override;
bool IsDenseTensorOutput(const std::string& name) const override;
bool IsSelectedRowsOutput(const std::string& name) const override;
......
......@@ -89,6 +89,8 @@ class ArgumentMappingContext {
virtual bool IsDenseTensorInput(const std::string& name) const = 0;
virtual bool IsSelectedRowsInput(const std::string& name) const = 0;
// For compatibility with LoDTensorArray
virtual bool IsDenseTensorVectorInput(const std::string& name) const = 0;
virtual bool IsDenseTensorOutput(const std::string& name) const = 0;
virtual bool IsSelectedRowsOutput(const std::string& name) const = 0;
......
......@@ -37,6 +37,13 @@ void KernelContext::EmplaceBackInputs(
std::make_move_iterator(inputs.end()));
}
// Appends every element of `inputs` to the end of `inputs_`. This function
// itself records no input range.
void KernelContext::EmplaceBackInputsWithoutSetRange(
    paddle::SmallVector<const TensorBase*> inputs) {
  // The elements are raw pointers, so a plain range insert yields the same
  // result as inserting through move iterators.
  inputs_.insert(inputs_.end(), inputs.begin(), inputs.end());
}
void KernelContext::EmplaceBackOutput(TensorBase* output) {
int index = outputs_.size();
outputs_.emplace_back(output);
......@@ -59,6 +66,13 @@ void KernelContext::EmplaceBackOutputs(
std::make_move_iterator(outputs.end()));
}
// Appends every element of `outputs` to the end of `outputs_`. This function
// itself records no output range.
void KernelContext::EmplaceBackOutputsWithoutSetRange(
    paddle::SmallVector<TensorBase*> outputs) {
  // The elements are raw pointers, so a plain range insert yields the same
  // result as inserting through move iterators.
  outputs_.insert(outputs_.end(), outputs.begin(), outputs.end());
}
void KernelContext::EmplaceBackAttr(paddle::any attr) {
attrs_.emplace_back(std::move(attr));
}
......
......@@ -52,12 +52,18 @@ class KernelContext {
void EmplaceBackInputs(paddle::SmallVector<const TensorBase*> inputs);
void EmplaceBackInputsWithoutSetRange(
paddle::SmallVector<const TensorBase*> inputs);
void EmplaceBackOutput(TensorBase* output);
void EmplaceBackOutputWithoutSetRange(TensorBase* output);
void EmplaceBackOutputs(paddle::SmallVector<TensorBase*> outputs);
void EmplaceBackOutputsWithoutSetRange(
paddle::SmallVector<TensorBase*> outputs);
void EmplaceBackAttr(paddle::any attr);
const std::pair<int, int>& InputRangeAt(size_t idx) const;
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/assign_kernel.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/copy_kernel.h"
#include "paddle/utils/optional.h"
namespace phi {

// Copies `x` into `out` through phi::Copy, using the place of `x` as the
// destination place. When the optional input `x` holds no tensor the call
// returns without touching `out`.
template <typename Context>
void AssignKernel(const Context& dev_ctx,
                  paddle::optional<const DenseTensor&> x,
                  DenseTensor* out) {
  if (x.is_initialized()) {
    const auto& src = *x.get_ptr();
    Copy<Context>(dev_ctx, src, src.place(), false, out);
  }
}

// Note: use `const paddle::optional<std::vector<const DenseTensor*>&> x`
// as input if needed
//
// Element-wise assign: the i-th tensor of `x` is assigned to the i-th entry
// of `out`. `out` is accessed with bounds checking (`at`).
template <typename Context>
void AssignArrayKernel(const Context& dev_ctx,
                       const std::vector<const DenseTensor*>& x,
                       std::vector<DenseTensor*> out) {
  size_t slot = 0;
  for (const DenseTensor* src : x) {
    DenseTensor* dst = out.at(slot);
    AssignKernel<Context>(dev_ctx, *src, dst);
    ++slot;
  }
}

}  // namespace phi
// Kernel registrations for this file. Each entry binds a kernel name
// (`assign` / `assign_array`) to the matching function template instantiated
// for one backend context, declared with ALL_LAYOUT and ALL_DTYPE.

// CPU backend.
PD_REGISTER_GENERAL_KERNEL(
assign, CPU, ALL_LAYOUT, phi::AssignKernel<phi::CPUContext>, ALL_DTYPE) {}
PD_REGISTER_GENERAL_KERNEL(assign_array,
CPU,
ALL_LAYOUT,
phi::AssignArrayKernel<phi::CPUContext>,
ALL_DTYPE) {}
// GPU backend: only compiled when building with CUDA or HIP.
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
PD_REGISTER_GENERAL_KERNEL(
assign, GPU, ALL_LAYOUT, phi::AssignKernel<phi::GPUContext>, ALL_DTYPE) {}
PD_REGISTER_GENERAL_KERNEL(assign_array,
GPU,
ALL_LAYOUT,
phi::AssignArrayKernel<phi::GPUContext>,
ALL_DTYPE) {}
#endif
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <vector>
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/utils/optional.h"
namespace phi {

// Assigns (copies) the dense tensor `x` to `out`.
//
// In order to be compatible with the `AsDispensable` input in the original
// assign op maker, the input parameter here needs to be dispensable, which is
// why `x` is an optional reference rather than a plain `const DenseTensor&`.
//
// dev_ctx: device context used to perform the assignment.
// x:       optional source tensor.
// out:     destination tensor.
template <typename Context>
void AssignKernel(const Context& dev_ctx,
                  paddle::optional<const DenseTensor&> x,
                  DenseTensor* out);

// Array variant of AssignKernel, operating on vectors of tensor pointers.
//
// dev_ctx: device context used to perform the assignment.
// x:       source tensors.
// out:     destination tensors.
template <typename Context>
void AssignArrayKernel(const Context& dev_ctx,
                       const std::vector<const DenseTensor*>& x,
                       std::vector<DenseTensor*> out);

}  // namespace phi
......@@ -38,7 +38,7 @@ void Copy(const Context& dev_ctx,
<< src_place;
dst->Resize(src.dims());
auto* dst_ptr = dev_ctx.Alloc(dst, src.dtype());
auto* dst_ptr = dev_ctx.HostAlloc(dst, src.dtype());
if (src_ptr == dst_ptr) {
VLOG(3) << "Skip copy the same data async from " << src_place << " to "
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/selected_rows/assign_kernel.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/assign_kernel.h"
namespace phi {
namespace sr {

// Note: use `const paddle::optional<const SelectedRows&> x`
// as input if needed
//
// Assigns SelectedRows `x` to `out`: the rows and the height are set on `out`
// first, then the value tensor of `x` is assigned to the mutable value tensor
// of `out` through the dense phi::AssignKernel.
template <typename Context>
void AssignKernel(const Context& dev_ctx,
                  const SelectedRows& x,
                  SelectedRows* out) {
  out->set_rows(x.rows());
  out->set_height(x.height());
  auto* out_value = out->mutable_value();
  phi::AssignKernel<Context>(dev_ctx, x.value(), out_value);
}

}  // namespace sr
}  // namespace phi
// Registers phi::sr::AssignKernel under the kernel name `assign_sr`, declared
// with ALL_LAYOUT and ALL_DTYPE.

// CPU backend.
PD_REGISTER_GENERAL_KERNEL(assign_sr,
CPU,
ALL_LAYOUT,
phi::sr::AssignKernel<phi::CPUContext>,
ALL_DTYPE) {}
// GPU backend: only compiled when building with CUDA or HIP.
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
PD_REGISTER_GENERAL_KERNEL(assign_sr,
GPU,
ALL_LAYOUT,
phi::sr::AssignKernel<phi::GPUContext>,
ALL_DTYPE) {}
#endif
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/selected_rows.h"
namespace phi {
namespace sr {
// SelectedRows overload of the assign kernel: assigns `x` to `out`.
//
// dev_ctx: device context used to perform the assignment.
// x:       source SelectedRows.
// out:     destination SelectedRows.
template <typename Context>
void AssignKernel(const Context& dev_ctx,
const SelectedRows& x,
SelectedRows* out);
} // namespace sr
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/core/compat/op_utils.h"
namespace phi {

// Selects the kernel signature for the `assign` op from the kind of input "X":
//  - dense-tensor vector input -> "assign_array"
//  - selected-rows input       -> "assign_sr"
//  - anything else, or no "X"  -> "assign"
// Every signature maps input {"X"}, no attributes, and output {"Out"}.
KernelSignature AssignOpArgumentMapping(const ArgumentMappingContext& ctx) {
  const char* kernel_name = "assign";
  // Only query the input kind when "X" is actually present.
  if (ctx.HasInput("X")) {
    if (ctx.IsDenseTensorVectorInput("X")) {
      kernel_name = "assign_array";
    } else if (ctx.IsSelectedRowsInput("X")) {
      kernel_name = "assign_sr";
    }
  }
  return KernelSignature(kernel_name, {"X"}, {}, {"Out"});
}

}  // namespace phi
// Registers phi::AssignOpArgumentMapping as the argument-mapping function for
// the `assign` op.
PD_REGISTER_ARG_MAPPING_FN(assign, phi::AssignOpArgumentMapping);
......@@ -61,6 +61,10 @@ TEST(DEV_API, copy) {
dev_ctx.SetAllocator(paddle::memory::allocation::AllocatorFacade::Instance()
.GetAllocator(paddle::platform::CPUPlace())
.get());
dev_ctx.SetHostAllocator(
paddle::memory::allocation::AllocatorFacade::Instance()
.GetAllocator(paddle::platform::CPUPlace())
.get());
dev_ctx.Init();
phi::Copy(
dev_ctx, *(dense_src.get()), phi::CPUPlace(), false, dense_dst.get());
......
......@@ -58,6 +58,10 @@ TEST(DEV_API, flatten) {
dev_ctx.SetAllocator(paddle::memory::allocation::AllocatorFacade::Instance()
.GetAllocator(paddle::platform::CPUPlace())
.get());
dev_ctx.SetHostAllocator(
paddle::memory::allocation::AllocatorFacade::Instance()
.GetAllocator(paddle::platform::CPUPlace())
.get());
dev_ctx.Init();
// 2. test API
......
......@@ -50,6 +50,10 @@ TEST(DEV_API, reshape) {
dev_ctx.SetAllocator(paddle::memory::allocation::AllocatorFacade::Instance()
.GetAllocator(paddle::platform::CPUPlace())
.get());
dev_ctx.SetHostAllocator(
paddle::memory::allocation::AllocatorFacade::Instance()
.GetAllocator(paddle::platform::CPUPlace())
.get());
dev_ctx.Init();
auto out = phi::Reshape<float>(dev_ctx, dense_x, shape);
// 3. check result
......
......@@ -72,6 +72,11 @@ class TestArgumentMappingContext : public phi::ArgumentMappingContext {
return selected_rows_inputs.count(name) > 0;
}
// add member if needed
// Test stub: always returns false, so no input is ever reported as a
// dense-tensor vector by this test context.
bool IsDenseTensorVectorInput(const std::string& name) const override {
return false;
}
bool IsDenseTensorOutput(const std::string& name) const override {
return dense_tensor_outputs.count(name) > 0;
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册