未验证 提交 f90dd802 编写于 作者: R Ray Liu 提交者: GitHub

Merge pull request #1075 from xiebaiyuan/develop

develop
......@@ -57,7 +57,12 @@ class RawData {
public:
char data[size];
RawData() {}
RawData(const RawData &raw_data) { strcpy(data, raw_data.data); }
RawData(const RawData &raw_data) { memcpy(data, raw_data.data, size); }
RawData &operator=(const RawData &raw_data) {
memcpy(data, raw_data.data, size);
return *this;
}
};
template <typename... Ts>
......@@ -74,14 +79,36 @@ struct Variant {
template <typename T, typename... Args>
void Set(Args &&... args) {
helper::Destroy(type_id, &data);
new (&data) T(std::forward<Args>(args)...);
helper::Destroy(type_id, &data.data);
new (&data.data) T(std::forward<Args>(args)...);
type_id = typeid(T).hash_code();
}
void SetString(std::string &string) {
// helper::Destroy(type_id, &data);
type_id = typeid(std::string).hash_code();
strcpy(data.data, string.c_str());
}
std::string GetString() const {
if (type_id == typeid(std::string).hash_code()) {
return std::string(data.data);
} else {
PADDLE_MOBILE_THROW_EXCEPTION(
" bad cast in variant data type not a string ");
exit(0);
}
}
template <typename T>
T &Get() const {
if (type_id == typeid(T).hash_code()) {
if (type_id == typeid(std::string).hash_code()) {
PADDLE_MOBILE_THROW_EXCEPTION(
"Please use getString to get an string (to avoid of an issue with "
"gcc "
"stl lib with string copy)");
exit(0);
} else if (type_id == typeid(T).hash_code()) {
return *const_cast<T *>(reinterpret_cast<const T *>(&data));
} else {
PADDLE_MOBILE_THROW_EXCEPTION(" bad cast in variant");
......
......@@ -51,7 +51,7 @@ class Attribute {
break;
}
case PADDLE_MOBILE__FRAMEWORK__PROTO__ATTR_TYPE__STRING: {
attr.Set<std::string>(std::string(attr_desc->s));
attr.SetString(std::string(attr_desc->s));
break;
}
case PADDLE_MOBILE__FRAMEWORK__PROTO__ATTR_TYPE__BOOLEANS: {
......@@ -108,6 +108,13 @@ class Attribute {
return variant_.Get<T>();
}
Attribute &SetString(std::string string) {
variant_.SetString(string);
return *this;
}
std::string GetString() const { return variant_.GetString(); }
template <typename Vistor>
static typename Vistor::type_t ApplyVistor(Vistor vistor, Attribute attr) {
if (attr.variant_.TypeId() == typeid(int).hash_code()) {
......@@ -115,7 +122,7 @@ class Attribute {
} else if (attr.variant_.TypeId() == typeid(float).hash_code()) {
return vistor(attr.variant_.Get<float>());
} else if (attr.variant_.TypeId() == typeid(string).hash_code()) {
return vistor(attr.variant_.Get<string>());
return vistor(attr.variant_.GetString());
} else if (attr.variant_.TypeId() == typeid(vector<int>).hash_code()) {
return vistor(attr.variant_.Get<vector<int>>());
} else if (attr.variant_.TypeId() == typeid(vector<float>).hash_code()) {
......
......@@ -33,6 +33,13 @@ class Variable {
template <typename T>
const T GetValue() const {
if (typeid(T) == typeid(std::string)) {
PADDLE_MOBILE_THROW_EXCEPTION(
"Please use getString to get an string (to avoid of an issue with "
"gcc "
"stl lib with string copy)");
exit(0);
}
return variant.Get<T>();
}
......
......@@ -263,6 +263,10 @@ class OpParam {
static const T GetAttr(const string &key, const AttributeMap &map) {
return ((Attribute)map.at(key)).Get<T>();
}
static const std::string GetStringAttr(const string &key,
const AttributeMap &map) {
return ((Attribute)map.at(key)).GetString();
}
static const bool HasAttr(const string &key, const AttributeMap &map) {
return map.count(key) > 0;
......@@ -502,7 +506,7 @@ class LrnParam : public OpParam {
alpha_ = GetAttr<float>("alpha", attrs);
beta_ = GetAttr<float>("beta", attrs);
k_ = GetAttr<float>("k", attrs);
data_format_ = GetAttr<string>("data_format", attrs);
data_format_ = GetStringAttr("data_format", attrs);
}
const RType *InputX() const { return input_x_; }
......@@ -599,7 +603,7 @@ class PoolParam : public OpParam {
input_ = InputXFrom<GType>(inputs, scope);
output_ = OutFrom<GType>(outputs, scope);
pooling_type_ = GetAttr<string>("pooling_type", attrs);
pooling_type_ = GetStringAttr("pooling_type", attrs);
ksize_ = GetAttr<vector<int>>("ksize", attrs);
strides_ = GetAttr<vector<int>>("strides", attrs);
paddings_ = GetAttr<vector<int>>("paddings", attrs);
......@@ -733,7 +737,7 @@ class BoxCoderParam : public OpParam {
input_priorboxvar_ = InputPriorBoxVarFrom<GType>(inputs, scope);
input_targetbox_ = InputTargetBoxFrom<GType>(inputs, scope);
output_box_ = OutputBoxFrom<GType>(outputs, scope);
code_type_ = GetAttr<std::string>("code_type", attrs);
code_type_ = GetStringAttr("code_type", attrs);
}
const RType *InputPriorBox() const { return input_priorbox_; }
......@@ -1208,7 +1212,7 @@ class PReluParam : public OpParam {
alpha_ = InputAlphaFrom<GType>(inputs, scope);
framework::DDim dims = alpha_->dims();
out_ = OutFrom<GType>(outputs, scope);
mode_ = GetAttr<std::string>("mode", attrs);
mode_ = GetStringAttr("mode", attrs);
DLOG << "PReluParam mode after" << mode_;
}
const RType *InputX() const { return input_x_; }
......@@ -1339,7 +1343,7 @@ class FusionConvAddPReluParam : public ConvParam<Dtype> {
const AttributeMap &attrs, const Scope &scope)
: ConvParam<Dtype>(inputs, outputs, attrs, scope) {
alpha_ = OpParam::InputAlphaFrom<GType>(inputs, scope);
mode_ = OpParam::GetAttr<std::string>("mode", attrs);
mode_ = OpParam::GetStringAttr("mode", attrs);
framework::DDim dims = alpha_->dims();
bias_ = OpParam::InputYFrom<GType>(inputs, scope);
axis_ = OpParam::GetAttr<int>("axis", attrs);
......@@ -1382,7 +1386,7 @@ class FusionConvAddAddPReluParam : public ConvParam<Dtype> {
: ConvParam<Dtype>(inputs, outputs, attrs, scope) {
bias1_ = OpParam::InputYFrom1<GType>(inputs, scope);
alpha_ = OpParam::InputAlphaFrom<GType>(inputs, scope);
mode_ = OpParam::GetAttr<std::string>("mode", attrs);
mode_ = OpParam::GetStringAttr("mode", attrs);
framework::DDim dims = alpha_->dims();
bias_ = OpParam::InputYFrom<GType>(inputs, scope);
output_ = OpParam::OutFrom<GType>(outputs, scope);
......@@ -1989,8 +1993,8 @@ class GruParam : public OpParam {
OutputBatchResetHiddenPrevFrom<GType>(outputs, scope);
output_batch_hidden_ = OutputBatchHiddenFrom<GType>(outputs, scope);
output_hidden_ = OutputHiddenFrom<GType>(outputs, scope);
activation_ = GetAttr<std::string>("activation", attrs);
gate_activation_ = GetAttr<std::string>("gate_activation", attrs);
activation_ = GetStringAttr("activation", attrs);
gate_activation_ = GetStringAttr("gate_activation", attrs);
is_reverse_ = GetAttr<bool>("is_reverse", attrs);
}
const GType *InputInput() const { return input_input_; }
......
......@@ -60,7 +60,15 @@ int main() {
std::cout << "load cost :" << time_diff(time1, time1) << "ms" << std::endl;
// 1064 1603 644 699 2878 1219 867 1352 8 1 13 312 479
std::vector<int64_t> ids{1791, 656, 1549, 281, 96};
std::vector<int64_t> ids{
2084, 635, 1035, 197, 990, 150, 1132, 2403, 546, 770, 4060, 3352,
1798, 1589, 1352, 98, 136, 3461, 3186, 1159, 515, 764, 278, 1178,
5044, 4060, 943, 932, 463, 1198, 3352, 374, 1198, 3352, 374, 2047,
1069, 1589, 3672, 1178, 1178, 2165, 1178, 2084, 635, 3087, 2236, 546,
2047, 1549, 546, 2047, 302, 2202, 398, 804, 397, 657, 804, 866,
932, 2084, 515, 2165, 397, 302, 2202, 526, 992, 906, 1215, 1589,
4493, 2403, 723, 932, 2084, 635, 1352, 932, 444, 2047, 1159, 1893,
1579, 59, 330, 98, 1296, 1159, 3430, 738, 3186, 1071, 2174, 3933};
paddle_mobile::framework::LoDTensor words;
auto size = static_cast<int>(ids.size());
......
......@@ -46,7 +46,7 @@ class TestBoxCoderOp {
DLOG << " Input TargetBox is : " << op->Input("TargetBox")[0];
DLOG << " OutputBox is : " << op->Output("OutputBox")[0];
DLOG << " code_type : "
<< op->GetAttrMap().at("code_type").Get<std::string>();
<< op->GetAttrMap().at("code_type").GetString();
std::shared_ptr<operators::BoxCoderOp<Dtype, float>> boxcoder =
std::make_shared<operators::BoxCoderOp<Dtype, float>>(
op->Type(), op->GetInputs(), op->GetOutputs(),
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册