提交 f9907045 编写于 作者: S sunsuodong

refactor arithmic

上级 ac9b69ba
......@@ -23,39 +23,46 @@
#include "schema/model_generated.h"
namespace mindspore::kernel {
class ArithmeticFP16CPUKernel : public LiteKernel {
typedef int (*ArithmeticRun)(float16_t *input0, float16_t *input1, float16_t *output, int element_size);
typedef int (*ArithmeticOptRun)(float16_t *input0, float16_t *input1, float16_t *output, int element_size,
ArithmeticParameter *param);
typedef int (*ArithmeticFuncFp16)(float16_t *input0, float16_t *input1, float16_t *output, int element_size);
typedef int (*ArithmeticOptFuncFp16)(float16_t *input0, float16_t *input1, float16_t *output, int element_size,
ArithmeticParameter *param);
typedef struct {
int primitive_type_;
int activation_type_;
ArithmeticFuncFp16 func_;
ArithmeticOptFuncFp16 opt_func_;
} ARITHMETIC_FUNC_INFO_FP16;
class ArithmeticFP16CPUKernel : public LiteKernel {
public:
ArithmeticFP16CPUKernel(OpParameter *parameter, const std::vector<lite::tensor::Tensor *> &inputs,
const std::vector<lite::tensor::Tensor *> &outputs, const lite::Context *ctx,
const mindspore::lite::PrimitiveC *primitive)
: LiteKernel(parameter, inputs, outputs, ctx, primitive) {
arithmeticParameter_ = reinterpret_cast<ArithmeticParameter *>(parameter);
param_ = reinterpret_cast<ArithmeticParameter *>(parameter);
}
~ArithmeticFP16CPUKernel() override;
~ArithmeticFP16CPUKernel() = default;
int Init() override;
int ReSize() override;
int Run() override;
int DoArithmetic(int task_id);
int BroadcastRun(float16_t *input0, float16_t *input1, float16_t *output, int dim, int out_count,
int out_thread_stride);
int out_thread_stride);
private:
void FreeTmpBuffer();
int outside_;
int break_pos_;
int out_thread_stride_;
int out_count_;
bool is_input0_fp32_ = false;
bool is_input1_fp32_ = false;
bool is_output_fp32_ = false;
float16_t *input0_fp16_ = nullptr;
float16_t *input1_fp16_ = nullptr;
float16_t *output_fp16_ = nullptr;
ArithmeticParameter *arithmeticParameter_ = nullptr;
ArithmeticRun arithmetic_run_ = nullptr;
ArithmeticOptRun arithmetic_opt_run_ = nullptr;
ArithmeticParameter *param_ = nullptr;
ArithmeticFuncFp16 arithmetic_func_ = nullptr;
ArithmeticOptFuncFp16 arithmetic_opt_func_ = nullptr;
};
} // namespace mindspore::kernel
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP16_ARITHMETIC_FP16_H_
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册