提交 1bd4f908 编写于 作者: B Bin Li

Fix hexagon supernode without biasadd

上级 97e2d5d4
......@@ -12,6 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import numpy as np
from operator import mul
from mace.proto import mace_pb2
from mace.python.tools.converter_tool import base_converter
......@@ -24,9 +27,6 @@ from mace.python.tools.converter_tool.base_converter import PoolingType
from mace.python.tools.convert_util import mace_check
from mace.python.tools import graph_util
import copy
from operator import mul
class HexagonOps(object):
def __init__(self):
......@@ -133,6 +133,8 @@ class HexagonConverter(base_converter.ConverterInterface):
print('Supernode requires biasadd, we add it.')
bias_data = np.zeros(channels, dtype=int)
bias_tensor = self._model.tensors.add()
bias_tensor.data_type = mace_pb2.DT_INT32
bias_tensor.dims.extend([channels])
bias_tensor.int32_data.extend(bias_data)
bias_tensor.minval = 0
bias_tensor.maxval = 0
......
......@@ -66,9 +66,22 @@ def calculate_similarity(u, v, data_type=np.float64):
return np.dot(u, v) / (np.linalg.norm(u) * np.linalg.norm(v))
def calculate_pixel_accuracy(out_value, mace_out_value):
out_value = out_value.reshape((-1, out_value.shape[-1]))
batches = out_value.shape[0]
classes = out_value.shape[1]
mace_out_value = mace_out_value.reshape((batches, classes))
correct_count = 0
for i in range(batches):
if np.argmax(out_value[i]) == np.argmax(mace_out_value[i]):
correct_count += 1
return 1.0 * correct_count / batches
def compare_output(platform, device_type, output_name, mace_out_value,
out_value, validation_threshold):
if mace_out_value.size != 0:
pixel_accuracy = calculate_pixel_accuracy(out_value, mace_out_value)
out_value = out_value.reshape(-1)
mace_out_value = mace_out_value.reshape(-1)
assert len(out_value) == len(mace_out_value)
......@@ -76,7 +89,8 @@ def compare_output(platform, device_type, output_name, mace_out_value,
similarity = calculate_similarity(out_value, mace_out_value)
common.MaceLogger.summary(
output_name + ' MACE VS ' + platform.upper()
+ ' similarity: ' + str(similarity) + ' , sqnr: ' + str(sqnr))
+ ' similarity: ' + str(similarity) + ' , sqnr: ' + str(sqnr)
+ ' , pixel_accuracy: ' + str(pixel_accuracy))
if similarity > validation_threshold:
common.MaceLogger.summary(
common.StringFormatter.block("Similarity Test Passed"))
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册