提交 61647c35 编写于 作者: X Xiaoyang LI 提交者: Yan Chunwei

add workspace compute funcs for direct conv, test=develop (#2132)

上级 e122b4be
......@@ -23,6 +23,9 @@ namespace lite {
namespace arm {
namespace math {
/// conv 3x3s1
/// Returns the workspace size required by the direct 3x3 stride-1
/// convolution for the given convolution parameters and ARM context.
/// The fp32 direct-conv kernel passes this value to ExtendWorkspace()
/// when the stride is not 2.
/// NOTE(review): the unit is presumably bytes — confirm against the
/// implementation.
size_t conv3x3s1_direct_workspace_size(const operators::ConvParam& param,
ARMContext* ctx);
void conv_3x3s1_direct_fp32(const float* din,
float* dout,
int num,
......@@ -53,6 +56,9 @@ void conv_3x3s1_direct_int8(const int8_t* din,
ARMContext* ctx,
const float* scale);
/// conv3x3s2
/// Returns the workspace size required by the direct 3x3 stride-2
/// convolution for the given convolution parameters and ARM context.
/// The fp32 direct-conv kernel passes this value to ExtendWorkspace()
/// when strides[0] == 2.
/// NOTE(review): the unit is presumably bytes — confirm against the
/// implementation.
size_t conv3x3s2_direct_workspace_size(const operators::ConvParam& param,
ARMContext* ctx);
void conv_3x3s2_direct_fp32(const float* din,
float* dout,
int num,
......
......@@ -1104,13 +1104,13 @@ void DeviceInfo::SetCache(int l1size, int l2size, int l3size) {
SetCacheInfo(0, 1, l1size);
SetCacheInfo(1, 1, l2size);
SetCacheInfo(2, 1, l3size);
workspace_.Resize({2 * (l1size + l2size)});
workspace_.Resize({llc_size()});
workspace_.mutable_data<int8_t>();
}
bool DeviceInfo::ExtendWorkspace(int size) {
bool DeviceInfo::ExtendWorkspace(size_t size) {
workspace_.Resize({size + llc_size()});
workspace_.mutable_data<int8_t>();
return true;
return workspace_.mutable_data<int8_t>() != nullptr;
}
#endif // LITE_WITH_ARM
......
......@@ -73,7 +73,7 @@ class DeviceInfo {
T* workspace_data() {
return reinterpret_cast<T*>(workspace_.mutable_data<int8_t>());
}
/// Grows the workspace so that it can hold `size` extra int8 elements on
/// top of llc_size(); returns whether the resulting buffer is non-null.
/// Takes size_t so that large workspace requests are not truncated to int.
bool ExtendWorkspace(size_t size);
private:
int core_num_;
......
......@@ -19,6 +19,25 @@ namespace lite {
namespace kernels {
namespace arm {
/// Extends the ARM context workspace for the fp32 direct 3x3 convolution
/// whenever the input shape differs from the one recorded on the previous
/// call; does nothing when the shape is unchanged.
///
/// The requested size comes from conv3x3s2_direct_workspace_size() when
/// strides[0] == 2 and from conv3x3s1_direct_workspace_size() otherwise.
template <>
void DirectConv<PRECISION(kFloat), PRECISION(kFloat)>::ReInitWhenNeeded() {
  auto& param = this->template Param<param_t>();
  auto x_dims = param.x->dims();
  if (last_shape_ == x_dims) {
    // Same input shape as last time: the workspace was already sized.
    return;
  }

  auto& ctx = this->ctx_->template As<ARMContext>();
  if (param.strides[0] == 2) {
    ctx.ExtendWorkspace(
        lite::arm::math::conv3x3s2_direct_workspace_size(param, &ctx));
  } else {
    ctx.ExtendWorkspace(
        lite::arm::math::conv3x3s1_direct_workspace_size(param, &ctx));
  }

  // Record the shape we just sized for; without this the guard above never
  // matches and the workspace size is recomputed on every call.
  // NOTE(review): this assumes no other user shrinks the shared workspace
  // between calls — confirm.
  last_shape_ = x_dims;
}
template <>
void DirectConv<PRECISION(kFloat), PRECISION(kFloat)>::Run() {
auto& param = this->Param<param_t>();
......@@ -70,6 +89,9 @@ void DirectConv<PRECISION(kFloat), PRECISION(kFloat)>::Run() {
}
}
/// int8-input / fp32-output specialization: intentionally empty, so this
/// hook performs no work for this precision pair.
template <>
void DirectConv<PRECISION(kInt8), PRECISION(kFloat)>::ReInitWhenNeeded() {}
template <>
void DirectConv<PRECISION(kInt8), PRECISION(kFloat)>::Run() {
auto& param = this->Param<param_t>();
......@@ -126,6 +148,9 @@ void DirectConv<PRECISION(kInt8), PRECISION(kFloat)>::Run() {
}
}
/// int8-input / int8-output specialization: intentionally empty, so this
/// hook performs no work for this precision pair.
template <>
void DirectConv<PRECISION(kInt8), PRECISION(kInt8)>::ReInitWhenNeeded() {}
template <>
void DirectConv<PRECISION(kInt8), PRECISION(kInt8)>::Run() {
auto& param = this->Param<param_t>();
......
......@@ -178,10 +178,12 @@ class DirectConv : public KernelLite<TARGET(kARM), Ptype> {
w_scale_);
}
/// Hook for shape-dependent re-initialization. It is specialized per
/// precision pair: the fp32 specialization extends the ARM context
/// workspace, the int8 specializations are empty.
virtual void ReInitWhenNeeded();
/// Executes the kernel; specialized per precision pair.
virtual void Run();
/// todo, support inplace weights transform
protected:
DDim last_shape_;
Tensor weights_;
Tensor bias_;
bool flag_trans_weights_{false};
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册