Unverified commit f3e5d3e5, authored by hong19860320 and committed by GitHub

[LITE][NPU] Use FullConnection op to solve the compatibility between Kirin 810 and 990 (#2283)

Parent 756140a8
@@ -31,38 +31,50 @@
// Extended Ops of HIAI DDK
namespace ge {
/**
* Multiply the matrix x1 by the matrix x2 to generate x1 * x2.
* The inputs must be two-dimensional matrices and the inner dimension of "x1"
* (after being transposed if transpose_x1 is true) must match the outer
* dimension of "x2" (after being transposed if transposed_x2 is true). <Input>
* x : the first input tensor, must be non const op.
* w : the second input tensor, must be const op.
* bias: the optional bias tensor, must be const op.
* <Output>
* y : the output tensor.
* <Attr>
* has_bias: If true, enable input bias.
*/
REG_OP(MatMul)
.INPUT(x, TensorType({DT_FLOAT}))
.INPUT(w, TensorType({DT_FLOAT}))
.OPTIONAL_INPUT(bias, TensorType({DT_FLOAT})) // bias must be const input
.OUTPUT(y, TensorType({DT_FLOAT}))
.ATTR(has_bias, AttrValue::BOOL{false}) // set to true when the bias input is provided
.OP_END()
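// Usage sketch (node names are illustrative): a bridge wires this op through
// the generated set_input_*/set_attr_* setters, as the previous fc bridge
// later in this commit does:
//   auto matmul = std::make_shared<ge::op::MatMul>("fc");
//   matmul->set_input_x(*x_node);          // non-const input node
//   matmul->set_input_w(*w_const_node);    // const weight node
//   matmul->set_input_bias(*b_const_node); // optional const bias node
//   matmul->set_attr_has_bias(ge::AttrValue::BOOL{true});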
/**
* Computes the gradients of convolution with respect to the input.
* <Input>
* input_sizes : An integer vector representing the shape of input,
* where input is a 4-D [batch, height, width, channels] tensor.
* filter : the filter tensor, with shape [H, W, filter_channel,
* filter_number]; filter_channel must be the same as the channel of x.
* x : The input tensor.
* <Output>
* y : The output tensor.
* <Attr>
* format: 0: NCHW. 1: NHWC
* group : 1: default
* num_output : 0: default, num_output must be equal to
* (filter_channel * group)
* pad : Padding for the beginning and ending along each axis
* stride : Stride along each axis.
* dilation : dilation value along each axis of the filter.
* pad_mode : 0:NOTSET, 5:VALID, 6:SAME. Default value is 0:NOTSET.
* bias_term : 0: default
* kernel : The shape of the convolution kernel
*/
REG_OP(Deconvolution)
.INPUT(input_sizes, TensorType({DT_UINT8}))
.INPUT(filter, TensorType({DT_FLOAT}))
.INPUT(x, TensorType({DT_FLOAT}))
@@ -78,28 +90,28 @@ REG_OP(MatMul)
.ATTR(pad_mode, AttrValue::INT{0})
.ATTR(bias_term, AttrValue::INT{0})
.ATTR(kernel, AttrValue::LIST_INT({0, 0}))
.OP_END()
/**
* Resize images to size using bilinear interpolation.
* <Input>
* x : the 4-D input tensor.
* w : an int32 tensor of 2 elements: [height, width].
* <Output>
* y : the output tensor
* <Attr>
* align_corners : If true, the centers of the 4 corner pixels of the
* input and output tensors are aligned, preserving the values at the corner
* pixels.
* output_dim_mode : Defaults to 2; options are 0: zoom_factor, 1:
* shrink_factor, 2: height/width. When output_dim_mode=2, the output dims are
* controlled by the [height, width] of w.
* shrink_factor : shrink factor.
* zoom_factor : zoom factor.
* pad_begin : begin of pad.
* pad_end : end of pad.
*/
REG_OP(ResizeBilinear)
.INPUT(x, TensorType({DT_FLOAT, DT_INT32}))
.INPUT(w, TensorType({DT_FLOAT, DT_INT32}))
.OUTPUT(y, TensorType({DT_FLOAT, DT_INT32}))
@@ -109,42 +121,42 @@ REG_OP(MatMul)
.ATTR(zoom_factor, AttrValue::INT{1})
.ATTR(pad_begin, AttrValue::INT{0})
.ATTR(pad_end, AttrValue::INT{0})
.OP_END()
/**
* Resize images to size using nearest neighbor interpolation.
* <Input>
* image : the input image tensor to be resized.
* size : must be a 1-D int32 tensor with two elements: [height, width].
* <Output>
* output : the output tensor
* <Attr>
* align_corners : If true, the centers of the 4 corner pixels of the
* input and output tensors are aligned, preserving the values at the corner
* pixels. Defaults to false
*/
REG_OP(ResizeNearestNeighbor)
.INPUT(image, TensorType({DT_FLOAT, DT_INT32, DT_UINT8, DT_BOOL}))
.INPUT(size, TensorType({DT_INT32}))
.OUTPUT(output, TensorType({DT_FLOAT, DT_INT32, DT_UINT8, DT_BOOL}))
.ATTR(align_corners, AttrValue::BOOL{false})
.OP_END()
/**
* Pads a tensor.
* <Input>
* x : the input tensor
* padding : the paddings tensor, must be 2-D.
* constant_values : the constant fill value, must be a scalar.
* <Output>
* output : the output tensor
* <Attr>
* t_paddings : defaults to DT_INT32; must match the datatype of padding.
* mode : 0: CONSTANT, 1: REFLECT, 2: SYMMETRIC
* T : the datatype of constant_values (DT_INT32: 3, DT_FLOAT: 0).
*/
REG_OP(Pad)
.INPUT(x, TensorType({DT_FLOAT, DT_INT32}))
.INPUT(padding, TensorType({DT_INT32}))
.OPTIONAL_INPUT(constant_values, TensorType({DT_INT32, DT_FLOAT}))
@@ -152,7 +164,7 @@ REG_OP(MatMul)
.ATTR(t_paddings, AttrValue::INT{3})
.ATTR(mode, AttrValue::INT{0})
.REQUIRED_ATTR(T, AttrValue::INT)
.OP_END();
} // namespace ge
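As a usage sketch only (the setter names are assumed from the INPUT/ATTR names above, following the same set_input_*/set_attr_* convention the bridges in this commit use; the *_node variables are hypothetical), the Pad op registered above could be wired like this, with T set to 0 for a DT_FLOAT constant_values input:

auto pad = std::make_shared<ge::op::Pad>("pad");
pad->set_input_x(*x_node);                        // tensor to pad
pad->set_input_padding(*padding_const_node);      // 2-D DT_INT32 paddings
pad->set_input_constant_values(*value_const_node);
pad->set_attr_mode(0);                            // 0: CONSTANT
pad->set_attr_T(0);                               // constant_values is DT_FLOAT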
@@ -93,11 +93,13 @@ void CompareOutputTensor(
auto ref_output_tensor_size = ShapeProduction(ref_output_tensor->shape());
EXPECT_EQ(tar_output_tensor_size, ref_output_tensor_size);
for (size_t j = 0; j < ref_output_tensor_size; j++) {
auto diff =
std::fabs(tar_output_tensor_data[j] - ref_output_tensor_data[j]) /
(std::fabs(ref_output_tensor_data[j]) + 1e-6);
VLOG(3) << diff;
EXPECT_LT(diff, 0.1);
auto abs_diff =
std::fabs(tar_output_tensor_data[j] - ref_output_tensor_data[j]);
auto rel_diff = abs_diff / (std::fabs(ref_output_tensor_data[j]) + 1e-6);
VLOG(3) << "val: " << tar_output_tensor_data[j]
<< " ref: " << ref_output_tensor_data[j]
<< " abs_diff: " << abs_diff << " rel_diff: " << rel_diff;
EXPECT_LT(rel_diff, 0.1);
}
}
}
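For reference, the updated check above accepts an element when its relative error, guarded by 1e-6 against division by near-zero references, stays below 0.1. A minimal standalone equivalent of that tolerance rule (the helper name is illustrative):

#include <cmath>

// Mirrors the tolerance used by CompareOutputTensor: relative error below 0.1,
// with a 1e-6 guard against division by (near-)zero references.
inline bool WithinRelTolerance(float target, float reference,
                               float rel_tol = 0.1f, float eps = 1e-6f) {
  float abs_diff = std::fabs(target - reference);
  float rel_diff = abs_diff / (std::fabs(reference) + eps);
  return rel_diff < rel_tol;
}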
@@ -23,20 +23,22 @@ namespace bridges {
node_map_type FCConverter(const std::shared_ptr<lite::OpLite> fc_op,
const node_map_type& inputs_map) {
LOG(INFO) << "Converting fc...";
lite::Scope* scope = fc_op->scope();
const lite::OpInfo* op_info = fc_op->op_info();
auto output_node =
std::make_shared<ge::op::MatMul>(lite::npu::UniqueName("fc"));
auto scope = fc_op->scope();
auto op_info = fc_op->op_info();
auto op_type = op_info->Type();
auto unique_op_type = lite::npu::UniqueName(op_type);
LOG(INFO) << "Converting " + op_type + "...";
auto fc_node = std::make_shared<ge::op::FullConnection>(unique_op_type);
auto x_var_name = op_info->Input("Input").front();
auto w_var_name = op_info->Input("W").front();
int in_num_col_dims = op_info->GetAttr<int>("in_num_col_dims");
auto* xtensor = scope->FindVar(x_var_name)->GetMutable<lite::Tensor>();
auto* wtensor = scope->FindVar(w_var_name)->GetMutable<lite::Tensor>();
auto x_dims = xtensor->dims();
auto w_dims = wtensor->dims();
auto x = scope->FindVar(x_var_name)->GetMutable<lite::Tensor>();
auto w = scope->FindVar(w_var_name)->GetMutable<lite::Tensor>();
auto x_dims = x->dims();
auto w_dims = w->dims();
CHECK_GE(x_dims.size(), 2UL);
CHECK_EQ(w_dims.size(), 2UL);
@@ -44,65 +46,69 @@ node_map_type FCConverter(const std::shared_ptr<lite::OpLite> fc_op,
int m = x_dims.Slice(0, in_num_col_dims).production();
int k = x_dims.Slice(in_num_col_dims, x_dims.size()).production();
int n = w_dims[1];
CHECK_EQ(k * n, w_dims.production());
VLOG(3) << "x dims: " << x_dims << " w dims: " << w_dims << " m: " << m
<< " k: " << k << " n: " << n;
CHECK(inputs_map.count(x_var_name));
CHECK(!inputs_map.count(w_var_name));
LOG(INFO) << "m:" << m << ",n:" << n << ",k:" << k;
LOG(INFO) << "x_var_name:" << x_var_name
<< ", is data: " << inputs_map.count(x_var_name);
LOG(INFO) << "w_var_name:" << w_var_name
<< ", is data: " << inputs_map.count(w_var_name);
auto xsrc = inputs_map.at(x_var_name);
auto reshapex = std::make_shared<ge::op::Reshape>(x_var_name + "_reshape");
reshapex->set_input_tensor(*xsrc);
reshapex->set_attr_shape({m, k});
reshapex->set_attr_axis(0);
lite::npu::OpList::Global().add(xsrc);
lite::npu::OpList::Global().add(reshapex);
output_node->set_input_x(*reshapex);
auto wconst = std::make_shared<ge::op::Const>(w_var_name);
ge::TensorDesc wdesc(ge::Shape({k, n}), ge::FORMAT_NCHW, ge::DT_FLOAT);
auto size = wdesc.GetShape().GetShapeSize();
CHECK_EQ(size, w_dims.production());
ge::TensorPtr ptensor = std::make_shared<ge::Tensor>();
ptensor->SetTensorDesc(wdesc);
auto* pdata = reinterpret_cast<uint8_t*>(wtensor->mutable_data<float>());
ptensor->SetData(pdata, size * sizeof(float));
wconst->set_attr_value(ptensor);
lite::npu::OpList::Global().add(wconst);
output_node->set_input_w(*wconst);
// reshape x to (m, k, 1, 1)
auto reshaped_x_node =
std::make_shared<ge::op::Reshape>(x_var_name + "_reshape");
reshaped_x_node->set_input_tensor(*inputs_map.at(x_var_name));
reshaped_x_node->set_attr_shape({m, k, 1, 1});
reshaped_x_node->set_attr_axis(0);
fc_node->set_input_x(*reshaped_x_node);
lite::npu::OpList::Global().add(inputs_map.at(x_var_name));
lite::npu::OpList::Global().add(reshaped_x_node);
// create w const node, set its shape to (n, k, 1, 1) and fill it with
// the transposed w tensor
auto w_const_node = std::make_shared<ge::op::Const>(w_var_name);
ge::TensorDesc w_const_desc(
ge::Shape({n, k, 1, 1}), ge::FORMAT_NCHW, ge::DT_FLOAT);
ge::TensorPtr w_const_tensor = std::make_shared<ge::Tensor>();
w_const_tensor->SetTensorDesc(w_const_desc);
auto w_data = w->mutable_data<float>();
std::vector<float> transposed_w_data(w_dims.production());
for (int i = 0; i < k; i++) {
for (int j = 0; j < n; j++) {
transposed_w_data[j * k + i] = w_data[i * n + j];
}
}
w_const_tensor->SetData(reinterpret_cast<uint8_t*>(transposed_w_data.data()),
transposed_w_data.size() * sizeof(float));
w_const_node->set_attr_value(w_const_tensor);
fc_node->set_input_w(*w_const_node);
lite::npu::OpList::Global().add(w_const_node);
// add bias node if bias tensor exists
if (lite::npu::HasInputArg(op_info, scope, "Bias")) {
auto b_var_name = op_info->Input("Bias").front();
auto* btensor = scope->FindVar(b_var_name)->GetMutable<lite::Tensor>();
LOG(INFO) << "b_var_name:" << b_var_name
<< ", is data: " << inputs_map.count(b_var_name);
CHECK(!inputs_map.count(b_var_name));
CHECK_EQ(btensor->numel(), n);
auto bconst = std::make_shared<ge::op::Const>(b_var_name);
ge::TensorDesc bdesc(
ge::Shape({1, n, 1, 1}), ge::FORMAT_NCHW, ge::DT_FLOAT);
auto size = bdesc.GetShape().GetShapeSize();
CHECK_EQ(size, n);
ge::TensorPtr ptensor = std::make_shared<ge::Tensor>();
ptensor->SetTensorDesc(bdesc);
auto* pdata = reinterpret_cast<uint8_t*>(btensor->mutable_data<float>());
ptensor->SetData(pdata, size * sizeof(float));
bconst->set_attr_value(ptensor);
lite::npu::OpList::Global().add(bconst);
output_node->set_input_bias(*bconst);
output_node->set_attr_has_bias(ge::AttrValue::BOOL{true});
auto bias_var_name = op_info->Input("Bias").front();
auto bias = scope->FindVar(bias_var_name)->GetMutable<lite::Tensor>();
auto bias_dims = bias->dims();
CHECK(!inputs_map.count(bias_var_name));
CHECK_EQ(bias_dims.production(), n);
auto bias_const_node = std::make_shared<ge::op::Const>(bias_var_name);
bias_const_node->set_attr_value(
lite::npu::CvtFromLiteTensor(bias, {1, n, 1, 1}));
fc_node->set_input_b(*bias_const_node);
lite::npu::OpList::Global().add(bias_const_node);
}
lite::npu::OpList::Global().add(fc_node);
lite::npu::OpList::Global().add(output_node);
// reshape output of fc_node from (m, n, 1, 1) to (m, n)
auto reshaped_fc_node =
std::make_shared<ge::op::Reshape>(unique_op_type + "_reshape");
reshaped_fc_node->set_input_tensor(*fc_node);
reshaped_fc_node->set_attr_shape({m, n});
reshaped_fc_node->set_attr_axis(0);
lite::npu::OpList::Global().add(reshaped_fc_node);
node_map_type outputs_map;
outputs_map[op_info->Output("Out").front()] = output_node;
outputs_map[op_info->Output("Out").front()] = reshaped_fc_node;
return outputs_map;
}
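The essence of the new converter is the weight relayout: fc computes y(m, n) = x(m, k) * w(k, n) + b, while FullConnection expects its input as (m, k, 1, 1) and its weight as (n, k, 1, 1), so the converter reshapes x and transposes w with transposed_w[j * k + i] = w[i * n + j]. A minimal, self-contained sketch (plain C++, no HiAI types; all names here are hypothetical) that checks the transposed-weight formulation gives the same result:

#include <cassert>
#include <cmath>
#include <vector>

int main() {
  const int m = 2, k = 3, n = 4;
  std::vector<float> x(m * k), w(k * n), b(n);
  for (int i = 0; i < m * k; ++i) x[i] = 0.1f * i;
  for (int i = 0; i < k * n; ++i) w[i] = 0.05f * i - 0.3f;
  for (int j = 0; j < n; ++j) b[j] = 0.01f * j;

  // Reference: y = x * w + b with w stored in (k, n) row-major layout.
  std::vector<float> y_ref(m * n, 0.f);
  for (int i = 0; i < m; ++i)
    for (int j = 0; j < n; ++j) {
      for (int l = 0; l < k; ++l) y_ref[i * n + j] += x[i * k + l] * w[l * n + j];
      y_ref[i * n + j] += b[j];
    }

  // FullConnection-style: weight stored as (n, k), i.e. the transpose of w,
  // built exactly as the converter does above.
  std::vector<float> wt(n * k);
  for (int i = 0; i < k; ++i)
    for (int j = 0; j < n; ++j) wt[j * k + i] = w[i * n + j];
  std::vector<float> y_fc(m * n, 0.f);
  for (int i = 0; i < m; ++i)
    for (int j = 0; j < n; ++j) {
      for (int l = 0; l < k; ++l) y_fc[i * n + j] += x[i * k + l] * wt[j * k + l];
      y_fc[i * n + j] += b[j];
    }

  // Both formulations produce the same fc output.
  for (int i = 0; i < m * n; ++i) assert(std::fabs(y_ref[i] - y_fc[i]) < 1e-5f);
  return 0;
}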
@@ -126,6 +126,7 @@ TEST(NPUBridges, fc) {
test_fc({1, 8, 8, 1}, {8, 4}, 2, use_bias);
test_fc({1, 5, 5, 1}, {5, 7}, 2, use_bias);
test_fc({1, 4, 1, 1}, {4, 8}, 1, use_bias);
test_fc({1, 1024, 1, 1}, {1024, 1000}, 1, use_bias);
}
}
@@ -43,7 +43,7 @@ void LauchOp(const std::shared_ptr<lite::OpLite> op,
ge::Shape(input->dims().Vectorize()), ge::FORMAT_NCHW, ge::DT_FLOAT);
auto input_node = std::make_shared<ge::op::Data>(input_var_name);
input_node->update_input_desc_x(input_desc);
OpList::Global().add(input_node);
lite::npu::OpList::Global().add(input_node);
inputs_map[input_var_name] = input_node;
}
auto outputs_map = supported_lists.at(op_type)(op, inputs_map);