未验证 提交 cbf22d65 编写于 作者: L Leo Chen 提交者: GitHub

[NPU] NpuOpRunner supports host tensor as input (#33992)

* NpuOpRunner supports host tensor as input

* fix compile issue
上级 20da7703
......@@ -39,14 +39,14 @@ class LookupTableV2NPUKernel : public framework::OpKernel<T> {
table_var->IsType<framework::LoDTensor>(), true,
platform::errors::InvalidArgument("npu only accept LoDTensor"));
output_t->mutable_data<T>(ctx.GetPlace());
framework::NPUAttributeMap attr_input = {{"validate_indices", false}};
const auto &runner =
NpuOpRunner("Gather", {*table_t, *ids_t}, {*output_t}, attr_input);
auto stream =
ctx.template device_context<paddle::platform::NPUDeviceContext>()
.stream();
runner.Run(stream);
NpuOpRunner runner;
runner.SetType("GatherV2")
.AddInput(*table_t)
.AddInput(*ids_t)
.AddInput(std::vector<int32_t>{0})
.AddOutput(*output_t);
runner.Run();
}
};
......
......@@ -74,15 +74,15 @@ aclrtStream GetCurrentNPUStream(int device_id) {
return dev_ctx->stream();
}
NpuOpRunner::NpuOpRunner(std::string op_type) : op_type_(op_type) {
attr_ = aclopCreateAttr();
}
NpuOpRunner::NpuOpRunner() {}
NpuOpRunner::NpuOpRunner(std::string op_type, const std::vector<Tensor> &inputs,
NpuOpRunner::NpuOpRunner(const std::string &op_type) : op_type_(op_type) {}
NpuOpRunner::NpuOpRunner(const std::string &op_type,
const std::vector<Tensor> &inputs,
const std::vector<Tensor> &outputs,
const NPUAttributeMap &attrs)
: op_type_(op_type) {
attr_ = aclopCreateAttr();
AddInputs(inputs);
AddOutputs(outputs);
AddAttrs(attrs);
......@@ -108,8 +108,16 @@ NpuOpRunner::~NpuOpRunner() {
const std::string &NpuOpRunner::Type() { return op_type_; }
NpuOpRunner &NpuOpRunner::SetType(const std::string &name) {
op_type_ = name;
return *this;
}
NpuOpRunner &NpuOpRunner::AddAttr(const std::string &name,
const NPUAttribute &attr) {
if (!attr_) {
attr_ = aclopCreateAttr();
}
if (attr.type() == typeid(bool)) {
PADDLE_ENFORCE_NPU_SUCCESS(
aclopSetAttrBool(attr_, name.c_str(), BOOST_GET_CONST(bool, attr)));
......@@ -191,6 +199,46 @@ NpuOpRunner &NpuOpRunner::AddInput(const Tensor &tensor) {
return *this;
}
NpuOpRunner &NpuOpRunner::AddInput(const Tensor &tensor, aclMemType mem_type) {
// create aclTensorDesc
input_descs_.emplace_back(CreateTensorDesc(tensor, mem_type));
// create aclDataBuffer
input_buffers_.emplace_back(CreateDataBuffer(tensor));
return *this;
}
NpuOpRunner &NpuOpRunner::AddInput(std::vector<int32_t> &&dims) {
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
auto *dev_ctx =
static_cast<platform::CPUDeviceContext *>(pool.Get(platform::CPUPlace()));
Tensor host_tensor;
TensorFromVector(dims, *dev_ctx, &host_tensor);
host_tensors_.emplace_back(host_tensor);
// create aclTensorDesc
input_descs_.emplace_back(CreateTensorDesc(host_tensor, ACL_MEMTYPE_HOST));
// create aclDataBuffer
input_buffers_.emplace_back(CreateDataBuffer(host_tensor));
return *this;
}
NpuOpRunner &NpuOpRunner::AddInput(std::vector<int64_t> &&dims) {
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
auto *dev_ctx =
static_cast<platform::CPUDeviceContext *>(pool.Get(platform::CPUPlace()));
Tensor host_tensor;
TensorFromVector(dims, *dev_ctx, &host_tensor);
host_tensors_.emplace_back(host_tensor);
// create aclTensorDesc
input_descs_.emplace_back(CreateTensorDesc(host_tensor, ACL_MEMTYPE_HOST));
// create aclDataBuffer
input_buffers_.emplace_back(CreateDataBuffer(host_tensor));
return *this;
}
NpuOpRunner &NpuOpRunner::AddOutput(const Tensor &tensor) {
// create aclTensorDesc
output_descs_.emplace_back(CreateTensorDesc(tensor));
......@@ -272,7 +320,8 @@ std::vector<aclDataBuffer *> &NpuOpRunner::GetOutputBuffers() {
return output_buffers_;
}
aclTensorDesc *NpuOpRunner::CreateTensorDesc(Tensor tensor) {
aclTensorDesc *NpuOpRunner::CreateTensorDesc(Tensor tensor,
aclMemType mem_type) {
auto dtype = ConvertToNpuDtype(tensor.type());
auto format = ConvertToNpuFormat(tensor.layout());
auto dims = framework::vectorize(tensor.dims());
......@@ -287,6 +336,9 @@ aclTensorDesc *NpuOpRunner::CreateTensorDesc(Tensor tensor) {
PADDLE_ENFORCE_NPU_SUCCESS(aclSetTensorStorageFormat(desc, format));
PADDLE_ENFORCE_NPU_SUCCESS(
aclSetTensorStorageShape(desc, dims.size(), dims.data()));
if (mem_type == ACL_MEMTYPE_HOST) {
PADDLE_ENFORCE_NPU_SUCCESS(aclSetTensorPlaceMent(desc, mem_type));
}
return desc;
}
......
......@@ -35,8 +35,9 @@ using DeviceContextPool = platform::DeviceContextPool;
class NpuOpRunner {
public:
explicit NpuOpRunner(std::string op_type);
explicit NpuOpRunner(std::string op_type,
NpuOpRunner();
explicit NpuOpRunner(const std::string &op_type);
NpuOpRunner(const std::string &op_type,
const std::vector<Tensor> &inputs = {},
const std::vector<Tensor> &outputs = {},
const NPUAttributeMap &attrs = {});
......@@ -53,12 +54,23 @@ class NpuOpRunner {
const std::string &Type();
NpuOpRunner &SetType(const std::string &name);
NpuOpRunner &AddAttr(const std::string &name, const NPUAttribute &attr);
NpuOpRunner &AddAttrs(const NPUAttributeMap &attrs);
NpuOpRunner &AddInput(const Tensor &tensor);
// NOTE(zhiqiu): CANN-5.0.2 support input tensors on host.
// Specifically, the tensor of shape, tensor of dims, etc, which are are small
// vector/list.
NpuOpRunner &AddInput(const Tensor &tensor, aclMemType mem_type);
NpuOpRunner &AddInput(std::vector<int32_t> &&dims);
NpuOpRunner &AddInput(std::vector<int64_t> &&dims);
NpuOpRunner &AddOutput(const Tensor &tensor);
NpuOpRunner &AddInputs(const std::vector<Tensor> &tensors);
......@@ -82,7 +94,8 @@ class NpuOpRunner {
void Run(aclrtStream stream = nullptr) const;
private:
aclTensorDesc *CreateTensorDesc(Tensor tensor);
aclTensorDesc *CreateTensorDesc(Tensor tensor,
aclMemType mem_type = ACL_MEMTYPE_DEVICE);
aclDataBuffer *CreateDataBuffer(Tensor tensor);
private:
......@@ -91,6 +104,7 @@ class NpuOpRunner {
std::vector<aclDataBuffer *> output_buffers_;
std::vector<aclTensorDesc *> input_descs_;
std::vector<aclTensorDesc *> output_descs_;
std::vector<Tensor> host_tensors_;
aclopAttr *attr_{nullptr};
};
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册