Unverified · Commit dfce4621 authored by J juncaipeng, committed by GitHub

add rsqrt op, test=develop (#2176)

Parent 77811367
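For reference, rsqrt is the elementwise reciprocal square root, out[i] = 1 / sqrt(in[i]), defined only for positive inputs. The hunks below register the op against the shared activation op, add a scalar ARM float kernel, and extend the activation precision test to cover it.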
...@@ -127,3 +127,4 @@ USE_LITE_OP(roi_align)
USE_LITE_OP(box_clip)
USE_LITE_OP(assign_value)
USE_LITE_OP(hard_sigmoid)
USE_LITE_OP(rsqrt)
...@@ -688,6 +688,18 @@ void act_hard_sigmoid<float>(const float* din,
    ++dout;
  }
}
// Elementwise reciprocal square root: dout[i] = 1 / sqrt(din[i]).
template <>
void act_rsqrt<float>(const float* din, float* dout, int size, int threads) {
  const float* ptr_in = din;
  float* ptr_out = dout;
  for (int i = 0; i < size; ++i) {
    ptr_out[0] = 1.f / sqrtf(ptr_in[0]);
    ptr_in++;
    ptr_out++;
  }
}
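The committed kernel is a plain scalar loop (the threads argument is currently unused). As a point of reference only — this is not part of the commit — the same loop could be vectorized with NEON intrinsics: vrsqrteq_f32 yields a coarse estimate of 1/sqrt(x), and each vrsqrtsq_f32 Newton-Raphson step roughly doubles the number of accurate bits. A minimal sketch, with a hypothetical act_rsqrt_neon name:

#include <arm_neon.h>
#include <math.h>

// Hypothetical NEON variant of act_rsqrt (illustration only).
void act_rsqrt_neon(const float* din, float* dout, int size) {
  int i = 0;
  for (; i + 3 < size; i += 4) {
    float32x4_t x = vld1q_f32(din + i);
    float32x4_t e = vrsqrteq_f32(x);                     // ~8-bit estimate
    e = vmulq_f32(e, vrsqrtsq_f32(vmulq_f32(x, e), e));  // refine once
    e = vmulq_f32(e, vrsqrtsq_f32(vmulq_f32(x, e), e));  // refine twice
    vst1q_f32(dout + i, e);
  }
  for (; i < size; ++i) {  // scalar tail for the remaining elements
    dout[i] = 1.f / sqrtf(din[i]);
  }
}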
} // namespace math
} // namespace arm
} // namespace lite
......
...@@ -65,6 +65,10 @@ void act_hard_sigmoid(const T* din,
                      const float slope,
                      const float offset,
                      int threads);
template <typename T>
void act_rsqrt(const T* din, T* dout, int size, int threads);
} // namespace math
} // namespace arm
} // namespace lite
......
...@@ -159,6 +159,16 @@ void HardSigmoidCompute::Run() {
      x_data, output_data, x_dims.production(), slope, offset, ctx.threads());
}
void RsqrtCompute::Run() {
  auto& param = this->Param<param_t>();
  auto& ctx = this->ctx_->template As<ARMContext>();
  auto x_dims = param.X->dims();
  auto x_data = param.X->data<float>();
  auto output_data = param.Out->mutable_data<float>();
  lite::arm::math::act_rsqrt<float>(
      x_data, output_data, x_dims.production(), ctx.threads());
}
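Note that Run() only reads X and Out from the shared ActivationParam: unlike hard_sigmoid above, rsqrt takes no extra attributes, so the kernel just forwards the flattened element count and the context's thread hint to the math routine.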
} // namespace arm
} // namespace kernels
} // namespace lite
...@@ -245,3 +255,8 @@ REGISTER_LITE_KERNEL(hard_sigmoid,
    .BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))})
    .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))})
    .Finalize();
REGISTER_LITE_KERNEL(
rsqrt, kARM, kFloat, kNCHW, paddle::lite::kernels::arm::RsqrtCompute, def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))})
.Finalize();
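This binds the rsqrt op name to RsqrtCompute for the (kARM, kFloat, kNCHW) slot, with BindInput/BindOutput declaring that both tensors live on the ARM target so the framework can match and place the kernel.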
...@@ -130,6 +130,15 @@ class HardSigmoidCompute : public KernelLite<TARGET(kARM), PRECISION(kFloat)> {
  virtual ~HardSigmoidCompute() = default;
};
class RsqrtCompute : public KernelLite<TARGET(kARM), PRECISION(kFloat)> {
 public:
  using param_t = operators::ActivationParam;
  void Run() override;
  virtual ~RsqrtCompute() = default;
};
} // namespace arm
} // namespace kernels
} // namespace lite
......
...@@ -117,6 +117,7 @@ REGISTER_LITE_OP(log, paddle::lite::operators::ActivationOp);
REGISTER_LITE_OP(exp, paddle::lite::operators::ActivationOp);
REGISTER_LITE_OP(floor, paddle::lite::operators::ActivationOp);
REGISTER_LITE_OP(hard_sigmoid, paddle::lite::operators::ActivationOp);
REGISTER_LITE_OP(rsqrt, paddle::lite::operators::ActivationOp);
#ifdef LITE_WITH_TRAIN
REGISTER_LITE_OP(square_grad, paddle::lite::operators::ActivationGradOp);
......
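Because rsqrt is registered against the shared ActivationOp, it reuses the existing input/output checks and shape inference; no new op definition is needed, only the ARM kernel above.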
...@@ -33,7 +33,8 @@ enum activation_type_test {
  RELU6,
  LOG,
  EXP,
  FLOOR,
RSQRT
};
class ActivationComputeTester : public arena::TestCase {
...@@ -177,6 +178,12 @@ class ActivationComputeTester : public arena::TestCase {
      }
      break;
    }
    case RSQRT: {
      for (int i = 0; i < dims_.production(); i++) {
        output_data[i] = 1.0 / std::sqrt(x_data[i]);
      }
      break;
    }
    default:
      LOG(INFO) << "the type of activation is unknown.";
  }
...@@ -205,7 +212,7 @@ class ActivationComputeTester : public arena::TestCase {
    std::vector<float> data(dims_.production());
    for (int i = 0; i < dims_.production(); i++) {
      float sign = i % 3 == 0 ? -1.0f : 1.0f;
      // log and rsqrt are only defined for positive inputs, so drop the sign.
      sign = (type_ == "log" || type_ == "rsqrt") ? 1 : sign;
      data[i] = sign * static_cast<float>(i % 128) * 0.013f + 0.001;
    }
    SetCommonTensor(input_, dims_, data.data());
...@@ -553,5 +560,31 @@ TEST(Activation_floor, precision) {
#endif
}
TEST(Activation_rsqrt, precision) {
  LOG(INFO) << "test rsqrt op";
#ifdef LITE_WITH_ARM
  Place place(TARGET(kARM));
  for (auto n : {2}) {
    for (auto c : {2}) {
      for (auto h : {2}) {
        for (auto w : {2}) {
          std::unique_ptr<arena::TestCase> tester(new ActivationComputeTester(
              place,
              "def",
              0.01,
              6.,
              "all",
              0.,
              DDim(std::vector<int64_t>({n, c, h, w})),
              "rsqrt",
              RSQRT));
          arena::Arena arena(std::move(tester), place, 2e-5);
          arena.TestPrecision();
        }
      }
    }
  }
#endif
}
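As a quick sanity check on the reference path: with the fill rule above and type_ == "rsqrt", x[1] = 1 * 0.013f + 0.001 = 0.014, so the expected output is 1 / sqrt(0.014) ≈ 8.4515, which the ARM kernel must reproduce within the 2e-5 tolerance passed to the arena.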
} // namespace lite
} // namespace paddle