Unverified commit dfce4621 authored by juncaipeng, committed by GitHub

add rsqrt op, test=develop (#2176)

Parent 77811367
......@@ -127,3 +127,4 @@ USE_LITE_OP(roi_align)
USE_LITE_OP(box_clip)
USE_LITE_OP(assign_value)
USE_LITE_OP(hard_sigmoid)
USE_LITE_OP(rsqrt)
......@@ -688,6 +688,18 @@ void act_hard_sigmoid<float>(const float* din,
++dout;
}
}
template <>
void act_rsqrt<float>(const float* din, float* dout, int size, int threads) {
const float* ptr_in = din;
float* ptr_out = dout;
for (int i = 0; i < size; ++i) {
ptr_out[0] = 1.0 / sqrtf(ptr_in[0]);
ptr_in++;
ptr_out++;
}
}
} // namespace math
} // namespace arm
} // namespace lite
......
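The committed act_rsqrt is a plain scalar loop that evaluates 1.0 / sqrtf(x) per element. For context, here is a minimal sketch of how the same loop could be vectorized with ARM NEON intrinsics, combining the hardware reciprocal-square-root estimate with two Newton-Raphson refinement steps. The function name act_rsqrt_neon and the tail handling are illustrative assumptions, not part of this commit.

#include <arm_neon.h>
#include <cmath>

// Illustrative NEON variant (an assumption, not in this commit).
// vrsqrteq_f32 yields a rough (~12-bit) estimate of 1/sqrt(x); each
// vrsqrtsq_f32 step applies one Newton-Raphson refinement:
// e' = e * (3 - x * e * e) / 2.
void act_rsqrt_neon(const float* din, float* dout, int size) {
  int i = 0;
  for (; i + 3 < size; i += 4) {
    float32x4_t x = vld1q_f32(din + i);
    float32x4_t e = vrsqrteq_f32(x);                     // initial estimate
    e = vmulq_f32(e, vrsqrtsq_f32(vmulq_f32(x, e), e));  // refine once
    e = vmulq_f32(e, vrsqrtsq_f32(vmulq_f32(x, e), e));  // refine twice
    vst1q_f32(dout + i, e);
  }
  for (; i < size; ++i) {  // scalar tail for the remaining elements
    dout[i] = 1.f / sqrtf(din[i]);
  }
}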
......@@ -65,6 +65,10 @@ void act_hard_sigmoid(const T* din,
const float slope,
const float offset,
int threads);
template <typename T>
void act_rsqrt(const T* din, T* dout, int size, int threads);
} // namespace math
} // namespace arm
} // namespace lite
......
......@@ -159,6 +159,16 @@ void HardSigmoidCompute::Run() {
x_data, output_data, x_dims.production(), slope, offset, ctx.threads());
}
void RsqrtCompute::Run() {
auto& param = this->Param<param_t>();
auto& ctx = this->ctx_->template As<ARMContext>();
auto x_dims = param.X->dims();
auto x_data = param.X->data<float>();
auto output_data = param.Out->mutable_data<float>();
lite::arm::math::act_rsqrt<float>(
x_data, output_data, x_dims.production(), ctx.threads());
}
} // namespace arm
} // namespace kernels
} // namespace lite
......@@ -245,3 +255,8 @@ REGISTER_LITE_KERNEL(hard_sigmoid,
.BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))})
.Finalize();
REGISTER_LITE_KERNEL(
rsqrt, kARM, kFloat, kNCHW, paddle::lite::kernels::arm::RsqrtCompute, def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))})
.Finalize();
......@@ -130,6 +130,15 @@ class HardSigmoidCompute : public KernelLite<TARGET(kARM), PRECISION(kFloat)> {
virtual ~HardSigmoidCompute() = default;
};
class RsqrtCompute : public KernelLite<TARGET(kARM), PRECISION(kFloat)> {
public:
using param_t = operators::ActivationParam;
void Run() override;
virtual ~RsqrtCompute() = default;
};
} // namespace arm
} // namespace kernels
} // namespace lite
......
......@@ -117,6 +117,7 @@ REGISTER_LITE_OP(log, paddle::lite::operators::ActivationOp);
REGISTER_LITE_OP(exp, paddle::lite::operators::ActivationOp);
REGISTER_LITE_OP(floor, paddle::lite::operators::ActivationOp);
REGISTER_LITE_OP(hard_sigmoid, paddle::lite::operators::ActivationOp);
REGISTER_LITE_OP(rsqrt, paddle::lite::operators::ActivationOp);
#ifdef LITE_WITH_TRAIN
REGISTER_LITE_OP(square_grad, paddle::lite::operators::ActivationGradOp);
......
......@@ -33,7 +33,8 @@ enum activation_type_test {
RELU6,
LOG,
EXP,
FLOOR,
RSQRT
};
class ActivationComputeTester : public arena::TestCase {
......@@ -177,6 +178,12 @@ class ActivationComputeTester : public arena::TestCase {
}
break;
}
case RSQRT: {
for (int i = 0; i < dims_.production(); i++) {
output_data[i] = 1.0 / std::sqrt(x_data[i]);
}
break;
}
default:
LOG(INFO) << "the type of activation is unknow.";
}
......@@ -205,7 +212,7 @@ class ActivationComputeTester : public arena::TestCase {
std::vector<float> data(dims_.production());
for (int i = 0; i < dims_.production(); i++) {
float sign = i % 3 == 0 ? -1.0f : 1.0f;
sign = type_ == "log" ? 1 : sign;
sign = (type_ == "log" || type_ == "rsqrt") ? 1 : sign;
data[i] = sign * static_cast<float>(i % 128) * 0.013f + 0.001;
}
SetCommonTensor(input_, dims_, data.data());
......@@ -553,5 +560,31 @@ TEST(Activation_floor, precision) {
#endif
}
TEST(Activation_rsqrt, precision) {
LOG(INFO) << "test rsqrt op";
#ifdef LITE_WITH_ARM
Place place(TARGET(kARM));
for (auto n : {2}) {
for (auto c : {2}) {
for (auto h : {2}) {
for (auto w : {2}) {
std::unique_ptr<arena::TestCase> tester(new ActivationComputeTester(
place,
"def",
0.01,
6.,
"all",
0.,
DDim(std::vector<int64_t>({n, c, h, w})),
"rsqrt",
RSQRT));
arena::Arena arena(std::move(tester), place, 2e-5);
arena.TestPrecision();
}
}
}
}
#endif
}
} // namespace lite
} // namespace paddle
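For a quick sanity check outside the arena harness, the sketch below mirrors the test's reference computation (out[i] = 1 / sqrt(in[i])), the positive input pattern PrepareData generates, and the 2e-5 tolerance passed to arena::Arena. This standalone program is an illustrative assumption, not part of this commit.

#include <cmath>
#include <cstdio>

int main() {
  // Inputs generated like PrepareData does for rsqrt (sign forced to +1,
  // since 1/sqrt(x) is only defined for x > 0):
  for (int i = 0; i < 16; ++i) {
    float x = static_cast<float>(i % 128) * 0.013f + 0.001f;
    float kernel_out = 1.f / sqrtf(x);   // what act_rsqrt computes
    float ref_out = 1.0 / std::sqrt(x);  // the test's reference formula
    if (std::fabs(kernel_out - ref_out) > 2e-5) {
      std::printf("mismatch at x = %f\n", x);
      return 1;
    }
  }
  std::printf("rsqrt reference check passed\n");
  return 0;
}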