Unverified commit 919282cf, authored by zhupengyang, committed by GitHub

[NPU] support 1-dimension input y in elementwise ops (#2546)

test=develop
Parent 1b875ae8
......@@ -21,6 +21,30 @@ namespace kernels {
 namespace npu {
 namespace bridges {

+std::vector<int64_t> CvtYShape(const Tensor& x, Tensor* y, int axis) {
+  auto x_dims = x.dims();
+  CHECK_EQ(x_dims.size(), 4UL) << "[NPU] only support 4-dimension x";
+  auto y_dims = y->dims();
+  CHECK_GE(x_dims.size(), y_dims.size());
+
+  if (axis < 0) {
+    axis += x_dims.size();
+  }
+
+  std::vector<int64_t> y_new_shape(y_dims.Vectorize());
+  if (y_new_shape.size() == 4UL) {
+    return y_new_shape;
+  }
+  for (int i = 0; i < axis; i++) {
+    y_new_shape.insert(y_new_shape.begin(), 1);
+  }
+  while (y_new_shape.size() < 4) {
+    y_new_shape.push_back(1);
+  }
+  CHECK_EQ(y_new_shape.size(), 4UL);
+  return y_new_shape;
+}
+
 node_map_type ElementwiseConverter(
     const std::shared_ptr<lite::OpLite> elementwise_op,
     const node_map_type& inputs_map) {
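For reference, here is a minimal standalone sketch (not part of this commit) of the padding rule CvtYShape applies: insert leading 1s up to axis, then append trailing 1s until the shape is 4-D. The helper name, the plain std::vector container, and the main() driver are illustrative assumptions, not code from the repository.

#include <cassert>
#include <cstdint>
#include <iostream>
#include <vector>

// Pad a lower-rank y shape to 4-D so it can broadcast against a 4-D x.
std::vector<int64_t> PadYShapeTo4D(std::vector<int64_t> y_shape,
                                   size_t x_rank, int axis) {
  assert(x_rank == 4);
  if (axis < 0) axis += static_cast<int>(x_rank);
  if (y_shape.size() == 4) return y_shape;  // already 4-D, nothing to do
  for (int i = 0; i < axis; i++) {
    y_shape.insert(y_shape.begin(), 1);     // leading 1s up to axis
  }
  while (y_shape.size() < 4) {
    y_shape.push_back(1);                   // trailing 1s up to rank 4
  }
  return y_shape;
}

int main() {
  // y of shape {2} with axis = 1 lines up with x's channel dimension,
  // so it becomes {1, 2, 1, 1}.
  for (int64_t d : PadYShapeTo4D({2}, 4, 1)) std::cout << d << " ";
  std::cout << "\n";  // prints: 1 2 1 1
  return 0;
}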
......@@ -33,6 +57,7 @@ node_map_type ElementwiseConverter(
   auto x_var_name = op_info->Input("X").front();
   auto y_var_name = op_info->Input("Y").front();
   CHECK(inputs_map.find(x_var_name) != inputs_map.end());
+  auto axis = op_info->GetAttr<int>("axis");

   std::shared_ptr<ge::Operator> elementwise_node = nullptr;
   std::shared_ptr<ge::Operator> x_node = inputs_map.at(x_var_name);
......@@ -41,8 +66,10 @@ node_map_type ElementwiseConverter(
     y_node = inputs_map.at(y_var_name);
   } else {
     auto y_const_node = std::make_shared<ge::op::Const>(y_var_name);
-    auto* y = scope->FindMutableTensor(y_var_name);
-    y_const_node->set_attr_value(lite::npu::CvtTensor(y));
+    auto x = scope->FindTensor(x_var_name);
+    auto y = scope->FindMutableTensor(y_var_name);
+    auto y_new_shape = CvtYShape(*x, y, axis);
+    y_const_node->set_attr_value(lite::npu::CvtTensor(y, y_new_shape));
     y_node = y_const_node;
   }
   lite::npu::OpList::Global().add(x_node);
......
......@@ -45,7 +45,7 @@ void elementwise_add_ref(const std::shared_ptr<operators::ElementwiseOp> op) {
   if (axis < 0) {
     axis += x_dims.size();
   }
-  int batch = x_dims[0] / y_dims[0];
+  int batch = 1;
   int channels = y->numel();
   int num = x->numel() / channels / batch;
   // do elementwise add/sub/max...
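The elided reference loop presumably decomposes x into batch x channels x num and broadcasts y along the channels dimension; with x shape {1, 2, 3, 4} and y expanded to {1, 2, 1, 1}, this gives batch = 1, channels = 2, num = 12. A hedged sketch of the add case under that assumption (the raw-pointer signature is illustrative, not the test file's actual code):

// Illustrative sketch: reference elementwise add with y broadcast
// along the `channels` dimension.
void elementwise_add_broadcast_ref(const float* x_data, const float* y_data,
                                   float* out_data, int batch, int channels,
                                   int num) {
  for (int b = 0; b < batch; ++b) {
    for (int c = 0; c < channels; ++c) {
      int offset = (b * channels + c) * num;
      for (int i = 0; i < num; ++i) {
        // every element in channel c of x is combined with y_data[c]
        out_data[offset + i] = x_data[offset + i] + y_data[c];
      }
    }
  }
}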
......@@ -143,8 +143,8 @@ void test_elementwise_add(const std::vector<int64_t>& x_shape,
   y->Resize(y_shape);

   // initialize input&output data
-  FillTensor<float>(x, 1, 5);
-  FillTensor<float>(y, 1, 5);
+  FillTensor<float>(x, 1, 3);
+  FillTensor<float>(y, 1, 3);

   // initialize op desc
   cpp::OpDesc opdesc;
......@@ -171,6 +171,7 @@ void test_elementwise_add(const std::vector<int64_t>& x_shape,

 TEST(NPUBridges, elementwise_add) {
   for (auto elt_type : {"add", "sub", "mul", "div"}) {
     test_elementwise_add({1, 2, 3, 4}, {2}, 1, elt_type);
+    test_elementwise_add({1, 2, 3, 4}, {1, 2, 1, 1}, 1, elt_type);
     test_elementwise_add({1, 2, 3, 4}, {1, 2, 3, 4}, 3, elt_type);
   }
......
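As a worked example of what these cases exercise: with x shape {1, 2, 3, 4} and axis = 1, the new {2} case is expanded by CvtYShape to {1, 2, 1, 1}; the {1, 2, 1, 1} case is already 4-D and passes through unchanged; and the full {1, 2, 3, 4} case matches x element for element, so no broadcasting is needed.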