提交 f2423c2d 编写于 作者: D dolphin8

merge with upstream

...@@ -9,7 +9,7 @@ option(LOG_PROFILE "log profile" ON) ...@@ -9,7 +9,7 @@ option(LOG_PROFILE "log profile" ON)
option(CPU "armv7 with neon" ON) option(CPU "armv7 with neon" ON)
option(MALI_GPU "mali gpu" OFF) option(MALI_GPU "mali gpu" OFF)
option(FPGA "fpga" OFF) option(FPGA "fpga" OFF)
set(DEBUGING ON)
if (CPU) if (CPU)
add_definitions(-DPADDLE_MOBILE_CPU) add_definitions(-DPADDLE_MOBILE_CPU)
endif() endif()
...@@ -28,7 +28,7 @@ set(CMAKE_CXX_FLAGS "-std=c++14 -O3 -s ${CMAKE_CXX_FLAGS}") ...@@ -28,7 +28,7 @@ set(CMAKE_CXX_FLAGS "-std=c++14 -O3 -s ${CMAKE_CXX_FLAGS}")
if (DEBUGING) if (DEBUGING)
message(STATUS "debug") message(STATUS "debug")
set(CMAKE_BUILD_TYPE Debug) set(CMAKE_BUILD_TYPE Debug)
set(CMAKE_CXX_FLAGS_DEBUG "-g") set(CMAKE_CXX_FLAGS_DEBUG "-g -DNDEBUG")
add_definitions(-DPADDLE_MOBILE_DEBUG) add_definitions(-DPADDLE_MOBILE_DEBUG)
if (ANDROID_NDK_TOOLCHAIN_INCLUDED) if (ANDROID_NDK_TOOLCHAIN_INCLUDED)
add_definitions(-DARMV7) add_definitions(-DARMV7)
...@@ -36,6 +36,7 @@ if (DEBUGING) ...@@ -36,6 +36,7 @@ if (DEBUGING)
endif () endif ()
else () else ()
set(CMAKE_BUILD_TYPE Release) set(CMAKE_BUILD_TYPE Release)
set(CMAKE_CXX_FLAGS_RELEASE "-DNDEBUG")
add_definitions(-fvisibility=hidden -fvisibility-inlines-hidden) add_definitions(-fvisibility=hidden -fvisibility-inlines-hidden)
endif () endif ()
......
...@@ -63,6 +63,7 @@ class OperatorBase { ...@@ -63,6 +63,7 @@ class OperatorBase {
std::vector<string> GetOutKeys() const; std::vector<string> GetOutKeys() const;
virtual void RunImpl() const = 0; virtual void RunImpl() const = 0;
// Initialization hook that every concrete operator has to implement.
// It is separate from RunImpl(); the Executor constructor shown in this change
// calls it on each op of block 0.
virtual void Init() const = 0;
/* /*
* @b op 运算所需的输入, 如上一层的输出结果、卷积核 * @b op 运算所需的输入, 如上一层的输出结果、卷积核
* */ * */
...@@ -110,15 +111,17 @@ class OperatorWithKernel : public OperatorBase<Dtype> { ...@@ -110,15 +111,17 @@ class OperatorWithKernel : public OperatorBase<Dtype> {
const VariableNameMap &outputs, const AttributeMap &attrs, const VariableNameMap &outputs, const AttributeMap &attrs,
std::shared_ptr<Scope> scope) std::shared_ptr<Scope> scope)
: OperatorBase<Dtype>(type, inputs, outputs, attrs, scope), : OperatorBase<Dtype>(type, inputs, outputs, attrs, scope),
param_(inputs, outputs, attrs, *scope) { param_(inputs, outputs, attrs, *scope) {}
PADDLE_MOBILE_ENFORCE(kernel_.Init(param_), " %s kernel init failed",
this->type_.c_str());
}
virtual void RunImpl() const { this->kernel_.Compute(this->param_); } virtual void RunImpl() const { this->kernel_.Compute(this->param_); }
virtual void InferShape() const = 0; virtual void InferShape() const = 0;
// Initializes the operator's kernel by handing param_ to kernel_.Init().
// The boolean result is checked through PADDLE_MOBILE_ENFORCE, and the
// operator type is formatted into the failure message so the failing kernel
// can be identified.
// NOTE(review): the kernel_.Init() call lives inside the macro argument. If
// PADDLE_MOBILE_ENFORCE expands to nothing in some build configuration, the
// kernel would never be initialized there -- confirm against the macro
// definition.
void Init() const {
PADDLE_MOBILE_ENFORCE(kernel_.Init(param_), " %s kernel init failed",
this->type_.c_str());
}
protected: protected:
KernelType kernel_; KernelType kernel_;
ParamType param_; ParamType param_;
......
...@@ -198,6 +198,13 @@ Executor<Dtype, P>::Executor(const framework::Program<Dtype> p, int batch_size, ...@@ -198,6 +198,13 @@ Executor<Dtype, P>::Executor(const framework::Program<Dtype> p, int batch_size,
} else { } else {
InitMemory(); InitMemory();
} }
std::shared_ptr<framework::BlockDesc> to_predict_block =
to_predict_program_->Block(0);
auto &ops = ops_of_block_[*to_predict_block.get()];
for (const auto &op : ops) {
op->Init();
}
} }
template <typename Dtype, Precision P> template <typename Dtype, Precision P>
...@@ -416,6 +423,8 @@ std::shared_ptr<framework::Tensor> Executor<Dtype, P>::Predict( ...@@ -416,6 +423,8 @@ std::shared_ptr<framework::Tensor> Executor<Dtype, P>::Predict(
clock_gettime(CLOCK_MONOTONIC, &ts); clock_gettime(CLOCK_MONOTONIC, &ts);
profile[i].runBegin = (uint64_t)ts.tv_sec * 1e9 + ts.tv_nsec; profile[i].runBegin = (uint64_t)ts.tv_sec * 1e9 + ts.tv_nsec;
#endif #endif
// to Run
ops[i]->Run(); ops[i]->Run();
#ifdef PADDLE_MOBILE_PROFILE #ifdef PADDLE_MOBILE_PROFILE
clock_gettime(CLOCK_MONOTONIC, &ts); clock_gettime(CLOCK_MONOTONIC, &ts);
......
...@@ -32,6 +32,8 @@ class FeedOp : public framework::OperatorBase<DeviceType> { ...@@ -32,6 +32,8 @@ class FeedOp : public framework::OperatorBase<DeviceType> {
param_(inputs, outputs, attrs, *scope) {} param_(inputs, outputs, attrs, *scope) {}
void RunImpl() const { param_.Out()->ShareDataWith(*param_.InputX()); } void RunImpl() const { param_.Out()->ShareDataWith(*param_.InputX()); }
// Required override of the base-class Init(); this operator performs no
// initialization work.
void Init() const {}
void InferShape() const { void InferShape() const {
auto out_dims = param_.Out()->dims(); auto out_dims = param_.Out()->dims();
out_dims[0] = param_.BatchSize(); out_dims[0] = param_.BatchSize();
......
...@@ -33,6 +33,8 @@ class FetchOp : public framework::OperatorBase<DeviceType> { ...@@ -33,6 +33,8 @@ class FetchOp : public framework::OperatorBase<DeviceType> {
param_(inputs, outputs, attrs, *scope) {} param_(inputs, outputs, attrs, *scope) {}
void RunImpl() const { param_.Out()->ShareDataWith(*param_.InputX()); } void RunImpl() const { param_.Out()->ShareDataWith(*param_.InputX()); }
// Required override of the base-class Init(); this operator performs no
// initialization work.
void Init() const {}
void InferShape() const { void InferShape() const {
auto x_dims = param_.InputX()->dims(); auto x_dims = param_.InputX()->dims();
param_.Out()->Resize(x_dims); param_.Out()->Resize(x_dims);
......
...@@ -21,6 +21,11 @@ limitations under the License. */ ...@@ -21,6 +21,11 @@ limitations under the License. */
namespace paddle_mobile { namespace paddle_mobile {
namespace operators { namespace operators {
// Init hook of the BatchNormKernel<CPU, float> specialization.
// No setup is performed: the argument is left unnamed because it is never
// read, and success (true) is always reported.
template <>
bool BatchNormKernel<CPU, float>::Init(
    const BatchNormParam & /*para*/) const {
  return true;
}
template <> template <>
void BatchNormKernel<CPU, float>::Compute(const BatchNormParam &param) const { void BatchNormKernel<CPU, float>::Compute(const BatchNormParam &param) const {
const Tensor *input_x = param.InputX(); const Tensor *input_x = param.InputX();
......
...@@ -110,6 +110,11 @@ void DecodeCenterSize(const framework::Tensor& target_box, ...@@ -110,6 +110,11 @@ void DecodeCenterSize(const framework::Tensor& target_box,
} }
} }
// Init hook of the BoxCoderKernel<CPU, float> specialization.
// No setup is performed: the argument is left unnamed because it is never
// read, and success (true) is always reported.
template <>
bool BoxCoderKernel<CPU, float>::Init(
    const BoxCoderParam& /*para*/) const {
  return true;
}
template <> template <>
void BoxCoderKernel<CPU, float>::Compute(const BoxCoderParam& param) const { void BoxCoderKernel<CPU, float>::Compute(const BoxCoderParam& param) const {
const auto* input_priorbox = param.InputPriorBox(); const auto* input_priorbox = param.InputPriorBox();
......
...@@ -52,6 +52,11 @@ class ConcatFunctor { ...@@ -52,6 +52,11 @@ class ConcatFunctor {
} }
}; };
// Init hook of the ConcatKernel<CPU, float> specialization.
// No setup is performed: the argument is left unnamed because it is never
// read, and success (true) is always reported.
template <>
bool ConcatKernel<CPU, float>::Init(const ConcatParam & /*para*/) const {
  return true;
}
template <> template <>
void ConcatKernel<CPU, float>::Compute(const ConcatParam &param) const { void ConcatKernel<CPU, float>::Compute(const ConcatParam &param) const {
auto inputs = param.Inputs(); auto inputs = param.Inputs();
......
...@@ -18,6 +18,11 @@ limitations under the License. */ ...@@ -18,6 +18,11 @@ limitations under the License. */
namespace paddle_mobile { namespace paddle_mobile {
namespace operators { namespace operators {
// Init hook of the ConvAddKernel<CPU, float> specialization.
// No setup is performed: the argument is left unnamed because it is never
// read, and success (true) is always reported.
template <>
bool ConvAddKernel<CPU, float>::Init(
    const FusionConvAddParam & /*para*/) const {
  return true;
}
template <> template <>
void ConvAddKernel<CPU, float>::Compute(const FusionConvAddParam &param) const { void ConvAddKernel<CPU, float>::Compute(const FusionConvAddParam &param) const {
const Tensor *input = param.Input(); const Tensor *input = param.Input();
......
...@@ -19,6 +19,12 @@ limitations under the License. */ ...@@ -19,6 +19,12 @@ limitations under the License. */
namespace paddle_mobile { namespace paddle_mobile {
namespace operators { namespace operators {
// Init hook of the ConvAddReluKernel<CPU, float> specialization.
// No setup is performed: the argument is left unnamed because it is never
// read, and success (true) is always reported.
template <>
bool ConvAddReluKernel<CPU, float>::Init(
    const FusionConvAddReluParam & /*para*/) const {
  return true;
}
template <> template <>
void ConvAddReluKernel<CPU, float>::Compute( void ConvAddReluKernel<CPU, float>::Compute(
const FusionConvAddReluParam &param) const { const FusionConvAddReluParam &param) const {
......
...@@ -19,6 +19,11 @@ limitations under the License. */ ...@@ -19,6 +19,11 @@ limitations under the License. */
namespace paddle_mobile { namespace paddle_mobile {
namespace operators { namespace operators {
// Init hook of the ConvKernel<CPU, float> specialization.
// No setup is performed: the argument is left unnamed because it is never
// read, and success (true) is always reported.
template <>
bool ConvKernel<CPU, float>::Init(const ConvParam & /*para*/) const {
  return true;
}
template <> template <>
void ConvKernel<CPU, float>::Compute(const ConvParam &param) const { void ConvKernel<CPU, float>::Compute(const ConvParam &param) const {
const Tensor *input = param.Input(); const Tensor *input = param.Input();
......
...@@ -20,6 +20,11 @@ limitations under the License. */ ...@@ -20,6 +20,11 @@ limitations under the License. */
namespace paddle_mobile { namespace paddle_mobile {
namespace operators { namespace operators {
// Init hook of the DepthwiseConvKernel<CPU, float> specialization.
// No setup is performed: the argument is left unnamed because it is never
// read, and success (true) is always reported.
template <>
bool DepthwiseConvKernel<CPU, float>::Init(
    const ConvParam & /*para*/) const {
  return true;
}
template <> template <>
void DepthwiseConvKernel<CPU, float>::Compute(const ConvParam &param) const { void DepthwiseConvKernel<CPU, float>::Compute(const ConvParam &param) const {
LOG(kLOG_DEBUG) << param; LOG(kLOG_DEBUG) << param;
......
...@@ -26,6 +26,12 @@ struct AddFunctor { ...@@ -26,6 +26,12 @@ struct AddFunctor {
inline T operator()(T a, T b) const { return a + b; } inline T operator()(T a, T b) const { return a + b; }
}; };
// Init hook of the ElementwiseAddKernel<CPU, float> specialization.
// No setup is performed: the argument is left unnamed because it is never
// read, and success (true) is always reported.
template <>
bool ElementwiseAddKernel<CPU, float>::Init(
    const ElementwiseAddParam & /*para*/) const {
  return true;
}
template <> template <>
void ElementwiseAddKernel<CPU, float>::Compute( void ElementwiseAddKernel<CPU, float>::Compute(
const ElementwiseAddParam &param) const { const ElementwiseAddParam &param) const {
......
...@@ -21,6 +21,11 @@ limitations under the License. */ ...@@ -21,6 +21,11 @@ limitations under the License. */
namespace paddle_mobile { namespace paddle_mobile {
namespace operators { namespace operators {
// Init hook of the FusionFcKernel<CPU, float> specialization.
// No setup is performed: the argument is left unnamed because it is never
// read, and success (true) is always reported.
template <>
bool FusionFcKernel<CPU, float>::Init(
    const FusionFcParam & /*para*/) const {
  return true;
}
template <> template <>
void FusionFcKernel<CPU, float>::Compute(const FusionFcParam &param) const { void FusionFcKernel<CPU, float>::Compute(const FusionFcParam &param) const {
const Tensor *input_x = param.InputX(); const Tensor *input_x = param.InputX();
......
...@@ -21,6 +21,11 @@ limitations under the License. */ ...@@ -21,6 +21,11 @@ limitations under the License. */
namespace paddle_mobile { namespace paddle_mobile {
namespace operators { namespace operators {
// Init hook of the LrnKernel<CPU, float> specialization.
// No setup is performed: the argument is left unnamed because it is never
// read, and success (true) is always reported.
template <>
bool LrnKernel<CPU, float>::Init(const LrnParam & /*para*/) const {
  return true;
}
template <> template <>
void LrnKernel<CPU, float>::Compute(const LrnParam &param) const { void LrnKernel<CPU, float>::Compute(const LrnParam &param) const {
const Tensor *input_x = param.InputX(); const Tensor *input_x = param.InputX();
......
...@@ -21,6 +21,11 @@ limitations under the License. */ ...@@ -21,6 +21,11 @@ limitations under the License. */
namespace paddle_mobile { namespace paddle_mobile {
namespace operators { namespace operators {
// Init hook of the MulKernel<CPU, float> specialization.
// No setup is performed: the argument is left unnamed because it is never
// read, and success (true) is always reported.
template <>
bool MulKernel<CPU, float>::Init(const MulParam & /*para*/) const {
  return true;
}
template <> template <>
void MulKernel<CPU, float>::Compute(const MulParam &param) const { void MulKernel<CPU, float>::Compute(const MulParam &param) const {
const Tensor *input_x = param.InputX(); const Tensor *input_x = param.InputX();
......
...@@ -203,6 +203,12 @@ void MultiClassOutput(const Tensor& scores, const Tensor& bboxes, ...@@ -203,6 +203,12 @@ void MultiClassOutput(const Tensor& scores, const Tensor& bboxes,
} }
} }
// Init hook of the MultiClassNMSKernel<CPU, float> specialization.
// No setup is performed: the argument is left unnamed because it is never
// read, and success (true) is always reported.
template <>
bool MultiClassNMSKernel<CPU, float>::Init(
    const MultiClassNMSParam& /*para*/) const {
  return true;
}
template <> template <>
void MultiClassNMSKernel<CPU, float>::Compute( void MultiClassNMSKernel<CPU, float>::Compute(
const MultiClassNMSParam& param) const { const MultiClassNMSParam& param) const {
......
...@@ -35,6 +35,11 @@ inline void PoolBasic(std::string pooling_type, std::vector<int> ksize, ...@@ -35,6 +35,11 @@ inline void PoolBasic(std::string pooling_type, std::vector<int> ksize,
} }
} }
// Init hook of the PoolKernel<CPU, float> specialization.
// No setup is performed: the argument is left unnamed because it is never
// read, and success (true) is always reported.
template <>
bool PoolKernel<CPU, float>::Init(const PoolParam & /*para*/) const {
  return true;
}
template <> template <>
void PoolKernel<CPU, float>::Compute(const PoolParam &param) const { void PoolKernel<CPU, float>::Compute(const PoolParam &param) const {
const Tensor *in_x = param.Input(); const Tensor *in_x = param.Input();
......
...@@ -26,6 +26,11 @@ struct ClipFunctor { ...@@ -26,6 +26,11 @@ struct ClipFunctor {
} }
}; };
// Init hook of the PriorBoxKernel<CPU, float> specialization.
// No setup is performed: the argument is left unnamed because it is never
// read, and success (true) is always reported.
template <>
bool PriorBoxKernel<CPU, float>::Init(
    const PriorBoxParam & /*para*/) const {
  return true;
}
template <> template <>
void PriorBoxKernel<CPU, float>::Compute(const PriorBoxParam &param) const { void PriorBoxKernel<CPU, float>::Compute(const PriorBoxParam &param) const {
const auto *input_ = param.Input(); const auto *input_ = param.Input();
......
...@@ -25,6 +25,11 @@ struct ReluFunctor { ...@@ -25,6 +25,11 @@ struct ReluFunctor {
inline T operator()(T in) const { return in > 0 ? in : 0; } inline T operator()(T in) const { return in > 0 ? in : 0; }
}; };
// Init hook of the ReluKernel<CPU, float> specialization.
// No setup is performed: the argument is left unnamed because it is never
// read, and success (true) is always reported.
template <>
bool ReluKernel<CPU, float>::Init(const ReluParam & /*para*/) const {
  return true;
}
/* /*
* @b 特化到具体平台的实现, param 从 op 层传入 * @b 特化到具体平台的实现, param 从 op 层传入
* */ * */
......
...@@ -19,6 +19,11 @@ limitations under the License. */ ...@@ -19,6 +19,11 @@ limitations under the License. */
namespace paddle_mobile { namespace paddle_mobile {
namespace operators { namespace operators {
// Init hook of the ReshapeKernel<CPU, float> specialization.
// No setup is performed: the argument is left unnamed because it is never
// read, and success (true) is always reported.
template <>
bool ReshapeKernel<CPU, float>::Init(const ReshapeParam & /*para*/) const {
  return true;
}
template <> template <>
void ReshapeKernel<CPU, float>::Compute(const ReshapeParam &param) const { void ReshapeKernel<CPU, float>::Compute(const ReshapeParam &param) const {
const auto *input_x = param.InputX(); const auto *input_x = param.InputX();
......
...@@ -71,6 +71,11 @@ void sigmoid(const Tensor *X, Tensor *Y) { ...@@ -71,6 +71,11 @@ void sigmoid(const Tensor *X, Tensor *Y) {
#endif #endif
} }
// Init hook of the SigmoidKernel<CPU, float> specialization.
// No setup is performed: the argument is left unnamed because it is never
// read, and success (true) is always reported.
template <>
bool SigmoidKernel<CPU, float>::Init(const SigmoidParam & /*para*/) const {
  return true;
}
template <> template <>
void SigmoidKernel<CPU, float>::Compute(const SigmoidParam &param) const { void SigmoidKernel<CPU, float>::Compute(const SigmoidParam &param) const {
const Tensor *in_x = param.InputX(); const Tensor *in_x = param.InputX();
......
...@@ -19,6 +19,11 @@ limitations under the License. */ ...@@ -19,6 +19,11 @@ limitations under the License. */
namespace paddle_mobile { namespace paddle_mobile {
namespace operators { namespace operators {
// Init hook of the SoftmaxKernel<CPU, float> specialization.
// No setup is performed: the argument is left unnamed because it is never
// read, and success (true) is always reported.
template <>
bool SoftmaxKernel<CPU, float>::Init(const SoftmaxParam & /*para*/) const {
  return true;
}
template <> template <>
void SoftmaxKernel<CPU, float>::Compute(const SoftmaxParam &param) const { void SoftmaxKernel<CPU, float>::Compute(const SoftmaxParam &param) const {
const Tensor *in_x = param.InputX(); const Tensor *in_x = param.InputX();
......
...@@ -34,6 +34,11 @@ namespace operators { ...@@ -34,6 +34,11 @@ namespace operators {
// } // }
// } // }
// Init hook of the TransposeKernel<CPU, float> specialization.
// No setup is performed: the argument is left unnamed because it is never
// read, and success (true) is always reported.
template <>
bool TransposeKernel<CPU, float>::Init(
    const TransposeParam& /*para*/) const {
  return true;
}
template <> template <>
void TransposeKernel<CPU, float>::Compute(const TransposeParam& param) const { void TransposeKernel<CPU, float>::Compute(const TransposeParam& param) const {
const auto* input_x = param.InputX(); const auto* input_x = param.InputX();
......
...@@ -29,6 +29,7 @@ class BatchNormKernel ...@@ -29,6 +29,7 @@ class BatchNormKernel
: public framework::OpKernelBase<DeviceType, BatchNormParam> { : public framework::OpKernelBase<DeviceType, BatchNormParam> {
public: public:
void Compute(const BatchNormParam &param) const; void Compute(const BatchNormParam &param) const;
bool Init(const BatchNormParam &para) const;
}; };
} // namespace operators } // namespace operators
......
...@@ -30,6 +30,7 @@ class BoxCoderKernel ...@@ -30,6 +30,7 @@ class BoxCoderKernel
: public framework::OpKernelBase<DeviceType, BoxCoderParam> { : public framework::OpKernelBase<DeviceType, BoxCoderParam> {
public: public:
void Compute(const BoxCoderParam& param) const; void Compute(const BoxCoderParam& param) const;
bool Init(const BoxCoderParam& para) const;
}; };
} // namespace operators } // namespace operators
} // namespace paddle_mobile } // namespace paddle_mobile
......
...@@ -27,6 +27,7 @@ template <typename DeviceType, typename T> ...@@ -27,6 +27,7 @@ template <typename DeviceType, typename T>
class ConcatKernel : public framework::OpKernelBase<DeviceType, ConcatParam> { class ConcatKernel : public framework::OpKernelBase<DeviceType, ConcatParam> {
public: public:
void Compute(const ConcatParam &param) const; void Compute(const ConcatParam &param) const;
bool Init(const ConcatParam &para) const;
}; };
} // namespace operators } // namespace operators
......
...@@ -38,6 +38,7 @@ template <typename DeviceType, typename T> ...@@ -38,6 +38,7 @@ template <typename DeviceType, typename T>
class ConvAddKernel : public OpKernelBase<DeviceType, FusionConvAddParam> { class ConvAddKernel : public OpKernelBase<DeviceType, FusionConvAddParam> {
public: public:
void Compute(const FusionConvAddParam &param) const; void Compute(const FusionConvAddParam &param) const;
bool Init(const FusionConvAddParam &para) const;
}; };
} // namespace operators } // namespace operators
......
...@@ -36,6 +36,7 @@ class ConvAddReluKernel ...@@ -36,6 +36,7 @@ class ConvAddReluKernel
: public OpKernelBase<DeviceType, FusionConvAddReluParam> { : public OpKernelBase<DeviceType, FusionConvAddReluParam> {
public: public:
void Compute(const FusionConvAddReluParam &param) const; void Compute(const FusionConvAddReluParam &param) const;
bool Init(const FusionConvAddReluParam &para) const;
}; };
} // namespace operators } // namespace operators
......
...@@ -32,6 +32,7 @@ template <typename DeviceType, typename T> ...@@ -32,6 +32,7 @@ template <typename DeviceType, typename T>
class ConvKernel : public OpKernelBase<DeviceType, ConvParam> { class ConvKernel : public OpKernelBase<DeviceType, ConvParam> {
public: public:
void Compute(const ConvParam &param) const; void Compute(const ConvParam &param) const;
bool Init(const ConvParam &para) const;
}; };
inline bool IsExpand(const std::vector<int64_t> &filter_dim, inline bool IsExpand(const std::vector<int64_t> &filter_dim,
......
...@@ -31,6 +31,7 @@ template <typename DeviceType, typename T> ...@@ -31,6 +31,7 @@ template <typename DeviceType, typename T>
class DepthwiseConvKernel : public OpKernelBase<DeviceType, ConvParam> { class DepthwiseConvKernel : public OpKernelBase<DeviceType, ConvParam> {
public: public:
void Compute(const ConvParam &param) const; void Compute(const ConvParam &param) const;
bool Init(const ConvParam &para) const;
}; };
} // namespace operators } // namespace operators
} // namespace paddle_mobile } // namespace paddle_mobile
......
...@@ -30,6 +30,7 @@ class ElementwiseAddKernel ...@@ -30,6 +30,7 @@ class ElementwiseAddKernel
: public framework::OpKernelBase<DeviceType, ElementwiseAddParam> { : public framework::OpKernelBase<DeviceType, ElementwiseAddParam> {
public: public:
void Compute(const ElementwiseAddParam &param) const; void Compute(const ElementwiseAddParam &param) const;
bool Init(const ElementwiseAddParam &para) const;
}; };
} // namespace operators } // namespace operators
} // namespace paddle_mobile } // namespace paddle_mobile
......
...@@ -19,6 +19,11 @@ limitations under the License. */ ...@@ -19,6 +19,11 @@ limitations under the License. */
namespace paddle_mobile { namespace paddle_mobile {
namespace operators { namespace operators {
// Init hook of the ConvKernel<FPGA, float> specialization.
// No setup is performed: the argument is left unnamed because it is never
// read, and success (true) is always reported.
template <>
bool ConvKernel<FPGA, float>::Init(const ConvParam & /*para*/) const {
  return true;
}
template <> template <>
void ConvKernel<FPGA, float>::Compute(const ConvParam &param) const {} void ConvKernel<FPGA, float>::Compute(const ConvParam &param) const {}
template class ConvKernel<FPGA, float>; template class ConvKernel<FPGA, float>;
......
...@@ -28,6 +28,7 @@ class FusionFcKernel ...@@ -28,6 +28,7 @@ class FusionFcKernel
: public framework::OpKernelBase<DeviceType, FusionFcParam> { : public framework::OpKernelBase<DeviceType, FusionFcParam> {
public: public:
void Compute(const FusionFcParam& param) const; void Compute(const FusionFcParam& param) const;
bool Init(const FusionFcParam& para) const;
}; };
} // namespace operators } // namespace operators
} // namespace paddle_mobile } // namespace paddle_mobile
......
...@@ -170,6 +170,7 @@ template <typename DeviceType, typename T> ...@@ -170,6 +170,7 @@ template <typename DeviceType, typename T>
class LrnKernel : public framework::OpKernelBase<DeviceType, LrnParam> { class LrnKernel : public framework::OpKernelBase<DeviceType, LrnParam> {
public: public:
void Compute(const LrnParam &param) const; void Compute(const LrnParam &param) const;
bool Init(const LrnParam &para) const;
}; };
} // namespace operators } // namespace operators
} // namespace paddle_mobile } // namespace paddle_mobile
......
...@@ -20,6 +20,11 @@ limitations under the License. */ ...@@ -20,6 +20,11 @@ limitations under the License. */
namespace paddle_mobile { namespace paddle_mobile {
namespace operators { namespace operators {
// Init hook of the BatchNormKernel<GPU_MALI, float> specialization.
// No setup is performed: the argument is left unnamed because it is never
// read, and success (true) is always reported.
template <>
bool BatchNormKernel<GPU_MALI, float>::Init(
    const BatchNormParam & /*para*/) const {
  return true;
}
template <> template <>
void BatchNormKernel<GPU_MALI, float>::Compute( void BatchNormKernel<GPU_MALI, float>::Compute(
const BatchNormParam &param) const {} const BatchNormParam &param) const {}
......
...@@ -19,6 +19,11 @@ limitations under the License. */ ...@@ -19,6 +19,11 @@ limitations under the License. */
namespace paddle_mobile { namespace paddle_mobile {
namespace operators { namespace operators {
// Init hook of the ConvKernel<GPU_MALI, float> specialization.
// No setup is performed: the argument is left unnamed because it is never
// read, and success (true) is always reported.
template <>
bool ConvKernel<GPU_MALI, float>::Init(const ConvParam & /*para*/) const {
  return true;
}
template <> template <>
void ConvKernel<GPU_MALI, float>::Compute(const ConvParam &param) const { void ConvKernel<GPU_MALI, float>::Compute(const ConvParam &param) const {
// ArmConvImplement imp; // ArmConvImplement imp;
......
...@@ -29,6 +29,7 @@ template <typename DeviceType, typename T> ...@@ -29,6 +29,7 @@ template <typename DeviceType, typename T>
class MulKernel : public framework::OpKernelBase<DeviceType, MulParam> { class MulKernel : public framework::OpKernelBase<DeviceType, MulParam> {
public: public:
void Compute(const MulParam &param) const; void Compute(const MulParam &param) const;
bool Init(const MulParam &para) const;
}; };
} // namespace operators } // namespace operators
} // namespace paddle_mobile } // namespace paddle_mobile
......
...@@ -28,6 +28,7 @@ class MultiClassNMSKernel ...@@ -28,6 +28,7 @@ class MultiClassNMSKernel
: public framework::OpKernelBase<DeviceType, MultiClassNMSParam> { : public framework::OpKernelBase<DeviceType, MultiClassNMSParam> {
public: public:
void Compute(const MultiClassNMSParam& param) const; void Compute(const MultiClassNMSParam& param) const;
bool Init(const MultiClassNMSParam& para) const;
}; };
} // namespace operators } // namespace operators
} // namespace paddle_mobile } // namespace paddle_mobile
......
...@@ -28,6 +28,7 @@ template <typename DeviceType, typename T> ...@@ -28,6 +28,7 @@ template <typename DeviceType, typename T>
class PoolKernel : public OpKernelBase<DeviceType, PoolParam> { class PoolKernel : public OpKernelBase<DeviceType, PoolParam> {
public: public:
void Compute(const PoolParam &param) const override; void Compute(const PoolParam &param) const override;
bool Init(const PoolParam &para) const;
}; };
} // namespace operators } // namespace operators
} // namespace paddle_mobile } // namespace paddle_mobile
......
...@@ -55,6 +55,7 @@ class PriorBoxKernel ...@@ -55,6 +55,7 @@ class PriorBoxKernel
: public framework::OpKernelBase<DeviceType, PriorBoxParam> { : public framework::OpKernelBase<DeviceType, PriorBoxParam> {
public: public:
void Compute(const PriorBoxParam& param) const; void Compute(const PriorBoxParam& param) const;
bool Init(const PriorBoxParam& para) const;
}; };
} // namespace operators } // namespace operators
} // namespace paddle_mobile } // namespace paddle_mobile
......
...@@ -27,6 +27,7 @@ template <typename DeviceType, typename T> ...@@ -27,6 +27,7 @@ template <typename DeviceType, typename T>
class ReluKernel : public framework::OpKernelBase<DeviceType, ReluParam> { class ReluKernel : public framework::OpKernelBase<DeviceType, ReluParam> {
public: public:
void Compute(const ReluParam& param) const; void Compute(const ReluParam& param) const;
bool Init(const ReluParam& para) const;
}; };
} // namespace operators } // namespace operators
} // namespace paddle_mobile } // namespace paddle_mobile
......
...@@ -71,6 +71,7 @@ template <typename DeviceType, typename T> ...@@ -71,6 +71,7 @@ template <typename DeviceType, typename T>
class ReshapeKernel : public framework::OpKernelBase<DeviceType, ReshapeParam> { class ReshapeKernel : public framework::OpKernelBase<DeviceType, ReshapeParam> {
public: public:
void Compute(const ReshapeParam& param) const; void Compute(const ReshapeParam& param) const;
bool Init(const ReshapeParam& para) const;
}; };
} // namespace operators } // namespace operators
} // namespace paddle_mobile } // namespace paddle_mobile
......
...@@ -26,6 +26,7 @@ template <typename DeviceType, typename T> ...@@ -26,6 +26,7 @@ template <typename DeviceType, typename T>
class SigmoidKernel : public OpKernelBase<DeviceType, SigmoidParam> { class SigmoidKernel : public OpKernelBase<DeviceType, SigmoidParam> {
public: public:
void Compute(const SigmoidParam& param) const override; void Compute(const SigmoidParam& param) const override;
bool Init(const SigmoidParam& para) const;
}; };
} // namespace operators } // namespace operators
} // namespace paddle_mobile } // namespace paddle_mobile
......
...@@ -29,6 +29,7 @@ template <typename DeviceType, typename T> ...@@ -29,6 +29,7 @@ template <typename DeviceType, typename T>
class SoftmaxKernel : public OpKernelBase<DeviceType, SoftmaxParam> { class SoftmaxKernel : public OpKernelBase<DeviceType, SoftmaxParam> {
public: public:
void Compute(const SoftmaxParam &param) const override; void Compute(const SoftmaxParam &param) const override;
bool Init(const SoftmaxParam &para) const;
}; };
} // namespace operators } // namespace operators
} // namespace paddle_mobile } // namespace paddle_mobile
......
...@@ -29,6 +29,7 @@ class TransposeKernel ...@@ -29,6 +29,7 @@ class TransposeKernel
: public framework::OpKernelBase<DeviceType, TransposeParam> { : public framework::OpKernelBase<DeviceType, TransposeParam> {
public: public:
void Compute(const TransposeParam& param) const; void Compute(const TransposeParam& param) const;
bool Init(const TransposeParam& para) const;
}; };
} // namespace operators } // namespace operators
} // namespace paddle_mobile } // namespace paddle_mobile
......
...@@ -26,12 +26,12 @@ alignas(64) float packedA[MC * KC]; ...@@ -26,12 +26,12 @@ alignas(64) float packedA[MC * KC];
alignas(64) float packedB[KC * NC]; alignas(64) float packedB[KC * NC];
alignas(64) float ab[MR * NR]; alignas(64) float ab[MR * NR];
// 将A矩阵分块复制到连续内存(ColMajor) // 将A矩阵分块复制到连续内存(ColMajor)
void PackMatrixA(int m, int k, int paddingM, const float *A, int lda, void PackMatrixA(int m, int k, int m_tail, const float *A, int lda,
float *buffer) { float *buffer) {
int i, j; int i, j;
const float *Aij; const float *Aij;
for (i = 0; i < m - paddingM; i += MR) { for (i = 0; i < m - m_tail; i += MR) {
for (int j = 0; j < k; ++j) { for (j = 0; j < k; ++j) {
Aij = &A(i, j); Aij = &A(i, j);
*buffer++ = *Aij; *buffer++ = *Aij;
*buffer++ = *(Aij + 1); *buffer++ = *(Aij + 1);
...@@ -39,13 +39,13 @@ void PackMatrixA(int m, int k, int paddingM, const float *A, int lda, ...@@ -39,13 +39,13 @@ void PackMatrixA(int m, int k, int paddingM, const float *A, int lda,
*buffer++ = *(Aij + 3); *buffer++ = *(Aij + 3);
} }
} }
if (paddingM != 0) { if (m_tail != 0) {
for (j = 0; j < k; ++j) { for (j = 0; j < k; ++j) {
Aij = &A(m - paddingM, j); Aij = &A(m - m_tail, j);
for (i = 0; i < paddingM; ++i) { for (i = 0; i < m_tail; ++i) {
*buffer++ = *(Aij + i); *buffer++ = *(Aij + i);
} }
for (i = paddingM; i < MR; ++i) { for (i = m_tail; i < MR; ++i) {
*buffer++ = 0; *buffer++ = 0;
} }
} }
...@@ -53,11 +53,11 @@ void PackMatrixA(int m, int k, int paddingM, const float *A, int lda, ...@@ -53,11 +53,11 @@ void PackMatrixA(int m, int k, int paddingM, const float *A, int lda,
} }
// 将A矩阵分块复制到连续内存(RowMajor) // 将A矩阵分块复制到连续内存(RowMajor)
void PackMatrixA_(int m, int k, int paddingM, const float *A, int lda, void PackMatrixA_(int m, int k, int m_tail, const float *A, int lda,
float *buffer) { float *buffer) {
int i, j; int i, j;
const float *Ai, *Ai1, *Ai2, *Ai3; const float *Ai, *Ai1, *Ai2, *Ai3;
for (i = 0; i < m - paddingM; i += MR) { for (i = 0; i < m - m_tail; i += MR) {
Ai = &A(i, 0); Ai = &A(i, 0);
Ai1 = &A(i + 1, 0); Ai1 = &A(i + 1, 0);
Ai2 = &A(i + 2, 0); Ai2 = &A(i + 2, 0);
...@@ -69,12 +69,12 @@ void PackMatrixA_(int m, int k, int paddingM, const float *A, int lda, ...@@ -69,12 +69,12 @@ void PackMatrixA_(int m, int k, int paddingM, const float *A, int lda,
*buffer++ = *Ai3++; *buffer++ = *Ai3++;
} }
} }
if (paddingM != 0) { if (m_tail != 0) {
for (j = 0; j < k; ++j) { for (j = 0; j < k; ++j) {
for (i = m - paddingM; i < m; ++i) { for (i = m - m_tail; i < m; ++i) {
*buffer++ = A(i, j); *buffer++ = A(i, j);
} }
for (i = m; i < m + (MR - paddingM); ++i) { for (i = m; i < m + (MR - m_tail); ++i) {
*buffer++ = 0; *buffer++ = 0;
} }
} }
...@@ -82,11 +82,11 @@ void PackMatrixA_(int m, int k, int paddingM, const float *A, int lda, ...@@ -82,11 +82,11 @@ void PackMatrixA_(int m, int k, int paddingM, const float *A, int lda,
} }
// 将B矩阵分块复制到连续内存(ColMajor) // 将B矩阵分块复制到连续内存(ColMajor)
void PackMatrixB(int k, int n, int paddingN, const float *B, int ldb, void PackMatrixB(int k, int n, int n_tail, const float *B, int ldb,
float *buffer) { float *buffer) {
int i, j; int i, j;
const float *Bj, *Bj1, *Bj2, *Bj3; const float *Bj, *Bj1, *Bj2, *Bj3;
for (j = 0; j < n - paddingN; j += NR) { for (j = 0; j < n - n_tail; j += NR) {
Bj = &B(0, j); Bj = &B(0, j);
Bj1 = &B(0, j + 1); Bj1 = &B(0, j + 1);
Bj2 = &B(0, j + 2); Bj2 = &B(0, j + 2);
...@@ -98,12 +98,12 @@ void PackMatrixB(int k, int n, int paddingN, const float *B, int ldb, ...@@ -98,12 +98,12 @@ void PackMatrixB(int k, int n, int paddingN, const float *B, int ldb,
*buffer++ = *Bj3++; *buffer++ = *Bj3++;
} }
} }
if (paddingN != 0) { if (n_tail != 0) {
for (i = 0; i < k; ++i) { for (i = 0; i < k; ++i) {
for (int j = n - paddingN; j < n; ++j) { for (int j = n - n_tail; j < n; ++j) {
*buffer++ = B(i, j); *buffer++ = B(i, j);
} }
for (int j = n; j < n + (NR - paddingN); ++j) { for (int j = n; j < n + (NR - n_tail); ++j) {
*buffer++ = 0; *buffer++ = 0;
} }
} }
...@@ -111,11 +111,11 @@ void PackMatrixB(int k, int n, int paddingN, const float *B, int ldb, ...@@ -111,11 +111,11 @@ void PackMatrixB(int k, int n, int paddingN, const float *B, int ldb,
} }
// 将B矩阵分块复制到连续内存(RowMajor) // 将B矩阵分块复制到连续内存(RowMajor)
void PackMatrixB_(int k, int n, int paddingN, const float *B, int ldb, void PackMatrixB_(int k, int n, int n_tail, const float *B, int ldb,
float *buffer) { float *buffer) {
int i, j; int i, j;
const float *Bij; const float *Bij;
for (j = 0; j < n - paddingN; j += NR) { for (j = 0; j < n - n_tail; j += NR) {
for (i = 0; i < k; ++i) { for (i = 0; i < k; ++i) {
Bij = &B(i, j); Bij = &B(i, j);
asm volatile( asm volatile(
...@@ -126,13 +126,13 @@ void PackMatrixB_(int k, int n, int paddingN, const float *B, int ldb, ...@@ -126,13 +126,13 @@ void PackMatrixB_(int k, int n, int paddingN, const float *B, int ldb,
: "memory", "q0"); : "memory", "q0");
} }
} }
if (paddingN != 0) { if (n_tail != 0) {
for (i = 0; i < k; ++i) { for (i = 0; i < k; ++i) {
Bij = &B(i, n - paddingN); Bij = &B(i, n - n_tail);
for (int j = n - paddingN; j < n; ++j) { for (int j = n - n_tail; j < n; ++j) {
*buffer++ = *Bij++; *buffer++ = *Bij++;
} }
for (int j = n; j < n + (NR - paddingN); ++j) { for (int j = n; j < n + (NR - n_tail); ++j) {
*buffer++ = 0; *buffer++ = 0;
} }
} }
...@@ -143,33 +143,25 @@ void PackMatrixB_(int k, int n, int paddingN, const float *B, int ldb, ...@@ -143,33 +143,25 @@ void PackMatrixB_(int k, int n, int paddingN, const float *B, int ldb,
void InnerKernel(int m, int n, int k, float alpha, const float *A, int lda, void InnerKernel(int m, int n, int k, float alpha, const float *A, int lda,
const float *B, int ldb, float beta, float *C, int ldc, const float *B, int ldb, float beta, float *C, int ldc,
int first_time) { int first_time) {
int Buff_A_M = m; int m_block = (m + MR - 1) / MR * MR;
int Buff_B_N = n; int n_block = (n + NR - 1) / NR * NR;
int _mc = m % MR; int m_tail = m % MR;
int _nc = n % NR; int n_tail = n % NR;
if (_mc != 0) {
Buff_A_M = m + (MR - _mc);
}
if (_nc != 0) {
Buff_B_N = n + (NR - _nc);
}
if (first_time) { if (first_time) {
PackMatrixB_(k, n, _nc, B, ldb, packedB); PackMatrixB_(k, n, n_tail, B, ldb, packedB);
} }
PackMatrixA_(m, k, _mc, A, lda, packedA); PackMatrixA_(m, k, m_tail, A, lda, packedA);
int i, j, mc, nc; int i, j, mc, nc;
// B 取 4 列, 打包预热 // B 取 4 列, 打包预热
for (j = 0; j < Buff_B_N; j += NR) { for (j = 0; j < n_block; j += NR) {
nc = (n - j) < NR ? _nc : NR; nc = (n - j) < NR ? n_tail : NR;
// A 取 4 行,打包预热 // A 取 4 行,打包预热
for (i = 0; i < Buff_A_M; i += MR) { for (i = 0; i < m_block; i += MR) {
mc = (m - i) < MR ? _mc : MR; mc = (m - i) < MR ? m_tail : MR;
AddDot4x4(k, alpha, &packedA[i * k], 4, &packedB[j * k], k, beta, AddDot4x4(k, alpha, &packedA[i * k], 4, &packedB[j * k], k, beta,
&C(i, j), ldc, mc, nc); &C(i, j), ldc, mc, nc);
} }
...@@ -180,36 +172,25 @@ void InnerKernel(int m, int n, int k, float alpha, const float *A, int lda, ...@@ -180,36 +172,25 @@ void InnerKernel(int m, int n, int k, float alpha, const float *A, int lda,
void InnerKernel_relu(int m, int n, int k, float alpha, const float *A, int lda, void InnerKernel_relu(int m, int n, int k, float alpha, const float *A, int lda,
const float *B, int ldb, float beta, float *C, int ldc, const float *B, int ldb, float beta, float *C, int ldc,
int first_time, bool relu = false) { int first_time, bool relu = false) {
int Buff_A_M = m; int m_block = (m + MR - 1) / MR * MR;
int Buff_B_N = n; int n_block = (n + NR - 1) / NR * NR;
int _mc = m % MR;
int _nc = n % NR;
if (_mc != 0) { int m_tail = m % MR;
Buff_A_M = m + (MR - _mc); int n_tail = n % NR;
}
if (_nc != 0) {
Buff_B_N = n + (NR - _nc);
}
float packedA[MC * KC];
static float packedB[KC * NC];
if (first_time) { if (first_time) {
PackMatrixB_(k, n, _nc, B, ldb, packedB); PackMatrixB_(k, n, n_tail, B, ldb, packedB);
} }
PackMatrixA_(m, k, _mc, A, lda, packedA); PackMatrixA_(m, k, m_tail, A, lda, packedA);
int i, j, mc, nc; int i, j, mc, nc;
// B 取 4 列, 打包预热 // B 取 4 列, 打包预热
for (j = 0; j < Buff_B_N; j += NR) { for (j = 0; j < n_block; j += NR) {
nc = (n - j) < NR ? _nc : NR; nc = (n - j) < NR ? n_tail : NR;
// A 取 4 行,打包预热 // A 取 4 行,打包预热
for (i = 0; i < Buff_A_M; i += MR) { for (i = 0; i < m_block; i += MR) {
mc = (m - i) < MR ? _mc : MR; mc = (m - i) < MR ? m_tail : MR;
AddDot4x4_relu(k, alpha, &packedA[i * k], 4, &packedB[j * k], k, beta, AddDot4x4_relu(k, alpha, &packedA[i * k], 4, &packedB[j * k], k, beta,
&C(i, j), ldc, mc, nc, relu); &C(i, j), ldc, mc, nc, relu);
} }
...@@ -375,12 +356,15 @@ void AddDot4x4(int k, float alpha, const float *a, int lda, const float *b, ...@@ -375,12 +356,15 @@ void AddDot4x4(int k, float alpha, const float *a, int lda, const float *b,
"subs %[kc2], %[kc2], #1 \n\t" "subs %[kc2], %[kc2], #1 \n\t"
"blt end_kc2_%= \n\t" "blt end_kc2_%= \n\t"
"loop_kc2_%=: \n\t"
"vld1.32 {q0}, [%[a]]! \n\t" "vld1.32 {q0}, [%[a]]! \n\t"
"vld1.32 {q1}, [%[b]]! \n\t" "vld1.32 {q1}, [%[b]]! \n\t"
"vmla.f32 q10, q1, d0[0] \n\t" "vmla.f32 q10, q1, d0[0] \n\t"
"vmla.f32 q11, q1, d0[1] \n\t" "vmla.f32 q11, q1, d0[1] \n\t"
"vmla.f32 q12, q1, d1[0] \n\t" "vmla.f32 q12, q1, d1[0] \n\t"
"vmla.f32 q13, q1, d1[1] \n\t" "vmla.f32 q13, q1, d1[1] \n\t"
"subs %[kc2], %[kc2], #1 \n\t"
"bge loop_kc2_%= \n\t"
"end_kc2_%=: \n\t" "end_kc2_%=: \n\t"
"cmp %[mc], #4 \n\t" "cmp %[mc], #4 \n\t"
...@@ -525,12 +509,15 @@ void AddDot4x4_relu(int k, float alpha, const float *a, int lda, const float *b, ...@@ -525,12 +509,15 @@ void AddDot4x4_relu(int k, float alpha, const float *a, int lda, const float *b,
"subs %[kc2], %[kc2], #1 \n\t" "subs %[kc2], %[kc2], #1 \n\t"
"blt end_kc2_%= \n\t" "blt end_kc2_%= \n\t"
"loop_kc2_%=: \n\t"
"vld1.32 {q0}, [%[a]]! \n\t" "vld1.32 {q0}, [%[a]]! \n\t"
"vld1.32 {q1}, [%[b]]! \n\t" "vld1.32 {q1}, [%[b]]! \n\t"
"vmla.f32 q10, q1, d0[0] \n\t" "vmla.f32 q10, q1, d0[0] \n\t"
"vmla.f32 q11, q1, d0[1] \n\t" "vmla.f32 q11, q1, d0[1] \n\t"
"vmla.f32 q12, q1, d1[0] \n\t" "vmla.f32 q12, q1, d1[0] \n\t"
"vmla.f32 q13, q1, d1[1] \n\t" "vmla.f32 q13, q1, d1[1] \n\t"
"subs %[kc2], %[kc2], #1 \n\t"
"bge loop_kc2_%= \n\t"
"end_kc2_%=: \n\t" "end_kc2_%=: \n\t"
"cmp %[mc], #4 \n\t" "cmp %[mc], #4 \n\t"
...@@ -599,7 +586,8 @@ void AddDot4x4_relu(int k, float alpha, const float *a, int lda, const float *b, ...@@ -599,7 +586,8 @@ void AddDot4x4_relu(int k, float alpha, const float *a, int lda, const float *b,
[kc2] "r"(kc2), [mc] "r"(mc), [nc] "r"(nc), [alpha] "r"(alpha), [kc2] "r"(kc2), [mc] "r"(mc), [nc] "r"(nc), [alpha] "r"(alpha),
[beta] "r"(beta), [bytes_ldc] "r"(bytes_ldc), [beta] "r"(beta), [bytes_ldc] "r"(bytes_ldc),
[flag_alpha] "r"(flag_alpha), [flag_beta] "r"(flag_beta) [flag_alpha] "r"(flag_alpha), [flag_beta] "r"(flag_beta)
: "memory", "q0", "q1", "q2", "q3", "q4", "q10", "q11", "q12", "q13"); : "memory", "q0", "q1", "q2", "q3", "q4", "q10", "q11", "q12", "q13",
"q14");
if (mc != MR || nc != NR) { if (mc != MR || nc != NR) {
int i, j; int i, j;
......
...@@ -33,19 +33,19 @@ namespace operators { ...@@ -33,19 +33,19 @@ namespace operators {
namespace math { namespace math {
// 将 A 矩阵分块复制到连续内存(ColMajor) // 将 A 矩阵分块复制到连续内存(ColMajor)
void PackMatrixA(int m, int k, int paddingM, const float *A, int lda, void PackMatrixA(int m, int k, int m_tail, const float *A, int lda,
float *buffer); float *buffer);
// 将 B 矩阵分块复制到连续内存(ColMajor) // 将 B 矩阵分块复制到连续内存(ColMajor)
void PackMatrixB(int k, int n, int paddingN, const float *B, int ldb, void PackMatrixB(int k, int n, int n_tail, const float *B, int ldb,
float *buffer); float *buffer);
// 将 A 矩阵分块复制到连续内存(RowMajor) // 将 A 矩阵分块复制到连续内存(RowMajor)
void PackMatrixA_(int m, int k, int paddingM, const float *A, int lda, void PackMatrixA_(int m, int k, int m_tail, const float *A, int lda,
float *buffer); float *buffer);
// 将 B 矩阵分块复制到连续内存(RowMajor) // 将 B 矩阵分块复制到连续内存(RowMajor)
void PackMatrixB_(int k, int n, int paddingN, const float *B, int ldb, void PackMatrixB_(int k, int n, int n_tail, const float *B, int ldb,
float *buffer); float *buffer);
// 分块矩阵乘法 // 分块矩阵乘法
......
...@@ -13,7 +13,9 @@ See the License for the specific language governing permissions and ...@@ -13,7 +13,9 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include <iostream> #include <iostream>
#include "../test_helper.h"
#include "common/log.h" #include "common/log.h"
#include "memory/t_malloc.h"
#include "operators/math/gemm.h" #include "operators/math/gemm.h"
#define a(i, j) a[(i)*lda + (j)] #define a(i, j) a[(i)*lda + (j)]
...@@ -29,10 +31,15 @@ int main() { ...@@ -29,10 +31,15 @@ int main() {
int ldb = n; int ldb = n;
int ldc = n; int ldc = n;
float a[62 * 74]; float *a =
float b[74 * 63]; static_cast<float *>(paddle_mobile::memory::Alloc(sizeof(float) * m * k));
float c[62 * 63] = {0}; float *b =
float c1[62 * 63] = {0}; static_cast<float *>(paddle_mobile::memory::Alloc(sizeof(float) * k * n));
float *c =
static_cast<float *>(paddle_mobile::memory::Alloc(sizeof(float) * m * n));
float *c1 =
static_cast<float *>(paddle_mobile::memory::Alloc(sizeof(float) * m * n));
for (int i = 0; i < m * k; ++i) { for (int i = 0; i < m * k; ++i) {
a[i] = 2; a[i] = 2;
} }
...@@ -44,8 +51,11 @@ int main() { ...@@ -44,8 +51,11 @@ int main() {
c1[i] = 2; c1[i] = 2;
} }
auto time1 = time();
paddle_mobile::operators::math::sgemm(m, n, k, 0.9, a, lda, b, ldb, 0.3, c, paddle_mobile::operators::math::sgemm(m, n, k, 0.9, a, lda, b, ldb, 0.3, c,
ldc); ldc);
auto time2 = time();
DLOG << "gemm cost :" << time_diff(time1, time2) << "ms\n";
for (int i = 0; i < m * n; ++i) { for (int i = 0; i < m * n; ++i) {
std::cout << c[i] << " | "; std::cout << c[i] << " | ";
if (i % n == (n - 1)) { if (i % n == (n - 1)) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册