Commit 907150a4 authored by zhupengyang, committed by GitHub

[NPU] enhance elementwise uts (#2784)

* [NPU] reshape x,y,out node in elementwise ops
Parent a11eaf6a
@@ -21,28 +21,42 @@ namespace lite {
namespace subgraph {
namespace npu {
void CvtYShape(std::vector<int64_t>* x_shape,
std::vector<int64_t>* y_shape,
int axis) {
CHECK_GE(x_shape->size(), y_shape->size());
if (axis < 0) {
axis = x_shape->size() - y_shape->size();
}
// only support:
// (n,c,h,w) * (n,c,h,w)
// (n,c,h,w) * (1,c,1,1)
// (n,c,h,w) * (1,c,h,1)
// (n,c,h,w) * (1,c,h,w)
int y_shape_size = y_shape->size();
if (y_shape_size == 1) {
y_shape->insert(y_shape->begin(), 1);
y_shape->insert(y_shape->end(), 2, 1);
} else if (y_shape_size == 2) {
y_shape->insert(y_shape->begin(), 1);
y_shape->insert(y_shape->end(), 1);
} else if (y_shape_size == 3) {
y_shape->insert(y_shape->begin(), 1);
}
if (y_shape_size < 4) {
int n = 1;
for (int i = 0; i < axis; i++) {
n *= x_shape->at(i);
}
x_shape->erase(x_shape->begin(), x_shape->begin() + axis);
x_shape->insert(x_shape->begin(), n);
x_shape->insert(x_shape->end(), 4 - x_shape->size(), 1);
}
CHECK_EQ(x_shape->size(), 4UL);
CHECK_EQ(y_shape->size(), 4UL);
}
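To make the conversion concrete, here is a minimal, self-contained sketch of the same fold/pad logic (`CvtYShapeDemo` and `main` are hypothetical names for illustration only; the real function above runs inside the converter and feeds graph nodes):

#include <cassert>
#include <cstdint>
#include <iostream>
#include <vector>

// Illustration of CvtYShape: pad y to 4-D and fold x's leading `axis`
// dims into one batch dim so the pair matches an NPU broadcast pattern.
void CvtYShapeDemo(std::vector<int64_t>* x_shape,
                   std::vector<int64_t>* y_shape,
                   int axis) {
  assert(x_shape->size() >= y_shape->size());
  if (axis < 0) axis = x_shape->size() - y_shape->size();
  const int y_size = y_shape->size();
  if (y_size == 1) {         // (c) -> (1, c, 1, 1)
    y_shape->insert(y_shape->begin(), 1);
    y_shape->insert(y_shape->end(), 2, 1);
  } else if (y_size == 2) {  // (c, h) -> (1, c, h, 1)
    y_shape->insert(y_shape->begin(), 1);
    y_shape->insert(y_shape->end(), 1);
  } else if (y_size == 3) {  // (c, h, w) -> (1, c, h, w)
    y_shape->insert(y_shape->begin(), 1);
  }
  if (y_size < 4) {
    int64_t n = 1;
    for (int i = 0; i < axis; i++) n *= (*x_shape)[i];
    x_shape->erase(x_shape->begin(), x_shape->begin() + axis);
    x_shape->insert(x_shape->begin(), n);
    x_shape->insert(x_shape->end(), 4 - x_shape->size(), 1);
  }
}

int main() {
  // x: (2, 3, 4, 5), y: (4, 5), axis = 2
  std::vector<int64_t> x{2, 3, 4, 5}, y{4, 5};
  CvtYShapeDemo(&x, &y, 2);
  for (auto d : x) std::cout << d << ' ';  // 6 4 5 1
  std::cout << '\n';
  for (auto d : y) std::cout << d << ' ';  // 1 4 5 1
  std::cout << '\n';
  return 0;
}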
int ElementwiseConverter(void* ctx, OpLite* op, KernelBase* kernel) {
@@ -61,32 +75,58 @@ int ElementwiseConverter(void* ctx, OpLite* op, KernelBase* kernel) {
CHECK(x_type->layout() == DATALAYOUT(kNCHW));
auto x = scope->FindMutableTensor(x_name);
auto x_dims = x->dims();
auto y_name = op_info->Input("Y").front();
auto y_type = kernel->GetInputDeclType("Y");
CHECK(y_type->precision() == PRECISION(kFloat));
CHECK(y_type->layout() == DATALAYOUT(kNCHW));
auto y = scope->FindMutableTensor(y_name);
auto y_dims = y->dims();
auto out_name = op_info->Output("Out").front();
auto out_type = kernel->GetOutputDeclType("Out");
CHECK(out_type->precision() == PRECISION(kFloat));
CHECK(out_type->layout() == DATALAYOUT(kNCHW));
auto out = scope->FindMutableTensor(out_name);
auto out_dims = out->dims();
auto axis = op_info->GetAttr<int>("axis");
auto x_new_shape = x_dims.Vectorize();
auto y_new_shape = y_dims.Vectorize();
CvtYShape(&x_new_shape, &y_new_shape, axis);
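// After this call both shapes are 4-D: x may have had its leading `axis`
// dims folded into one batch dim, and y is padded to one of the supported
// broadcast layouts (1,c,1,1) / (1,c,h,1) / (1,c,h,w).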
// X node
std::shared_ptr<Node> x_node = nullptr;
if (graph->Has(x_name)) {
x_node = graph->Get(x_name);
if (x_dims.Vectorize() != x_new_shape) {
auto reshaped_x_node = graph->Add<ge::op::Reshape>(x_name + "/reshape");
auto reshaped_x_op = reshaped_x_node->data<ge::op::Reshape>();
reshaped_x_op->set_input_tensor(*x_node->data());
reshaped_x_op->set_attr_shape(
ge::AttrValue::LIST_INT(x_new_shape.begin(), x_new_shape.end()));
reshaped_x_op->set_attr_axis(0);
x_node = reshaped_x_node;
}
} else {
x_node = graph->Add(x_name, *x, x_new_shape);
}
// Y node
std::shared_ptr<Node> y_node = nullptr;
if (graph->Has(y_name)) {
y_node = graph->Get(y_name);
if (y_dims.Vectorize() != y_new_shape) {
auto reshaped_y_node = graph->Add<ge::op::Reshape>(y_name + "/reshape");
auto reshaped_y_op = reshaped_y_node->data<ge::op::Reshape>();
reshaped_y_op->set_input_tensor(*y_node->data());
reshaped_y_op->set_attr_shape(
ge::AttrValue::LIST_INT(y_new_shape.begin(), y_new_shape.end()));
reshaped_y_op->set_attr_axis(0);
y_node = reshaped_y_node;
}
} else {
y_node = graph->Add(y_name, *y, y_new_shape);
}
@@ -98,17 +138,20 @@ int ElementwiseConverter(void* ctx, OpLite* op, KernelBase* kernel) {
auto elt_op = elt_node->data<ge::op::Add>();
elt_op->set_input_x1(*x_node->data());
elt_op->set_input_x2(*y_node->data());
} else if (op_type == "elementwise_sub") {
} else if (op_type == "elementwise_sub" ||
op_type == "fusion_elementwise_sub_activation") {
elt_node = graph->Add<ge::op::Sub>(out_name);
auto elt_op = elt_node->data<ge::op::Sub>();
elt_op->set_input_x1(*x_node->data());
elt_op->set_input_x2(*y_node->data());
} else if (op_type == "elementwise_mul") {
} else if (op_type == "elementwise_mul" ||
op_type == "fusion_elementwise_mul_activation") {
elt_node = graph->Add<ge::op::Mul>(out_name);
auto elt_op = elt_node->data<ge::op::Mul>();
elt_op->set_input_x(*x_node->data());
elt_op->set_input_y(*y_node->data());
} else if (op_type == "elementwise_div") {
} else if (op_type == "elementwise_div" ||
op_type == "fusion_elementwise_div_activation") {
elt_node = graph->Add<ge::op::RealDiv>(out_name);
auto elt_op = elt_node->data<ge::op::RealDiv>();
elt_op->set_input_x1(*x_node->data());
@@ -118,8 +161,22 @@ int ElementwiseConverter(void* ctx, OpLite* op, KernelBase* kernel) {
return FAILED;
}
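// If x was folded to 4-D above, the result still carries that folded
// shape, so reshape it back to the original output dims.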
if (out_dims.Vectorize() != x_new_shape) {
auto reshaped_elt_node = graph->Add<ge::op::Reshape>(out_name);
auto reshaped_elt_op = reshaped_elt_node->data<ge::op::Reshape>();
reshaped_elt_op->set_input_tensor(*elt_node->data());
auto out_shape = out_dims.Vectorize();
reshaped_elt_op->set_attr_shape(
ge::AttrValue::LIST_INT(out_shape.begin(), out_shape.end()));
reshaped_elt_op->set_attr_axis(0);
elt_node = reshaped_elt_node;
}
// Act node
if (op_type == "fusion_elementwise_add_activation") {
if (op_type == "fusion_elementwise_add_activation" ||
op_type == "fusion_elementwise_sub_activation" ||
op_type == "fusion_elementwise_mul_activation" ||
op_type == "fusion_elementwise_div_activation") {
auto act_type = op_info->GetAttr<std::string>("act_type");
auto act_node = graph->Add<ge::op::Activation>(out_name);
auto act_op = act_node->data<ge::op::Activation>();
@@ -128,6 +185,7 @@ int ElementwiseConverter(void* ctx, OpLite* op, KernelBase* kernel) {
// clipped_relu etc.
act_op->set_attr_mode(CvtActMode(act_type));
}
return REBUILD_WHEN_SHAPE_CHANGED;
}
@@ -139,9 +197,6 @@ int ElementwiseConverter(void* ctx, OpLite* op, KernelBase* kernel) {
REGISTER_SUBGRAPH_BRIDGE(elementwise_add,
kNPU,
paddle::lite::subgraph::npu::ElementwiseConverter);
REGISTER_SUBGRAPH_BRIDGE(elementwise_sub,
kNPU,
paddle::lite::subgraph::npu::ElementwiseConverter);
@@ -151,3 +206,15 @@ REGISTER_SUBGRAPH_BRIDGE(elementwise_mul,
REGISTER_SUBGRAPH_BRIDGE(elementwise_div,
kNPU,
paddle::lite::subgraph::npu::ElementwiseConverter);
REGISTER_SUBGRAPH_BRIDGE(fusion_elementwise_add_activation,
kNPU,
paddle::lite::subgraph::npu::ElementwiseConverter);
REGISTER_SUBGRAPH_BRIDGE(fusion_elementwise_sub_activation,
kNPU,
paddle::lite::subgraph::npu::ElementwiseConverter);
REGISTER_SUBGRAPH_BRIDGE(fusion_elementwise_mul_activation,
kNPU,
paddle::lite::subgraph::npu::ElementwiseConverter);
REGISTER_SUBGRAPH_BRIDGE(fusion_elementwise_div_activation,
kNPU,
paddle::lite::subgraph::npu::ElementwiseConverter);
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/operators/elementwise_ops.h"
#include <gtest/gtest.h>
#include <random>
#include "lite/core/op_registry.h"
#include "lite/kernels/npu/bridges/registry.h"
#include "lite/kernels/npu/bridges/test_helper.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace npu {
namespace bridges {
template <typename dtype>
void elementwise_add_ref(const std::shared_ptr<operators::ElementwiseOp> op) {
Scope* scope = op->scope();
const OpInfo* op_info = op->op_info();
auto x = scope->FindTensor("x");
auto y = scope->FindTensor("y");
auto out = scope->FindMutableTensor("out_ref");
out->Resize(x->dims());
auto x_data = x->data<dtype>();
auto y_data = y->data<dtype>();
auto out_data = out->mutable_data<dtype>();
auto x_dims = x->dims();
auto y_dims = y->dims();
int axis = op_info->GetAttr<int>("axis");
if (axis < 0) {
axis += x_dims.size();
}
int batch = 1;
int channels = y->numel();
int num = x->numel() / channels / batch;
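// The reference flattens broadcasting into (batch, channels, num):
// y supplies one value per channel, matching y shapes such as {2} or
// {1, 2, 1, 1} broadcast against x {1, 2, 3, 4} in the tests below.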
// do elementwise add/sub/max...
std::string op_type = op_info->Type();
if (op_type == "elementwise_add") {
for (int i = 0; i < batch; ++i) {
for (int j = 0; j < channels; ++j) {
int offset = (i * channels + j) * num;
const dtype* din_ptr = x_data + offset;
const dtype diny_data = y_data[j];
dtype* dout_ptr = out_data + offset;
for (int k = 0; k < num; ++k) {
*dout_ptr = *din_ptr + diny_data;
dout_ptr++;
din_ptr++;
}
}
}
} else if (op_type == "elementwise_sub") {
for (int i = 0; i < batch; ++i) {
for (int j = 0; j < channels; ++j) {
int offset = (i * channels + j) * num;
const dtype* din_ptr = x_data + offset;
const dtype diny_data = y_data[j];
dtype* dout_ptr = out_data + offset;
for (int k = 0; k < num; ++k) {
*dout_ptr = *din_ptr - diny_data;
dout_ptr++;
din_ptr++;
}
}
}
} else if (op_type == "elementwise_mul") {
for (int i = 0; i < batch; ++i) {
for (int j = 0; j < channels; ++j) {
int offset = (i * channels + j) * num;
const dtype* din_ptr = x_data + offset;
const dtype diny_data = y_data[j];
dtype* dout_ptr = out_data + offset;
for (int k = 0; k < num; ++k) {
*dout_ptr = *din_ptr * diny_data;
dout_ptr++;
din_ptr++;
}
}
}
} else if (op_type == "elementwise_div") {
for (int i = 0; i < batch; ++i) {
for (int j = 0; j < channels; ++j) {
int offset = (i * channels + j) * num;
const dtype* din_ptr = x_data + offset;
const dtype diny_data = y_data[j];
dtype* dout_ptr = out_data + offset;
for (int k = 0; k < num; ++k) {
*dout_ptr = *din_ptr / diny_data;
dout_ptr++;
din_ptr++;
}
}
}
} else if (op_type == "elementwise_max") {
for (int i = 0; i < batch; ++i) {
for (int j = 0; j < channels; ++j) {
int offset = (i * channels + j) * num;
const dtype* din_ptr = x_data + offset;
const dtype diny_data = y_data[j];
dtype* dout_ptr = out_data + offset;
for (int k = 0; k < num; ++k) {
*dout_ptr = std::max(*din_ptr, diny_data);
dout_ptr++;
din_ptr++;
}
}
}
} else {
LOG(FATAL) << "unsupported Elementwise type: " << op_type;
}
}
void test_elementwise_add(const std::vector<int64_t>& x_shape,
const std::vector<int64_t>& y_shape,
int axis,
std::string elt_type) {
// prepare input&output variables
Scope scope;
std::string x_var_name = "x";
std::string y_var_name = "y";
std::string out_var_name = "out";
std::string out_ref_var_name = "out_ref";
auto* x = scope.Var(x_var_name)->GetMutable<Tensor>();
auto* y = scope.Var(y_var_name)->GetMutable<Tensor>();
auto* out = scope.Var(out_var_name)->GetMutable<Tensor>();
auto* out_ref = scope.Var(out_ref_var_name)->GetMutable<Tensor>();
x->Resize(x_shape);
y->Resize(y_shape);
// initialize input&output data
FillTensor<float>(x, 1, 3);
FillTensor<float>(y, 1, 3);
// initialize op desc
cpp::OpDesc opdesc;
opdesc.SetType("elementwise_" + elt_type);
opdesc.SetInput("X", {x_var_name});
opdesc.SetInput("Y", {y_var_name});
opdesc.SetOutput("Out", {out_var_name});
opdesc.SetAttr("axis", axis);
// create and convert op to NPU model, then run it on NPU
auto op = CreateOp<operators::ElementwiseOp>(opdesc, &scope);
LauchOp(op, {x_var_name}, {out_var_name});
// execute reference implementation and save to output tensor
elementwise_add_ref<float>(op);
// compare results
auto* out_data = out->mutable_data<float>();
auto* out_ref_data = out_ref->mutable_data<float>();
for (int i = 0; i < out->dims().production(); i++) {
EXPECT_NEAR(out_data[i], out_ref_data[i], 1e-2);
}
}
TEST(NPUBridges, elementwise_add) {
for (auto elt_type : {"add", "sub", "mul", "div"}) {
test_elementwise_add({1, 2, 3, 4}, {2}, 1, elt_type);
test_elementwise_add({1, 2, 3, 4}, {1, 2, 1, 1}, 1, elt_type);
test_elementwise_add({1, 2, 3, 4}, {1, 2, 3, 4}, 3, elt_type);
}
}
} // namespace bridges
} // namespace npu
} // namespace kernels
} // namespace lite
} // namespace paddle
USE_LITE_OP(elementwise_add);
USE_NPU_BRIDGE(elementwise_add);
USE_LITE_OP(elementwise_sub);
USE_NPU_BRIDGE(elementwise_sub);
USE_LITE_OP(elementwise_mul);
USE_NPU_BRIDGE(elementwise_mul);
USE_LITE_OP(elementwise_div);
USE_NPU_BRIDGE(elementwise_div);
@@ -30,10 +30,13 @@ USE_SUBGRAPH_BRIDGE(conv2d_transpose, kNPU);
USE_SUBGRAPH_BRIDGE(dropout, kNPU);
USE_SUBGRAPH_BRIDGE(elementwise_add, kNPU);
USE_SUBGRAPH_BRIDGE(elementwise_sub, kNPU);
USE_SUBGRAPH_BRIDGE(elementwise_mul, kNPU);
USE_SUBGRAPH_BRIDGE(elementwise_div, kNPU);
USE_SUBGRAPH_BRIDGE(fusion_elementwise_add_activation, kNPU);
USE_SUBGRAPH_BRIDGE(fusion_elementwise_sub_activation, kNPU);
USE_SUBGRAPH_BRIDGE(fusion_elementwise_mul_activation, kNPU);
USE_SUBGRAPH_BRIDGE(fusion_elementwise_div_activation, kNPU);
USE_SUBGRAPH_BRIDGE(fc, kNPU);
USE_SUBGRAPH_BRIDGE(bilinear_interp, kNPU);
@@ -16,654 +16,228 @@
#include "lite/api/paddle_use_kernels.h"
#include "lite/api/paddle_use_ops.h"
#include "lite/core/arena/framework.h"
#include "lite/tests/utils/fill_data.h"
namespace paddle {
namespace lite {
#define ELT(MATHOP) \
for (int n = 0; n < xn; n++) { \
for (int c = 0; c < xc; c++) { \
for (int h = 0; h < xh; h++) { \
for (int w = 0; w < xw; w++) { \
int x_offset = n * xc * xh * xw + c * xh * xw + h * xw + w; \
int y_offset = 0; \
if (yn != 1) y_offset += n * yc * yh * yw; \
if (yc != 1) y_offset += c * yh * yw; \
if (yh != 1) y_offset += h * yw; \
if (yw != 1) y_offset += w; \
out_data[x_offset] = out_data[x_offset] MATHOP y_data[y_offset]; \
} \
} \
} \
}
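The ELT macro applies the binary op in place over x's NCHW layout; any y dimension equal to 1 contributes nothing to y_offset, which is exactly numpy-style broadcasting. A minimal, self-contained sketch of that indexing (a hypothetical standalone `main`, not part of the test):

#include <cstdio>
#include <vector>

int main() {
  // x: (1, 2, 3, 1) broadcast-added with y: (1, 2, 1, 1).
  int xn = 1, xc = 2, xh = 3, xw = 1;
  int yn = 1, yc = 2, yh = 1, yw = 1;
  std::vector<float> out(xn * xc * xh * xw, 1.f);  // out starts as x
  std::vector<float> y = {10.f, 20.f};             // one value per channel
  for (int n = 0; n < xn; n++)
    for (int c = 0; c < xc; c++)
      for (int h = 0; h < xh; h++)
        for (int w = 0; w < xw; w++) {
          int x_offset = n * xc * xh * xw + c * xh * xw + h * xw + w;
          int y_offset = 0;
          if (yn != 1) y_offset += n * yc * yh * yw;
          if (yc != 1) y_offset += c * yh * yw;
          if (yh != 1) y_offset += h * yw;
          if (yw != 1) y_offset += w;
          out[x_offset] = out[x_offset] + y[y_offset];
        }
  for (float v : out) printf("%g ", v);  // 11 11 11 21 21 21
  printf("\n");
  return 0;
}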
class ElementwiseComputeTester : public arena::TestCase {
protected:
// common attributes for this op.
std::string inputx_ = "x";
std::string inputy_ = "y";
std::string output_ = "out";
int axis_;
DDim dims_{{1, 2, 3, 4}};
std::string x_ = "x";
std::string y_ = "y";
std::string out_ = "out";
// add, sub, mul, div, max
std::string elt_type_ = "";
DDim x_dims_{{1, 2, 3, 4}};
DDim y_dims_{{1, 2, 3, 4}};
int axis_ = 1;
std::string act_type_ = "";
public:
ElementwiseComputeTester(const Place& place,
const std::string& alias,
int axis)
: TestCase(place, alias), axis_(axis) {}
void RunBaseline(Scope* scope) override {
auto* out = scope->NewTensor(output_);
CHECK(out);
out->Resize(dims_);
auto* out_data = out->mutable_data<float>();
auto* x = scope->FindTensor(inputx_);
const auto* x_data = x->data<float>();
auto* y = scope->FindTensor(inputy_);
const auto* y_data = y->data<float>();
for (int i = 0; i < dims_.production(); i++) {
out_data[i] = x_data[i] + y_data[i];
}
}
void PrepareOpDesc(cpp::OpDesc* op_desc) {
op_desc->SetType("elementwise_add");
op_desc->SetInput("X", {inputx_});
op_desc->SetInput("Y", {inputy_});
op_desc->SetOutput("Out", {output_});
op_desc->SetAttr("axis", axis_);
}
void PrepareData() override {
std::vector<float> data(dims_.production());
for (int i = 0; i < dims_.production(); i++) {
data[i] = i * 1.1;
}
SetCommonTensor(inputx_, dims_, data.data());
SetCommonTensor(inputy_, dims_, data.data());
}
};
class ElementwiseSubComputeTester : public arena::TestCase {
protected:
// common attributes for this op.
std::string inputx_ = "x";
std::string inputy_ = "y";
std::string output_ = "out";
int axis_;
DDim dims_{{1, 2, 3, 4}};
public:
ElementwiseSubComputeTester(const Place& place,
const std::string& alias,
int axis)
: TestCase(place, alias), axis_(axis) {}
void RunBaseline(Scope* scope) override {
auto* out = scope->NewTensor(output_);
CHECK(out);
out->Resize(dims_);
auto* out_data = out->mutable_data<float>();
auto* x = scope->FindTensor(inputx_);
const auto* x_data = x->data<float>();
auto* y = scope->FindTensor(inputy_);
const auto* y_data = y->data<float>();
for (int i = 0; i < dims_.production(); i++) {
out_data[i] = x_data[i] - y_data[i];
}
}
void PrepareOpDesc(cpp::OpDesc* op_desc) {
op_desc->SetType("elementwise_sub");
op_desc->SetInput("X", {inputx_});
op_desc->SetInput("Y", {inputy_});
op_desc->SetOutput("Out", {output_});
op_desc->SetAttr("axis", axis_);
}
void PrepareData() override {
std::vector<float> data(dims_.production());
for (int i = 0; i < dims_.production(); i++) {
data[i] = i * 1.1;
}
SetCommonTensor(inputx_, dims_, data.data());
SetCommonTensor(inputy_, dims_, data.data());
}
};
class ElementwiseMulComputeTester : public arena::TestCase {
protected:
// common attributes for this op.
std::string inputx_ = "x";
std::string inputy_ = "y";
std::string output_ = "out";
int axis_;
DDim dims_{{1, 2, 3, 4}};
public:
ElementwiseMulComputeTester(const Place& place,
const std::string& alias,
int axis)
: TestCase(place, alias), axis_(axis) {}
void RunBaseline(Scope* scope) override {
auto* out = scope->NewTensor(output_);
CHECK(out);
out->Resize(dims_);
auto* out_data = out->mutable_data<float>();
auto* x = scope->FindTensor(inputx_);
const auto* x_data = x->data<float>();
auto* y = scope->FindTensor(inputy_);
const auto* y_data = y->data<float>();
for (int i = 0; i < dims_.production(); i++) {
out_data[i] = x_data[i] * y_data[i];
}
}
void PrepareOpDesc(cpp::OpDesc* op_desc) {
op_desc->SetType("elementwise_mul");
op_desc->SetInput("X", {inputx_});
op_desc->SetInput("Y", {inputy_});
op_desc->SetOutput("Out", {output_});
op_desc->SetAttr("axis", axis_);
}
void PrepareData() override {
std::vector<float> data(dims_.production());
for (int i = 0; i < dims_.production(); i++) {
data[i] = i * 1.1;
}
SetCommonTensor(inputx_, dims_, data.data());
SetCommonTensor(inputy_, dims_, data.data());
}
};
class ElementwiseMaxComputeTester : public arena::TestCase {
protected:
// common attributes for this op.
std::string inputx_ = "x";
std::string inputy_ = "y";
std::string output_ = "out";
int axis_;
DDim dims_{{1, 2, 3, 4}};
public:
ElementwiseMaxComputeTester(const Place& place,
const std::string& alias,
int axis)
: TestCase(place, alias), axis_(axis) {}
void RunBaseline(Scope* scope) override {
auto* out = scope->NewTensor(output_);
CHECK(out);
out->Resize(dims_);
auto* out_data = out->mutable_data<float>();
auto* x = scope->FindTensor(inputx_);
const auto* x_data = x->data<float>();
auto* y = scope->FindTensor(inputy_);
const auto* y_data = y->data<float>();
for (int i = 0; i < dims_.production(); i++) {
out_data[i] = std::max(x_data[i], y_data[i]);
}
}
void PrepareOpDesc(cpp::OpDesc* op_desc) {
op_desc->SetType("elementwise_max");
op_desc->SetInput("X", {inputx_});
op_desc->SetInput("Y", {inputy_});
op_desc->SetOutput("Out", {output_});
op_desc->SetAttr("axis", axis_);
}
void PrepareData() override {
std::vector<float> data(dims_.production());
for (int i = 0; i < dims_.production(); i++) {
data[i] = i * 1.1;
}
SetCommonTensor(inputx_, dims_, data.data());
SetCommonTensor(inputy_, dims_, data.data());
}
};
class FusionElementwiseAddActivationComputeTester : public arena::TestCase {
protected:
// common attributes for this op.
std::string inputx_ = "x";
std::string inputy_ = "y";
std::string output_ = "out";
int axis_;
std::string act_type_;
DDim dims_{{1, 2, 3, 4}};
public:
FusionElementwiseAddActivationComputeTester(const Place& place,
const std::string& alias,
int axis,
std::string act_type)
: TestCase(place, alias), axis_(axis), act_type_(act_type) {}
void RunBaseline(Scope* scope) override {
auto* out = scope->NewTensor(output_);
CHECK(out);
out->Resize(dims_);
auto* out_data = out->mutable_data<float>();
auto* x = scope->FindTensor(inputx_);
const auto* x_data = x->data<float>();
auto* y = scope->FindTensor(inputy_);
const auto* y_data = y->data<float>();
for (int i = 0; i < dims_.production(); i++) {
out_data[i] = x_data[i] + y_data[i];
if (act_type_ == "relu") {
out_data[i] = out_data[i] > 0 ? out_data[i] : 0;
} else {
LOG(FATAL) << "unsupported Activation type: " << act_type_;
}
}
}
void PrepareOpDesc(cpp::OpDesc* op_desc) {
op_desc->SetType("fusion_elementwise_add_activation");
op_desc->SetInput("X", {inputx_});
op_desc->SetInput("Y", {inputy_});
op_desc->SetOutput("Out", {output_});
op_desc->SetAttr("axis", axis_);
op_desc->SetAttr("act_type", act_type_);
}
void PrepareData() override {
std::vector<float> data(dims_.production());
for (int i = 0; i < dims_.production(); i++) {
data[i] = i * 1.1;
}
SetCommonTensor(inputx_, dims_, data.data());
SetCommonTensor(inputy_, dims_, data.data());
}
};
class FusionElementwiseSubActivationComputeTester : public arena::TestCase {
protected:
// common attributes for this op.
std::string inputx_ = "x";
std::string inputy_ = "y";
std::string output_ = "out";
int axis_;
std::string act_type_;
DDim dims_{{1, 2, 3, 4}};
public:
FusionElementwiseSubActivationComputeTester(const Place& place,
const std::string& alias,
int axis,
std::string act_type)
: TestCase(place, alias), axis_(axis), act_type_(act_type) {}
std::string elt_type = "add",
std::vector<int64_t> x_shape = {1, 2, 3, 4},
std::vector<int64_t> y_shape = {1, 2, 3, 4},
int axis = 1,
std::string act_type = "")
: TestCase(place, alias),
elt_type_(elt_type),
x_dims_(DDim(x_shape)),
y_dims_(DDim(y_shape)),
axis_(axis),
act_type_(act_type) {}
void RunBaseline(Scope* scope) override {
auto* out = scope->NewTensor(output_);
CHECK(out);
out->Resize(dims_);
auto* out_data = out->mutable_data<float>();
auto* x = scope->FindTensor(inputx_);
const auto* x_data = x->data<float>();
auto* y = scope->FindTensor(inputy_);
const auto* y_data = y->data<float>();
for (int i = 0; i < dims_.production(); i++) {
out_data[i] = x_data[i] - y_data[i];
if (act_type_ == "relu") {
out_data[i] = out_data[i] > 0 ? out_data[i] : 0;
} else {
LOG(FATAL) << "unsupported Activation type: " << act_type_;
if (axis_ < 0) {
axis_ = x_dims_.size() - y_dims_.size();
}
auto x_shape = x_dims_.Vectorize();
while (x_shape.size() < 4) {
x_shape.push_back(1);
}
auto y_shape = y_dims_.Vectorize();
y_shape.insert(y_shape.begin(), axis_, 1);
while (y_shape.size() < 4) {
y_shape.push_back(1);
}
CHECK_EQ(x_shape.size(), 4);
CHECK_EQ(y_shape.size(), 4);
auto x = scope->FindTensor(x_);
auto y = scope->FindTensor(y_);
auto x_data = x->data<float>();
auto y_data = y->data<float>();
auto out = scope->NewTensor(out_);
out->Resize(x_dims_);
auto out_data = out->mutable_data<float>();
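// Seed the output with x, then fold y in place via the ELT macro above.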
memcpy(out_data, x_data, sizeof(float) * x_dims_.production());
int xn = x_shape[0];
int xc = x_shape[1];
int xh = x_shape[2];
int xw = x_shape[3];
int yn = y_shape[0];
int yc = y_shape[1];
int yh = y_shape[2];
int yw = y_shape[3];
if (elt_type_ == "add") {
ELT(+);
} else if (elt_type_ == "sub") {
ELT(-);
} else if (elt_type_ == "mul") {
ELT(*);
} else if (elt_type_ == "div") {
ELT(/);
} else if (elt_type_ == "max") {
for (int n = 0; n < xn; n++) {
for (int c = 0; c < xc; c++) {
for (int h = 0; h < xh; h++) {
for (int w = 0; w < xw; w++) {
int x_offset = n * xc * xh * xw + c * xh * xw + h * xw + w;
int y_offset = 0;
if (yn != 1) y_offset += n * yc * yh * yw;
if (yc != 1) y_offset += c * yh * yw;
if (yh != 1) y_offset += h * yw;
if (yw != 1) y_offset += w;
out_data[x_offset] =
std::max(out_data[x_offset], y_data[y_offset]);
}
}
}
}
} else {
LOG(FATAL) << "unsupported";
}
}
void PrepareOpDesc(cpp::OpDesc* op_desc) {
op_desc->SetType("fusion_elementwise_sub_activation");
op_desc->SetInput("X", {inputx_});
op_desc->SetInput("Y", {inputy_});
op_desc->SetOutput("Out", {output_});
op_desc->SetAttr("axis", axis_);
op_desc->SetAttr("act_type", act_type_);
}
void PrepareData() override {
std::vector<float> data(dims_.production());
for (int i = 0; i < dims_.production(); i++) {
data[i] = i * 1.1;
}
SetCommonTensor(inputx_, dims_, data.data());
SetCommonTensor(inputy_, dims_, data.data());
}
};
class FusionElementwiseMulActivationComputeTester : public arena::TestCase {
protected:
// common attributes for this op.
std::string inputx_ = "x";
std::string inputy_ = "y";
std::string output_ = "out";
int axis_;
std::string act_type_;
DDim dims_{{1, 2, 3, 4}};
public:
FusionElementwiseMulActivationComputeTester(const Place& place,
const std::string& alias,
int axis,
std::string act_type)
: TestCase(place, alias), axis_(axis), act_type_(act_type) {}
void RunBaseline(Scope* scope) override {
auto* out = scope->NewTensor(output_);
CHECK(out);
out->Resize(dims_);
auto* out_data = out->mutable_data<float>();
auto* x = scope->FindTensor(inputx_);
const auto* x_data = x->data<float>();
auto* y = scope->FindTensor(inputy_);
const auto* y_data = y->data<float>();
for (int i = 0; i < dims_.production(); i++) {
out_data[i] = x_data[i] * y_data[i];
if (!act_type_.empty()) {
if (act_type_ == "relu") {
out_data[i] = out_data[i] > 0 ? out_data[i] : 0;
for (int i = 0; i < x_dims_.production(); i++) {
out_data[i] = std::max(0.f, out_data[i]);
}
} else {
LOG(FATAL) << "unsupported Activation type: " << act_type_;
LOG(FATAL) << "unsupported";
}
}
}
void PrepareOpDesc(cpp::OpDesc* op_desc) {
op_desc->SetType("fusion_elementwise_mul_activation");
op_desc->SetInput("X", {inputx_});
op_desc->SetInput("Y", {inputy_});
op_desc->SetOutput("Out", {output_});
std::string op_type = "elementwise_" + elt_type_;
if (!act_type_.empty()) {
op_type = "fusion_" + op_type + "_activation";
}
op_desc->SetType(op_type);
op_desc->SetInput("X", {x_});
op_desc->SetInput("Y", {y_});
op_desc->SetOutput("Out", {out_});
op_desc->SetAttr("axis", axis_);
op_desc->SetAttr("act_type", act_type_);
}
void PrepareData() override {
std::vector<float> data(dims_.production());
for (int i = 0; i < dims_.production(); i++) {
data[i] = i * 1.1;
}
SetCommonTensor(inputx_, dims_, data.data());
SetCommonTensor(inputy_, dims_, data.data());
}
};
class FusionElementwiseMaxActivationComputeTester : public arena::TestCase {
protected:
// common attributes for this op.
std::string inputx_ = "x";
std::string inputy_ = "y";
std::string output_ = "out";
int axis_;
std::string act_type_;
DDim dims_{{1, 2, 3, 4}};
public:
FusionElementwiseMaxActivationComputeTester(const Place& place,
const std::string& alias,
int axis,
std::string act_type)
: TestCase(place, alias), axis_(axis), act_type_(act_type) {}
void RunBaseline(Scope* scope) override {
auto* out = scope->NewTensor(output_);
CHECK(out);
out->Resize(dims_);
auto* out_data = out->mutable_data<float>();
auto* x = scope->FindTensor(inputx_);
const auto* x_data = x->data<float>();
auto* y = scope->FindTensor(inputy_);
const auto* y_data = y->data<float>();
for (int i = 0; i < dims_.production(); i++) {
out_data[i] = std::max(x_data[i], y_data[i]);
if (act_type_ == "relu") {
out_data[i] = out_data[i] > 0 ? out_data[i] : 0;
} else {
LOG(FATAL) << "unsupported Activation type: " << act_type_;
}
if (!act_type_.empty()) {
op_desc->SetAttr("act_type", act_type_);
}
}
void PrepareOpDesc(cpp::OpDesc* op_desc) {
op_desc->SetType("fusion_elementwise_max_activation");
op_desc->SetInput("X", {inputx_});
op_desc->SetInput("Y", {inputy_});
op_desc->SetOutput("Out", {output_});
op_desc->SetAttr("axis", axis_);
op_desc->SetAttr("act_type", act_type_);
}
void PrepareData() override {
std::vector<float> data(dims_.production());
for (int i = 0; i < dims_.production(); i++) {
data[i] = i * 1.1;
std::vector<float> dx(x_dims_.production());
for (size_t i = 0; i < dx.size(); i++) {
dx[i] = (i % 3) * 1.1f;
dx[i] = dx[i] == 0 ? 1.f : dx[i];
}
SetCommonTensor(x_, x_dims_, dx.data());
SetCommonTensor(inputx_, dims_, data.data());
SetCommonTensor(inputy_, dims_, data.data());
}
};
class ElementwiseDivComputeTester : public arena::TestCase {
protected:
// common attributes for this op.
std::string inputx_ = "x";
std::string inputy_ = "y";
std::string output_ = "out";
int axis_;
DDim dims_{{1, 2, 3, 4}};
public:
ElementwiseDivComputeTester(const Place& place,
const std::string& alias,
int axis)
: TestCase(place, alias), axis_(axis) {}
void RunBaseline(Scope* scope) override {
auto* out = scope->NewTensor(output_);
CHECK(out);
out->Resize(dims_);
auto* out_data = out->mutable_data<float>();
auto* x = scope->FindTensor(inputx_);
const auto* x_data = x->data<float>();
auto* y = scope->FindTensor(inputy_);
const auto* y_data = y->data<float>();
for (int i = 0; i < dims_.production(); i++) {
out_data[i] = x_data[i] / y_data[i];
std::vector<float> dy(y_dims_.production());
for (size_t i = 0; i < dy.size(); i++) {
dy[i] = (i % 5) * 1.1f;
dy[i] = dy[i] == 0 ? 1.f : dy[i];
}
}
void PrepareOpDesc(cpp::OpDesc* op_desc) {
op_desc->SetType("elementwise_div");
op_desc->SetInput("X", {inputx_});
op_desc->SetInput("Y", {inputy_});
op_desc->SetOutput("Out", {output_});
op_desc->SetAttr("axis", axis_);
}
void PrepareData() override {
std::vector<float> data(dims_.production());
for (int i = 0; i < dims_.production(); i++) {
data[i] = i * 1.1;
}
std::vector<float> data2(dims_.production());
for (int i = 0; i < dims_.production(); i++) {
data2[i] = (i + 1) * 1.1;
}
SetCommonTensor(inputx_, dims_, data.data());
SetCommonTensor(inputy_, dims_, data2.data());
SetCommonTensor(y_, y_dims_, dy.data());
}
};
class FusionElementwiseDivActivationComputeTester : public arena::TestCase {
protected:
// common attributes for this op.
std::string inputx_ = "x";
std::string inputy_ = "y";
std::string output_ = "out";
int axis_;
std::string act_type_;
DDim dims_{{1, 2, 3, 4}};
public:
FusionElementwiseDivActivationComputeTester(const Place& place,
const std::string& alias,
int axis,
std::string act_type)
: TestCase(place, alias), axis_(axis), act_type_(act_type) {}
void RunBaseline(Scope* scope) override {
auto* out = scope->NewTensor(output_);
CHECK(out);
out->Resize(dims_);
auto* out_data = out->mutable_data<float>();
auto* x = scope->FindTensor(inputx_);
const auto* x_data = x->data<float>();
auto* y = scope->FindTensor(inputy_);
const auto* y_data = y->data<float>();
for (int i = 0; i < dims_.production(); i++) {
out_data[i] = x_data[i] / y_data[i];
if (act_type_ == "relu") {
out_data[i] = out_data[i] > 0 ? out_data[i] : 0;
} else {
LOG(FATAL) << "unsupported Activation type: " << act_type_;
}
}
}
void PrepareOpDesc(cpp::OpDesc* op_desc) {
op_desc->SetType("fusion_elementwise_div_activation");
op_desc->SetInput("X", {inputx_});
op_desc->SetInput("Y", {inputy_});
op_desc->SetOutput("Out", {output_});
op_desc->SetAttr("axis", axis_);
op_desc->SetAttr("act_type", act_type_);
}
void PrepareData() override {
std::vector<float> data(dims_.production());
for (int i = 0; i < dims_.production(); i++) {
data[i] = i * 1.1;
}
std::vector<float> data2(dims_.production());
for (int i = 0; i < dims_.production(); i++) {
data2[i] = (i + 1) * 1.1;
}
SetCommonTensor(inputx_, dims_, data.data());
SetCommonTensor(inputy_, dims_, data2.data());
}
};
// add sub mul div max +act
void TestElt(Place place,
float abs_error,
std::string elt_type,
std::vector<int64_t> x_shape,
std::vector<int64_t> y_shape,
int axis,
std::string act_type = "") {
std::unique_ptr<arena::TestCase> tester(new ElementwiseComputeTester(
place, "def", elt_type, x_shape, y_shape, axis, act_type));
arena::Arena arena(std::move(tester), place, abs_error);
arena.TestPrecision();
}
void TestEltDims(Place place, float abs_error) {
TestElt(place, abs_error, "add", {2, 3, 4, 5}, {2, 3, 4, 5}, 0);
TestElt(place, abs_error, "add", {2, 3, 4}, {2, 3, 4}, 0);
TestElt(place, abs_error, "add", {2, 3, 4}, {2, 3}, 0);
TestElt(place, abs_error, "add", {2, 3}, {2}, 0);
TestElt(place, abs_error, "add", {2, 3, 4, 5}, {3, 4}, 1);
TestElt(place, abs_error, "add", {2, 3, 4}, {3, 4}, 1);
TestElt(place, abs_error, "add", {2, 3}, {3}, 1);
TestElt(place, abs_error, "add", {2, 3, 4, 5}, {4, 5}, 2);
TestElt(place, abs_error, "add", {2, 3, 4}, {4}, 2);
TestElt(place, abs_error, "add", {2, 3, 4, 5}, {5}, 3);
TestElt(place, abs_error, "add", {2, 3, 4, 5}, {3, 4, 5}, -1);
TestElt(place, abs_error, "add", {2, 3, 4}, {3, 4}, -1);
}
void TestEltTypes(Place place, float abs_error) {
for (auto elt_type :
std::vector<std::string>{"add", "sub", "mul", "div", "max"}) {
TestElt(place, abs_error, elt_type, {2, 3, 4, 5}, {2, 3, 4, 5}, 0);
TestElt(place, abs_error, elt_type, {2, 3, 4, 5}, {3}, 1);
}
}
void TestEltFuseAct(Place place, float abs_error) {
for (auto elt_type :
std::vector<std::string>{"add", "sub", "mul", "div", "max"}) {
TestElt(place, abs_error, elt_type, {2, 3, 4, 5}, {2, 3, 4, 5}, 0, "relu");
TestElt(place, abs_error, elt_type, {2, 3, 4, 5}, {3}, 1, "relu");
}
}
TEST(Elementwise, precision) {
Place place;
float abs_error = 2e-5;
#if defined(LITE_WITH_NPU)
place = TARGET(kNPU);
abs_error = 1e-2; // use fp16 in npu
#elif defined(LITE_WITH_ARM)
place = TARGET(kARM);
#elif defined(LITE_WITH_XPU)
place = TARGET(kXPU);
#else
return;
#endif
// TestEltDims(place, abs_error);
TestEltTypes(place, abs_error);
TestEltFuseAct(place, abs_error);
}
} // namespace lite
} // namespace paddle