提交 1783d8f3 编写于 作者: B Bin Li

Avoid dividing by zero for cosine similarity

上级 1be8257c
...@@ -238,9 +238,8 @@ bool HexagonControlWrapper::SetupGraph(const NetDef &net_def, ...@@ -238,9 +238,8 @@ bool HexagonControlWrapper::SetupGraph(const NetDef &net_def,
// input info // input info
num_inputs_ = 0; num_inputs_ = 0;
for (const InputInfo &input_info : net_def.input_info()) { for (const InputInfo &input_info : net_def.input_info()) {
std::vector<index_t> input_shape; std::vector<index_t> input_shape(input_info.dims().begin(),
input_shape.insert(input_shape.begin(), input_info.dims().begin(), input_info.dims().end());
input_info.dims().end());
while (input_shape.size() < 4) { while (input_shape.size() < 4) {
input_shape.insert(input_shape.begin(), 1); input_shape.insert(input_shape.begin(), 1);
} }
...@@ -252,9 +251,8 @@ bool HexagonControlWrapper::SetupGraph(const NetDef &net_def, ...@@ -252,9 +251,8 @@ bool HexagonControlWrapper::SetupGraph(const NetDef &net_def,
// output info // output info
num_outputs_ = 0; num_outputs_ = 0;
for (const OutputInfo &output_info : net_def.output_info()) { for (const OutputInfo &output_info : net_def.output_info()) {
std::vector<index_t> output_shape; std::vector<index_t> output_shape(output_info.dims().begin(),
output_shape.insert(output_shape.begin(), output_info.dims().begin(), output_info.dims().end());
output_info.dims().end());
while (output_shape.size() < 4) { while (output_shape.size() < 4) {
output_shape.insert(output_shape.begin(), 1); output_shape.insert(output_shape.begin(), 1);
} }
...@@ -478,12 +476,11 @@ bool HexagonControlWrapper::ExecuteGraphNew( ...@@ -478,12 +476,11 @@ bool HexagonControlWrapper::ExecuteGraphNew(
&outputs[index + 3]); &outputs[index + 3]);
} }
int res = int res = hexagon_nn_execute_new(nn_id_,
hexagon_nn_execute_new(nn_id_, inputs.data(),
inputs.data(), num_inputs * NUM_METADATA,
num_inputs * NUM_METADATA, outputs.data(),
outputs.data(), num_outputs * NUM_METADATA);
num_outputs * NUM_METADATA);
for (size_t i = 0; i < num_outputs; ++i) { for (size_t i = 0; i < num_outputs; ++i) {
size_t index = i * NUM_METADATA; size_t index = i * NUM_METADATA;
......
...@@ -138,6 +138,7 @@ inline bool IsSameSize(const Tensor &x, const Tensor &y) { ...@@ -138,6 +138,7 @@ inline bool IsSameSize(const Tensor &x, const Tensor &y) {
inline std::string ShapeToString(const Tensor &x) { inline std::string ShapeToString(const Tensor &x) {
std::stringstream stream; std::stringstream stream;
stream << "[";
for (int i = 0; i < x.dim_size(); i++) { for (int i = 0; i < x.dim_size(); i++) {
if (i > 0) stream << ","; if (i > 0) stream << ",";
int64_t dim = x.dim(i); int64_t dim = x.dim(i);
...@@ -174,8 +175,8 @@ inline void ExpectEqual<double>(const double &a, const double &b) { ...@@ -174,8 +175,8 @@ inline void ExpectEqual<double>(const double &a, const double &b) {
} }
inline void AssertSameDims(const Tensor &x, const Tensor &y) { inline void AssertSameDims(const Tensor &x, const Tensor &y) {
ASSERT_TRUE(IsSameSize(x, y)) << "x.shape [" << ShapeToString(x) << "] vs " ASSERT_TRUE(IsSameSize(x, y)) << "x.shape " << ShapeToString(x) << " vs "
<< "y.shape [ " << ShapeToString(y) << "]"; << "y.shape " << ShapeToString(y);
} }
template<typename EXP_TYPE, template<typename EXP_TYPE,
...@@ -282,7 +283,7 @@ void ExpectTensorNear(const Tensor &x, ...@@ -282,7 +283,7 @@ void ExpectTensorNear(const Tensor &x,
template<typename T> template<typename T>
void ExpectTensorSimilar(const Tensor &x, void ExpectTensorSimilar(const Tensor &x,
const Tensor &y, const Tensor &y,
const double abs_err = 1e-5) { const double rel_err = 1e-5) {
AssertSameDims(x, y); AssertSameDims(x, y);
Tensor::MappingGuard x_mapper(&x); Tensor::MappingGuard x_mapper(&x);
Tensor::MappingGuard y_mapper(&y); Tensor::MappingGuard y_mapper(&y);
...@@ -294,8 +295,11 @@ void ExpectTensorSimilar(const Tensor &x, ...@@ -294,8 +295,11 @@ void ExpectTensorSimilar(const Tensor &x,
x_norm += x_data[i] * x_data[i]; x_norm += x_data[i] * x_data[i];
y_norm += y_data[i] * y_data[i]; y_norm += y_data[i] * y_data[i];
} }
double similarity = dot_product / (sqrt(x_norm) * sqrt(y_norm)); double norm_product = sqrt(x_norm) * sqrt(y_norm);
EXPECT_NEAR(1.0, similarity, abs_err); double error = rel_err * std::abs(dot_product);
EXPECT_NEAR(dot_product, norm_product, error)
<< "Shape " << ShapeToString(x);
} }
} // namespace test } // namespace test
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册