Unverified commit a981af15, authored by zhupengyang, committed by GitHub

[NPU] support more elementwise cases (#3133)

* [NPU] revert shape check for input&output

* [NPU] reshape elt input&output
Parent 66f0b25b
@@ -21,42 +21,78 @@ namespace lite {
namespace subgraph {
namespace npu {
void CvtYShape(std::vector<int64_t>* x_shape,
std::vector<int64_t>* y_shape,
int axis) {
CHECK_GE(x_shape->size(), y_shape->size());
void CvtXYShape(std::vector<int64_t>* x_shape,
std::vector<int64_t>* y_shape,
int axis) {
int x_shape_size = x_shape->size();
int y_shape_size = y_shape->size();
CHECK_GE(x_shape_size, y_shape_size);
if (axis < 0) {
axis = x_shape->size() - y_shape->size();
// only support:
// 1. same shape
// 2. (n,c,h,w) * (1,c,1,1)
// 3. (n,c,h,w) * (n,c,1,1)
// 4. (n,c,h,w) * (1,c,h,1)
// 5. (n,c,h,w) * (1,c,h,w)
// 6. (n,c,h,w) * (n,c,1,w)
if (*x_shape == *y_shape) {
*x_shape = CvtShape(*x_shape);
*y_shape = CvtShape(*y_shape);
return;
}
// only support:
// (n,c,h,w) * (n,c,h,w)
// (n,c,h,w) * (1,c,1,1)
// (n,c,h,w) * (1,c,h,1)
// (n,c,h,w) * (1,c,h,w)
int y_shape_size = y_shape->size();
if (y_shape_size == 1) {
y_shape->insert(y_shape->begin(), 1);
y_shape->insert(y_shape->end(), 2, 1);
} else if (y_shape_size == 2) {
y_shape->insert(y_shape->begin(), 1);
y_shape->insert(y_shape->end(), 1);
} else if (y_shape_size == 3) {
y_shape->insert(y_shape->begin(), 1);
for (int i = 0; i < 4 - x_shape_size; i++) {
x_shape->push_back(1);
}
int64_t n = x_shape->at(0);
int64_t c = x_shape->at(1);
int64_t h = x_shape->at(2);
int64_t w = x_shape->at(3);
if (axis == 0) {
*x_shape = std::vector<int64_t>{1, n, c * h * w, 1};
} else if (axis == 2) {
*x_shape = std::vector<int64_t>{n * c, h, w, 1};
} else if (axis == 3) {
*x_shape = std::vector<int64_t>{n * c * h, w, 1, 1};
}
*y_shape = std::vector<int64_t>{1, y_shape->at(0), 1, 1};
return;
}
if (y_shape_size < 4) {
int n = 1;
for (int i = 0; i < axis; i++) {
n *= x_shape->at(i);
if (y_shape_size == 2) {
for (int i = 0; i < 4 - x_shape_size; i++) {
x_shape->push_back(1);
}
int64_t n = x_shape->at(0);
int64_t c = x_shape->at(1);
int64_t h = x_shape->at(2);
int64_t w = x_shape->at(3);
if (axis == 0) {
y_shape->insert(y_shape->end(), 2, 1);
} else if (axis == 1) {
y_shape->insert(y_shape->begin(), 1);
y_shape->insert(y_shape->end(), 1);
} else if (axis == 2) {
*x_shape = std::vector<int64_t>{n * c, h, w, 1};
y_shape->insert(y_shape->begin(), 1);
y_shape->insert(y_shape->end(), 1);
}
x_shape->erase(x_shape->begin(), x_shape->begin() + axis);
x_shape->insert(x_shape->begin(), n);
x_shape->insert(x_shape->end(), 4 - x_shape->size(), 1);
return;
}
CHECK_EQ(x_shape->size(), 4UL);
CHECK_EQ(y_shape->size(), 4UL);
if (y_shape_size == 3) {
y_shape->insert(y_shape->begin(), 1);
int64_t n = x_shape->at(0);
int64_t c = x_shape->at(1);
int64_t h = x_shape->at(2);
int64_t w = x_shape->at(3);
if (axis == 0) {
*x_shape = std::vector<int64_t>{1, n * c * h, w, 1};
*y_shape = std::vector<int64_t>{1, n * c * h, 1, 1};
}
return;
}
}
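For reference, here is a minimal standalone sketch of the axis-folding idea behind the `y_shape_size == 1` branch of `CvtXYShape` above. The function name `FoldForVectorY` and the `main` driver are illustrative only and not part of the converter; the folding rules follow the branch shown in the hunk.

```cpp
#include <cstdint>
#include <iostream>
#include <vector>

// Fold X so that a 1-D Y broadcasts along the HiAI channel dimension;
// mirrors the y_shape_size == 1 branch of CvtXYShape above.
void FoldForVectorY(std::vector<int64_t>* x, std::vector<int64_t>* y, int axis) {
  while (x->size() < 4) x->push_back(1);  // pad X up to 4-D
  int64_t n = (*x)[0], c = (*x)[1], h = (*x)[2], w = (*x)[3];
  if (axis == 0) {
    *x = {1, n, c * h * w, 1};            // Y aligns with N
  } else if (axis == 2) {
    *x = {n * c, h, w, 1};                // Y aligns with H
  } else if (axis == 3) {
    *x = {n * c * h, w, 1, 1};            // Y aligns with W
  }                                       // axis == 1: X is already aligned
  *y = {1, (*y)[0], 1, 1};                // Y becomes (1, c', 1, 1)
}

int main() {
  std::vector<int64_t> x{2, 3, 4, 5}, y{4};
  FoldForVectorY(&x, &y, /*axis=*/2);
  // x -> 6 4 5 1, y -> 1 4 1 1: a plain (n,c,h,w) * (1,c,1,1) broadcast.
  for (auto d : x) std::cout << d << ' ';
  std::cout << "| ";
  for (auto d : y) std::cout << d << ' ';
  std::cout << '\n';
}
```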
int ElementwiseConverter(void* ctx, OpLite* op, KernelBase* kernel) {
@@ -70,36 +106,37 @@ int ElementwiseConverter(void* ctx, OpLite* op, KernelBase* kernel) {
// Get input and output vars and op attributes
auto x_name = op_info->Input("X").front();
auto x = scope->FindMutableTensor(x_name);
auto x = scope->FindTensor(x_name);
auto x_dims = x->dims();
auto y_name = op_info->Input("Y").front();
auto y = scope->FindMutableTensor(y_name);
auto y = scope->FindTensor(y_name);
auto y_dims = y->dims();
auto out_name = op_info->Output("Out").front();
auto out = scope->FindMutableTensor(out_name);
auto out = scope->FindTensor(out_name);
auto out_dims = out->dims();
auto axis = op_info->GetAttr<int>("axis");
if (axis < 0) {
axis = x_dims.size() - y_dims.size();
}
auto x_new_shape = x_dims.Vectorize();
auto y_new_shape = y_dims.Vectorize();
CvtYShape(&x_new_shape, &y_new_shape, axis);
CvtXYShape(&x_new_shape, &y_new_shape, axis);
// X node
std::shared_ptr<Node> x_node = nullptr;
if (graph->Has(x_name)) {
x_node = graph->Get(x_name);
if (x_dims.Vectorize() != x_new_shape) {
auto reshaped_x_node = graph->Add<ge::op::Reshape>(x_name + "/reshape");
auto reshaped_x_op = reshaped_x_node->data<ge::op::Reshape>();
reshaped_x_op->set_input_tensor(*x_node->data());
reshaped_x_op->set_attr_shape(
ge::AttrValue::LIST_INT(x_new_shape.begin(), x_new_shape.end()));
reshaped_x_op->set_attr_axis(0);
x_node = reshaped_x_node;
}
auto reshaped_x_node = graph->Add<ge::op::Reshape>(x_name + "/reshape");
auto reshaped_x_op = reshaped_x_node->data<ge::op::Reshape>();
reshaped_x_op->set_input_tensor(*x_node->data());
reshaped_x_op->set_attr_shape(
ge::AttrValue::LIST_INT(x_new_shape.begin(), x_new_shape.end()));
reshaped_x_op->set_attr_axis(0);
x_node = reshaped_x_node;
} else {
x_node = graph->Add(x_name, *x, x_new_shape);
}
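The X and Y branches in this hunk follow the same pattern. Below is a hedged sketch of how that pattern could be factored into a shared helper; `ReshapeIfNeeded` is a hypothetical name and the refactor is not part of this commit, but every graph/op call used is one that already appears in the hunk above.

```cpp
// Hypothetical helper capturing the pattern used for both the X and Y nodes:
// insert a ge::op::Reshape only when the tensor already registered in the
// graph does not match the shape computed by CvtXYShape.
std::shared_ptr<Node> ReshapeIfNeeded(Graph* graph,
                                      const std::string& name,
                                      std::shared_ptr<Node> node,
                                      const std::vector<int64_t>& old_shape,
                                      const std::vector<int64_t>& new_shape) {
  if (old_shape == new_shape) {
    return node;  // shape untouched by CvtXYShape, reuse the existing node
  }
  auto reshaped_node = graph->Add<ge::op::Reshape>(name + "/reshape");
  auto reshaped_op = reshaped_node->data<ge::op::Reshape>();
  reshaped_op->set_input_tensor(*node->data());
  reshaped_op->set_attr_shape(
      ge::AttrValue::LIST_INT(new_shape.begin(), new_shape.end()));
  reshaped_op->set_attr_axis(0);
  return reshaped_node;
}
```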
@@ -108,15 +145,13 @@ int ElementwiseConverter(void* ctx, OpLite* op, KernelBase* kernel) {
std::shared_ptr<Node> y_node = nullptr;
if (graph->Has(y_name)) {
y_node = graph->Get(y_name);
if (y_dims.Vectorize() != y_new_shape) {
auto reshaped_y_node = graph->Add<ge::op::Reshape>(y_name + "/reshape");
auto reshaped_y_op = reshaped_y_node->data<ge::op::Reshape>();
reshaped_y_op->set_input_tensor(*y_node->data());
reshaped_y_op->set_attr_shape(
ge::AttrValue::LIST_INT(y_new_shape.begin(), y_new_shape.end()));
reshaped_y_op->set_attr_axis(0);
y_node = reshaped_y_node;
}
auto reshaped_y_node = graph->Add<ge::op::Reshape>(y_name + "/reshape");
auto reshaped_y_op = reshaped_y_node->data<ge::op::Reshape>();
reshaped_y_op->set_input_tensor(*y_node->data());
reshaped_y_op->set_attr_shape(
ge::AttrValue::LIST_INT(y_new_shape.begin(), y_new_shape.end()));
reshaped_y_op->set_attr_axis(0);
y_node = reshaped_y_node;
} else {
y_node = graph->Add(y_name, *y, y_new_shape);
}
@@ -152,11 +187,11 @@ int ElementwiseConverter(void* ctx, OpLite* op, KernelBase* kernel) {
return FAILED;
}
if (out_dims.Vectorize() != x_new_shape) {
auto out_shape = out_dims.Vectorize();
if (out_shape != x_new_shape) {
auto reshaped_elt_node = graph->Add<ge::op::Reshape>(out_name);
auto reshaped_elt_op = reshaped_elt_node->data<ge::op::Reshape>();
reshaped_elt_op->set_input_tensor(*elt_node->data());
auto out_shape = out_dims.Vectorize();
reshaped_elt_op->set_attr_shape(
ge::AttrValue::LIST_INT(out_shape.begin(), out_shape.end()));
reshaped_elt_op->set_attr_axis(0);
......
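The trailing Reshape back to `out_dims` at the end of the previous hunk is only legal because `CvtXYShape` regroups dimensions without changing the total element count. A small self-contained check of that invariant, using the same hypothetical shapes as the folding example above:

```cpp
#include <cassert>
#include <cstdint>
#include <functional>
#include <numeric>
#include <vector>

// Total element count of a shape, analogous to DDim::production().
int64_t Production(const std::vector<int64_t>& dims) {
  return std::accumulate(dims.begin(), dims.end(), int64_t{1},
                         std::multiplies<int64_t>());
}

int main() {
  std::vector<int64_t> out_dims{2, 3, 4, 5};  // operator's original output
  std::vector<int64_t> folded{6, 4, 5, 1};    // shape the NPU op computed in
  // Folding only regroups dimensions, so the element count is unchanged and
  // the final ge::op::Reshape back to out_dims is always valid.
  assert(Production(out_dims) == Production(folded));
}
```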
@@ -151,15 +151,6 @@ int CvtActMode(std::string act_type) {
return act_mode;
}
bool CheckShape(DDim origin_dims, hiai::TensorDimension device_dims) {
auto origin_shape = CvtShape(origin_dims);
CHECK_EQ(origin_shape.size(), 4);
return origin_shape[0] == device_dims.GetNumber() &&
origin_shape[1] == device_dims.GetChannel() &&
origin_shape[2] == device_dims.GetHeight() &&
origin_shape[3] == device_dims.GetWidth();
}
} // namespace npu
} // namespace subgraph
} // namespace lite
......
@@ -19,7 +19,6 @@
#include <string>
#include <unordered_map>
#include <vector>
#include "HiAiModelManagerService.h"
#include "graph/buffer.h"
#include "graph/graph.h"
#include "graph/model.h"
@@ -146,8 +145,6 @@ ge::TensorPtr CvtTensor(const Tensor& in_tensor,
int CvtActMode(std::string act_type);
bool CheckShape(DDim origin_dims, hiai::TensorDimension device_dims);
} // namespace npu
} // namespace subgraph
} // namespace lite
......
@@ -124,19 +124,9 @@ int SubgraphEngine::BuildDeviceProgram() {
<< device_idims[i].GetHeight() << "," << device_idims[i].GetWidth()
<< "}";
// Prepare the device input tensors
if (!subgraph::npu::CheckShape(origin_idims_[i], device_idims[i])) {
LOG(WARNING) << "origin and device input's dims are mismatched.";
for (int j = 0; j < origin_idims_[i].size(); j++) {
LOG(WARNING) << "origin_idims_[" << i << "][" << j
<< "]: " << origin_idims_[i][j];
}
LOG(WARNING) << "device_idims[" << i << "]: {"
<< device_idims[i].GetNumber() << ", "
<< device_idims[i].GetChannel() << ", "
<< device_idims[i].GetHeight() << ", "
<< device_idims[i].GetWidth() << "}";
return subgraph::FAILED;
}
CHECK_EQ(origin_idims_[i].production(),
device_idims[i].GetNumber() * device_idims[i].GetChannel() *
device_idims[i].GetHeight() * device_idims[i].GetWidth());
device_itensors_[i].reset(new hiai::AiTensor);
device_itensors_[i]->Init(&(device_idims[i]));
}
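A worked example of the new element-count `CHECK_EQ` versus the dropped per-dimension `CheckShape`: presumably the dim-by-dim comparison was too strict once the elementwise folding can change how the device splits the same buffer across N/C/H/W. The numbers below are hypothetical and continue the earlier folding example.

```cpp
#include <cassert>
#include <cstdint>

int main() {
  // origin_idims_[i] of the subgraph input: (2, 3, 4, 5)
  // device_idims[i] reported by HiAI after folding: (6, 4, 5, 1)
  int64_t origin_production = int64_t{2} * 3 * 4 * 5;  // 120
  int64_t device_production = int64_t{6} * 4 * 5 * 1;  // 120
  // The buffers are byte-compatible, so only the totals need to agree ...
  assert(origin_production == device_production);
  // ... whereas a dimension-by-dimension comparison (2 == 6, 3 == 4, ...)
  // would reject this perfectly valid binding.
}
```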
@@ -177,21 +167,9 @@ int SubgraphEngine::BuildDeviceProgram() {
<< PrecisionToStr(precision);
break;
}
/*
if (!subgraph::npu::CheckShape(origin_odims_[i], device_odims[i])) {
LOG(WARNING) << "origin and device output's dims are mismatched.";
for (int j = 0; j < origin_odims_[i].size(); j++) {
LOG(WARNING) << "origin_odims_[" << i << "][" << j
<< "]: " << origin_odims_[i][j];
}
LOG(WARNING) << "device_odims[" << i << "]: {"
<< device_odims[i].GetNumber() << ", "
<< device_odims[i].GetChannel() << ", "
<< device_odims[i].GetHeight() << ", "
<< device_odims[i].GetWidth() << "}";
return subgraph::FAILED;
}
*/
CHECK_EQ(origin_odims_[i].production(),
device_odims[i].GetNumber() * device_odims[i].GetChannel() *
device_odims[i].GetHeight() * device_odims[i].GetWidth());
device_otensors_[i].reset(new hiai::AiTensor);
device_otensors_[i]->Init(&(device_odims[i]));
}
......