Unverified commit 91f0ef0b, authored by hong19860320 and committed by GitHub

[LITE][NPU] Add instance_norm op bridge and unit test, refine the registration of op bridges (#2747)
Parent 0c0a8a94
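The registration refactor below keys op bridges by TargetType instead of free-form target strings. As a minimal sketch of the resulting flow (the graph, op, and kernel variables stand in for the SubgraphEngine::BuildDeviceProgram context shown later in this diff):

// Declaring a bridge (as in instance_norm_op.cc below):
REGISTER_SUBGRAPH_BRIDGE(instance_norm,
                         kNPU,
                         paddle::lite::subgraph::npu::InstanceNormConverter);

// Looking it up at subgraph build time:
auto& bridges = paddle::lite::subgraph::Registry::Instance();
if (bridges.Exists("instance_norm", TARGET(kNPU))) {
  int status = bridges.Select("instance_norm", TARGET(kNPU))(
      reinterpret_cast<void*>(&graph),
      const_cast<OpLite*>(op),
      const_cast<KernelBase*>(kernel));
}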
......@@ -40,6 +40,7 @@ lite_cc_library(subgraph_bridge_sqrt_op_npu SRCS sqrt_op.cc DEPS ${npu_subgraph_
lite_cc_library(subgraph_bridge_reduce_mean_op_npu SRCS reduce_mean_op.cc DEPS ${npu_subgraph_bridge_deps})
lite_cc_library(subgraph_bridge_unsqueeze_op_npu SRCS unsqueeze_op.cc DEPS ${npu_subgraph_bridge_deps})
lite_cc_library(subgraph_bridge_argmax_op_npu SRCS argmax_op.cc DEPS ${npu_subgraph_bridge_deps})
lite_cc_library(subgraph_bridge_instance_norm_op_npu SRCS instance_norm_op.cc DEPS ${npu_subgraph_bridge_deps})
lite_cc_library(subgraph_bridge_dropout_op_npu SRCS dropout_op.cc DEPS ${npu_subgraph_bridge_deps})
set(npu_subgraph_bridges
......@@ -68,6 +69,7 @@ set(npu_subgraph_bridges
subgraph_bridge_reduce_mean_op_npu
subgraph_bridge_unsqueeze_op_npu
subgraph_bridge_argmax_op_npu
subgraph_bridge_instance_norm_op_npu
subgraph_bridge_dropout_op_npu
CACHE INTERNAL "npu_subgraph_bridges")
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <gtest/gtest.h>
#include <algorithm>
#include <cmath>
#include <random>
#include "lite/core/op_registry.h"
#include "lite/kernels/npu/bridges/registry.h"
#include "lite/kernels/npu/bridges/test_helper.h"
#include "lite/operators/activation_ops.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace npu {
namespace bridges {
void act_ref(const std::shared_ptr<operators::ActivationOp> op) {
Scope* scope = op->scope();
const OpInfo* op_info = op->op_info();
auto op_type = op_info->Type();
auto x = scope->FindTensor("x");
auto out = scope->FindMutableTensor("out_ref");
out->Resize(x->dims());
auto x_data = x->data<float>();
auto out_data = out->mutable_data<float>();
CHECK_EQ(x->numel(), out->numel());
// "sigmoid","relu","tanh","relu_clipped","leaky_relu","softsign","hard_sigmoid"
if (op_type == "sigmoid") {
for (size_t i = 0; i < out->numel(); i++) {
out_data[i] = 1.f / (1.f + std::exp(-x_data[i]));
}
} else if (op_type == "relu") {
for (size_t i = 0; i < out->numel(); i++) {
out_data[i] = std::max(0.f, x_data[i]);
}
} else if (op_type == "tanh") {
for (size_t i = 0; i < out->numel(); i++) {
out_data[i] = (std::exp(x_data[i]) - std::exp(-x_data[i])) /
(std::exp(x_data[i]) + std::exp(-x_data[i]));
}
} else if (op_type == "relu_clipped") {
auto relu_clipped_coef = op_info->GetAttr<float>("Relu_clipped_coef");
for (size_t i = 0; i < out->numel(); i++) {
out_data[i] = std::min(std::max(0.f, x_data[i]), relu_clipped_coef);
}
} else if (op_type == "relu6") {
for (size_t i = 0; i < out->numel(); i++) {
out_data[i] = std::min(std::max(0.f, x_data[i]), 6.f);
}
} else if (op_type == "leaky_relu") {
auto alpha = op_info->GetAttr<float>("alpha");
for (size_t i = 0; i < out->numel(); i++) {
out_data[i] = std::max(x_data[i], x_data[i] * alpha);
}
} else if (op_type == "softsign") {
for (size_t i = 0; i < out->numel(); i++) {
out_data[i] = x_data[i] / (1 + std::abs(x_data[i]));
}
} else if (op_type == "hard_sigmoid") {
auto slope = op_info->GetAttr<float>("slope");
auto offset = op_info->GetAttr<float>("offset");
for (size_t i = 0; i < out->numel(); i++) {
out_data[i] = std::min(1.f, slope * x_data[i] + offset);
out_data[i] = std::max(0.f, out_data[i]);
}
} else {
LOG(FATAL) << "unsupported activation type: " << op_type;
}
}
void test_act(std::vector<int64_t> x_shape, std::string op_type) {
// prepare input&output variables
Scope scope;
std::string x_var_name("x");
std::string out_var_name("out");
std::string out_ref_var_name("out_ref");
auto* x = scope.NewTensor(x_var_name);
auto* out = scope.NewTensor(out_var_name);
auto* out_ref = scope.NewTensor(out_ref_var_name);
x->Resize(x_shape);
// initialize input&output data
FillTensor<float>(x, -8, 8);
// initialize op desc
cpp::OpDesc opdesc;
opdesc.SetType(op_type);
opdesc.SetInput("X", {x_var_name});
opdesc.SetOutput("Out", {out_var_name});
if (op_type == "relu_clipped") {
opdesc.SetAttr("Relu_clipped_coef", 3.f);
} else if (op_type == "relu6") {
opdesc.SetAttr("Relu_clipped_coef", 6.f);
} else if (op_type == "leaky_relu") {
opdesc.SetAttr("alpha", 0.02f);
} else if (op_type == "hard_sigmoid") {
opdesc.SetAttr("slope", 0.2f);
opdesc.SetAttr("offset", 0.5f);
}
// create and convert op to NPU model, then run it on NPU
auto op = CreateOp<operators::ActivationOp>(opdesc, &scope);
LauchOp(op, {x_var_name}, {out_var_name});
// execute reference implementation and save to output tensor
act_ref(op);
// compare results
auto* out_data = out->mutable_data<float>();
auto* out_ref_data = out_ref->mutable_data<float>();
for (int i = 0; i < out->dims().production(); i++) {
EXPECT_NEAR(out_data[i], out_ref_data[i], 1e-2);
}
}
TEST(NPUBridges, activation) {
std::vector<std::vector<int64_t>> shapes{{1}, {2, 3}, {1, 2, 3, 4}};
std::vector<std::string> types{"sigmoid",
"relu",
"tanh",
"relu_clipped",
"relu6",
"leaky_relu",
"softsign",
"hard_sigmoid"};
for (auto x_shape : shapes) {
for (auto op_type : types) {
test_act(x_shape, op_type);
}
}
}
} // namespace bridges
} // namespace npu
} // namespace kernels
} // namespace lite
} // namespace paddle
USE_LITE_OP(sigmoid);
USE_NPU_BRIDGE(sigmoid);
USE_LITE_OP(relu);
USE_NPU_BRIDGE(relu);
USE_LITE_OP(tanh);
USE_NPU_BRIDGE(tanh);
USE_LITE_OP(relu_clipped);
USE_NPU_BRIDGE(relu_clipped);
USE_LITE_OP(relu6);
USE_NPU_BRIDGE(relu6);
USE_LITE_OP(leaky_relu);
USE_NPU_BRIDGE(leaky_relu);
USE_LITE_OP(softsign);
USE_NPU_BRIDGE(softsign);
USE_LITE_OP(hard_sigmoid);
USE_NPU_BRIDGE(hard_sigmoid);
......@@ -62,7 +62,8 @@ int ConvConverter(void* ctx, OpLite* op, KernelBase* kernel) {
auto paddings = op_info->GetAttr<std::vector<int>>("paddings");
auto groups = op_info->GetAttr<int>("groups");
auto dilations = op_info->GetAttr<std::vector<int>>("dilations");
auto fuse_relu = op_info->GetAttr<bool>("fuse_relu");
auto fuse_relu =
op_info->HasAttr("fuse_relu") && op_info->GetAttr<bool>("fuse_relu");
CHECK_EQ(strides.size(), 2L);
CHECK_EQ(dilations.size(), 2L);
......
......@@ -53,7 +53,8 @@ int ConvTransposeConverter(void* ctx, OpLite* op, KernelBase* kernel) {
auto paddings = op_info->GetAttr<std::vector<int>>("paddings");
auto groups = op_info->GetAttr<int>("groups");
auto dilations = op_info->GetAttr<std::vector<int>>("dilations");
auto fuse_relu = op_info->GetAttr<bool>("fuse_relu");
auto fuse_relu =
op_info->HasAttr("fuse_relu") && op_info->GetAttr<bool>("fuse_relu");
CHECK_EQ(strides.size(), 2L);
CHECK_EQ(dilations.size(), 2L);
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/npu/bridges/graph.h"
#include "lite/kernels/npu/bridges/registry.h"
#include "lite/kernels/npu/bridges/utility.h"
namespace paddle {
namespace lite {
namespace subgraph {
namespace npu {
int InstanceNormConverter(void* ctx, OpLite* op, KernelBase* kernel) {
CHECK(ctx != nullptr);
CHECK(op != nullptr);
auto graph = static_cast<Graph*>(ctx);
auto op_info = op->op_info();
auto op_type = op_info->Type();
auto scope = op->scope();
VLOG(3) << "[NPU] Converting " + op_type + "...";
// Get input and output vars and op attributes
auto x_name = op_info->Input("X").front();
auto x_type = kernel->GetInputDeclType("X");
CHECK(x_type->precision() == PRECISION(kFloat));
CHECK(x_type->layout() == DATALAYOUT(kNCHW));
auto x = scope->FindMutableTensor(x_name);
auto x_dims = x->dims();
CHECK_EQ(x_dims.size(), 4L);
auto batch_size = x_dims[0];
auto channel_size = x_dims[1];
auto spatial_size = x_dims[2] * x_dims[3];
DDim scale_bias_dims({1, channel_size, 1, 1});
auto y_name = op_info->Output("Y").front();
auto y_type = kernel->GetOutputDeclType("Y");
CHECK(y_type->precision() == PRECISION(kFloat));
CHECK(y_type->layout() == DATALAYOUT(kNCHW));
float epsilon = op_info->GetAttr<float>("epsilon");
// X node
std::shared_ptr<Node> x_node = nullptr;
if (graph->Has(x_name)) {
x_node = graph->Get(x_name);
} else {
x_node = graph->Add(x_name, *x);
}
// Bias node
std::shared_ptr<Node> bias_node = nullptr;
if (HasInputArg(op_info, scope, "Bias")) {
auto bias_name = op_info->Input("Bias").front();
auto bias_type = kernel->GetInputDeclType("Bias");
CHECK(bias_type->precision() == PRECISION(kFloat));
CHECK(bias_type->layout() == DATALAYOUT(kNCHW));
auto bias = scope->FindMutableTensor(bias_name);
auto bias_dims = bias->dims();
CHECK_EQ(channel_size, bias_dims.production());
if (spatial_size <= 1) {
// Bug exists in HiAI DDK when h=1 and w=1
auto bias_data = bias->mutable_data<float>();
Tensor y;
y.Resize(x_dims);
y.set_persistable(true);
auto y_data = y.mutable_data<float>();
for (int i = 0; i < batch_size; i++) {
std::memcpy(y_data, bias_data, sizeof(float) * channel_size);
y_data += channel_size;
}
graph->Add(y_name, y);
return SUCCESS;
} else {
if (!bias->persistable()) {
LOG(WARNING) << "[NPU] Only supporting persistable bias tensor.";
bias->set_persistable(true);
}
bias_node = graph->Add(bias_name, *bias, scale_bias_dims);
}
} else {
if (spatial_size <= 1) {
// Bug exists in HiAI DDK when h=1 and w=1
graph->Add(y_name, 0.0f, x_dims);
return SUCCESS;
} else {
bias_node = graph->Add(y_name + "/bias", 0.0f, scale_bias_dims);
}
}
// Scale node
std::shared_ptr<Node> scale_node = nullptr;
if (HasInputArg(op_info, scope, "Scale")) {
auto scale_name = op_info->Input("Scale").front();
auto scale_type = kernel->GetInputDeclType("Scale");
CHECK(scale_type->precision() == PRECISION(kFloat));
CHECK(scale_type->layout() == DATALAYOUT(kNCHW));
auto scale = scope->FindMutableTensor(scale_name);
auto scale_dims = scale->dims();
CHECK_EQ(channel_size, scale_dims.production());
if (!scale->persistable()) {
LOG(WARNING) << "[NPU] Only supporting persistable scale tensor.";
scale->set_persistable(true);
}
scale_node = graph->Add(scale_name, *scale, scale_bias_dims);
} else {
scale_node = graph->Add(y_name + "/scale", 1.0f, scale_bias_dims);
}
// InstanceNorm node
auto instance_norm_node = graph->Add<ge::op::InstanceNorm>(y_name);
auto instance_norm_op = instance_norm_node->data<ge::op::InstanceNorm>();
instance_norm_op->set_input_x(*x_node->data());
instance_norm_op->set_input_scale(*scale_node->data());
instance_norm_op->set_input_bias(*bias_node->data());
instance_norm_op->set_attr_reduction_indices(
ge::AttrValue::LIST_INT({0, 1, 2}));
instance_norm_op->set_attr_epsilon(epsilon);
return SUCCESS;
}
} // namespace npu
} // namespace subgraph
} // namespace lite
} // namespace paddle
REGISTER_SUBGRAPH_BRIDGE(instance_norm,
kNPU,
paddle::lite::subgraph::npu::InstanceNormConverter);
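For reference, instance_norm normalizes each (n, c) slice over its spatial elements; a minimal CPU sketch of the computation this bridge offloads, assuming NCHW layout and per-channel scale/bias as in the converter above:

#include <cmath>
#include <cstdint>

// y[n,c,h,w] = scale[c] * (x[n,c,h,w] - mean[n,c]) / sqrt(var[n,c] + eps) + bias[c]
void instance_norm_ref(const float* x, const float* scale, const float* bias,
                       float* y, int64_t n, int64_t c, int64_t spatial,
                       float epsilon) {
  for (int64_t i = 0; i < n * c; ++i) {
    const float* xp = x + i * spatial;
    float* yp = y + i * spatial;
    float mean = 0.f, var = 0.f;
    for (int64_t j = 0; j < spatial; ++j) mean += xp[j];
    mean /= spatial;
    for (int64_t j = 0; j < spatial; ++j) var += (xp[j] - mean) * (xp[j] - mean);
    var /= spatial;
    const float rstd = 1.f / std::sqrt(var + epsilon);
    for (int64_t j = 0; j < spatial; ++j) {
      yp[j] = scale[i % c] * (xp[j] - mean) * rstd + bias[i % c];
    }
  }
}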
......@@ -32,7 +32,7 @@ int Pad2dConverter(void* ctx, OpLite* op, KernelBase* kernel) {
// Get input and output vars and op attributes
auto x_name = op_info->Input("X").front();
auto x_type = kernel->GetInputDeclType("Input");
auto x_type = kernel->GetInputDeclType("X");
CHECK(x_type->precision() == PRECISION(kFloat));
CHECK(x_type->layout() == DATALAYOUT(kNCHW));
auto x = scope->FindMutableTensor(x_name);
......@@ -58,25 +58,34 @@ int Pad2dConverter(void* ctx, OpLite* op, KernelBase* kernel) {
auto padding_node = graph->Add(out_name + "/padding", padding, {xds, 2});
// Pad node
auto pad2d_node = graph->Add<ge::op::Pad>(out_name);
auto pad2d_op = pad2d_node->data<ge::op::Pad>();
pad2d_op->set_input_x(*x_node->data());
pad2d_op->set_input_padding(*padding_node->data());
auto mode = op_info->GetAttr<std::string>("mode");
if (mode == "constant") {
auto pad2d_node = graph->Add<ge::op::PadV2>(out_name);
auto pad2d_op = pad2d_node->data<ge::op::PadV2>();
pad2d_op->set_input_x(*x_node->data());
pad2d_op->set_input_paddings(*padding_node->data());
// Pad value node
auto pad_value = op_info->GetAttr<float>("pad_value");
auto pad_value_node = graph->Add(out_name + "/pad_value", pad_value);
pad2d_op->set_input_constant_values(*pad_value_node->data());
pad2d_op->set_attr_T(0); // type of pad_value: 0:float 3:int32
pad2d_op->set_attr_mode(0);
} else if (mode == "reflect") {
LOG(WARNING) << "[NPU] pad mode " << mode << " isn't supported in HiAI DDK";
pad2d_op->set_attr_mode(1);
return FAILED;
} else {
LOG(WARNING) << "[NPU] pad mode " << mode << " isn't supported in HiAI DDK";
return FAILED;
auto pad2d_node = graph->Add<ge::op::Pad>(out_name);
auto pad2d_op = pad2d_node->data<ge::op::Pad>();
pad2d_op->set_input_x(*x_node->data());
pad2d_op->set_input_padding(*padding_node->data());
if (mode == "reflect") {
pad2d_op->set_attr_mode(1);
LOG(WARNING) << "[NPU] pad mode " << mode
<< " isn't supported in HiAI DDK";
} else if (mode == "edge") {
pad2d_op->set_attr_mode(3);
LOG(WARNING) << "[NPU] pad mode " << mode
<< " isn't supported in HiAI DDK";
} else {
LOG(WARNING) << "[NPU] pad mode " << mode
<< " isn't supported in HiAI DDK";
return FAILED;
}
}
return REBUILD_WHEN_SHAPE_CHANGED;
}
......
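For context, the padding node consumed by ge::op::Pad and ge::op::PadV2 above is a {rank, 2} tensor of [before, after] pairs; its construction is elided from this hunk, so the following is only a hypothetical reconstruction of how the pad2d paddings attribute {top, bottom, left, right} would be laid out for an NCHW input:

// Hypothetical sketch (pad_top/pad_bottom/pad_left/pad_right are illustrative
// names): only the H and W axes of an NCHW input receive padding.
int xds = x_dims.size();               // 4 for NCHW
std::vector<int> padding(xds * 2, 0);  // one [before, after] pair per axis
padding[2 * 2 + 0] = pad_top;          // H: before
padding[2 * 2 + 1] = pad_bottom;       // H: after
padding[3 * 2 + 0] = pad_left;         // W: before
padding[3 * 2 + 1] = pad_right;        // W: after
auto padding_node = graph->Add(out_name + "/padding", padding, {xds, 2});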
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/operators/pad2d_op.h"
#include <gtest/gtest.h>
#include <algorithm>
#include "lite/core/op_registry.h"
#include "lite/kernels/npu/bridges/registry.h"
#include "lite/kernels/npu/bridges/test_helper.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace npu {
namespace bridges {
template <typename dtype>
void pad2d_ref(const std::shared_ptr<operators::Pad2dOpLite> op) {
Scope* scope = op->scope();
const OpInfo* op_info = op->op_info();
auto x = scope->FindMutableTensor(op_info->Input("X").front());
auto out = scope->FindMutableTensor(op_info->Output("Out").front());
auto paddings = op_info->GetAttr<std::vector<int>>("paddings");
int pad_top = paddings[0];
int pad_bottom = paddings[1];
int pad_left = paddings[2];
int pad_right = paddings[3];
auto mode = op_info->GetAttr<std::string>("mode");
int pad_mode;
if (mode == "constant") {
pad_mode = 0;
} else if (mode == "reflect") {
pad_mode = 1;
} else if (mode == "edge") {
pad_mode = 2;
} else {
LOG(FATAL) << "Unknown mode type";
}
float pad_value = op_info->GetAttr<float>("pad_value");
auto out_dims = out->dims();
int n = out_dims[0];
int c = out_dims[1];
int h = out_dims[2];
int w = out_dims[3];
int in_w = w - pad_left - pad_right;
int in_h = h - pad_bottom - pad_top;
int spatial_size_out = w * h;
int spatial_size_in = in_w * in_h;
auto x_data = x->data<float>();
auto out_data = out->mutable_data<float>();
#pragma omp parallel for
for (int i = 0; i < n * c; ++i) {
const float* din_batch = x_data + i * spatial_size_in;
float* dout_batch = out_data + i * spatial_size_out;
int in_y = 0;
int in_x = 0;
for (int y = 0; y < h; ++y) {
for (int x = 0; x < w; ++x) {
switch (pad_mode) {
case 0:
in_y = y - pad_top;
in_x = x - pad_left;
dout_batch[y * w + x] =
(in_x >= 0 && in_x < in_w) && (in_y >= 0 && in_y < in_h)
? din_batch[in_y * in_w + in_x]
: pad_value;
break;
case 1:
// reflect: mirror the index without repeating the border element
in_y = y - pad_top;
in_x = x - pad_left;
in_y = std::max(in_y, -in_y);
in_y = std::min(in_y, 2 * in_h - in_y - 2);
in_x = std::max(in_x, -in_x);
in_x = std::min(in_x, 2 * in_w - in_x - 2);
dout_batch[y * w + x] = din_batch[in_y * in_w + in_x];
break;
case 2:
// edge: clamp the index to the nearest border element
in_x =
std::min(std::max(pad_left, x), in_w + pad_left - 1) - pad_left;
in_y = std::min(std::max(pad_top, y), in_h + pad_top - 1) - pad_top;
dout_batch[y * w + x] = din_batch[in_y * in_w + in_x];
break;
default:
LOG(ERROR) << "ERROR: unknown pad mode:" << pad_mode;
}
}
}
}
}
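// In index form, with j = x - pad_left and input width in_w, the modes above
// map an output column to an input column as follows (rows use pad_top and
// in_h in the same way):
//   constant: read j if 0 <= j < in_w, otherwise write pad_value
//   reflect:  j -> min(|j|, 2 * in_w - |j| - 2)  (mirror, border not repeated)
//   edge:     j -> min(max(j, 0), in_w - 1)      (clamp to the border)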
void test_pad2d(int bs,
int ic,
int ih,
int iw,
std::vector<int> paddings,
float pad_value,
std::string mode) {
// prepare input&output variables
Scope scope;
std::string x_var_name = "x";
std::string out_var_name = "out";
std::string out_ref_var_name = "out_ref";
auto* x = scope.NewTensor(x_var_name);
auto* out = scope.NewTensor(out_var_name);
auto* out_ref = scope.NewTensor(out_ref_var_name);
x->Resize({bs, ic, ih, iw});
// initialize input&output data
FillTensor<float>(x, -1, 1);
// initialize op desc
cpp::OpDesc opdesc;
opdesc.SetType("pad2d");
opdesc.SetInput("X", {x_var_name});
opdesc.SetOutput("Out", {out_var_name});
opdesc.SetAttr("paddings", paddings);
opdesc.SetAttr("pad_value", pad_value);
opdesc.SetAttr("mode", mode);
opdesc.SetAttr("data_format", std::string("NCHW"));
auto op = CreateOp<operators::Pad2dOpLite>(opdesc, &scope);
pad2d_ref<float>(op);
out_ref->CopyDataFrom(*out);
LauchOp(op, {x_var_name}, {out_var_name});
// compare results
auto* out_data = out->mutable_data<float>();
auto* out_ref_data = out_ref->mutable_data<float>();
for (int i = 0; i < out->numel(); i++) {
EXPECT_NEAR(out_data[i], out_ref_data[i], 1e-2) << "-----" << i;
}
}
TEST(NPUBridges, pad2d) {
#if 1
for (auto bs : {1, 4, 7}) {
for (auto ic : {1, 4, 7}) {
for (auto ih : {1, 4, 7}) {
for (auto iw : {1, 4, 7}) {
for (auto paddings : {/*std::vector<int>{0, 0, 0, 0},*/
std::vector<int>{0, 0, 0, 1},
std::vector<int>{0, 1, 0, 2},
std::vector<int>{1, 2, 3, 4}}) {
// NPU doesn't support pad_value != 0
for (auto pad_value : {0.f /*,1.f*/}) {
// NPU only supports mode = "constant"
for (std::string mode : {"constant" /*, "reflect", "edge"*/}) {
if (mode == "edge") continue;
VLOG(3) << "bs: " << bs << " ic: " << ic << " ih: " << ih
<< " iw: " << iw << " paddings: {" << paddings[0]
<< "," << paddings[1] << "," << paddings[2] << ","
<< paddings[3] << "}"
<< " pad_value: " << pad_value << " mode: " << mode;
test_pad2d(bs, ic, ih, iw, paddings, pad_value, mode);
}
}
}
}
}
}
}
#else
test_pad2d(1, 1, 1, 1, {0, 0, 0, 1}, 0, "constant");
#endif
}
} // namespace bridges
} // namespace npu
} // namespace kernels
} // namespace lite
} // namespace paddle
USE_LITE_OP(pad2d);
USE_NPU_BRIDGE(pad2d);
......@@ -54,3 +54,4 @@ USE_SUBGRAPH_BRIDGE(transpose, kNPU);
USE_SUBGRAPH_BRIDGE(transpose2, kNPU);
USE_SUBGRAPH_BRIDGE(unsqueeze, kNPU);
USE_SUBGRAPH_BRIDGE(unsqueeze2, kNPU);
USE_SUBGRAPH_BRIDGE(instance_norm, kNPU);
......@@ -25,26 +25,29 @@ Registry& Registry::Instance() {
}
void Registry::Insert(const std::string& op_type,
const std::string& target,
const TargetType& target,
const cvt_func_type& cvt_func_name) {
auto it = map_.find(target);
int key = static_cast<int>(target);
auto it = map_.find(key);
if (it == map_.end()) {
map_.insert(std::make_pair(
target, std::unordered_map<std::string, cvt_func_type>()));
map_.insert(
std::make_pair(key, std::unordered_map<std::string, cvt_func_type>()));
}
map_.at(target).insert(std::make_pair(op_type, cvt_func_name));
map_.at(key).insert(std::make_pair(op_type, cvt_func_name));
}
const cvt_func_type& Registry::Select(const std::string& op_type,
const std::string& target) const {
return map_.at(target).at(op_type);
const TargetType& target) const {
int key = static_cast<int>(target);
return map_.at(key).at(op_type);
}
bool Registry::Exists(const std::string& op_type,
const std::string& target) const {
bool found = map_.find(target) != map_.end();
const TargetType& target) const {
int key = static_cast<int>(target);
bool found = map_.find(key) != map_.end();
if (found) {
found = map_.at(target).find(op_type) != map_.at(target).end();
found = map_.at(key).find(op_type) != map_.at(key).end();
}
return found;
}
......
......@@ -36,18 +36,17 @@ inline bool CHECK_REBUILD_WHEN_SHAPE_CHANGED(int status) {
using cvt_func_type =
std::function<int(void* ctx, OpLite* op, KernelBase* kernel)>;
using cvt_map_type =
std::unordered_map<std::string,
std::unordered_map<std::string, cvt_func_type>>;
std::unordered_map<int, std::unordered_map<std::string, cvt_func_type>>;
class Registry {
public:
static Registry& Instance();
void Insert(const std::string& op_type,
const std::string& target,
const TargetType& target,
const cvt_func_type& cvt_func_name);
const cvt_func_type& Select(const std::string& op_type,
const std::string& target) const;
bool Exists(const std::string& op_type, const std::string& target) const;
const TargetType& target) const;
bool Exists(const std::string& op_type, const TargetType& target) const;
Registry() = default;
private:
......@@ -80,7 +79,7 @@ class Registry {
"once!"); \
int __reg_subgraph_bridge_##op_type__##_##target__##_Insert() { \
paddle::lite::subgraph::Registry::Instance().Insert( \
#op_type__, #target__, cvt_func_name); \
#op_type__, TARGET(target__), cvt_func_name); \
return 0; \
}
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/npu/bridges/test_helper.h"
#include <utility>
#include "lite/backends/npu/builder.h"
#include "lite/core/op_registry.h"
#include "lite/kernels/npu/bridges/registry.h"
#include "lite/operators/graph_op.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace npu {
namespace bridges {
void LauchOp(const std::shared_ptr<lite::OpLite> op,
const std::vector<std::string>& input_var_names,
const std::vector<std::string>& output_var_names) {
auto scope = op->scope();
auto op_type = op->op_info()->Type();
// convert op to IR graph
const auto& bridges = lite::kernels::npu::bridges::Factory::Instance();
const auto& supported_lists = bridges.AllFunctions();
CHECK(bridges.HasType(op_type));
node_map_type inputs_map;
for (auto input_var_name : input_var_names) {
auto input = scope->FindVar(input_var_name)->GetMutable<lite::Tensor>();
ge::TensorDesc input_desc(
ge::Shape(input->dims().Vectorize()), ge::FORMAT_NCHW, ge::DT_FLOAT);
auto input_node = std::make_shared<ge::op::Data>(input_var_name);
input_node->update_input_desc_x(input_desc);
lite::npu::OpList::Global().add(input_node);
inputs_map[input_var_name] = input_node;
}
auto outputs_map = supported_lists.at(op_type)(op, inputs_map);
CHECK_GT(outputs_map.size(), 0);
// compile IR graph to om model
std::vector<ge::Operator> graph_inputs;
for (auto input_var_name : input_var_names) {
graph_inputs.push_back(*inputs_map[input_var_name]);
}
std::vector<ge::Operator> graph_outputs;
for (auto output_var_name : output_var_names) {
graph_outputs.push_back(*outputs_map[output_var_name]);
}
std::string weight_var_name = "weight";
auto weight = scope->Var(weight_var_name)->GetMutable<Tensor>();
weight->set_persistable(true);
weight->set_precision(PRECISION(kInt8));
CHECK(lite::npu::BuildModel(graph_inputs, graph_outputs, weight));
CHECK_GT(weight->numel(), 0);
CHECK(weight->data<uint8_t>() != nullptr);
// create graph op and set inputs and outputs
cpp::OpDesc graph_op_desc;
graph_op_desc.SetType("graph_op");
graph_op_desc.SetInput("Inputs", input_var_names);
graph_op_desc.SetInput("Weight", {weight_var_name});
graph_op_desc.SetOutput("Outputs", output_var_names);
auto graph_op =
std::make_shared<operators::GraphOpLite>(graph_op_desc.Type());
graph_op->SetValidPlaces({Place{TARGET(kNPU), PRECISION(kFloat)}});
CHECK(graph_op->Attach(graph_op_desc, scope));
CHECK(graph_op->CheckShape());
CHECK(graph_op->InferShape());
// create graph op kernel and set NPU context
auto graph_kernels =
graph_op->CreateKernels({Place{TARGET(kNPU), PRECISION(kFloat)}});
CHECK(!graph_kernels.empty());
auto graph_kernel =
std::move(graph_kernels.front()); // use the first kernel by default
auto graph_ctx = ContextScheduler::Global().NewContext(TARGET(kNPU));
graph_kernel->SetContext(std::move(graph_ctx));
// perform graph op kernel and store to output variables
graph_kernel->Launch();
// release all of resources of generated model
lite::npu::OpList::Global().clear();
}
} // namespace bridges
} // namespace npu
} // namespace kernels
} // namespace lite
} // namespace paddle
USE_LITE_OP(graph_op);
USE_LITE_KERNEL(graph_op, kNPU, kFloat, kNCHW, def);
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <random>
#include <string>
#include <vector>
#include "lite/core/op_lite.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace npu {
namespace bridges {
template <typename T>
std::shared_ptr<T> CreateOp(const cpp::OpDesc& opdesc, lite::Scope* scope) {
auto op = std::make_shared<T>(opdesc.Type());
op->SetValidPlaces({Place{TARGET(kHost), PRECISION(kFloat)},
Place{TARGET(kARM), PRECISION(kFloat)},
Place{TARGET(kNPU), PRECISION(kFloat)}});
CHECK(op->Attach(opdesc, scope));
CHECK(op->CheckShape());
CHECK(op->InferShape());
return op;
}
// T is the target data type
// R is the range data type, e.g. int, half
template <typename T, typename R = float>
void FillTensor(Tensor* x,
T lower = static_cast<T>(-2),
T upper = static_cast<T>(2)) {
static unsigned int seed = 100;
std::mt19937 rng(seed++);
std::uniform_real_distribution<double> uniform_dist(0, 1);
T* x_data = x->mutable_data<T>();
for (int i = 0; i < x->dims().production(); ++i) {
auto r = uniform_dist(rng) * (upper - lower) + lower;
x_data[i] = static_cast<T>(static_cast<R>(r));
}
}
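// Usage sketch: R routes each random sample through an intermediate type
// before the final cast to T, e.g.
//   FillTensor<float>(x, -8, 8);       // uniform floats in [-8, 8)
//   FillTensor<float, int>(x, -8, 8);  // whole-valued floats, truncated via int
// The second form matches the commented-out call in the pad2d bridge test.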
void LauchOp(const std::shared_ptr<lite::OpLite> op,
const std::vector<std::string>& input_var_names,
const std::vector<std::string>& output_var_names);
} // namespace bridges
} // namespace npu
} // namespace kernels
} // namespace lite
} // namespace paddle
......@@ -30,29 +30,69 @@
// Extended ops based on HIAI DDK
namespace ge {
/**
/*
* Pads a tensor.
* <Input>
* x : the input tensor
* padding : the input tensor must be 2-D
* constant_values : constant values must be a scalar
* x : the input tensor
* padding : the input tensor must be 2-D
* constant_values : constant values must be a scalar
* <Output>
* output : the output tensor
* y : the output tensor
* <Attr>
t_paddings : Default DT_INT32; must match the datatype of the padding input
mode : 0: CONSTANT, 1: REFLECT, 2: SYMMETRIC
T : datatype of constant_values (DT_FLOAT: 0, DT_INT32: 3)
mode : 0: CONSTANT, 1: REFLECT, 2: SYMMETRIC, 3: EDGE.
* <Added in HiAI version>
* 100.320.010.010
*/
REG_OP(Pad)
.INPUT(x, TensorType({DT_FLOAT, DT_INT32}))
.INPUT(padding, TensorType({DT_INT32}))
.OPTIONAL_INPUT(constant_values, TensorType({DT_INT32, DT_FLOAT}))
.OUTPUT(output, TensorType({DT_FLOAT, DT_INT32}))
.ATTR(t_paddings, AttrValue::INT{3})
.OUTPUT(y, TensorType({DT_FLOAT, DT_INT32}))
.ATTR(mode, AttrValue::INT{0})
.REQUIRED_ATTR(T, AttrValue::INT)
.OP_END();
.OP_END()
/*
* The operation pads input according to the paddings and constant_values.
* <Input>
* x : The input tensor.
paddings : The values of paddings, one [before, after] pair per dimension
to be added around the input tensor x; must be a Const-OP.
constant_values : A tensor of the same type as x that indicates the value
to use for padding; must be a Const-OP.
* <Output>
* y : The output tensor.
* <Added in HiAI version>
* 100.320.010.010
*/
REG_OP(PadV2)
.INPUT(x, TensorType({DT_FLOAT, DT_INT32}))
.INPUT(paddings, TensorType({DT_INT32}))
.INPUT(constant_values, TensorType({DT_FLOAT, DT_INT32}))
.OUTPUT(y, TensorType({DT_FLOAT, DT_INT32}))
.OP_END()
/*
* Computes instance norm
* <Input>
* x : Input tensor which supports 4D dimension format.
scale : A tensor, multiplied with the normalized result
bias : A tensor, added to the normalized result
* <Output>
* y : Output tensor
* <Attr>
* reduction_indices : The dimensions to reduce
* epsilon : A very small float number used to avoid dividing by zero.
* <Added in HiAI version>
* 100.320.010.010
*/
REG_OP(InstanceNorm)
.INPUT(x, TensorType({DT_FLOAT}))
.INPUT(scale, TensorType({DT_FLOAT}))
.INPUT(bias, TensorType({DT_FLOAT}))
.OUTPUT(y, TensorType({DT_FLOAT}))
.REQUIRED_ATTR(reduction_indices, AttrValue::LIST_INT)
.ATTR(epsilon, AttrValue::FLOAT{1e-7f})
.OP_END()
} // namespace ge
......
......@@ -39,11 +39,12 @@ int SubgraphEngine::BuildDeviceProgram() {
op->CheckShape();
op->InferShape();
std::string op_type = op->op_info()->Type();
if (!bridges.Exists(op_type, "kNPU")) {
if (!bridges.Exists(op_type, TARGET(kNPU))) {
return subgraph::FAILED;
}
auto kernel = inst.kernel();
status |= bridges.Select(op_type, "kNPU")(reinterpret_cast<void*>(&graph),
status |=
bridges.Select(op_type, TARGET(kNPU))(reinterpret_cast<void*>(&graph),
const_cast<OpLite*>(op),
const_cast<KernelBase*>(kernel));
if (subgraph::CHECK_FAILED(status)) {
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/operators/softmax_op.h"
#include <gtest/gtest.h>
#include <cmath>
#include <limits>
#include "lite/core/op_registry.h"
#include "lite/kernels/xpu/bridges/registry.h"
#include "lite/kernels/xpu/bridges/test_helper.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace xpu {
namespace bridges {
template <typename dtype>
void softmax_ref(const std::shared_ptr<operators::SoftmaxOp> op) {
Scope* scope = op->scope();
const OpInfo* op_info = op->op_info();
auto x = scope->FindVar(op_info->Input("X").front())->GetMutable<Tensor>();
auto out =
scope->FindVar(op_info->Output("Out").front())->GetMutable<Tensor>();
auto x_data = x->data<dtype>();
auto out_data = out->mutable_data<dtype>();
DDim x_dims = x->dims();
auto x_rank = x_dims.size();
int axis = op_info->GetAttr<int>("axis");
if (axis < 0) {
axis += x_rank;
}
int axis_size = x_dims[axis];
int outer_num = x_dims.Slice(0, axis).production();
int inner_num = x_dims.Slice(axis + 1, x_rank).production();
int compute_size = outer_num * inner_num;
for (int i = 0; i < compute_size; i++) {
int idx_inner = i % inner_num;
int idx_outer = (i / inner_num) * axis_size;
int start = idx_outer * inner_num + idx_inner;
int offset;
offset = start;
dtype max_data = std::numeric_limits<dtype>::lowest();
for (int j = 0; j < axis_size; j++) {
max_data = x_data[offset] > max_data ? x_data[offset] : max_data;
offset += inner_num;
}
offset = start;
dtype sum_data = (dtype)0;
for (int j = 0; j < axis_size; j++) {
out_data[offset] = exp(x_data[offset] - max_data);
sum_data += out_data[offset];
offset += inner_num;
}
offset = start;
for (int j = 0; j < axis_size; j++) {
out_data[offset] /= sum_data;
offset += inner_num;
}
}
}
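// In formula form, the reference above computes the numerically stable
// softmax along `axis`:
//   softmax(x)_j = exp(x_j - m) / sum_k exp(x_k - m),  where m = max_k x_k
// subtracting the per-slice maximum first so the exponentials cannot overflow.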
void test_softmax(int bs, int ic, int ih, int iw, int axis) {
// prepare input&output variables
Scope scope;
std::string x_var_name = "x";
std::string out_var_name = "out";
std::string out_ref_var_name = "out_ref";
auto* x = scope.Var(x_var_name)->GetMutable<Tensor>();
auto* out = scope.Var(out_var_name)->GetMutable<Tensor>();
auto* out_ref = scope.Var(out_ref_var_name)->GetMutable<Tensor>();
x->Resize({bs, ic, ih, iw});
// initialize input&output data
FillTensor<float>(x);
// initialize op desc
cpp::OpDesc opdesc;
opdesc.SetType("softmax");
opdesc.SetInput("X", {x_var_name});
opdesc.SetOutput("Out", {out_var_name});
opdesc.SetAttr("axis", axis);
// create and convert op to XPU model, then run it on XPU
auto op = CreateOp<operators::SoftmaxOp>(opdesc, &scope);
LauchOp(op, {x_var_name}, {out_var_name});
out_ref->CopyDataFrom(*out);
// execute reference implementation and save to output tensor
softmax_ref<float>(op);
// compare results
auto* out_data = out->mutable_data<float>();
auto* out_ref_data = out_ref->mutable_data<float>();
for (int i = 0; i < out->dims().production(); i++) {
EXPECT_NEAR(out_data[i], out_ref_data[i], 1e-5);
}
}
TEST(XPUBridges, softmax) {
for (auto bs : {2, 3}) {
for (auto ic : {4}) {
for (auto ih : {5}) {
for (auto iw : {6}) {
for (auto axis : {-3, -1, 0, 1, 2, 3}) {
test_softmax(bs, ic, ih, iw, axis);
}
}
}
}
}
}
} // namespace bridges
} // namespace xpu
} // namespace kernels
} // namespace lite
} // namespace paddle
USE_LITE_OP(softmax);
USE_XPU_BRIDGE(softmax);
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/xpu/bridges/test_helper.h"
#include <utility>
#include "lite/backends/xpu/builder.h"
#include "lite/core/op_registry.h"
#include "lite/kernels/xpu/bridges/registry.h"
#include "lite/operators/graph_op.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace xpu {
namespace bridges {
void LauchOp(const std::shared_ptr<lite::OpLite> op,
const std::vector<std::string>& input_var_names,
const std::vector<std::string>& output_var_names) {
auto scope = op->scope();
auto op_type = op->op_info()->Type();
// convert lite op to XPU op
const auto& bridges = lite::kernels::xpu::bridges::Factory::Instance();
const auto& supported_lists = bridges.AllFunctions();
CHECK(bridges.HasType(op_type));
graph_ctx_type graph_ctx;
graph_ctx.builder = std::make_shared<xtcl::network::xNetworkBuilder>();
graph_ctx.params =
std::make_shared<xtcl::network::xTensorCompiler::ParamNDArrayMap>();
node_map_type input_nodes;
for (auto input_var_name : input_var_names) {
auto input = scope->FindVar(input_var_name)->GetMutable<lite::Tensor>();
auto input_node = std::make_shared<xtcl::xExpr>(
graph_ctx.builder->CreateTensor(input_var_name,
lite::xpu::CvtShape(input->dims()),
::xtcl::Float(32)));
input_nodes[input_var_name] = input_node;
}
auto output_nodes = supported_lists.at(op_type)(op, &graph_ctx, input_nodes);
CHECK_GT(output_nodes.size(), 0);
// build network graph and output model data
std::vector<std::shared_ptr<xtcl::xExpr>> ordered_output_nodes;
for (auto output_var_name : output_var_names) {
ordered_output_nodes.push_back(output_nodes.at(output_var_name));
}
std::string weight_var_name = "weight";
auto weight = scope->Var(weight_var_name)->GetMutable<Tensor>();
weight->set_persistable(true);
weight->set_precision(PRECISION(kInt8));
CHECK(lite::xpu::BuildModel(
graph_ctx.builder, graph_ctx.params, &ordered_output_nodes, weight));
CHECK_GT(weight->numel(), 0);
CHECK(weight->data<uint8_t>() != nullptr);
// create graph op and set inputs and outputs
cpp::OpDesc graph_op_desc;
graph_op_desc.SetType("graph_op");
graph_op_desc.SetInput("Inputs", input_var_names);
graph_op_desc.SetInput("Weight", {weight_var_name});
graph_op_desc.SetOutput("Outputs", output_var_names);
auto graph_op =
std::make_shared<operators::GraphOpLite>(graph_op_desc.Type());
graph_op->SetValidPlaces({Place{TARGET(kXPU), PRECISION(kFloat)}});
CHECK(graph_op->Attach(graph_op_desc, scope));
CHECK(graph_op->CheckShape());
CHECK(graph_op->InferShape());
// create graph op kernel and set XPU context
auto graph_kernels =
graph_op->CreateKernels({Place{TARGET(kXPU), PRECISION(kFloat)}});
CHECK(!graph_kernels.empty());
auto graph_kernel =
std::move(graph_kernels.front()); // use the first kernel by default
auto graph_device = ContextScheduler::Global().NewContext(TARGET(kXPU));
graph_kernel->SetContext(std::move(graph_device));
// perform graph op kernel and store to output variables
graph_kernel->Launch();
lite::xpu::DeviceInfo::Global().Clear();
}
} // namespace bridges
} // namespace xpu
} // namespace kernels
} // namespace lite
} // namespace paddle
USE_LITE_OP(graph_op);
USE_LITE_KERNEL(graph_op, kXPU, kFloat, kNCHW, def);
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <random>
#include <string>
#include <vector>
#include "lite/core/op_lite.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace xpu {
namespace bridges {
template <typename T>
std::shared_ptr<T> CreateOp(const cpp::OpDesc& opdesc, lite::Scope* scope) {
auto op = std::make_shared<T>(opdesc.Type());
op->SetValidPlaces({Place{TARGET(kHost), PRECISION(kFloat)},
Place{TARGET(kARM), PRECISION(kFloat)},
Place{TARGET(kXPU), PRECISION(kFloat)}});
CHECK(op->Attach(opdesc, scope));
CHECK(op->CheckShape());
CHECK(op->InferShape());
return op;
}
// T is the target data type
// R is the range data type, e.g. int, half
template <typename T, typename R = float>
void FillTensor(Tensor* x,
T lower = static_cast<T>(-2),
T upper = static_cast<T>(2)) {
static unsigned int seed = 100;
std::mt19937 rng(seed++);
std::uniform_real_distribution<double> uniform_dist(0, 1);
T* x_data = x->mutable_data<T>();
for (int i = 0; i < x->dims().production(); ++i) {
auto r = uniform_dist(rng) * (upper - lower) + lower;
x_data[i] = static_cast<T>(static_cast<R>(r));
}
}
void LauchOp(const std::shared_ptr<lite::OpLite> op,
const std::vector<std::string>& input_var_names,
const std::vector<std::string>& output_var_names);
} // namespace bridges
} // namespace xpu
} // namespace kernels
} // namespace lite
} // namespace paddle
......@@ -39,11 +39,12 @@ int SubgraphEngine::BuildDeviceProgram() {
op->CheckShape();
op->InferShape();
std::string op_type = op->op_info()->Type();
if (!bridges.Exists(op_type, "kXPU")) {
if (!bridges.Exists(op_type, TARGET(kXPU))) {
return subgraph::FAILED;
}
auto kernel = inst.kernel();
status |= bridges.Select(op_type, "kXPU")(reinterpret_cast<void*>(&graph),
status |=
bridges.Select(op_type, TARGET(kXPU))(reinterpret_cast<void*>(&graph),
const_cast<OpLite*>(op),
const_cast<KernelBase*>(kernel));
if (subgraph::CHECK_FAILED(status)) {
......
......@@ -14,7 +14,7 @@ if((NOT LITE_WITH_OPENCL AND NOT LITE_WITH_FPGA) AND (LITE_WITH_X86 OR LITE_WITH
lite_cc_test(test_kernel_conv2d_transpose_compute SRCS conv2d_transpose_compute_test.cc DEPS arena_framework ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_norm_compute SRCS norm_compute_test.cc DEPS arena_framework ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_cast_compute SRCS cast_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_instance_norm_compute SRCS instance_norm_compute_test.cc DEPS arena_framework ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_instance_norm_compute SRCS instance_norm_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_grid_sampler_compute SRCS grid_sampler_compute_test.cc DEPS arena_framework ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
#lite_cc_test(test_kernel_sequence_softmax_compute SRCS sequence_softmax_compute_test.cc DEPS arena_framework ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
#lite_cc_test(test_kernel_im2sequence_compute SRCS im2sequence_compute_test.cc DEPS arena_framework ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
......
......@@ -24,8 +24,8 @@ namespace lite {
class InstanceNormComputeTest : public arena::TestCase {
protected:
// common attributes for this op.
std::string input_ = "x";
std::string output_ = "y";
std::string x_ = "x";
std::string y_ = "y";
std::string saved_mean_ = "saved_mean";
std::string saved_variance_ = "saved_variance";
std::string scale_ = "scale";
......@@ -42,24 +42,24 @@ class InstanceNormComputeTest : public arena::TestCase {
: TestCase(place, alias), dims_(dims), epsilon_(epsilon) {}
void RunBaseline(Scope* scope) override {
auto x = scope->FindTensor(input_);
auto x = scope->FindTensor(x_);
auto scale = scope->FindTensor(scale_);
auto bias = scope->FindTensor(bias_);
auto out = scope->NewTensor(output_);
auto y = scope->NewTensor(y_);
auto saved_mean = scope->NewTensor(saved_mean_);
auto saved_variance = scope->NewTensor(saved_variance_);
CHECK(out);
CHECK(y);
CHECK(saved_mean);
CHECK(saved_variance);
DDim saved_dim({dims_[0] * dims_[1]});
out->Resize(dims_);
y->Resize(dims_);
saved_mean->Resize(saved_dim);
saved_variance->Resize(saved_dim);
auto x_data = x->data<float>();
auto scale_data = scale->data<float>();
auto bias_data = bias->data<float>();
auto out_data = out->mutable_data<float>();
auto y_data = y->mutable_data<float>();
auto saved_mean_data = saved_mean->mutable_data<float>();
auto saved_variance_data = saved_variance->mutable_data<float>();
......@@ -89,46 +89,47 @@ class InstanceNormComputeTest : public arena::TestCase {
// compute out
for (int i = 0; i < n * c; ++i) {
const float* x_ptr = x_data + i * spatial_size;
float* out_ptr = out_data + i * spatial_size;
float* y_ptr = y_data + i * spatial_size;
float scale_val = scale_data[i % c];
float bias_val = bias_data[i % c];
for (int j = 0; j < spatial_size; ++j) {
out_ptr[j] = scale_val * (x_ptr[j] - saved_mean_data[i]) *
saved_variance_data[i] +
bias_val;
y_ptr[j] = scale_val * (x_ptr[j] - saved_mean_data[i]) *
saved_variance_data[i] +
bias_val;
}
}
}
void PrepareOpDesc(cpp::OpDesc* op_desc) {
op_desc->SetType("instance_norm");
op_desc->SetInput("X", {input_});
op_desc->SetInput("X", {x_});
op_desc->SetInput("Bias", {bias_});
op_desc->SetInput("Scale", {scale_});
op_desc->SetOutput("Y", {output_});
op_desc->SetOutput("Y", {y_});
op_desc->SetOutput("SavedMean", {saved_mean_});
op_desc->SetOutput("SavedVariance", {saved_variance_});
op_desc->SetAttr("epsilon", epsilon_);
}
void PrepareData() override {
std::vector<float> din(dims_.production());
fill_data_rand(din.data(), -1.f, 1.f, dims_.production());
DDim scale_dim{{dims_[1]}};
std::vector<float> scale(scale_dim.production());
fill_data_rand(scale.data(), -1.f, 1.f, scale_dim.production());
std::vector<float> bias(scale_dim.production());
fill_data_rand(bias.data(), -1.f, 1.f, scale_dim.production());
SetCommonTensor(input_, dims_, din.data());
SetCommonTensor(scale_, scale_dim, scale.data());
SetCommonTensor(bias_, scale_dim, bias.data());
std::vector<float> x(dims_.production());
fill_data_rand(x.data(), -1.f, 1.f, dims_.production());
DDim scale_bias_dims{{dims_[1]}};
std::vector<float> scale(scale_bias_dims.production());
fill_data_rand(scale.data(), -1.f, 1.f, scale_bias_dims.production());
std::vector<float> bias(scale_bias_dims.production());
fill_data_rand(bias.data(), -1.f, 1.f, scale_bias_dims.production());
SetCommonTensor(x_, dims_, x.data());
SetCommonTensor(scale_, scale_bias_dims, scale.data());
SetCommonTensor(bias_, scale_bias_dims, bias.data());
}
};
void test_instance_norm(Place place) {
void TestInstanceNorm(Place place,
float abs_error = 6e-5,
std::vector<std::string> ignored_outs = {}) {
for (auto& n : {1, 3, 16}) {
for (auto& c : {1, 4, 16}) {
for (auto& h : {1, 16, 33, 56}) {
......@@ -138,11 +139,13 @@ void test_instance_norm(Place place) {
std::unique_ptr<arena::TestCase> tester(
new InstanceNormComputeTest(place, "def", dim_in, epsilon));
#ifdef LITE_WITH_ARM
auto& ctx = tester->context()->As<ARMContext>();
ctx.SetRunMode(lite_api::LITE_POWER_HIGH, 4);
if (place == TARGET(kARM)) {
auto& ctx = tester->context()->As<ARMContext>();
ctx.SetRunMode(lite_api::LITE_POWER_HIGH, 4);
}
#endif
arena::Arena arena(std::move(tester), place, 6e-5);
if (!arena.TestPrecision()) {
arena::Arena arena(std::move(tester), place, abs_error);
if (!arena.TestPrecision(ignored_outs)) {
LOG(ERROR) << "run n: " << n << ", c: " << c << ", h: " << h
<< ", w: " << w;
return;
......@@ -154,10 +157,19 @@ void test_instance_norm(Place place) {
}
TEST(InstanceNorm, precision) {
#ifdef LITE_WITH_ARM
Place place(TARGET(kARM));
test_instance_norm(place);
Place place;
float abs_error = 6e-5;
std::vector<std::string> ignored_outs = {};
#if defined(LITE_WITH_NPU)
place = TARGET(kNPU);
abs_error = 1e-2; // Using fp16 in NPU
ignored_outs = {"saved_mean", "saved_variance"};
#elif defined(LITE_WITH_ARM)
place = TARGET(kARM);
#else
return;
#endif
TestInstanceNorm(place, abs_error, ignored_outs);
}
} // namespace lite
......
......@@ -16,6 +16,7 @@
#include "lite/api/paddle_use_kernels.h"
#include "lite/api/paddle_use_ops.h"
#include "lite/core/arena/framework.h"
#include "lite/tests/utils/fill_data.h"
namespace paddle {
namespace lite {
......@@ -23,8 +24,8 @@ namespace lite {
class Pad2dComputeTester : public arena::TestCase {
protected:
// common attributes for this op.
std::string input_ = "X";
std::string output_ = "Out";
std::string x_ = "X";
std::string out_ = "Out";
DDim dims_{{1, 1, 14, 14}};
std::string mode_{"constant"};
std::vector<int> paddings_;
......@@ -46,13 +47,13 @@ class Pad2dComputeTester : public arena::TestCase {
void RunBaseline(Scope* scope) override {
LOG(INFO) << "into runbase";
auto* out = scope->NewTensor(output_);
auto* out = scope->NewTensor(out_);
CHECK(out);
int out_h = dims_[2] + paddings_[0] + paddings_[1];
int out_w = dims_[3] + paddings_[2] + paddings_[3];
out->Resize(lite::DDim({dims_[0], dims_[1], out_h, out_w}));
auto* out_data = out->mutable_data<float>();
auto* x = scope->FindTensor(input_);
auto* x = scope->FindTensor(x_);
const auto* x_data = x->data<float>();
LOG(INFO) << "get nums";
......@@ -125,8 +126,8 @@ class Pad2dComputeTester : public arena::TestCase {
void PrepareOpDesc(cpp::OpDesc* op_desc) {
op_desc->SetType("pad2d");
op_desc->SetInput("X", {input_});
op_desc->SetOutput("Out", {output_});
op_desc->SetInput("X", {x_});
op_desc->SetOutput("Out", {out_});
op_desc->SetAttr("mode", mode_);
op_desc->SetAttr("pad_value", pad_value_);
op_desc->SetAttr("paddings", paddings_);
......@@ -134,17 +135,13 @@ class Pad2dComputeTester : public arena::TestCase {
}
void PrepareData() override {
std::vector<float> data(dims_.production());
for (int i = 0; i < dims_.production(); i++) {
data[i] = i * 1;
}
SetCommonTensor(input_, dims_, data.data());
std::vector<float> x(dims_.production());
fill_data_rand(x.data(), -1.f, 1.f, dims_.production());
SetCommonTensor(x_, dims_, x.data());
}
};
void TestPad2d(const Place& place) {
void TestPad2d(const Place& place, float abs_error = 2e-5) {
std::string data_format = "NCHW";
for (int pad_top : {0, 1}) {
for (int pad_bottom : {0, 1}) {
......@@ -158,7 +155,7 @@ void TestPad2d(const Place& place) {
<< paddings[2] << " " << paddings[3];
std::unique_ptr<arena::TestCase> tester(new Pad2dComputeTester(
place, "def", pad_mode, paddings, pad_value, data_format));
arena::Arena arena(std::move(tester), place, 2e-5);
arena::Arena arena(std::move(tester), place, abs_error);
arena.TestPrecision();
}
}
......@@ -169,13 +166,17 @@ void TestPad2d(const Place& place) {
}
TEST(Pad2d, precision) {
#ifdef LITE_WITH_X86
Place place(TARGET(kX86));
#endif
#ifdef LITE_WITH_ARM
Place place(TARGET(kARM));
TestPad2d(place);
Place place;
float abs_error = 2e-5;
#if defined(LITE_WITH_NPU)
place = TARGET(kNPU);
abs_error = 1e-2; // Using fp16 in NPU
#elif defined(LITE_WITH_ARM)
place = TARGET(kARM);
#else
return;
#endif
TestPad2d(place, abs_error);
}
} // namespace lite
......