未验证 提交 80d5f6fb 编写于 作者: Y Yanzhan Yang 提交者: GitHub

fix variant memory leak (#1951)

* fix variant memory leak test=develop

* fix variant memory leak & fix style test=develop

* fix cpplint test=develop
上级 60bbc691
......@@ -16,7 +16,9 @@ limitations under the License. */
#include <cstdlib>
#include <cstring>
#include <memory>
#include <string>
#include <utility>
#include "common/enforce.h"
#include "common/log.h"
#include "common/type_define.h"
......@@ -30,41 +32,33 @@ struct IDToType {
template <typename F, typename... Ts>
struct VariantHelper {
static const size_t size = sizeof(F) > VariantHelper<Ts...>::size
? sizeof(F)
: VariantHelper<Ts...>::size;
inline static void Destroy(kTypeId_t type, void *data) {
inline static void Destroy(kTypeId_t type, void *raw_ptr) {
if (type == type_id<F>()) {
reinterpret_cast<F *>(data)->~F();
auto ptr = reinterpret_cast<F *>(raw_ptr);
delete ptr;
} else {
VariantHelper<Ts...>::Destroy(type, data);
VariantHelper<Ts...>::Destroy(type, raw_ptr);
}
}
};
template <typename F>
struct VariantHelper<F> {
static const size_t size = sizeof(F);
inline static void Destroy(kTypeId_t type, void *data) {
inline static void Destroy(kTypeId_t type, void *raw_ptr) {
if (type == type_id<F>()) {
// reinterpret_cast<F*>(data)->~F();
} else {
// std::cout << "未匹配到 " << std::endl;
auto ptr = reinterpret_cast<F *>(raw_ptr);
delete ptr;
}
}
};
template <size_t size>
class RawData {
public:
char data[size]; // NOLINT
RawData() {}
RawData(const RawData &raw_data) { memcpy(data, raw_data.data, size); }
RawData &operator=(const RawData &raw_data) {
memcpy(data, raw_data.data, size);
return *this;
template <typename... Ts>
struct VariantDeleter {
kTypeId_t type_ = type_id<void>().hash_code();
explicit VariantDeleter(kTypeId_t type) { type_ = type; }
void operator()(void *raw_ptr) {
// DLOG << "variant delete: " << type_ << " " << raw_ptr;
VariantHelper<Ts...>::Destroy(type_, raw_ptr);
}
};
......@@ -78,43 +72,21 @@ struct Variant {
}
virtual ~Variant() {
// helper::Destroy(type_id, &data);
// DLOG << "variant deinit: " << type_ << " " << (void *)data_.get();
data_.reset();
}
template <typename T, typename... Args>
void Set(Args &&... args) {
helper::Destroy(type_, data_.data);
new (data_.data) T(std::forward<Args>(args)...);
auto raw_ptr = new T(std::forward<Args>(args)...);
type_ = type_id<T>().hash_code();
}
void SetString(const std::string &string) {
helper::Destroy(type_, data_.data);
type_ = type_id<std::string>().hash_code();
strcpy(data_.data, string.c_str()); // NOLINT
}
std::string GetString() const {
if (type_ == type_id<std::string>()) {
return std::string(data_.data);
} else {
PADDLE_MOBILE_THROW_EXCEPTION(
" bad cast in variant data type not a string ");
exit(0);
}
// DLOG << "variant new: " << type_ << " " << (void *)raw_ptr;
data_.reset(raw_ptr, VariantDeleter<Ts...>(type_));
}
template <typename T>
T &Get() const {
if (type_ == type_id<std::string>()) {
PADDLE_MOBILE_THROW_EXCEPTION(
"Please use getString to get an string (to avoid of an issue with "
"gcc "
"stl lib with string copy)");
exit(0);
} else {
return *const_cast<T *>(reinterpret_cast<const T *>(data_.data));
}
return *const_cast<T *>(reinterpret_cast<const T *>(data_.get()));
}
kTypeId_t TypeId() const { return type_; }
......@@ -123,8 +95,7 @@ struct Variant {
static inline kTypeId_t invalid_type() { return type_id<void>().hash_code(); }
typedef VariantHelper<Ts...> helper;
kTypeId_t type_ = type_id<void>().hash_code();
// todo use an anto size to suite this.
RawData<64> data_;
std::shared_ptr<void> data_;
};
template <typename T>
......
......@@ -51,7 +51,7 @@ class Attribute {
break;
}
case PADDLE_MOBILE__FRAMEWORK__PROTO__ATTR_TYPE__STRING: {
attr.SetString(std::string(attr_desc->s));
attr.Set<std::string>(attr_desc->s);
break;
}
case PADDLE_MOBILE__FRAMEWORK__PROTO__ATTR_TYPE__BOOLEANS: {
......@@ -119,12 +119,7 @@ class Attribute {
return variant_.Get<T>();
}
Attribute &SetString(std::string string) {
variant_.SetString(string);
return *this;
}
std::string GetString() const { return variant_.GetString(); }
std::string GetString() const { return variant_.Get<std::string>(); }
template <typename Vistor>
static typename Vistor::type_t ApplyVistor(Vistor vistor, Attribute attr) {
......@@ -133,7 +128,7 @@ class Attribute {
} else if (attr.variant_.TypeId() == type_id<float>()) { // NOLINT
return vistor(attr.variant_.Get<float>());
} else if (attr.variant_.TypeId() == type_id<string>()) {
return vistor(attr.variant_.GetString());
return vistor(attr.variant_.Get<std::string>());
} else if (attr.variant_.TypeId() == type_id<vector<int>>()) {
return vistor(attr.variant_.Get<vector<int>>());
} else if (attr.variant_.TypeId() == type_id<vector<float>>()) {
......
......@@ -17,8 +17,7 @@ limitations under the License. */
namespace paddle_mobile {
namespace framework {
const DDim &CLImageConverterDefault::InitImageDimInfoWith(
const DDim &tensor_dim) {
DDim CLImageConverterDefault::InitImageDimInfoWith(const DDim &tensor_dim) {
size_t new_dims[] = {1, 1, 1, 1};
for (int j = 0; j < tensor_dim.size(); ++j) {
new_dims[4 - tensor_dim.size() + j] = tensor_dim[j];
......@@ -119,8 +118,7 @@ void CLImageConverterDefault::ImageToNCHW(half_t *image, float *tensor,
}
}
const DDim &CLImageConverterFolder::InitImageDimInfoWith(
const DDim &tensor_dim) {
DDim CLImageConverterFolder::InitImageDimInfoWith(const DDim &tensor_dim) {
if (tensor_dim.size() <= 2) {
int tdim[2] = {1, 1};
if (tensor_dim.size() == 1) {
......@@ -218,8 +216,7 @@ void CLImageConverterFolder::ImageToNCHW(half_t *image, float *tensor,
}
}
const DDim &CLImageConverterNWBlock::InitImageDimInfoWith(
const DDim &tensor_dim) {
DDim CLImageConverterNWBlock::InitImageDimInfoWith(const DDim &tensor_dim) {
PADDLE_MOBILE_ENFORCE(tensor_dim.size() == 4, " tensor dim is not 4");
size_t N, C, H, W;
N = tensor_dim[0];
......@@ -297,8 +294,7 @@ void CLImageConverterNWBlock::ImageToNCHW(half_t *image, float *tensor,
DLOG << " init done";
}
const DDim &CLImageConverterDWBlock::InitImageDimInfoWith(
const DDim &tensor_dim) {
DDim CLImageConverterDWBlock::InitImageDimInfoWith(const DDim &tensor_dim) {
PADDLE_MOBILE_ENFORCE(tensor_dim.size() == 4, " tensor dim is not 4");
size_t N, C, H, W;
N = tensor_dim[0];
......@@ -389,8 +385,7 @@ void CLImageConverterDWBlock::ImageToNCHW(half_t *image, float *tensor,
}
}
const DDim &CLImageConverterNormal::InitImageDimInfoWith(
const DDim &tensor_dim) {
DDim CLImageConverterNormal::InitImageDimInfoWith(const DDim &tensor_dim) {
PADDLE_MOBILE_ENFORCE(tensor_dim.size() <= 4 && tensor_dim.size() > 0,
"tensor dim is not support ");
size_t new_dims[] = {1, 1, 1, 1};
......@@ -428,7 +423,7 @@ void CLImageConverterNormal::ImageToNCHW(half_t *image, float *tensor,
default_converter.ImageToNCHW(image, tensor, image_dim, tensor_dim);
}
const DDim &CLImageConverterWinoTransWeight::InitImageDimInfoWith(
DDim CLImageConverterWinoTransWeight::InitImageDimInfoWith(
const DDim &tensor_dim) {
PADDLE_MOBILE_ENFORCE(tensor_dim.size() == 4, " tensor dim is not 4");
size_t N, C, H, W;
......@@ -448,7 +443,7 @@ void CLImageConverterWinoTransWeight::ImageToNCHW(half_t *image, float *tensor,
const DDim &image_dim,
const DDim &tensor_dim) {}
const DDim &CLImageConverterConv2dTransposeTransWeight::InitImageDimInfoWith(
DDim CLImageConverterConv2dTransposeTransWeight::InitImageDimInfoWith(
const DDim &tensor_dim) {
size_t new_dims[] = {1, 1, 1, 1};
for (int j = 0; j < tensor_dim.size(); ++j) {
......
......@@ -27,12 +27,12 @@ class CLImageConverterBase {
virtual void ImageToNCHW(half_t *image, float *nchw, const DDim &image_dim,
const DDim &tensor_dim) = 0;
virtual const DDim &InitImageDimInfoWith(const DDim &tensor_dim) = 0;
virtual DDim InitImageDimInfoWith(const DDim &tensor_dim) = 0;
};
class CLImageConverterDefault : public CLImageConverterBase {
public:
const DDim &InitImageDimInfoWith(const DDim &tensor_dim);
DDim InitImageDimInfoWith(const DDim &tensor_dim);
void NCHWToImage(float *nchw, half_t *image, const DDim &tensor_dim);
void ImageToNCHW(half_t *image, float *tensor, const DDim &image_dim,
const DDim &tensor_dim);
......@@ -40,7 +40,7 @@ class CLImageConverterDefault : public CLImageConverterBase {
class CLImageConverterFolder : public CLImageConverterBase {
public:
const DDim &InitImageDimInfoWith(const DDim &tensor_dim);
DDim InitImageDimInfoWith(const DDim &tensor_dim);
void NCHWToImage(float *tensor, half_t *image, const DDim &tensor_dim);
void ImageToNCHW(half_t *image, float *tensor, const DDim &image_dim,
const DDim &tensor_dim);
......@@ -65,7 +65,7 @@ class CLImageConverterFolder : public CLImageConverterBase {
class CLImageConverterNormal : public CLImageConverterBase {
public:
const DDim &InitImageDimInfoWith(const DDim &tensor_dim);
DDim InitImageDimInfoWith(const DDim &tensor_dim);
void NCHWToImage(float *tensor, half_t *image, const DDim &tensor_dim);
void ImageToNCHW(half_t *image, float *tensor, const DDim &image_dim,
const DDim &tensor_dim);
......@@ -89,13 +89,13 @@ class CLImageConverterNormal : public CLImageConverterBase {
};
class CLImageConverterNWBlock : public CLImageConverterBase {
const DDim &InitImageDimInfoWith(const DDim &tensor_dim);
DDim InitImageDimInfoWith(const DDim &tensor_dim);
void NCHWToImage(float *tensor, half_t *image, const DDim &tensor_dim);
void ImageToNCHW(half_t *image, float *tensor, const DDim &image_dim,
const DDim &tensor_dim);
};
class CLImageConverterDWBlock : public CLImageConverterBase {
const DDim &InitImageDimInfoWith(const DDim &tensor_dim);
DDim InitImageDimInfoWith(const DDim &tensor_dim);
void NCHWToImage(float *tensor, half_t *image, const DDim &tensor_dim);
void ImageToNCHW(half_t *image, float *tensor, const DDim &image_dim,
const DDim &tensor_dim);
......@@ -103,7 +103,7 @@ class CLImageConverterDWBlock : public CLImageConverterBase {
class CLImageConverterWinoTransWeight : public CLImageConverterBase {
public:
const DDim &InitImageDimInfoWith(const DDim &tensor_dim);
DDim InitImageDimInfoWith(const DDim &tensor_dim);
void NCHWToImage(float *tensor, half_t *image, const DDim &tensor_dim);
void ImageToNCHW(half_t *image, float *tensor, const DDim &image_dim,
const DDim &tensor_dim);
......@@ -111,7 +111,7 @@ class CLImageConverterWinoTransWeight : public CLImageConverterBase {
class CLImageConverterConv2dTransposeTransWeight : public CLImageConverterBase {
public:
const DDim &InitImageDimInfoWith(const DDim &tensor_dim);
DDim InitImageDimInfoWith(const DDim &tensor_dim);
void NCHWToImage(float *tensor, half_t *image, const DDim &tensor_dim);
void ImageToNCHW(half_t *image, float *tensor, const DDim &image_dim,
const DDim &tensor_dim);
......
......@@ -62,7 +62,6 @@ struct DDim {
return vistor(d.var.Get<Dim<9>>());
} else {
PADDLE_MOBILE_ENFORCE(false, " dim not support");
exit(0);
}
}
......@@ -73,6 +72,8 @@ struct DDim {
var.Set<Dim<D>>(in);
}
DDim(const DDim &in) { setNewDim(in); }
/*implicit*/ DDim(std::initializer_list<int64_t> init_list);
template <int D>
......@@ -81,11 +82,42 @@ struct DDim {
return *this;
}
DDim &operator=(const DDim &in) {
setNewDim(in);
return *this;
}
void setNewDim(const DDim &d) {
if (d.var.TypeId() == type_id<Dim<0>>()) {
return var.Set<Dim<0>>(d.var.Get<Dim<0>>());
} else if (d.var.TypeId() == type_id<Dim<1>>()) {
return var.Set<Dim<1>>(d.var.Get<Dim<1>>());
} else if (d.var.TypeId() == type_id<Dim<2>>()) {
return var.Set<Dim<2>>(d.var.Get<Dim<2>>());
} else if (d.var.TypeId() == type_id<Dim<3>>()) {
return var.Set<Dim<3>>(d.var.Get<Dim<3>>());
} else if (d.var.TypeId() == type_id<Dim<4>>()) {
return var.Set<Dim<4>>(d.var.Get<Dim<4>>());
} else if (d.var.TypeId() == type_id<Dim<5>>()) {
return var.Set<Dim<5>>(d.var.Get<Dim<5>>());
} else if (d.var.TypeId() == type_id<Dim<6>>()) {
return var.Set<Dim<6>>(d.var.Get<Dim<6>>());
} else if (d.var.TypeId() == type_id<Dim<7>>()) {
return var.Set<Dim<7>>(d.var.Get<Dim<7>>());
} else if (d.var.TypeId() == type_id<Dim<8>>()) {
return var.Set<Dim<8>>(d.var.Get<Dim<8>>());
} else if (d.var.TypeId() == type_id<Dim<9>>()) {
return var.Set<Dim<9>>(d.var.Get<Dim<9>>());
} else {
PADDLE_MOBILE_ENFORCE(false, " dim not support");
}
}
int64_t &operator[](int idx);
int64_t operator[](int idx) const;
DDimVar getVar() { return var; }
DDimVar getVar() const { return var; }
bool operator==(DDim d) const;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册