提交 62ea82d0 编写于 作者: W Wilber 提交者: cyj1986

add elementwise_sub and modify argmax (#1964)

上级 111db475
...@@ -45,6 +45,7 @@ USE_LITE_KERNEL(box_coder, kARM, kFloat, kNCHW, def); ...@@ -45,6 +45,7 @@ USE_LITE_KERNEL(box_coder, kARM, kFloat, kNCHW, def);
USE_LITE_KERNEL(conv2d, kARM, kFloat, kNCHW, def); USE_LITE_KERNEL(conv2d, kARM, kFloat, kNCHW, def);
USE_LITE_KERNEL(depthwise_conv2d, kARM, kFloat, kNCHW, def); USE_LITE_KERNEL(depthwise_conv2d, kARM, kFloat, kNCHW, def);
USE_LITE_KERNEL(elementwise_add, kARM, kFloat, kNCHW, def); USE_LITE_KERNEL(elementwise_add, kARM, kFloat, kNCHW, def);
USE_LITE_KERNEL(elementwise_sub, kARM, kFloat, kNCHW, def);
USE_LITE_KERNEL(elementwise_mul, kARM, kFloat, kNCHW, def); USE_LITE_KERNEL(elementwise_mul, kARM, kFloat, kNCHW, def);
USE_LITE_KERNEL(elementwise_max, kARM, kFloat, kNCHW, def); USE_LITE_KERNEL(elementwise_max, kARM, kFloat, kNCHW, def);
USE_LITE_KERNEL(elementwise_div, kARM, kFloat, kNCHW, def); USE_LITE_KERNEL(elementwise_div, kARM, kFloat, kNCHW, def);
......
...@@ -51,7 +51,7 @@ USE_LITE_OP(batch_norm) ...@@ -51,7 +51,7 @@ USE_LITE_OP(batch_norm)
USE_LITE_OP(fusion_elementwise_sub_activation) USE_LITE_OP(fusion_elementwise_sub_activation)
USE_LITE_OP(transpose) USE_LITE_OP(transpose)
USE_LITE_OP(transpose2) USE_LITE_OP(transpose2)
USE_LITE_OP(argmax) USE_LITE_OP(arg_max)
USE_LITE_OP(axpy) USE_LITE_OP(axpy)
USE_LITE_OP(leaky_relu) USE_LITE_OP(leaky_relu)
USE_LITE_OP(relu_clipped) USE_LITE_OP(relu_clipped)
......
...@@ -266,6 +266,251 @@ void elementwise_add_relu_broadcast<float>(const float* dinx, ...@@ -266,6 +266,251 @@ void elementwise_add_relu_broadcast<float>(const float* dinx,
} }
} }
template <>
void elementwise_sub<float>(const float* dinx,
const float* diny,
float* dout,
int num) {
int cnt = num >> 4;
int remain = num % 16;
#pragma omp parallel for
for (int i = 0; i < cnt; i++) {
const float* dinx_ptr = dinx + (i << 4);
const float* diny_ptr = diny + (i << 4);
float* dout_ptr = dout + (i << 4);
float32x4_t dinx0 = vld1q_f32(dinx_ptr);
float32x4_t dinx1 = vld1q_f32(dinx_ptr + 4);
float32x4_t dinx2 = vld1q_f32(dinx_ptr + 8);
float32x4_t dinx3 = vld1q_f32(dinx_ptr + 12);
float32x4_t diny0 = vld1q_f32(diny_ptr);
float32x4_t diny1 = vld1q_f32(diny_ptr + 4);
float32x4_t diny2 = vld1q_f32(diny_ptr + 8);
float32x4_t diny3 = vld1q_f32(diny_ptr + 12);
dinx0 = vsubq_f32(dinx0, diny0);
dinx1 = vsubq_f32(dinx1, diny1);
dinx2 = vsubq_f32(dinx2, diny2);
dinx3 = vsubq_f32(dinx3, diny3);
vst1q_f32(dout_ptr, dinx0);
vst1q_f32(dout_ptr + 4, dinx1);
vst1q_f32(dout_ptr + 8, dinx2);
vst1q_f32(dout_ptr + 12, dinx3);
}
if (remain > 0) {
const float* dinx_ptr = dinx + (cnt << 4);
const float* diny_ptr = diny + (cnt << 4);
float* dout_ptr = dout + (cnt << 4);
for (int i = 0; i < remain; i++) {
*dout_ptr = *dinx_ptr - *diny_ptr;
dout_ptr++;
dinx_ptr++;
diny_ptr++;
}
}
}
template <>
void elementwise_sub_relu<float>(const float* dinx,
const float* diny,
float* dout,
int num) {
int cnt = num >> 4;
int remain = num % 16;
float32x4_t vzero = vdupq_n_f32(0.f);
#pragma omp parallel for
for (int i = 0; i < cnt; i++) {
const float* dinx_ptr = dinx + (i << 4);
const float* diny_ptr = diny + (i << 4);
float* dout_ptr = dout + (i << 4);
float32x4_t dinx0 = vld1q_f32(dinx_ptr);
float32x4_t dinx1 = vld1q_f32(dinx_ptr + 4);
float32x4_t dinx2 = vld1q_f32(dinx_ptr + 8);
float32x4_t dinx3 = vld1q_f32(dinx_ptr + 12);
float32x4_t diny0 = vld1q_f32(diny_ptr);
float32x4_t diny1 = vld1q_f32(diny_ptr + 4);
float32x4_t diny2 = vld1q_f32(diny_ptr + 8);
float32x4_t diny3 = vld1q_f32(diny_ptr + 12);
dinx0 = vsubq_f32(dinx0, diny0);
dinx1 = vsubq_f32(dinx1, diny1);
dinx2 = vsubq_f32(dinx2, diny2);
dinx3 = vsubq_f32(dinx3, diny3);
// relu
dinx0 = vmaxq_f32(dinx0, vzero);
dinx1 = vmaxq_f32(dinx1, vzero);
dinx2 = vmaxq_f32(dinx2, vzero);
dinx3 = vmaxq_f32(dinx3, vzero);
vst1q_f32(dout_ptr, dinx0);
vst1q_f32(dout_ptr + 4, dinx1);
vst1q_f32(dout_ptr + 8, dinx2);
vst1q_f32(dout_ptr + 12, dinx3);
}
if (remain > 0) {
const float* dinx_ptr = dinx + (cnt << 4);
const float* diny_ptr = diny + (cnt << 4);
float* dout_ptr = dout + (cnt << 4);
for (int i = 0; i < remain; i++) {
float tmp = *dinx_ptr - *diny_ptr;
*dout_ptr = tmp > 0.f ? tmp : 0.f;
dout_ptr++;
dinx_ptr++;
diny_ptr++;
}
}
}
template <>
void elementwise_sub_broadcast<float>(const float* dinx,
const float* diny,
float* dout,
int batch,
int channels,
int num) {
#pragma omp parallel for collapse(2)
for (int i = 0; i < batch; ++i) {
for (int j = 0; j < channels; ++j) {
int offset = (i * channels + j) * num;
const float* din_ptr = dinx + offset;
const float diny_data = diny[j];
float* dout_ptr = dout + offset;
int cnt = num >> 4;
int remain = num % 16;
float32x4_t rb = vdupq_n_f32(diny_data);
for (int k = 0; k < cnt; ++k) {
float32x4_t din0 = vld1q_f32(din_ptr);
float32x4_t din1 = vld1q_f32(din_ptr + 4);
float32x4_t din2 = vld1q_f32(din_ptr + 8);
float32x4_t din3 = vld1q_f32(din_ptr + 12);
din0 = vsubq_f32(din0, rb);
din1 = vsubq_f32(din1, rb);
din2 = vsubq_f32(din2, rb);
din3 = vsubq_f32(din3, rb);
vst1q_f32(dout_ptr, din0);
vst1q_f32(dout_ptr + 4, din1);
vst1q_f32(dout_ptr + 8, din2);
vst1q_f32(dout_ptr + 12, din3);
din_ptr += 16;
dout_ptr += 16;
}
if (remain >= 8) {
float32x4_t din0 = vld1q_f32(din_ptr);
float32x4_t din1 = vld1q_f32(din_ptr + 4);
din0 = vsubq_f32(din0, rb);
din1 = vsubq_f32(din1, rb);
vst1q_f32(dout_ptr, din0);
vst1q_f32(dout_ptr + 4, din1);
din_ptr += 8;
dout_ptr += 8;
remain -= 8;
}
if (remain >= 4) {
float32x4_t din0 = vld1q_f32(din_ptr);
din0 = vsubq_f32(din0, rb);
vst1q_f32(dout_ptr, din0);
din_ptr += 4;
dout_ptr += 4;
remain -= 4;
}
if (remain > 0) {
for (int p = 0; p < remain; p++) {
*dout_ptr = *din_ptr - diny_data;
dout_ptr++;
din_ptr++;
}
}
}
}
}
template <>
void elementwise_sub_relu_broadcast<float>(const float* dinx,
const float* diny,
float* dout,
int batch,
int channels,
int num) {
float32x4_t vzero = vdupq_n_f32(0.f);
#pragma omp parallel for collapse(2)
for (int i = 0; i < batch; ++i) {
for (int j = 0; j < channels; ++j) {
int offset = (i * channels + j) * num;
const float* din_ptr = dinx + offset;
const float diny_data = diny[j];
float* dout_ptr = dout + offset;
int cnt = num >> 4;
int remain = num % 16;
float32x4_t rb = vdupq_n_f32(diny_data);
for (int k = 0; k < cnt; ++k) {
float32x4_t din0 = vld1q_f32(din_ptr);
float32x4_t din1 = vld1q_f32(din_ptr + 4);
float32x4_t din2 = vld1q_f32(din_ptr + 8);
float32x4_t din3 = vld1q_f32(din_ptr + 12);
din0 = vsubq_f32(din0, rb);
din1 = vsubq_f32(din1, rb);
din2 = vsubq_f32(din2, rb);
din3 = vsubq_f32(din3, rb);
// relu
din0 = vmaxq_f32(din0, vzero);
din1 = vmaxq_f32(din1, vzero);
din2 = vmaxq_f32(din2, vzero);
din3 = vmaxq_f32(din3, vzero);
vst1q_f32(dout_ptr, din0);
vst1q_f32(dout_ptr + 4, din1);
vst1q_f32(dout_ptr + 8, din2);
vst1q_f32(dout_ptr + 12, din3);
din_ptr += 16;
dout_ptr += 16;
}
if (remain >= 8) {
float32x4_t din0 = vld1q_f32(din_ptr);
float32x4_t din1 = vld1q_f32(din_ptr + 4);
din0 = vsubq_f32(din0, rb);
din1 = vsubq_f32(din1, rb);
// relu
din0 = vmaxq_f32(din0, vzero);
din1 = vmaxq_f32(din1, vzero);
vst1q_f32(dout_ptr, din0);
vst1q_f32(dout_ptr + 4, din1);
din_ptr += 8;
dout_ptr += 8;
remain -= 8;
}
if (remain >= 4) {
float32x4_t din0 = vld1q_f32(din_ptr);
din0 = vsubq_f32(din0, rb);
// relu
din0 = vmaxq_f32(din0, vzero);
vst1q_f32(dout_ptr, din0);
din_ptr += 4;
dout_ptr += 4;
remain -= 4;
}
if (remain > 0) {
for (int p = 0; p < remain; p++) {
float tmp = *din_ptr - diny_data;
*dout_ptr = tmp > 0.f ? tmp : 0.f;
dout_ptr++;
din_ptr++;
}
}
}
}
}
template <> template <>
void elementwise_mul<float>(const float* dinx, void elementwise_mul<float>(const float* dinx,
const float* diny, const float* diny,
......
...@@ -33,6 +33,20 @@ template <typename T> ...@@ -33,6 +33,20 @@ template <typename T>
void elementwise_add_relu_broadcast( void elementwise_add_relu_broadcast(
const T* dinx, const T* diny, T* dout, int batch, int channels, int num); const T* dinx, const T* diny, T* dout, int batch, int channels, int num);
template <typename T>
void elementwise_sub(const T* dinx, const T* diny, T* dout, int num);
template <typename T>
void elementwise_sub_relu(const T* dinx, const T* diny, T* dout, int num);
template <typename T>
void elementwise_sub_broadcast(
const T* dinx, const T* diny, T* dout, int batch, int channels, int num);
template <typename T>
void elementwise_sub_relu_broadcast(
const T* dinx, const T* diny, T* dout, int batch, int channels, int num);
template <typename T> template <typename T>
void elementwise_mul(const T* dinx, const T* diny, T* dout, int num); void elementwise_mul(const T* dinx, const T* diny, T* dout, int num);
......
...@@ -40,8 +40,12 @@ void ArgmaxCompute::Run() { ...@@ -40,8 +40,12 @@ void ArgmaxCompute::Run() {
} // namespace lite } // namespace lite
} // namespace paddle } // namespace paddle
REGISTER_LITE_KERNEL( REGISTER_LITE_KERNEL(arg_max,
argmax, kARM, kFloat, kNCHW, paddle::lite::kernels::arm::ArgmaxCompute, def) kARM,
kFloat,
kNCHW,
paddle::lite::kernels::arm::ArgmaxCompute,
def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))}) .BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))}) .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))})
.Finalize(); .Finalize();
...@@ -68,7 +68,7 @@ void argmax_compute_ref(const operators::ArgmaxParam& param) { ...@@ -68,7 +68,7 @@ void argmax_compute_ref(const operators::ArgmaxParam& param) {
TEST(argmax_arm, retrive_op) { TEST(argmax_arm, retrive_op) {
auto argmax = auto argmax =
KernelRegistry::Global().Create<TARGET(kARM), PRECISION(kFloat)>( KernelRegistry::Global().Create<TARGET(kARM), PRECISION(kFloat)>(
"argmax"); "arg_max");
ASSERT_FALSE(argmax.empty()); ASSERT_FALSE(argmax.empty());
ASSERT_TRUE(argmax.front()); ASSERT_TRUE(argmax.front());
} }
...@@ -136,4 +136,4 @@ TEST(argmax_arm, compute) { ...@@ -136,4 +136,4 @@ TEST(argmax_arm, compute) {
} // namespace kernels } // namespace kernels
} // namespace lite } // namespace lite
} // namespace paddle } // namespace paddle
USE_LITE_KERNEL(argmax, kARM, kFloat, kNCHW, def); USE_LITE_KERNEL(arg_max, kARM, kFloat, kNCHW, def);
...@@ -116,6 +116,51 @@ void ElementwiseAddActivationCompute::Run() { ...@@ -116,6 +116,51 @@ void ElementwiseAddActivationCompute::Run() {
} }
} }
void ElementwiseSubCompute::Run() {
auto& param = Param<operators::ElementwiseParam>();
const float* x_data = param.X->data<float>();
const float* y_data = param.Y->data<float>();
float* out_data = param.Out->mutable_data<float>();
int axis = param.axis;
auto x_dims = param.X->dims();
auto y_dims = param.Y->dims();
int pre, n, post;
if (is_broadcast(x_dims, y_dims, axis, &pre, &n, &post)) {
lite::arm::math::elementwise_sub_broadcast(
x_data, y_data, out_data, pre, n, post);
} else {
lite::arm::math::elementwise_sub(
x_data, y_data, out_data, x_dims.production());
}
}
void ElementwiseSubActivationCompute::Run() {
auto& param = Param<operators::FusionElementwiseActivationParam>();
const float* x_data = param.X->data<float>();
const float* y_data = param.Y->data<float>();
float* out_data = param.Out->mutable_data<float>();
int axis = param.axis;
std::string act_type = param.act_type;
auto x_dims = param.X->dims();
auto y_dims = param.Y->dims();
int pre, n, post;
if (is_broadcast(x_dims, y_dims, axis, &pre, &n, &post)) {
if (act_type == "relu") {
lite::arm::math::elementwise_sub_relu_broadcast(
x_data, y_data, out_data, pre, n, post);
} else {
LOG(FATAL) << "unsupported Activation type: " << act_type;
}
} else {
if (act_type == "relu") {
lite::arm::math::elementwise_sub_relu(
x_data, y_data, out_data, x_dims.production());
} else {
LOG(FATAL) << "unsupported Activation type: " << act_type;
}
}
}
void ElementwiseMulCompute::Run() { void ElementwiseMulCompute::Run() {
auto& param = Param<operators::ElementwiseParam>(); auto& param = Param<operators::ElementwiseParam>();
const float* x_data = param.X->data<float>(); const float* x_data = param.X->data<float>();
...@@ -249,10 +294,6 @@ void ElementwiseDivActivationCompute::Run() { ...@@ -249,10 +294,6 @@ void ElementwiseDivActivationCompute::Run() {
LOG(FATAL) << "unsupported Activation type: " << act_type; LOG(FATAL) << "unsupported Activation type: " << act_type;
} }
} }
for (int i = 0; i < x_dims.production(); i++) {
LOG(INFO) << "x:" << x_data[i] << " y:" << y_data[i]
<< " out:" << out_data[i];
}
} }
} // namespace arm } // namespace arm
...@@ -283,6 +324,29 @@ REGISTER_LITE_KERNEL( ...@@ -283,6 +324,29 @@ REGISTER_LITE_KERNEL(
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))}) .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))})
.Finalize(); .Finalize();
REGISTER_LITE_KERNEL(elementwise_sub,
kARM,
kFloat,
kNCHW,
paddle::lite::kernels::arm::ElementwiseSubCompute,
def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))})
.BindInput("Y", {LiteType::GetTensorTy(TARGET(kARM))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))})
.Finalize();
REGISTER_LITE_KERNEL(
fusion_elementwise_sub_activation,
kARM,
kFloat,
kNCHW,
paddle::lite::kernels::arm::ElementwiseSubActivationCompute,
def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))})
.BindInput("Y", {LiteType::GetTensorTy(TARGET(kARM))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))})
.Finalize();
REGISTER_LITE_KERNEL(elementwise_mul, REGISTER_LITE_KERNEL(elementwise_mul,
kARM, kARM,
kFloat, kFloat,
......
...@@ -38,6 +38,22 @@ class ElementwiseAddActivationCompute ...@@ -38,6 +38,22 @@ class ElementwiseAddActivationCompute
virtual ~ElementwiseAddActivationCompute() = default; virtual ~ElementwiseAddActivationCompute() = default;
}; };
class ElementwiseSubCompute
: public KernelLite<TARGET(kARM), PRECISION(kFloat)> {
public:
void Run() override;
virtual ~ElementwiseSubCompute() = default;
};
class ElementwiseSubActivationCompute
: public KernelLite<TARGET(kARM), PRECISION(kFloat)> {
public:
void Run() override;
virtual ~ElementwiseSubActivationCompute() = default;
};
class ElementwiseMulCompute class ElementwiseMulCompute
: public KernelLite<TARGET(kARM), PRECISION(kFloat)> { : public KernelLite<TARGET(kARM), PRECISION(kFloat)> {
public: public:
......
...@@ -50,7 +50,7 @@ bool ArgmaxOpLite::AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) { ...@@ -50,7 +50,7 @@ bool ArgmaxOpLite::AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) {
param_.X = scope->FindVar(x)->GetMutable<lite::Tensor>(); param_.X = scope->FindVar(x)->GetMutable<lite::Tensor>();
param_.Out = scope->FindVar(out)->GetMutable<lite::Tensor>(); param_.Out = scope->FindVar(out)->GetMutable<lite::Tensor>();
param_.Axis = op_desc.GetAttr<int>("Axis"); param_.Axis = op_desc.GetAttr<int64_t>("axis");
return true; return true;
} }
...@@ -59,4 +59,4 @@ bool ArgmaxOpLite::AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) { ...@@ -59,4 +59,4 @@ bool ArgmaxOpLite::AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) {
} // namespace lite } // namespace lite
} // namespace paddle } // namespace paddle
REGISTER_LITE_OP(argmax, paddle::lite::operators::ArgmaxOpLite); REGISTER_LITE_OP(arg_max, paddle::lite::operators::ArgmaxOpLite);
...@@ -761,7 +761,6 @@ struct GenerateProposalsParam { ...@@ -761,7 +761,6 @@ struct GenerateProposalsParam {
lite::Tensor* RpnRois{}; lite::Tensor* RpnRois{};
lite::Tensor* RpnRoiProbs{}; lite::Tensor* RpnRoiProbs{};
}; };
/// ----------------------- shape operators ----------------------
/// ----------------------- squeeze operators ---------------------- /// ----------------------- squeeze operators ----------------------
struct SqueezeParam { struct SqueezeParam {
const lite::Tensor* X{}; const lite::Tensor* X{};
......
...@@ -25,7 +25,7 @@ class ArgmaxComputeTester : public arena::TestCase { ...@@ -25,7 +25,7 @@ class ArgmaxComputeTester : public arena::TestCase {
// common attributes for this op. // common attributes for this op.
std::string input_ = "x"; std::string input_ = "x";
std::string output_ = "out"; std::string output_ = "out";
int axis_ = 0.; int64_t axis_ = 0.;
DDim dims_{{2, 5, 20, 30}}; DDim dims_{{2, 5, 20, 30}};
public: public:
...@@ -82,10 +82,10 @@ class ArgmaxComputeTester : public arena::TestCase { ...@@ -82,10 +82,10 @@ class ArgmaxComputeTester : public arena::TestCase {
} }
void PrepareOpDesc(cpp::OpDesc* op_desc) { void PrepareOpDesc(cpp::OpDesc* op_desc) {
op_desc->SetType("argmax"); op_desc->SetType("arg_max");
op_desc->SetInput("X", {input_}); op_desc->SetInput("X", {input_});
op_desc->SetOutput("Out", {output_}); op_desc->SetOutput("Out", {output_});
op_desc->SetAttr("Axis", axis_); op_desc->SetAttr("axis", axis_);
} }
void PrepareData() override { void PrepareData() override {
......
...@@ -71,6 +71,57 @@ class ElementwiseComputeTester : public arena::TestCase { ...@@ -71,6 +71,57 @@ class ElementwiseComputeTester : public arena::TestCase {
} }
}; };
class ElementwiseSubComputeTester : public arena::TestCase {
protected:
// common attributes for this op.
std::string inputx_ = "x";
std::string inputy_ = "y";
std::string output_ = "out";
int axis_;
DDim dims_{{1, 2, 3, 4}};
public:
ElementwiseSubComputeTester(const Place& place,
const std::string& alias,
int axis)
: TestCase(place, alias), axis_(axis) {}
void RunBaseline(Scope* scope) override {
auto* out = scope->NewTensor(output_);
CHECK(out);
out->Resize(dims_);
auto* out_data = out->mutable_data<float>();
auto* x = scope->FindTensor(inputx_);
const auto* x_data = x->data<float>();
auto* y = scope->FindTensor(inputy_);
const auto* y_data = x->data<float>();
for (int i = 0; i < dims_.production(); i++) {
out_data[i] = x_data[i] - y_data[i];
}
}
void PrepareOpDesc(cpp::OpDesc* op_desc) {
op_desc->SetType("elementwise_sub");
op_desc->SetInput("X", {inputx_});
op_desc->SetInput("Y", {inputy_});
op_desc->SetOutput("Out", {output_});
op_desc->SetAttr("axis", axis_);
}
void PrepareData() override {
std::vector<float> data(dims_.production());
for (int i = 0; i < dims_.production(); i++) {
data[i] = i * 1.1;
}
SetCommonTensor(inputx_, dims_, data.data());
SetCommonTensor(inputy_, dims_, data.data());
}
};
class ElementwiseMulComputeTester : public arena::TestCase { class ElementwiseMulComputeTester : public arena::TestCase {
protected: protected:
// common attributes for this op. // common attributes for this op.
...@@ -232,6 +283,65 @@ class FusionElementwiseAddActivationComputeTester : public arena::TestCase { ...@@ -232,6 +283,65 @@ class FusionElementwiseAddActivationComputeTester : public arena::TestCase {
} }
}; };
class FusionElementwiseSubActivationComputeTester : public arena::TestCase {
protected:
// common attributes for this op.
std::string inputx_ = "x";
std::string inputy_ = "y";
std::string output_ = "out";
int axis_;
std::string act_type_;
DDim dims_{{1, 2, 3, 4}};
public:
FusionElementwiseSubActivationComputeTester(const Place& place,
const std::string& alias,
int axis,
std::string act_type)
: TestCase(place, alias), axis_(axis), act_type_(act_type) {}
void RunBaseline(Scope* scope) override {
auto* out = scope->NewTensor(output_);
CHECK(out);
out->Resize(dims_);
auto* out_data = out->mutable_data<float>();
auto* x = scope->FindTensor(inputx_);
const auto* x_data = x->data<float>();
auto* y = scope->FindTensor(inputy_);
const auto* y_data = x->data<float>();
for (int i = 0; i < dims_.production(); i++) {
out_data[i] = x_data[i] - y_data[i];
if (act_type_ == "relu") {
out_data[i] = out_data[i] > 0 ? out_data[i] : 0;
} else {
LOG(FATAL) << "unsupported Activation type: " << act_type_;
}
}
}
void PrepareOpDesc(cpp::OpDesc* op_desc) {
op_desc->SetType("fusion_elementwise_sub_activation");
op_desc->SetInput("X", {inputx_});
op_desc->SetInput("Y", {inputy_});
op_desc->SetOutput("Out", {output_});
op_desc->SetAttr("axis", axis_);
op_desc->SetAttr("act_type", act_type_);
}
void PrepareData() override {
std::vector<float> data(dims_.production());
for (int i = 0; i < dims_.production(); i++) {
data[i] = i * 1.1;
}
SetCommonTensor(inputx_, dims_, data.data());
SetCommonTensor(inputy_, dims_, data.data());
}
};
class FusionElementwiseMulActivationComputeTester : public arena::TestCase { class FusionElementwiseMulActivationComputeTester : public arena::TestCase {
protected: protected:
// common attributes for this op. // common attributes for this op.
...@@ -441,7 +551,6 @@ class FusionElementwiseDivActivationComputeTester : public arena::TestCase { ...@@ -441,7 +551,6 @@ class FusionElementwiseDivActivationComputeTester : public arena::TestCase {
} else { } else {
LOG(FATAL) << "unsupported Activation type: " << act_type_; LOG(FATAL) << "unsupported Activation type: " << act_type_;
} }
LOG(INFO) << "fusion div resul:" << out_data[i];
} }
} }
...@@ -476,6 +585,11 @@ void test_elementwise(Place place) { ...@@ -476,6 +585,11 @@ void test_elementwise(Place place) {
arena::Arena arena(std::move(tester), place, 2e-5); arena::Arena arena(std::move(tester), place, 2e-5);
arena.TestPrecision(); arena.TestPrecision();
std::unique_ptr<arena::TestCase> tester_sub(
new ElementwiseSubComputeTester(place, "def", axis));
arena::Arena arena_sub(std::move(tester_sub), place, 2e-5);
arena_sub.TestPrecision();
std::unique_ptr<arena::TestCase> tester_mul( std::unique_ptr<arena::TestCase> tester_mul(
new ElementwiseMulComputeTester(place, "def", axis)); new ElementwiseMulComputeTester(place, "def", axis));
arena::Arena arena_mul(std::move(tester_mul), place, 2e-5); arena::Arena arena_mul(std::move(tester_mul), place, 2e-5);
...@@ -511,6 +625,12 @@ void test_fusion_elementwise(Place place) { ...@@ -511,6 +625,12 @@ void test_fusion_elementwise(Place place) {
arena::Arena arena_add_act(std::move(tester_add_act), place, 2e-5); arena::Arena arena_add_act(std::move(tester_add_act), place, 2e-5);
arena_add_act.TestPrecision(); arena_add_act.TestPrecision();
std::unique_ptr<arena::TestCase> tester_sub_act(
new FusionElementwiseSubActivationComputeTester(
place, "def", axis, "relu"));
arena::Arena arena_sub_act(std::move(tester_sub_act), place, 2e-5);
arena_sub_act.TestPrecision();
std::unique_ptr<arena::TestCase> tester_mul_act( std::unique_ptr<arena::TestCase> tester_mul_act(
new FusionElementwiseMulActivationComputeTester( new FusionElementwiseMulActivationComputeTester(
place, "def", axis, "relu")); place, "def", axis, "relu"));
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册