Unverified commit 1904572a, authored by Chen Weihang, committed by GitHub

[Phi] Move assign kernel into phi (#40022)

* move assign kernel init commit

* change vec<tensor> to vec<tensor*>

* support tensor array

* support api declare

* fix test_list failed

* fix npu and xpu failed

* fix infrt failed

* remove assign array size in operator

* move assign sr header into sr dir

* add infermeta for assign

* test op success

* fix test_list failed

* fix kunlun failed

* add set host allocator in tests

* support tensor array in arg ctx

* open set layout in share_meta

* fix meta tensor layout error

* fix test failed
Parent 31776199
@@ -78,6 +78,11 @@ class InferShapeArgumentMappingContext : public phi::ArgumentMappingContext {
return var_types[0] == proto::VarType::SELECTED_ROWS;
}
bool IsDenseTensorVectorInput(const std::string& name) const override {
auto var_types = ctx_.GetInputsVarType(name);
return var_types[0] == proto::VarType::LOD_TENSOR_ARRAY;
}
bool IsDenseTensorOutput(const std::string& name) const override {
auto var_types = ctx_.GetOutputsVarType(name);
return var_types[0] == proto::VarType::LOD_TENSOR;
@@ -125,9 +130,14 @@ class CompatMetaTensor : public phi::MetaTensor {
return var->Get<phi::DenseTensor>().dims();
} else if (var->IsType<phi::SelectedRows>()) {
return var->Get<phi::SelectedRows>().dims();
} else if (var->IsType<framework::LoDTensorArray>()) {
// use tensor array size as dims
auto& tensor_array = var->Get<framework::LoDTensorArray>();
return phi::make_ddim({static_cast<int64_t>(tensor_array.size())});
} else {
PADDLE_THROW(platform::errors::Unimplemented(
"Currently, only can get dims from DenseTensor or SelectedRows or "
"DenseTensorArray."));
}
} else {
auto* var = BOOST_GET_CONST(VarDesc*, var_);
@@ -144,6 +154,10 @@ class CompatMetaTensor : public phi::MetaTensor {
return var->Get<phi::DenseTensor>().dtype();
} else if (var->IsType<phi::SelectedRows>()) {
return var->Get<phi::SelectedRows>().dtype();
} else if (var->IsType<framework::LoDTensorArray>()) {
// NOTE(chenweihang): do nothing
// Unsupported get dtype from LoDTensorArray now
return phi::DataType::UNDEFINED;
} else {
PADDLE_THROW(platform::errors::Unimplemented(
"Currently, only can get dtype from DenseTensor or SelectedRows."));
@@ -157,7 +171,19 @@ class CompatMetaTensor : public phi::MetaTensor {
DataLayout layout() const override {
if (is_runtime_) {
auto* var = BOOST_GET_CONST(Variable*, var_);
if (var->IsType<phi::DenseTensor>()) {
return var->Get<phi::DenseTensor>().layout();
} else if (var->IsType<phi::SelectedRows>()) {
return var->Get<phi::SelectedRows>().layout();
} else if (var->IsType<framework::LoDTensorArray>()) {
// NOTE(chenweihang): do nothing
// Unsupported get layout from LoDTensorArray now
return phi::DataLayout::UNDEFINED;
} else {
PADDLE_THROW(platform::errors::Unimplemented(
"Currently, only can get layout from DenseTensor or "
"SelectedRows."));
}
} else {
// NOTE(chenweihang): do nothing
// Unsupported get layout for VarDesc now
@@ -174,6 +200,16 @@ class CompatMetaTensor : public phi::MetaTensor {
} else if (var->IsType<phi::SelectedRows>()) {
auto* tensor = var->GetMutable<phi::SelectedRows>()->mutable_value();
phi::DenseTensorUtils::GetMutableMeta(tensor)->dims = dims;
} else if (var->IsType<framework::LoDTensorArray>()) {
auto* tensor_array = var->GetMutable<framework::LoDTensorArray>();
// Note: Here we would like to enforce `tensor_array->size() == 0UL`, because
// in-place use of LoDTensorArray is dangerous, but the unittest
// `test_list` relies on this behavior
PADDLE_ENFORCE_EQ(dims.size(), 1UL,
platform::errors::InvalidArgument(
"LoDTensorArray can only have one dimension."));
// only set the array size for LoDTensorArray input
tensor_array->resize(dims[0]);
} else {
PADDLE_THROW(platform::errors::Unimplemented(
"Currently, only can set dims from DenseTensor or SelectedRows."));
@@ -193,6 +229,9 @@ class CompatMetaTensor : public phi::MetaTensor {
} else if (var->IsType<phi::SelectedRows>()) {
auto* tensor = var->GetMutable<phi::SelectedRows>()->mutable_value();
phi::DenseTensorUtils::GetMutableMeta(tensor)->dtype = dtype;
} else if (var->IsType<framework::LoDTensorArray>()) {
// NOTE(chenweihang): do nothing
// Unsupported set dtype for LoDTensorArray now
} else {
PADDLE_THROW(platform::errors::Unimplemented(
"Currently, only can set dtype from DenseTensor or SelectedRows."));
@@ -206,10 +245,20 @@ class CompatMetaTensor : public phi::MetaTensor {
void set_layout(DataLayout layout) override {
if (is_runtime_) {
auto* var = BOOST_GET(Variable*, var_);
if (var->IsType<phi::DenseTensor>()) {
auto* tensor = var->GetMutable<phi::DenseTensor>();
phi::DenseTensorUtils::GetMutableMeta(tensor)->layout = layout;
} else if (var->IsType<phi::SelectedRows>()) {
auto* tensor = var->GetMutable<phi::SelectedRows>()->mutable_value();
phi::DenseTensorUtils::GetMutableMeta(tensor)->layout = layout;
} else if (var->IsType<framework::LoDTensorArray>()) {
// NOTE(chenweihang): do nothing
// Unsupported set layout for LoDTensorArray now
} else {
PADDLE_THROW(platform::errors::Unimplemented(
"Currently, only can set layout from DenseTensor or "
"SelectedRows."));
}
} else {
// NOTE(chenweihang): do nothing
// Unsupported set layout for VarDesc now
@@ -251,9 +300,7 @@ class CompatMetaTensor : public phi::MetaTensor {
void share_meta(const MetaTensor& meta_tensor) override {
share_dims(meta_tensor);
set_dtype(meta_tensor.dtype());
set_layout(meta_tensor.layout());
// special case: share lod of LoDTensor
share_lod(meta_tensor);
}
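The get-dims and set_dims changes above rely on a small convention for LoDTensorArray: the array has no real shape, so its length is reported as a one-dimensional dim, and set_dims reads that single entry back as the new array size. Below is a minimal sketch of that round trip, with std::vector standing in for the Paddle types (names are illustrative only, not the real API):

// Sketch only (not part of this commit): how the "dims as array size"
// convention round-trips. std::vector stands in for LoDTensorArray.
#include <cstdint>
#include <vector>

int main() {
  std::vector<int> tensor_array;                      // stand-in for LoDTensorArray
  std::vector<int64_t> dims = {3};                    // what dims() reports for an array of 3
  tensor_array.resize(static_cast<size_t>(dims[0]));  // what set_dims() does above
  return tensor_array.size() == 3 ? 0 : 1;            // exits 0: the size survived
}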
......
@@ -2103,16 +2103,25 @@ void OperatorWithKernel::BuildPhiKernelContext(
auto* var = ins_vector[offset];
if (var->IsType<framework::LoDTensor>()) {
tensor_in = &(var->Get<framework::LoDTensor>());
pt_kernel_context->EmplaceBackInputWithoutSetRange(tensor_in);
} else if (var->IsType<phi::SelectedRows>()) {
tensor_in = &(var->Get<phi::SelectedRows>());
pt_kernel_context->EmplaceBackInputWithoutSetRange(tensor_in);
} else if (var->IsType<framework::LoDTensorArray>()) {
paddle::SmallVector<const phi::TensorBase*> tensor_vector;
auto& tensor_array = var->Get<framework::LoDTensorArray>();
for (auto& t : tensor_array) {
tensor_vector.emplace_back(&t);
}
pt_kernel_context->EmplaceBackInputsWithoutSetRange(tensor_vector);
end_idx += tensor_array.size() - 1;
} else {
PADDLE_THROW(platform::errors::Unimplemented(
"Unsupported input `%s` type when call pt kernel.",
framework::ToTypeName(var->Type())));
}
}
// Note: here cannot deal with vector<LoDTensorArray> input
pt_kernel_context->AssignInputRange(std::make_pair(start_idx, end_idx), i);
}
VLOG(4) << "Done inputs";
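The `end_idx += tensor_array.size() - 1` adjustment above exists because a LoDTensorArray input is flattened into one kernel-context entry per element, so the input's index range has to grow accordingly. A toy sketch of that range bookkeeping with plain containers (the real KernelContext is not used here):

// Sketch only: each op input i owns the half-open slice [start, end) of the
// flattened tensor list; an array input of N tensors widens its slice by N - 1.
#include <cstddef>
#include <utility>
#include <vector>

int main() {
  std::vector<std::size_t> input_sizes = {1, 3, 1};  // the 2nd input is an array of 3 tensors
  std::vector<std::pair<std::size_t, std::size_t>> ranges;
  std::size_t start_idx = 0;
  for (std::size_t i = 0; i < input_sizes.size(); ++i) {
    std::size_t end_idx = start_idx + 1;      // one entry by default
    end_idx += input_sizes[i] - 1;            // mirrors `end_idx += tensor_array.size() - 1`
    ranges.emplace_back(start_idx, end_idx);  // mirrors AssignInputRange(...)
    start_idx = end_idx;
  }
  // ranges is now {(0,1), (1,4), (4,5)}
  return ranges[1].second == 4 ? 0 : 1;
}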
@@ -2140,22 +2149,33 @@ void OperatorWithKernel::BuildPhiKernelContext(
for (size_t offset = 0; offset < outs_vector.size(); ++offset) {
phi::TensorBase* tensor_out = nullptr;
auto* var = outs_vector[offset];
if (var) {
if (var->template IsType<framework::LoDTensor>()) {
tensor_out = var->template GetMutable<framework::LoDTensor>();
pt_kernel_context->EmplaceBackOutputWithoutSetRange(tensor_out);
} else if (var->template IsType<phi::SelectedRows>()) {
tensor_out = var->template GetMutable<phi::SelectedRows>();
pt_kernel_context->EmplaceBackOutputWithoutSetRange(tensor_out);
} else if (var->template IsType<framework::LoDTensorArray>()) {
paddle::SmallVector<phi::TensorBase*> tensor_vector;
auto* tensor_array =
var->template GetMutable<framework::LoDTensorArray>();
// Note: If the input LoDTensorArray is empty, the output
// LoDTensorArray is also empty
for (auto& t : *tensor_array) {
tensor_vector.emplace_back(&t);
}
pt_kernel_context->EmplaceBackOutputsWithoutSetRange(tensor_vector);
end_idx += tensor_array->size() - 1;
} else {
PADDLE_THROW(platform::errors::Unimplemented(
"Unsupported output `%s` type when call pt kernel.",
framework::ToTypeName(var->Type())));
}
} else {
pt_kernel_context->EmplaceBackOutputWithoutSetRange(tensor_out);
}
}
pt_kernel_context->AssignOutputRange(std::make_pair(start_idx, end_idx), i);
}
VLOG(4) << "Done outputs";
......
@@ -483,6 +483,10 @@ class ExecutionArgumentMappingContext : public phi::ArgumentMappingContext {
return ctx_.InputVar(name)->IsType<phi::SelectedRows>();
}
bool IsDenseTensorVectorInput(const std::string& name) const override {
return ctx_.InputVar(name)->IsType<framework::LoDTensorArray>();
}
bool IsDenseTensorOutput(const std::string& name) const override {
return ctx_.OutputVar(name)->IsType<framework::LoDTensor>();
}
......
@@ -289,14 +289,23 @@ void BuildDygraphPhiKernelContext(
auto& var = ins_vector[offset]->Var();
if (var.template IsType<phi::DenseTensor>()) {
tensor_in = &(var.template Get<phi::DenseTensor>());
kernel_ctx->EmplaceBackInputWithoutSetRange(tensor_in);
} else if (var.template IsType<phi::SelectedRows>()) {
tensor_in = &(var.template Get<phi::SelectedRows>());
kernel_ctx->EmplaceBackInputWithoutSetRange(tensor_in);
} else if (var.template IsType<framework::LoDTensorArray>()) {
paddle::SmallVector<const phi::TensorBase*> tensor_vector;
auto& tensor_array = var.template Get<framework::LoDTensorArray>();
for (auto& t : tensor_array) {
tensor_vector.emplace_back(&t);
}
kernel_ctx->EmplaceBackInputsWithoutSetRange(tensor_vector);
end_idx += tensor_array.size() - 1;
} else {
PADDLE_THROW(platform::errors::Unimplemented(
"Unsupported input `%s` type when call pt kernel.",
framework::ToTypeName(var.Type())));
}
}
kernel_ctx->AssignInputRange(std::make_pair(start_idx, end_idx), i);
}
@@ -326,16 +335,27 @@ void BuildDygraphPhiKernelContext(
if (var) {
if (var->template IsType<phi::DenseTensor>()) {
tensor_out = var->template GetMutable<phi::DenseTensor>();
kernel_ctx->EmplaceBackOutputWithoutSetRange(tensor_out);
} else if (var->template IsType<phi::SelectedRows>()) {
tensor_out = var->template GetMutable<phi::SelectedRows>();
kernel_ctx->EmplaceBackOutputWithoutSetRange(tensor_out);
} else if (var->template IsType<framework::LoDTensorArray>()) {
paddle::SmallVector<phi::TensorBase*> tensor_vector;
auto* tensor_array =
var->template GetMutable<framework::LoDTensorArray>();
for (auto& t : *tensor_array) {
tensor_vector.emplace_back(&t);
}
kernel_ctx->EmplaceBackOutputsWithoutSetRange(tensor_vector);
end_idx += tensor_array->size() - 1;
} else {
PADDLE_THROW(platform::errors::Unimplemented(
"Unsupported output `%s` type when call pt kernel.",
framework::ToTypeName(var->Type())));
}
} else {
kernel_ctx->EmplaceBackOutputWithoutSetRange(tensor_out);
}
}
kernel_ctx->AssignOutputRange(std::make_pair(start_idx, end_idx), i);
}
......
@@ -16,6 +16,9 @@ limitations under the License. */
#include <string>
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/unary.h"
namespace paddle {
namespace framework {
class OpDesc;
@@ -36,26 +39,6 @@ class AssignOp : public framework::OperatorWithKernel {
const framework::AttributeMap &attrs)
: OperatorWithKernel(type, inputs, outputs, attrs) {}
void InferShape(framework::InferShapeContext *ctx) const override {
if (ctx->HasInput("X")) {
auto type = ctx->GetInputsVarType("X")[0];
if (type == framework::proto::VarType::SELECTED_ROWS ||
type == framework::proto::VarType::LOD_TENSOR) {
ctx->SetOutputDim("Out", ctx->GetInputDim("X"));
if (type == framework::proto::VarType::LOD_TENSOR) {
ctx->ShareLoD("X", /*->*/ "Out");
}
} else if (type == framework::proto::VarType::LOD_TENSOR_ARRAY) {
if (ctx->IsRuntime()) {
// The runtime output shape is determined in kernel.
return;
} else {
ctx->SetOutputDim("Out", ctx->GetInputDim("X"));
}
}
}
}
protected:
framework::OpKernelType GetKernelTypeForVar(
const std::string &var_name, const framework::Tensor &tensor,
@@ -91,24 +74,6 @@ class AssignInferVarType : public framework::VarTypeInference {
}
};
class AssignKernel {
public:
void operator()(const framework::ExecutionContext &ctx) const {
auto *x = ctx.InputVar("X");
if (x == nullptr) {
return;
}
PADDLE_ENFORCE_EQ(
ctx.HasOutput("Out"), true,
platform::errors::NotFound("Output(Out) of assign_op is not found."));
auto *out = ctx.OutputVar("Out");
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
auto &dev_ctx = *pool.Get(ctx.GetPlace());
framework::VisitVarType(*x, AssignFunctor(out, dev_ctx));
}
};
class AssignOpProtoMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
@@ -147,23 +112,11 @@ DECLARE_INPLACE_OP_INFERER(AssignOpInplaceInferer, {"X", "Out"});
namespace ops = paddle::operators;
namespace plat = paddle::platform;
DECLARE_INFER_SHAPE_FUNCTOR(assign, AssignInferShapeFunctor,
PD_INFER_META(phi::UnchangedInferMeta));
REGISTER_OPERATOR(assign, ops::AssignOp,
ops::AssignGradMaker<paddle::framework::OpDesc>,
ops::AssignGradMaker<paddle::imperative::OpBase>,
ops::AssignOpProtoMaker, ops::AssignOpInplaceInferer,
ops::AssignInferVarType, AssignInferShapeFunctor);
REGISTER_OP_CPU_KERNEL_FUNCTOR(assign, float, ops::AssignKernel, double,
ops::AssignKernel, int, ops::AssignKernel,
int64_t, ops::AssignKernel, uint8_t,
ops::AssignKernel, bool, ops::AssignKernel,
plat::float16, ops::AssignKernel, plat::bfloat16,
ops::AssignKernel);
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
REGISTER_OP_CUDA_KERNEL_FUNCTOR(assign, float, ops::AssignKernel, double,
ops::AssignKernel, int, ops::AssignKernel,
int64_t, ops::AssignKernel, uint8_t,
ops::AssignKernel, bool, ops::AssignKernel,
plat::float16, ops::AssignKernel);
#endif
@@ -29,7 +29,7 @@ limitations under the License. */
namespace f = paddle::framework;
namespace p = paddle::platform;
USE_OP_ITSELF(assign);
USE_OP_DEVICE_KERNEL(assign, NPU);
template <typename T>
......
@@ -60,6 +60,10 @@ bool ProtoArgumentMappingContext::IsSelectedRowsInput(
const std::string& name) const {
return false;
}
bool ProtoArgumentMappingContext::IsDenseTensorVectorInput(
const std::string& name) const {
return false;
}
bool ProtoArgumentMappingContext::IsDenseTensorOutput(
const std::string& name) const {
......
@@ -42,6 +42,7 @@ class ProtoArgumentMappingContext : public ::phi::ArgumentMappingContext {
bool IsDenseTensorInput(const std::string& name) const override;
bool IsSelectedRowsInput(const std::string& name) const override;
bool IsDenseTensorVectorInput(const std::string& name) const override;
bool IsDenseTensorOutput(const std::string& name) const override;
bool IsSelectedRowsOutput(const std::string& name) const override;
......
@@ -89,6 +89,8 @@ class ArgumentMappingContext {
virtual bool IsDenseTensorInput(const std::string& name) const = 0;
virtual bool IsSelectedRowsInput(const std::string& name) const = 0;
// For compatibility with LoDTensorArray
virtual bool IsDenseTensorVectorInput(const std::string& name) const = 0;
virtual bool IsDenseTensorOutput(const std::string& name) const = 0;
virtual bool IsSelectedRowsOutput(const std::string& name) const = 0;
......
@@ -37,6 +37,13 @@ void KernelContext::EmplaceBackInputs(
std::make_move_iterator(inputs.end()));
}
void KernelContext::EmplaceBackInputsWithoutSetRange(
paddle::SmallVector<const TensorBase*> inputs) {
inputs_.insert(inputs_.end(),
std::make_move_iterator(inputs.begin()),
std::make_move_iterator(inputs.end()));
}
void KernelContext::EmplaceBackOutput(TensorBase* output) {
int index = outputs_.size();
outputs_.emplace_back(output);
@@ -59,6 +66,13 @@ void KernelContext::EmplaceBackOutputs(
std::make_move_iterator(outputs.end()));
}
void KernelContext::EmplaceBackOutputsWithoutSetRange(
paddle::SmallVector<TensorBase*> outputs) {
outputs_.insert(outputs_.end(),
std::make_move_iterator(outputs.begin()),
std::make_move_iterator(outputs.end()));
}
void KernelContext::EmplaceBackAttr(paddle::any attr) {
attrs_.emplace_back(std::move(attr));
}
......
@@ -52,12 +52,18 @@ class KernelContext {
void EmplaceBackInputs(paddle::SmallVector<const TensorBase*> inputs);
void EmplaceBackInputsWithoutSetRange(
paddle::SmallVector<const TensorBase*> inputs);
void EmplaceBackOutput(TensorBase* output);
void EmplaceBackOutputWithoutSetRange(TensorBase* output);
void EmplaceBackOutputs(paddle::SmallVector<TensorBase*> outputs);
void EmplaceBackOutputsWithoutSetRange(
paddle::SmallVector<TensorBase*> outputs);
void EmplaceBackAttr(paddle::any attr);
const std::pair<int, int>& InputRangeAt(size_t idx) const;
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/assign_kernel.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/copy_kernel.h"
#include "paddle/utils/optional.h"
namespace phi {
template <typename Context>
void AssignKernel(const Context& dev_ctx,
paddle::optional<const DenseTensor&> x,
DenseTensor* out) {
if (!x.is_initialized()) {
return;
}
auto& x_tensor = *x.get_ptr();
Copy<Context>(dev_ctx, x_tensor, x_tensor.place(), false, out);
}
// Note: use `const paddle::optional<std::vector<const DenseTensor*>&> x`
// as input if needed
template <typename Context>
void AssignArrayKernel(const Context& dev_ctx,
const std::vector<const DenseTensor*>& x,
std::vector<DenseTensor*> out) {
for (size_t i = 0; i < x.size(); ++i) {
AssignKernel<Context>(dev_ctx, *x[i], out.at(i));
}
}
} // namespace phi
PD_REGISTER_GENERAL_KERNEL(
assign, CPU, ALL_LAYOUT, phi::AssignKernel<phi::CPUContext>, ALL_DTYPE) {}
PD_REGISTER_GENERAL_KERNEL(assign_array,
CPU,
ALL_LAYOUT,
phi::AssignArrayKernel<phi::CPUContext>,
ALL_DTYPE) {}
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
PD_REGISTER_GENERAL_KERNEL(
assign, GPU, ALL_LAYOUT, phi::AssignKernel<phi::GPUContext>, ALL_DTYPE) {}
PD_REGISTER_GENERAL_KERNEL(assign_array,
GPU,
ALL_LAYOUT,
phi::AssignArrayKernel<phi::GPUContext>,
ALL_DTYPE) {}
#endif
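Because the original assign op marks X as AsDispensable, the phi kernel above takes `paddle::optional<const DenseTensor&>` and simply returns when the input is absent. The pattern is easy to see with std::optional standing in for paddle::optional (sketch only, not the Paddle API):

// Sketch only: an "assign" that copies when the optional input is present
// and is a no-op otherwise, mirroring the early return in AssignKernel above.
#include <iostream>
#include <optional>
#include <vector>

using Tensor = std::vector<float>;  // toy stand-in for DenseTensor

void ToyAssign(const std::optional<Tensor>& x, Tensor* out) {
  if (!x.has_value()) return;  // dispensable input: leave the output untouched
  *out = *x;                   // otherwise behave like a plain copy
}

int main() {
  Tensor out;
  ToyAssign(std::nullopt, &out);           // no-op, out stays empty
  ToyAssign(Tensor{1.f, 2.f, 3.f}, &out);  // copies three elements
  std::cout << out.size() << std::endl;    // prints 3
  return 0;
}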
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
// In order to be compatible with the `AsDispensable` input in the original
// assign op maker, the input parameter here needs to be dispensable, but
// this looks weird
template <typename Context>
void AssignKernel(const Context& dev_ctx,
paddle::optional<const DenseTensor&> x,
DenseTensor* out);
template <typename Context>
void AssignArrayKernel(const Context& dev_ctx,
const std::vector<const DenseTensor*>& x,
std::vector<DenseTensor*> out);
} // namespace phi
@@ -38,7 +38,7 @@ void Copy(const Context& dev_ctx,
<< src_place;
dst->Resize(src.dims());
auto* dst_ptr = dev_ctx.HostAlloc(dst, src.dtype());
if (src_ptr == dst_ptr) {
VLOG(3) << "Skip copy the same data async from " << src_place << " to "
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/selected_rows/assign_kernel.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/assign_kernel.h"
namespace phi {
namespace sr {
// Note: use `const paddle::optional<const SelectedRows&> x`
// as input if needed
template <typename Context>
void AssignKernel(const Context& dev_ctx,
const SelectedRows& x,
SelectedRows* out) {
out->set_rows(x.rows());
out->set_height(x.height());
phi::AssignKernel<Context>(dev_ctx, x.value(), out->mutable_value());
}
} // namespace sr
} // namespace phi
PD_REGISTER_GENERAL_KERNEL(assign_sr,
CPU,
ALL_LAYOUT,
phi::sr::AssignKernel<phi::CPUContext>,
ALL_DTYPE) {}
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
PD_REGISTER_GENERAL_KERNEL(assign_sr,
GPU,
ALL_LAYOUT,
phi::sr::AssignKernel<phi::GPUContext>,
ALL_DTYPE) {}
#endif
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/selected_rows.h"
namespace phi {
namespace sr {
template <typename Context>
void AssignKernel(const Context& dev_ctx,
const SelectedRows& x,
SelectedRows* out);
} // namespace sr
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/core/compat/op_utils.h"
namespace phi {
KernelSignature AssignOpArgumentMapping(const ArgumentMappingContext& ctx) {
if (ctx.HasInput("X")) {
if (ctx.IsDenseTensorVectorInput("X")) {
return KernelSignature("assign_array", {"X"}, {}, {"Out"});
} else if (ctx.IsSelectedRowsInput("X")) {
return KernelSignature("assign_sr", {"X"}, {}, {"Out"});
} else {
return KernelSignature("assign", {"X"}, {}, {"Out"});
}
} else {
return KernelSignature("assign", {"X"}, {}, {"Out"});
}
}
} // namespace phi
PD_REGISTER_ARG_MAPPING_FN(assign, phi::AssignOpArgumentMapping);
@@ -61,6 +61,10 @@ TEST(DEV_API, copy) {
dev_ctx.SetAllocator(paddle::memory::allocation::AllocatorFacade::Instance()
.GetAllocator(paddle::platform::CPUPlace())
.get());
dev_ctx.SetHostAllocator(
paddle::memory::allocation::AllocatorFacade::Instance()
.GetAllocator(paddle::platform::CPUPlace())
.get());
dev_ctx.Init();
phi::Copy(
dev_ctx, *(dense_src.get()), phi::CPUPlace(), false, dense_dst.get());
......
@@ -58,6 +58,10 @@ TEST(DEV_API, flatten) {
dev_ctx.SetAllocator(paddle::memory::allocation::AllocatorFacade::Instance()
.GetAllocator(paddle::platform::CPUPlace())
.get());
dev_ctx.SetHostAllocator(
paddle::memory::allocation::AllocatorFacade::Instance()
.GetAllocator(paddle::platform::CPUPlace())
.get());
dev_ctx.Init();
// 2. test API
......
@@ -50,6 +50,10 @@ TEST(DEV_API, reshape) {
dev_ctx.SetAllocator(paddle::memory::allocation::AllocatorFacade::Instance()
.GetAllocator(paddle::platform::CPUPlace())
.get());
dev_ctx.SetHostAllocator(
paddle::memory::allocation::AllocatorFacade::Instance()
.GetAllocator(paddle::platform::CPUPlace())
.get());
dev_ctx.Init();
auto out = phi::Reshape<float>(dev_ctx, dense_x, shape);
// 3. check result
......
@@ -72,6 +72,11 @@ class TestArgumentMappingContext : public phi::ArgumentMappingContext {
return selected_rows_inputs.count(name) > 0;
}
// add member if needed
bool IsDenseTensorVectorInput(const std::string& name) const override {
return false;
}
bool IsDenseTensorOutput(const std::string& name) const override {
return dense_tensor_outputs.count(name) > 0;
}
......