提交 f565eb45 编写于 作者: qnqinan's avatar qnqinan

add tanh kernel in FPGA v1

上级 a35cbdff
......@@ -15,17 +15,61 @@ limitations under the License. */
#ifdef TANH_OP
#include "operators/kernel/tanh_kernel.h"
#include <math.h>
namespace paddle_mobile {
namespace operators {
template <>
bool TanhKernel<FPGA, float>::Init(TanhParam<FPGA> *param) {
auto input = const_cast<Tensor *>(param->InputX());
auto input_ptr = input->data<float>();
auto float_input = new Tensor;
float_input->mutable_data<float>(
{1, input->dims()[1], input->dims()[2], input->dims()[3]});
fpga::format_fp32_ofm(float_input);
fpga::BypassArgs args = {fpga::DATA_TYPE_FP16};
args.input_layout_type = fpga::LAYOUT_HWC;
args.output_layout_type = fpga::LAYOUT_CHW;
args.input_data_type = fpga::DATA_TYPE_FP16;
args.output_data_type = fpga::DATA_TYPE_FP32;
args.image.address = input_ptr;
args.image.height = (uint32_t)input->dims()[2];
args.image.width = (uint32_t)input->dims()[3];
args.image.channels = (uint32_t)input->dims()[1];
args.output.address = float_input->data<float>();
args.output.scale_address = float_input->scale;
param->SetFloatInput(float_input);
param->SetFpgaArgs(args);
return true;
}
#define EXP_MAX_INPUT 40.0
template <typename T>
T Tanh(const T a) {
T tmp = -2.0 * a;
tmp = (tmp > EXP_MAX_INPUT) ? EXP_MAX_INPUT : tmp;
return (2.0 / (1.0 + exp(tmp))) - 1.0;
}
template <typename T>
void tanhFuntor(Tensor *input, Tensor *output) {
auto *input_ptr = input->data<T>();
auto *output_ptr = output->mutable_data<T>();
for (int i = 0; i < input->numel(); i++) {
*(output_ptr + i) = Tanh<T>(*(input_ptr + i));
}
}
template <>
void TanhKernel<FPGA, float>::Compute(const TanhParam<FPGA> &param) {}
void TanhKernel<FPGA, float>::Compute(const TanhParam<FPGA> &param) {
Tensor *in_x = param.FloatInput();
Tensor *out = param.Out();
fpga::PerformBypass(param.FpgaArgs());
fpga::fpga_invalidate((void *)in_x->data<float>(),
in_x->numel() * sizeof(float));
tanhFuntor<float>(in_x, out);
fpga::fpga_flush(out->data<float>(), out->memory_size());
}
} // namespace operators
} // namespace paddle_mobile
......
......@@ -1554,6 +1554,20 @@ class TanhParam : public OpParam {
private:
RType *input_x_;
RType *out_;
#ifdef PADDLE_MOBILE_FPGA
private:
std::shared_ptr<RType> float_input_x_;
fpga::BypassArgs fpga_bypass_args;
public:
RType *FloatInput() const {
return float_input_x_ == nullptr ? input_x_ : float_input_x_.get();
}
void SetFloatInput(Tensor *input) { float_input_x_.reset(input); }
const fpga::BypassArgs &FpgaArgs() const { return fpga_bypass_args; }
void SetFpgaArgs(const fpga::BypassArgs &args) { fpga_bypass_args = args; }
#endif
};
#endif
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册