未验证 提交 405e5de6 编写于 作者: C cc 提交者: GitHub

Check scale value in set scales, test=develop (#3885)

上级 79c54557
...@@ -18,18 +18,11 @@ ...@@ -18,18 +18,11 @@
#include <utility> #include <utility>
#include <vector> #include <vector>
#include "lite/core/op_registry.h" #include "lite/core/op_registry.h"
#include "lite/utils/string.h"
namespace paddle { namespace paddle {
namespace lite { namespace lite {
static std::string int2string(int index) {
const int BUFFER_LENGTH = 30;
char buffer[BUFFER_LENGTH];
int num = snprintf(buffer, sizeof(buffer), "%d", index);
CHECK(num > 0 && num < sizeof(buffer));
return std::string(buffer);
}
bool OpLite::InferShape() { bool OpLite::InferShape() {
// if input_tensor_ptrs and output_tensor_ptrs are overloaded in param_ // if input_tensor_ptrs and output_tensor_ptrs are overloaded in param_
// InferShapeByMemoryInternal will be applied. // InferShapeByMemoryInternal will be applied.
...@@ -245,7 +238,7 @@ bool OpInfo::HasInputScale(const std::string &input_name) const { ...@@ -245,7 +238,7 @@ bool OpInfo::HasInputScale(const std::string &input_name) const {
int index; int index;
if (GetInputArgname(input_name, &argname) && if (GetInputArgname(input_name, &argname) &&
GetInputIndex(input_name, &index)) { GetInputIndex(input_name, &index)) {
return HasAttr(argname + int2string(index) + "_scale"); return HasAttr(argname + to_string(index) + "_scale");
} else { } else {
return false; return false;
} }
...@@ -256,7 +249,7 @@ bool OpInfo::HasOutputScale(const std::string &output_name) const { ...@@ -256,7 +249,7 @@ bool OpInfo::HasOutputScale(const std::string &output_name) const {
int index; int index;
if (GetOutputArgname(output_name, &argname) && if (GetOutputArgname(output_name, &argname) &&
GetOutputIndex(output_name, &index)) { GetOutputIndex(output_name, &index)) {
return HasAttr(argname + int2string(index) + "_scale"); return HasAttr(argname + to_string(index) + "_scale");
} else { } else {
return false; return false;
} }
...@@ -268,7 +261,9 @@ void OpInfo::SetInputScale(const std::string &input_name, ...@@ -268,7 +261,9 @@ void OpInfo::SetInputScale(const std::string &input_name,
int index; int index;
CHECK(GetInputArgname(input_name, &argname)); CHECK(GetInputArgname(input_name, &argname));
CHECK(GetInputIndex(input_name, &index)); CHECK(GetInputIndex(input_name, &index));
SetAttr<std::vector<float>>(argname + int2string(index) + "_scale", CHECK(scale_value.size() > 0)
<< "Error in SetInputScale: the scales should not be empty";
SetAttr<std::vector<float>>(argname + to_string(index) + "_scale",
scale_value); scale_value);
} }
...@@ -278,7 +273,9 @@ void OpInfo::SetOutputScale(const std::string &output_name, ...@@ -278,7 +273,9 @@ void OpInfo::SetOutputScale(const std::string &output_name,
int index; int index;
CHECK(GetOutputArgname(output_name, &argname)); CHECK(GetOutputArgname(output_name, &argname));
CHECK(GetOutputIndex(output_name, &index)); CHECK(GetOutputIndex(output_name, &index));
SetAttr<std::vector<float>>(argname + int2string(index) + "_scale", CHECK(scale_value.size() > 0)
<< "Error in SetOutputScale: the scales should not be empty";
SetAttr<std::vector<float>>(argname + to_string(index) + "_scale",
scale_value); scale_value);
} }
...@@ -287,7 +284,7 @@ std::vector<float> OpInfo::GetInputScale(const std::string &input_name) const { ...@@ -287,7 +284,7 @@ std::vector<float> OpInfo::GetInputScale(const std::string &input_name) const {
int index; int index;
CHECK(GetInputArgname(input_name, &argname)); CHECK(GetInputArgname(input_name, &argname));
CHECK(GetInputIndex(input_name, &index)); CHECK(GetInputIndex(input_name, &index));
return GetAttr<std::vector<float>>(argname + int2string(index) + "_scale"); return GetAttr<std::vector<float>>(argname + to_string(index) + "_scale");
} }
std::vector<float> OpInfo::GetOutputScale( std::vector<float> OpInfo::GetOutputScale(
...@@ -296,7 +293,7 @@ std::vector<float> OpInfo::GetOutputScale( ...@@ -296,7 +293,7 @@ std::vector<float> OpInfo::GetOutputScale(
int index; int index;
CHECK(GetOutputArgname(output_name, &argname)); CHECK(GetOutputArgname(output_name, &argname));
CHECK(GetOutputIndex(output_name, &index)); CHECK(GetOutputIndex(output_name, &index));
return GetAttr<std::vector<float>>(argname + int2string(index) + "_scale"); return GetAttr<std::vector<float>>(argname + to_string(index) + "_scale");
} }
} // namespace lite } // namespace lite
......
...@@ -60,6 +60,13 @@ static std::string to_string(const T& v) { ...@@ -60,6 +60,13 @@ static std::string to_string(const T& v) {
return ss.str(); return ss.str();
} }
static std::string to_string(int index) {
const int BUFFER_LENGTH = 15;
char buffer[BUFFER_LENGTH];
snprintf(buffer, sizeof(buffer), "%d", index);
return std::string(buffer);
}
template <typename T> template <typename T>
std::string Join(const std::vector<T>& vec, const std::string& delim) { std::string Join(const std::vector<T>& vec, const std::string& delim) {
if (vec.empty()) return ""; if (vec.empty()) return "";
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册