未验证 提交 26450c49 编写于 作者: H huzhiqiang 提交者: GitHub

add floor op,elementwise_div op and assign op test=develop (#1882)

上级 5e8b15f5
......@@ -46,6 +46,8 @@ USE_LITE_KERNEL(depthwise_conv2d, kARM, kFloat, kNCHW, def);
USE_LITE_KERNEL(elementwise_add, kARM, kFloat, kNCHW, def);
USE_LITE_KERNEL(elementwise_mul, kARM, kFloat, kNCHW, def);
USE_LITE_KERNEL(elementwise_max, kARM, kFloat, kNCHW, def);
USE_LITE_KERNEL(elementwise_div, kARM, kFloat, kNCHW, def);
USE_LITE_KERNEL(fusion_elementwise_div_activation, kARM, kFloat, kNCHW, def);
USE_LITE_KERNEL(fusion_elementwise_add_activation, kARM, kFloat, kNCHW, def);
USE_LITE_KERNEL(fusion_elementwise_mul_activation, kARM, kFloat, kNCHW, def);
USE_LITE_KERNEL(fusion_elementwise_max_activation, kARM, kFloat, kNCHW, def);
......@@ -118,6 +120,7 @@ USE_LITE_KERNEL(while, kARM, kFloat, kNCHW, def)
USE_LITE_KERNEL(lod_reset, kARM, kFloat, kNCHW, def)
USE_LITE_KERNEL(lookup_table, kARM, kFloat, kNCHW, def)
USE_LITE_KERNEL(is_empty, kARM, kFloat, kNCHW, def)
USE_LITE_KERNEL(assign, kARM, kFloat, kNCHW, def);
#endif
#ifdef LITE_WITH_X86
......
......@@ -35,9 +35,11 @@ USE_LITE_OP(elementwise_add)
USE_LITE_OP(elementwise_sub)
USE_LITE_OP(elementwise_mul)
USE_LITE_OP(elementwise_max)
USE_LITE_OP(elementwise_div)
USE_LITE_OP(fusion_elementwise_add_activation)
USE_LITE_OP(fusion_elementwise_mul_activation)
USE_LITE_OP(fusion_elementwise_max_activation)
USE_LITE_OP(fusion_elementwise_div_activation)
USE_LITE_OP(square)
USE_LITE_OP(softmax)
USE_LITE_OP(dropout)
......@@ -68,6 +70,7 @@ USE_LITE_OP(yolo_box)
USE_LITE_OP(bilinear_interp)
USE_LITE_OP(nearest_interp)
USE_LITE_OP(assign);
USE_LITE_OP(crop)
USE_LITE_OP(prior_box)
USE_LITE_OP(density_prior_box)
......
......@@ -666,6 +666,17 @@ void act_exp(const float* din, float* dout, int size, int threads) {
}
}
template <>
void act_floor<float>(const float* din, float* dout, int size, int threads) {
const float* ptr_in = din;
float* ptr_out = dout;
for (int i = 0; i < size; ++i) {
ptr_out[0] = floorf(ptr_in[0]);
ptr_in++;
ptr_out++;
}
}
} // namespace math
} // namespace arm
} // namespace lite
......
......@@ -55,6 +55,9 @@ void act_log(const T* din, T* dout, int size, int threads);
template <typename T>
void act_exp(const T* din, T* dout, int size, int threads);
template <typename T>
void act_floor(const T* din, T* dout, int size, int threads);
} // namespace math
} // namespace arm
} // namespace lite
......
......@@ -752,6 +752,293 @@ void elementwise_max_relu_broadcast<float>(const float* dinx,
}
}
template <>
void elementwise_div<float>(const float* dinx,
const float* diny,
float* dout,
int num) {
int cnt = num >> 4;
int remain = num % 16;
#pragma omp parallel for
for (int i = 0; i < cnt; i++) {
const float* dinx_ptr = dinx + (i << 4);
const float* diny_ptr = diny + (i << 4);
float* dout_ptr = dout + (i << 4);
float32x4_t dinx0 = vld1q_f32(dinx_ptr);
float32x4_t dinx1 = vld1q_f32(dinx_ptr + 4);
float32x4_t dinx2 = vld1q_f32(dinx_ptr + 8);
float32x4_t dinx3 = vld1q_f32(dinx_ptr + 12);
float32x4_t diny0 = vld1q_f32(diny_ptr);
float32x4_t diny1 = vld1q_f32(diny_ptr + 4);
float32x4_t diny2 = vld1q_f32(diny_ptr + 8);
float32x4_t diny3 = vld1q_f32(diny_ptr + 12);
#ifdef __aarch64__
dinx0 = vdivq_f32(dinx0, diny0);
dinx1 = vdivq_f32(dinx1, diny1);
dinx2 = vdivq_f32(dinx2, diny2);
dinx3 = vdivq_f32(dinx3, diny3);
#else
dinx0 = div_ps(dinx0, diny0);
dinx1 = div_ps(dinx1, diny1);
dinx2 = div_ps(dinx2, diny2);
dinx3 = div_ps(dinx3, diny3);
#endif
vst1q_f32(dout_ptr, dinx0);
vst1q_f32(dout_ptr + 4, dinx1);
vst1q_f32(dout_ptr + 8, dinx2);
vst1q_f32(dout_ptr + 12, dinx3);
}
if (remain > 0) {
const float* dinx_ptr = dinx + (cnt << 4);
const float* diny_ptr = diny + (cnt << 4);
float* dout_ptr = dout + (cnt << 4);
for (int i = 0; i < remain; i++) {
*dout_ptr = *dinx_ptr / *diny_ptr;
dout_ptr++;
dinx_ptr++;
diny_ptr++;
}
}
}
template <>
void elementwise_div_broadcast<float>(const float* dinx,
const float* diny,
float* dout,
int batch,
int channels,
int num) {
#pragma omp parallel for collapse(2)
for (int i = 0; i < batch; ++i) {
for (int j = 0; j < channels; ++j) {
int offset = (i * channels + j) * num;
const float* din_ptr = dinx + offset;
const float diny_data = diny[j];
float* dout_ptr = dout + offset;
int cnt = num >> 4;
int remain = num % 16;
float32x4_t rb = vdupq_n_f32(diny_data);
for (int k = 0; k < cnt; ++k) {
float32x4_t din0 = vld1q_f32(din_ptr);
float32x4_t din1 = vld1q_f32(din_ptr + 4);
float32x4_t din2 = vld1q_f32(din_ptr + 8);
float32x4_t din3 = vld1q_f32(din_ptr + 12);
#ifdef __aarch64__
din0 = vdivq_f32(din0, rb);
din1 = vdivq_f32(din1, rb);
din2 = vdivq_f32(din2, rb);
din3 = vdivq_f32(din3, rb);
#else
din0 = div_ps(din0, rb);
din1 = div_ps(din1, rb);
din2 = div_ps(din2, rb);
din3 = div_ps(din3, rb);
#endif
vst1q_f32(dout_ptr, din0);
vst1q_f32(dout_ptr + 4, din1);
vst1q_f32(dout_ptr + 8, din2);
vst1q_f32(dout_ptr + 12, din3);
din_ptr += 16;
dout_ptr += 16;
}
if (remain >= 8) {
float32x4_t din0 = vld1q_f32(din_ptr);
float32x4_t din1 = vld1q_f32(din_ptr + 4);
#ifdef __aarch64__
din0 = vdivq_f32(din0, rb);
din1 = vdivq_f32(din1, rb);
#else
din0 = div_ps(din0, rb);
din1 = div_ps(din1, rb);
#endif
vst1q_f32(dout_ptr, din0);
vst1q_f32(dout_ptr + 4, din1);
din_ptr += 8;
dout_ptr += 8;
remain -= 8;
}
if (remain >= 4) {
float32x4_t din0 = vld1q_f32(din_ptr);
#ifdef __aarch64__
din0 = vdivq_f32(din0, rb);
#else
din0 = div_ps(din0, rb);
#endif
vst1q_f32(dout_ptr, din0);
din_ptr += 4;
dout_ptr += 4;
remain -= 4;
}
if (remain > 0) {
for (int p = 0; p < remain; p++) {
*dout_ptr = *din_ptr / diny_data;
dout_ptr++;
din_ptr++;
}
}
}
}
}
template <>
void elementwise_div_relu<float>(const float* dinx,
const float* diny,
float* dout,
int num) {
int cnt = num >> 4;
int remain = num % 16;
float32x4_t vzero = vdupq_n_f32(0.f);
#pragma omp parallel for
for (int i = 0; i < cnt; ++i) {
const float* dinx_ptr = dinx + (i << 4);
const float* diny_ptr = diny + (i << 4);
float* dout_ptr = dout + (i << 4);
float32x4_t dinx0 = vld1q_f32(dinx_ptr);
float32x4_t dinx1 = vld1q_f32(dinx_ptr + 4);
float32x4_t dinx2 = vld1q_f32(dinx_ptr + 8);
float32x4_t dinx3 = vld1q_f32(dinx_ptr + 12);
float32x4_t diny0 = vld1q_f32(diny_ptr);
float32x4_t diny1 = vld1q_f32(diny_ptr + 4);
float32x4_t diny2 = vld1q_f32(diny_ptr + 8);
float32x4_t diny3 = vld1q_f32(diny_ptr + 12);
#ifdef __aarch64__
dinx0 = vdivq_f32(dinx0, diny0);
dinx1 = vdivq_f32(dinx1, diny1);
dinx2 = vdivq_f32(dinx2, diny2);
dinx3 = vdivq_f32(dinx3, diny3);
#else
dinx0 = div_ps(dinx0, diny0);
dinx1 = div_ps(dinx1, diny1);
dinx2 = div_ps(dinx2, diny2);
dinx3 = div_ps(dinx3, diny3);
#endif
// relu
dinx0 = vmaxq_f32(dinx0, vzero);
dinx1 = vmaxq_f32(dinx1, vzero);
dinx2 = vmaxq_f32(dinx2, vzero);
dinx3 = vmaxq_f32(dinx3, vzero);
vst1q_f32(dout_ptr, dinx0);
vst1q_f32(dout_ptr + 4, dinx1);
vst1q_f32(dout_ptr + 8, dinx2);
vst1q_f32(dout_ptr + 12, dinx3);
}
if (remain > 0) {
const float* dinx_ptr = dinx + (cnt << 4);
const float* diny_ptr = diny + (cnt << 4);
float* dout_ptr = dout + (cnt << 4);
for (int i = 0; i < remain; ++i) {
float tmp = *dinx_ptr / *diny_ptr;
*(dout_ptr++) = tmp > 0.f ? tmp : 0.f;
dinx_ptr++;
diny_ptr++;
}
}
}
template <>
void elementwise_div_relu_broadcast<float>(const float* dinx,
const float* diny,
float* dout,
int batch,
int channels,
int num) {
float32x4_t vzero = vdupq_n_f32(0.f);
#pragma omp parallel for collapse(2)
for (int i = 0; i < batch; ++i) {
for (int j = 0; j < channels; ++j) {
int offset = (i * channels + j) * num;
const float* din_ptr = dinx + offset;
const float diny_data = diny[j];
float* dout_ptr = dout + offset;
int cnt = num >> 4;
int remain = num % 16;
float32x4_t rb = vdupq_n_f32(diny_data);
for (int k = 0; k < cnt; ++k) {
float32x4_t din0 = vld1q_f32(din_ptr);
float32x4_t din1 = vld1q_f32(din_ptr + 4);
float32x4_t din2 = vld1q_f32(din_ptr + 8);
float32x4_t din3 = vld1q_f32(din_ptr + 12);
#ifdef __aarch64__
din0 = vdivq_f32(din0, rb);
din1 = vdivq_f32(din1, rb);
din2 = vdivq_f32(din2, rb);
din3 = vdivq_f32(din3, rb);
#else
din0 = div_ps(din0, rb);
din1 = div_ps(din1, rb);
din2 = div_ps(din2, rb);
din3 = div_ps(din3, rb);
#endif
// relu
din0 = vmaxq_f32(din0, vzero);
din1 = vmaxq_f32(din1, vzero);
din2 = vmaxq_f32(din2, vzero);
din3 = vmaxq_f32(din3, vzero);
vst1q_f32(dout_ptr, din0);
vst1q_f32(dout_ptr + 4, din1);
vst1q_f32(dout_ptr + 8, din2);
vst1q_f32(dout_ptr + 12, din3);
din_ptr += 16;
dout_ptr += 16;
}
if (remain >= 8) {
float32x4_t din0 = vld1q_f32(din_ptr);
float32x4_t din1 = vld1q_f32(din_ptr + 4);
#ifdef __aarch64__
din0 = vdivq_f32(din0, rb);
din1 = vdivq_f32(din1, rb);
#else
din0 = div_ps(din0, rb);
din1 = div_ps(din1, rb);
#endif
// relu
din0 = vmaxq_f32(din0, vzero);
din1 = vmaxq_f32(din1, vzero);
vst1q_f32(dout_ptr, din0);
vst1q_f32(dout_ptr + 4, din1);
din_ptr += 8;
dout_ptr += 8;
remain -= 8;
}
if (remain >= 4) {
float32x4_t din0 = vld1q_f32(din_ptr);
#ifdef __aarch64__
din0 = vdivq_f32(din0, rb);
#else
din0 = div_ps(din0, rb);
#endif
// relu
din0 = vmaxq_f32(din0, vzero);
vst1q_f32(dout_ptr, din0);
din_ptr += 4;
dout_ptr += 4;
remain -= 4;
}
if (remain > 0) {
for (int p = 0; p < remain; p++) {
float tmp = *din_ptr / diny_data;
*dout_ptr = tmp > 0.f ? tmp : 0.f;
dout_ptr++;
din_ptr++;
}
}
}
}
}
} // namespace math
} // namespace arm
} // namespace lite
......
......@@ -61,6 +61,20 @@ template <typename T>
void elementwise_max_relu_broadcast(
const T* dinx, const T* diny, T* dout, int batch, int channels, int num);
template <typename T>
void elementwise_div(const T* dinx, const T* diny, T* dout, int num);
template <typename T>
void elementwise_div_broadcast(
const T* dinx, const T* diny, T* dout, int batch, int channels, int num);
template <typename T>
void elementwise_div_relu(const T* dinx, const T* diny, T* dout, int num);
template <typename T>
void elementwise_div_relu_broadcast(
const T* dinx, const T* diny, T* dout, int batch, int channels, int num);
} // namespace math
} // namespace arm
} // namespace lite
......
......@@ -45,6 +45,7 @@ add_kernel(reduce_max_compute_arm ARM basic SRCS reduce_max_compute.cc DEPS ${li
add_kernel(sequence_expand_compute_arm ARM basic SRCS sequence_expand_compute.cc DEPS ${lite_kernel_deps} math_arm)
add_kernel(im2sequence_compute_arm ARM basic SRCS im2sequence_compute.cc DEPS ${lite_kernel_deps} math_arm)
add_kernel(sequence_pool_compute_arm ARM basic SRCS sequence_pool_compute.cc DEPS ${lite_kernel_deps} math_arm)
add_kernel(assign_compute_arm ARM basic SRCS assign_compute.cc DEPS ${lite_kernel_deps} math_arm)
# for OCR specific
add_kernel(gru_unit_compute_arm ARM extra SRCS gru_unit_compute.cc DEPS ${lite_kernel_deps} math_arm)
......
......@@ -137,6 +137,16 @@ void ExpCompute::Run() {
x_data, output_data, x_dims.production(), ctx.threads());
}
void FloorCompute::Run() {
auto& param = this->Param<param_t>();
auto& ctx = this->ctx_->template As<ARMContext>();
auto x_dims = param.X->dims();
auto x_data = param.X->data<float>();
auto output_data = param.Out->mutable_data<float>();
lite::arm::math::act_floor<float>(
x_data, output_data, x_dims.production(), ctx.threads());
}
} // namespace arm
} // namespace kernels
} // namespace lite
......@@ -209,3 +219,8 @@ REGISTER_LITE_KERNEL(
.BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))})
.Finalize();
REGISTER_LITE_KERNEL(
floor, kARM, kFloat, kNCHW, paddle::lite::kernels::arm::FloorCompute, def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))})
.Finalize();
......@@ -112,6 +112,15 @@ class ExpCompute : public KernelLite<TARGET(kARM), PRECISION(kFloat)> {
virtual ~ExpCompute() = default;
};
class FloorCompute : public KernelLite<TARGET(kARM), PRECISION(kFloat)> {
public:
using param_t = operators::ActivationParam;
void Run() override;
virtual ~FloorCompute() = default;
};
} // namespace arm
} // namespace kernels
} // namespace lite
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/arm/assign_compute.h"
#include <vector>
#include "lite/arm/math/funcs.h"
#include "lite/core/op_registry.h"
#include "lite/core/type_system.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace arm {
void AssignCompute::PrepareForRun() {
// CHECK_OR_FALSE(param_t.Out);
}
void AssignCompute::Run() {
// LOG(INFO) << "into kernel compute run";
auto& param = Param<param_t>();
const lite::Tensor* input = param.X;
lite::Tensor* output = param.Out;
output->CopyDataFrom(*input);
}
} // namespace arm
} // namespace kernels
} // namespace lite
} // namespace paddle
REGISTER_LITE_KERNEL(
assign, kARM, kFloat, kNCHW, paddle::lite::kernels::arm::AssignCompute, def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))})
.Finalize();
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <algorithm>
#include "lite/core/kernel.h"
#include "lite/operators/assign_op.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace arm {
class AssignCompute : public KernelLite<TARGET(kARM), PRECISION(kFloat)> {
public:
using param_t = operators::AssignParam;
void PrepareForRun() override;
void Run() override;
virtual ~AssignCompute() = default;
};
} // namespace arm
} // namespace kernels
} // namespace lite
} // namespace paddle
......@@ -205,6 +205,55 @@ void ElementwiseMaxActivationCompute::Run() {
}
}
void ElementwiseDivCompute::Run() {
auto& param = Param<operators::ElementwiseParam>();
const float* x_data = param.X->data<float>();
const float* y_data = param.Y->data<float>();
float* out_data = param.Out->mutable_data<float>();
int axis = param.axis;
auto x_dims = param.X->dims();
auto y_dims = param.Y->dims();
int pre, n, post;
if (is_broadcast(x_dims, y_dims, axis, &pre, &n, &post)) {
lite::arm::math::elementwise_div_broadcast(
x_data, y_data, out_data, pre, n, post);
} else {
lite::arm::math::elementwise_div(
x_data, y_data, out_data, x_dims.production());
}
}
void ElementwiseDivActivationCompute::Run() {
auto& param = Param<operators::FusionElementwiseActivationParam>();
const float* x_data = param.X->data<float>();
const float* y_data = param.Y->data<float>();
float* out_data = param.Out->mutable_data<float>();
int axis = param.axis;
std::string act_type = param.act_type;
auto x_dims = param.X->dims();
auto y_dims = param.Y->dims();
int pre, n, post;
if (is_broadcast(x_dims, y_dims, axis, &pre, &n, &post)) {
if (act_type == "relu") {
lite::arm::math::elementwise_div_relu_broadcast(
x_data, y_data, out_data, pre, n, post);
} else {
LOG(FATAL) << "unsupported Activation type: " << act_type;
}
} else {
if (act_type == "relu") {
lite::arm::math::elementwise_div_relu(
x_data, y_data, out_data, x_dims.production());
} else {
LOG(FATAL) << "unsupported Activation type: " << act_type;
}
}
for (int i = 0; i < x_dims.production(); i++) {
LOG(INFO) << "x:" << x_data[i] << " y:" << y_data[i]
<< " out:" << out_data[i];
}
}
} // namespace arm
} // namespace kernels
} // namespace lite
......@@ -278,3 +327,26 @@ REGISTER_LITE_KERNEL(
.BindInput("Y", {LiteType::GetTensorTy(TARGET(kARM))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))})
.Finalize();
REGISTER_LITE_KERNEL(elementwise_div,
kARM,
kFloat,
kNCHW,
paddle::lite::kernels::arm::ElementwiseDivCompute,
def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))})
.BindInput("Y", {LiteType::GetTensorTy(TARGET(kARM))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))})
.Finalize();
REGISTER_LITE_KERNEL(
fusion_elementwise_div_activation,
kARM,
kFloat,
kNCHW,
paddle::lite::kernels::arm::ElementwiseDivActivationCompute,
def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))})
.BindInput("Y", {LiteType::GetTensorTy(TARGET(kARM))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))})
.Finalize();
......@@ -70,6 +70,22 @@ class ElementwiseMaxActivationCompute
virtual ~ElementwiseMaxActivationCompute() = default;
};
class ElementwiseDivCompute
: public KernelLite<TARGET(kARM), PRECISION(kFloat)> {
public:
void Run() override;
virtual ~ElementwiseDivCompute() = default;
};
class ElementwiseDivActivationCompute
: public KernelLite<TARGET(kARM), PRECISION(kFloat)> {
public:
void Run() override;
virtual ~ElementwiseDivActivationCompute() = default;
};
} // namespace arm
} // namespace kernels
} // namespace lite
......
......@@ -59,6 +59,7 @@ add_operator(shape_op_lite basic SRCS shape_op.cc DEPS ${op_DEPS})
add_operator(sequence_expand_op_lite basic SRCS sequence_expand_op.cc DEPS ${op_DEPS})
add_operator(squeeze_op_lite basic SRCS squeeze_op.cc DEPS ${op_DEPS})
add_operator(im2sequence_op basic SRCS im2sequence_op.cc DEPS ${op_DEPS})
add_operator(assign_op basic SRCS assign_op.cc DEPS ${op_DEPS})
# for OCR specific
add_operator(while_op extra SRCS while_op.cc DEPS ${op_DEPS})
......
......@@ -110,6 +110,7 @@ REGISTER_LITE_OP(swish, paddle::lite::operators::ActivationOp);
REGISTER_LITE_OP(relu6, paddle::lite::operators::ActivationOp);
REGISTER_LITE_OP(log, paddle::lite::operators::ActivationOp);
REGISTER_LITE_OP(exp, paddle::lite::operators::ActivationOp);
REGISTER_LITE_OP(floor, paddle::lite::operators::ActivationOp);
#ifdef LITE_WITH_TRAIN
REGISTER_LITE_OP(square_grad, paddle::lite::operators::ActivationGradOp);
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/operators/assign_op.h"
#include "lite/core/op_lite.h"
#include "lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace operators {
bool AssignOpLite::CheckShape() const {
CHECK_OR_FALSE(param_.X);
CHECK_OR_FALSE(param_.Out);
return true;
}
bool AssignOpLite::InferShape() const {
lite::DDim input_dims;
input_dims = param_.X->dims();
param_.Out->Resize(lite::DDim(input_dims));
return true;
}
// TODO(Superjomn) replace framework::OpDesc with a lite one.
bool AssignOpLite::AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) {
auto input = op_desc.Input("X").front();
auto out = op_desc.Output("Out").front();
param_.X = scope->FindVar(input)->GetMutable<lite::Tensor>();
CHECK(scope->FindVar(out));
param_.Out = scope->FindVar(out)->GetMutable<lite::Tensor>();
return true;
}
} // namespace operators
} // namespace lite
} // namespace paddle
REGISTER_LITE_OP(assign, paddle::lite::operators::AssignOpLite);
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include <vector>
#include "lite/core/op_lite.h"
#include "lite/core/scope.h"
#include "lite/utils/all.h"
namespace paddle {
namespace lite {
namespace operators {
class AssignOpLite : public OpLite {
public:
AssignOpLite() {}
explicit AssignOpLite(const std::string &op_type) : OpLite(op_type) {}
bool CheckShape() const override;
bool InferShape() const override;
bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); }
std::string DebugString() const override { return "assign"; }
private:
mutable AssignParam param_;
};
} // namespace operators
} // namespace lite
} // namespace paddle
......@@ -89,6 +89,7 @@ REGISTER_LITE_OP(elementwise_add, paddle::lite::operators::ElementwiseOp);
REGISTER_LITE_OP(elementwise_mul, paddle::lite::operators::ElementwiseOp);
REGISTER_LITE_OP(elementwise_max, paddle::lite::operators::ElementwiseOp);
REGISTER_LITE_OP(elementwise_div, paddle::lite::operators::ElementwiseOp);
#ifdef LITE_WITH_TRAIN
REGISTER_LITE_OP(elementwise_sub_grad,
......
......@@ -97,6 +97,8 @@ REGISTER_LITE_OP(fusion_elementwise_mul_activation,
paddle::lite::operators::FusionElementwiseActivationOp);
REGISTER_LITE_OP(fusion_elementwise_max_activation,
paddle::lite::operators::FusionElementwiseActivationOp);
REGISTER_LITE_OP(fusion_elementwise_div_activation,
paddle::lite::operators::FusionElementwiseActivationOp);
#ifdef LITE_WITH_TRAIN
REGISTER_LITE_OP(
......
......@@ -719,6 +719,12 @@ struct MatMulParam {
bool transpose_Y{false};
float alpha{1.0f};
};
/// ----------------------- assign operators -----------------------
struct AssignParam {
const lite::Tensor* X{};
lite::Tensor* Out{};
};
} // namespace operators
} // namespace lite
} // namespace paddle
......@@ -13,6 +13,7 @@ if((NOT LITE_WITH_OPENCL AND NOT LITE_WITH_FPGA) AND (LITE_WITH_X86 OR LITE_WITH
lite_cc_test(test_kernel_axpy_compute SRCS axpy_compute_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_conv2d_transpose_compute SRCS conv2d_transpose_compute_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_norm_compute SRCS norm_compute_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_assign_compute SRCS assign_compute_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
#lite_cc_test(test_kernel_sequence_softmax_compute SRCS sequence_softmax_compute_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
#lite_cc_test(test_kernel_im2sequence_compute SRCS im2sequence_compute_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
#lite_cc_test(test_kernel_compare_compute SRCS compare_compute_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
......
......@@ -32,7 +32,8 @@ enum activation_type_test {
SWISH,
RELU6,
LOG,
EXP
EXP,
FLOOR
};
class ActivationComputeTester : public arena::TestCase {
......@@ -170,6 +171,12 @@ class ActivationComputeTester : public arena::TestCase {
}
break;
}
case FLOOR: {
for (int i = 0; i < dims_.production(); i++) {
output_data[i] = std::floor(x_data[i]);
}
break;
}
default:
LOG(INFO) << "the type of activation is unknow.";
}
......@@ -519,5 +526,32 @@ TEST(Activation_exp, precision) {
#endif
}
TEST(Activation_floor, precision) {
LOG(INFO) << "test floor op";
#ifdef LITE_WITH_ARM
Place place(TARGET(kARM));
for (auto n : {1, 3}) {
for (auto c : {3, 6}) {
for (auto h : {9, 18}) {
for (auto w : {9, 18}) {
std::unique_ptr<arena::TestCase> tester(new ActivationComputeTester(
place,
"def",
0.01,
6.,
"all",
0.,
DDim(std::vector<int64_t>({n, c, h, w})),
"floor",
FLOOR));
arena::Arena arena(std::move(tester), place, 2e-5);
arena.TestPrecision();
}
}
}
}
#endif
}
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <gtest/gtest.h>
#include "lite/api/paddle_use_kernels.h"
#include "lite/api/paddle_use_ops.h"
#include "lite/core/arena/framework.h"
namespace paddle {
namespace lite {
class AssignComputeTester : public arena::TestCase {
protected:
std::string input_ = "X";
std::string output_ = "Out";
DDim dims_{{100, 20}};
public:
AssignComputeTester(const Place& place, const std::string& alias)
: TestCase(place, alias) {}
void RunBaseline(Scope* scope) override {
auto* out = scope->NewTensor(output_);
CHECK(out);
out->Resize(dims_);
auto* out_data = out->mutable_data<float>();
auto* x = scope->FindTensor(input_);
const auto* x_data = x->data<float>();
for (int i = 0; i < dims_.production(); i++) {
out_data[i] = x_data[i];
}
}
void PrepareOpDesc(cpp::OpDesc* op_desc) {
op_desc->SetType("assign");
op_desc->SetInput("X", {input_});
op_desc->SetOutput("Out", {output_});
}
void PrepareData() override {
std::vector<float> data(dims_.production());
for (int i = 0; i < dims_.production(); i++) {
data[i] = i * 1.1;
}
SetCommonTensor(input_, dims_, data.data());
}
};
void TestAssign(const Place& place) {
std::unique_ptr<arena::TestCase> tester(
new AssignComputeTester(place, "def"));
arena::Arena arena(std::move(tester), place, 2e-5);
arena.TestPrecision();
}
TEST(Assign, precision) {
#ifdef LITE_WITH_X86
Place place(TARGET(kX86));
#endif
#ifdef LITE_WITH_ARM
Place place(TARGET(kARM));
TestAssign(place);
#endif
}
} // namespace lite
} // namespace paddle
......@@ -350,6 +350,125 @@ class FusionElementwiseMaxActivationComputeTester : public arena::TestCase {
}
};
class ElementwiseDivComputeTester : public arena::TestCase {
protected:
// common attributes for this op.
std::string inputx_ = "x";
std::string inputy_ = "y";
std::string output_ = "out";
int axis_;
DDim dims_{{1, 2, 3, 4}};
public:
ElementwiseDivComputeTester(const Place& place,
const std::string& alias,
int axis)
: TestCase(place, alias), axis_(axis) {}
void RunBaseline(Scope* scope) override {
auto* out = scope->NewTensor(output_);
CHECK(out);
out->Resize(dims_);
auto* out_data = out->mutable_data<float>();
auto* x = scope->FindTensor(inputx_);
const auto* x_data = x->data<float>();
auto* y = scope->FindTensor(inputy_);
const auto* y_data = y->data<float>();
for (int i = 0; i < dims_.production(); i++) {
out_data[i] = x_data[i] / y_data[i];
}
}
void PrepareOpDesc(cpp::OpDesc* op_desc) {
op_desc->SetType("elementwise_div");
op_desc->SetInput("X", {inputx_});
op_desc->SetInput("Y", {inputy_});
op_desc->SetOutput("Out", {output_});
op_desc->SetAttr("axis", axis_);
}
void PrepareData() override {
std::vector<float> data(dims_.production());
for (int i = 0; i < dims_.production(); i++) {
data[i] = i * 1.1;
}
std::vector<float> data2(dims_.production());
for (int i = 0; i < dims_.production(); i++) {
data2[i] = (i + 1) * 1.1;
}
SetCommonTensor(inputx_, dims_, data.data());
SetCommonTensor(inputy_, dims_, data2.data());
}
};
class FusionElementwiseDivActivationComputeTester : public arena::TestCase {
protected:
// common attributes for this op.
std::string inputx_ = "x";
std::string inputy_ = "y";
std::string output_ = "out";
int axis_;
std::string act_type_;
DDim dims_{{1, 2, 3, 4}};
public:
FusionElementwiseDivActivationComputeTester(const Place& place,
const std::string& alias,
int axis,
std::string act_type)
: TestCase(place, alias), axis_(axis), act_type_(act_type) {}
void RunBaseline(Scope* scope) override {
auto* out = scope->NewTensor(output_);
CHECK(out);
out->Resize(dims_);
auto* out_data = out->mutable_data<float>();
auto* x = scope->FindTensor(inputx_);
const auto* x_data = x->data<float>();
auto* y = scope->FindTensor(inputy_);
const auto* y_data = y->data<float>();
for (int i = 0; i < dims_.production(); i++) {
out_data[i] = x_data[i] / y_data[i];
if (act_type_ == "relu") {
out_data[i] = out_data[i] > 0 ? out_data[i] : 0;
} else {
LOG(FATAL) << "unsupported Activation type: " << act_type_;
}
LOG(INFO) << "fusion div resul:" << out_data[i];
}
}
void PrepareOpDesc(cpp::OpDesc* op_desc) {
op_desc->SetType("fusion_elementwise_div_activation");
op_desc->SetInput("X", {inputx_});
op_desc->SetInput("Y", {inputy_});
op_desc->SetOutput("Out", {output_});
op_desc->SetAttr("axis", axis_);
op_desc->SetAttr("act_type", act_type_);
}
void PrepareData() override {
std::vector<float> data(dims_.production());
for (int i = 0; i < dims_.production(); i++) {
data[i] = i * 1.1;
}
std::vector<float> data2(dims_.production());
for (int i = 0; i < dims_.production(); i++) {
data2[i] = (i + 1) * 1.1;
}
SetCommonTensor(inputx_, dims_, data.data());
SetCommonTensor(inputy_, dims_, data2.data());
}
};
void test_elementwise(Place place) {
for (int axis : {-1, 0, 1, 3}) {
std::unique_ptr<arena::TestCase> tester(
......@@ -366,6 +485,11 @@ void test_elementwise(Place place) {
new ElementwiseMaxComputeTester(place, "def", axis));
arena::Arena arena_max(std::move(tester_max), place, 2e-5);
arena_max.TestPrecision();
std::unique_ptr<arena::TestCase> tester_div(
new ElementwiseDivComputeTester(place, "def", axis));
arena::Arena arena_div(std::move(tester_div), place, 2e-5);
arena_div.TestPrecision();
}
}
......@@ -398,6 +522,12 @@ void test_fusion_elementwise(Place place) {
place, "def", axis, "relu"));
arena::Arena arena_max_act(std::move(tester_max_act), place, 2e-5);
arena_max_act.TestPrecision();
std::unique_ptr<arena::TestCase> tester_div_act(
new FusionElementwiseDivActivationComputeTester(
place, "def", axis, "relu"));
arena::Arena arena_div_act(std::move(tester_div_act), place, 2e-5);
arena_div_act.TestPrecision();
}
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册