Unverified commit 919282cf, authored by zhupengyang, committed via GitHub

[NPU] support 1-dimension input y in elementwise ops (#2546)

test=develop
Parent 1b875ae8
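The new CvtYShape helper added in the diff below pads a lower-rank y to a 4-D shape around the broadcast axis before the bridge builds the constant NPU node for y. Here is a minimal standalone sketch of that padding rule, assuming a 4-D x; the name PadYShapeTo4D and the plain std::vector interface are illustrative only, not the Paddle-Lite API:

```cpp
#include <cassert>
#include <cstdint>
#include <iostream>
#include <vector>

// Illustrative re-implementation of the shape padding this commit adds:
// prepend 1s for every dimension before `axis`, then append 1s until y is 4-D.
std::vector<int64_t> PadYShapeTo4D(const std::vector<int64_t>& x_shape,
                                   std::vector<int64_t> y_shape, int axis) {
  assert(x_shape.size() == 4);           // the bridge only supports 4-D x
  if (axis < 0) axis += static_cast<int>(x_shape.size());  // normalize axis
  if (y_shape.size() == 4) return y_shape;
  y_shape.insert(y_shape.begin(), axis, 1);          // leading 1s up to `axis`
  while (y_shape.size() < 4) y_shape.push_back(1);   // trailing 1s
  return y_shape;
}

int main() {
  // Matches the new test case: x = {1, 2, 3, 4}, y = {2}, axis = 1 -> {1, 2, 1, 1}.
  for (int64_t d : PadYShapeTo4D({1, 2, 3, 4}, {2}, 1)) std::cout << d << " ";
  std::cout << std::endl;
  return 0;
}
```

With the 1-D y case added to the tests below (x = {1, 2, 3, 4}, y = {2}, axis = 1), the padded shape is {1, 2, 1, 1}, which can then be broadcast element-wise against x on the NPU.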
@@ -21,6 +21,30 @@ namespace kernels {
 namespace npu {
 namespace bridges {
 
+std::vector<int64_t> CvtYShape(const Tensor& x, Tensor* y, int axis) {
+  auto x_dims = x.dims();
+  CHECK_EQ(x_dims.size(), 4UL) << "[NPU] only support 4-dimension x";
+  auto y_dims = y->dims();
+  CHECK_GE(x_dims.size(), y_dims.size());
+
+  if (axis < 0) {
+    axis += x_dims.size();
+  }
+
+  std::vector<int64_t> y_new_shape(y_dims.Vectorize());
+  if (y_new_shape.size() == 4UL) {
+    return y_new_shape;
+  }
+  for (int i = 0; i < axis; i++) {
+    y_new_shape.insert(y_new_shape.begin(), 1);
+  }
+  while (y_new_shape.size() < 4) {
+    y_new_shape.push_back(1);
+  }
+  CHECK_EQ(y_new_shape.size(), 4UL);
+  return y_new_shape;
+}
+
 node_map_type ElementwiseConverter(
     const std::shared_ptr<lite::OpLite> elementwise_op,
     const node_map_type& inputs_map) {
@@ -33,6 +57,7 @@ node_map_type ElementwiseConverter(
   auto x_var_name = op_info->Input("X").front();
   auto y_var_name = op_info->Input("Y").front();
   CHECK(inputs_map.find(x_var_name) != inputs_map.end());
+  auto axis = op_info->GetAttr<int>("axis");
 
   std::shared_ptr<ge::Operator> elementwise_node = nullptr;
   std::shared_ptr<ge::Operator> x_node = inputs_map.at(x_var_name);
@@ -41,8 +66,10 @@ node_map_type ElementwiseConverter(
     y_node = inputs_map.at(y_var_name);
   } else {
     auto y_const_node = std::make_shared<ge::op::Const>(y_var_name);
-    auto* y = scope->FindMutableTensor(y_var_name);
-    y_const_node->set_attr_value(lite::npu::CvtTensor(y));
+    auto x = scope->FindTensor(x_var_name);
+    auto y = scope->FindMutableTensor(y_var_name);
+    auto y_new_shape = CvtYShape(*x, y, axis);
+    y_const_node->set_attr_value(lite::npu::CvtTensor(y, y_new_shape));
     y_node = y_const_node;
   }
   lite::npu::OpList::Global().add(x_node);
...
@@ -45,7 +45,7 @@ void elementwise_add_ref(const std::shared_ptr<operators::ElementwiseOp> op) {
   if (axis < 0) {
     axis += x_dims.size();
   }
-  int batch = x_dims[0] / y_dims[0];
+  int batch = 1;
   int channels = y->numel();
   int num = x->numel() / channels / batch;
   // do elementwise add/sub/max...
@@ -143,8 +143,8 @@ void test_elementwise_add(const std::vector<int64_t>& x_shape,
   y->Resize(y_shape);
 
   // initialize input&output data
-  FillTensor<float>(x, 1, 5);
-  FillTensor<float>(y, 1, 5);
+  FillTensor<float>(x, 1, 3);
+  FillTensor<float>(y, 1, 3);
 
   // initialize op desc
   cpp::OpDesc opdesc;
@@ -171,6 +171,7 @@ void test_elementwise_add(const std::vector<int64_t>& x_shape,
 
 TEST(NPUBridges, elementwise_add) {
   for (auto elt_type : {"add", "sub", "mul", "div"}) {
+    test_elementwise_add({1, 2, 3, 4}, {2}, 1, elt_type);
     test_elementwise_add({1, 2, 3, 4}, {1, 2, 1, 1}, 1, elt_type);
     test_elementwise_add({1, 2, 3, 4}, {1, 2, 3, 4}, 3, elt_type);
   }
...
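For context on the reference-side change above (int batch = 1): with channels = y->numel(), x is viewed as a (batch, channels, num) layout and y is broadcast along the channel block, so for the shapes exercised here batch can simply be fixed at 1. A hedged sketch of that broadcast pattern, assuming plain float vectors; broadcast_add_ref and its parameters are illustrative names, not the test file's actual reference code:

```cpp
#include <vector>

// Sketch of a broadcast elementwise add in (batch, channels, num) layout,
// mirroring the batch/channels/num split used in the reference function:
//   batch = 1, channels = y.size(), num = x.size() / channels / batch.
void broadcast_add_ref(const std::vector<float>& x,
                       const std::vector<float>& y,
                       std::vector<float>* out) {
  const int batch = 1;
  const int channels = static_cast<int>(y.size());
  const int num = static_cast<int>(x.size()) / channels / batch;
  out->resize(x.size());
  for (int b = 0; b < batch; ++b) {
    for (int c = 0; c < channels; ++c) {
      for (int n = 0; n < num; ++n) {
        int idx = (b * channels + c) * num + n;
        (*out)[idx] = x[idx] + y[c];  // y is broadcast along the channel axis
      }
    }
  }
}
```

For x = {1, 2, 3, 4} and y = {2}, this gives channels = 2 and num = 12, matching the 1-D y case added to the test list.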