未验证 提交 26470600 编写于 作者: J juncaipeng 提交者: GitHub

Upgrade concat and unsqueeze, test=develop (#2378)

* update concat and unsqueeze, test=develop
上级 15eccb9e
...@@ -39,6 +39,11 @@ void ConcatCompute::Run() { ...@@ -39,6 +39,11 @@ void ConcatCompute::Run() {
std::vector<lite::Tensor*> inputs = param.x; std::vector<lite::Tensor*> inputs = param.x;
auto* out = param.output; auto* out = param.output;
int axis = param.axis; int axis = param.axis;
auto* axis_tensor = param.axis_tensor;
if (axis_tensor != nullptr) {
auto* axis_tensor_data = axis_tensor->data<int>();
axis = axis_tensor_data[0];
}
out->mutable_data<float>(); out->mutable_data<float>();
/// Sometimes direct copies will be faster, this maybe need deeply analysis. /// Sometimes direct copies will be faster, this maybe need deeply analysis.
...@@ -83,5 +88,7 @@ void ConcatCompute::Run() { ...@@ -83,5 +88,7 @@ void ConcatCompute::Run() {
REGISTER_LITE_KERNEL( REGISTER_LITE_KERNEL(
concat, kARM, kFloat, kNCHW, paddle::lite::kernels::arm::ConcatCompute, def) concat, kARM, kFloat, kNCHW, paddle::lite::kernels::arm::ConcatCompute, def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))}) .BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))})
.BindInput("AxisTensor",
{LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt32))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))}) .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))})
.Finalize(); .Finalize();
...@@ -55,6 +55,10 @@ REGISTER_LITE_KERNEL(unsqueeze, ...@@ -55,6 +55,10 @@ REGISTER_LITE_KERNEL(unsqueeze,
paddle::lite::kernels::host::UnsqueezeCompute, paddle::lite::kernels::host::UnsqueezeCompute,
def) def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))}) .BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))})
.BindInput("AxesTensor",
{LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt32))})
.BindInput("AxesTensorList",
{LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt32))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))}) .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))})
.Finalize(); .Finalize();
...@@ -65,6 +69,10 @@ REGISTER_LITE_KERNEL(unsqueeze2, ...@@ -65,6 +69,10 @@ REGISTER_LITE_KERNEL(unsqueeze2,
paddle::lite::kernels::host::Unsqueeze2Compute, paddle::lite::kernels::host::Unsqueeze2Compute,
def) def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))}) .BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))})
.BindInput("AxesTensor",
{LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt32))})
.BindInput("AxesTensorList",
{LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt32))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))}) .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))})
.BindOutput("XShape", {LiteType::GetTensorTy(TARGET(kARM))}) .BindOutput("XShape", {LiteType::GetTensorTy(TARGET(kARM))})
.Finalize(); .Finalize();
...@@ -51,6 +51,11 @@ void ConcatCompute<Dtype>::Run() { ...@@ -51,6 +51,11 @@ void ConcatCompute<Dtype>::Run() {
Tensor* output = param.output; Tensor* output = param.output;
auto* output_data = output->mutable_data<Dtype>(TARGET(kCUDA)); auto* output_data = output->mutable_data<Dtype>(TARGET(kCUDA));
int axis = param.axis; int axis = param.axis;
auto* axis_tensor = param.axis_tensor;
if (axis_tensor != nullptr) {
auto* axis_tensor_data = axis_tensor->data<int>();
axis = axis_tensor_data[0];
}
int inner_size = 1; int inner_size = 1;
int outer_size = 1; int outer_size = 1;
auto input_dims = input[0]->dims(); auto input_dims = input[0]->dims();
...@@ -97,5 +102,7 @@ REGISTER_LITE_KERNEL(concat, ...@@ -97,5 +102,7 @@ REGISTER_LITE_KERNEL(concat,
paddle::lite::kernels::cuda::ConcatCompute<float>, paddle::lite::kernels::cuda::ConcatCompute<float>,
def) def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kCUDA))}) .BindInput("X", {LiteType::GetTensorTy(TARGET(kCUDA))})
.BindInput("AxisTensor",
{LiteType::GetTensorTy(TARGET(kCUDA), PRECISION(kInt32))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kCUDA))}) .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kCUDA))})
.Finalize(); .Finalize();
...@@ -21,5 +21,7 @@ REGISTER_LITE_KERNEL(concat, ...@@ -21,5 +21,7 @@ REGISTER_LITE_KERNEL(concat,
paddle::lite::kernels::x86::ConcatCompute<float>, paddle::lite::kernels::x86::ConcatCompute<float>,
def) def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kX86))}) .BindInput("X", {LiteType::GetTensorTy(TARGET(kX86))})
.BindInput("AxisTensor",
{LiteType::GetTensorTy(TARGET(kX86), PRECISION(kInt32))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kX86))}) .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kX86))})
.Finalize(); .Finalize();
...@@ -40,6 +40,11 @@ class ConcatCompute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> { ...@@ -40,6 +40,11 @@ class ConcatCompute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> {
void Run() override { void Run() override {
auto& param = *param_.get_mutable<param_t>(); auto& param = *param_.get_mutable<param_t>();
int64_t axis = static_cast<int64_t>(param.axis); int64_t axis = static_cast<int64_t>(param.axis);
auto* axis_tensor = param.axis_tensor;
if (axis_tensor != nullptr) {
auto* axis_tensor_data = axis_tensor->data<int>();
axis = static_cast<int64_t>(axis_tensor_data[0]);
}
auto x_dims = param.x[0]->dims(); auto x_dims = param.x[0]->dims();
auto out = param.output; auto out = param.output;
if (param.x.size() == 1) { if (param.x.size() == 1) {
......
...@@ -31,14 +31,25 @@ bool ConcatOpLite::InferShape() const { ...@@ -31,14 +31,25 @@ bool ConcatOpLite::InferShape() const {
for (auto p : param_.x) { for (auto p : param_.x) {
input_dims.push_back(p->dims()); input_dims.push_back(p->dims());
} }
size_t axis = static_cast<size_t>(param_.axis);
const size_t n = input_dims.size(); const size_t n = input_dims.size();
CHECK_GT_OR_FALSE(n, 0); CHECK_GT_OR_FALSE(n, 0);
int axis = 0;
if (param_.axis_tensor == nullptr) {
axis = param_.axis;
} else {
auto *axis_tensor_val = param_.axis_tensor->data<int>();
axis = axis_tensor_val[0];
}
if (axis < 0) {
axis += input_dims[0].size();
}
auto &out_dims = input_dims[0]; auto &out_dims = input_dims[0];
size_t in_zero_dims_size = out_dims.size(); size_t in_zero_dims_size = out_dims.size();
for (size_t i = 1; i < n; i++) { for (size_t i = 1; i < n; i++) {
for (size_t j = 0; j < in_zero_dims_size; j++) { for (size_t j = 0; j < in_zero_dims_size; j++) {
if (j == axis) { if (j == static_cast<size_t>(axis)) {
out_dims[axis] += input_dims[i][j]; out_dims[axis] += input_dims[i][j];
} else { } else {
CHECK_EQ_OR_FALSE(out_dims[j], input_dims[i][j]); CHECK_EQ_OR_FALSE(out_dims[j], input_dims[i][j]);
...@@ -68,6 +79,17 @@ bool ConcatOpLite::AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) { ...@@ -68,6 +79,17 @@ bool ConcatOpLite::AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) {
param_.output = scope->FindVar(out)->GetMutable<lite::Tensor>(); param_.output = scope->FindVar(out)->GetMutable<lite::Tensor>();
param_.axis = op_desc.GetAttr<int>("axis"); param_.axis = op_desc.GetAttr<int>("axis");
std::vector<std::string> input_arg_names = op_desc.InputArgumentNames();
if (std::find(input_arg_names.begin(), input_arg_names.end(), "AxisTensor") !=
input_arg_names.end()) {
auto arguments = op_desc.Input("AxisTensor");
if (arguments.size() > 0) {
auto var = scope->FindVar(arguments.front());
if (var != nullptr) {
param_.axis_tensor = var->GetMutable<lite::Tensor>();
}
}
}
return true; return true;
} }
......
...@@ -207,6 +207,7 @@ struct ConcatParam { ...@@ -207,6 +207,7 @@ struct ConcatParam {
std::vector<lite::Tensor*> x{}; std::vector<lite::Tensor*> x{};
lite::Tensor* output{}; lite::Tensor* output{};
int axis{0}; int axis{0};
lite::Tensor* axis_tensor{};
}; };
/// ----------------------- activation operators ---------------------- /// ----------------------- activation operators ----------------------
...@@ -854,6 +855,8 @@ struct UnsqueezeParam { ...@@ -854,6 +855,8 @@ struct UnsqueezeParam {
lite::Tensor* Out{}; lite::Tensor* Out{};
lite::Tensor* XShape{}; lite::Tensor* XShape{};
std::vector<int> axes{}; std::vector<int> axes{};
const lite::Tensor* axes_tensor{};
std::vector<lite::Tensor>* axes_tensor_vct{};
}; };
/// ----------------------- expand operators ---------------------- /// ----------------------- expand operators ----------------------
......
...@@ -63,9 +63,30 @@ bool UnsqueezeOp::CheckShape() const { ...@@ -63,9 +63,30 @@ bool UnsqueezeOp::CheckShape() const {
} }
bool UnsqueezeOp::InferShape() const { bool UnsqueezeOp::InferShape() const {
std::vector<int> unsqueeze_dims = param_.axes; std::vector<int> final_axes;
auto axes = param_.axes;
auto *axes_tensor = param_.axes_tensor;
std::vector<lite::Tensor> axes_tensor_vct;
if (param_.axes_tensor_vct) {
axes_tensor_vct = *(param_.axes_tensor_vct);
}
if (!axes.empty()) {
final_axes = axes;
} else if (axes_tensor != nullptr) {
auto *axes_tensor_data = axes_tensor->data<int>();
final_axes = std::vector<int>(axes_tensor_data,
axes_tensor_data + axes_tensor->numel());
} else if (!axes_tensor_vct.empty()) {
for (int i = 0; i < axes_tensor_vct.size(); i++) {
final_axes.push_back(axes_tensor_vct[i].data<int>()[0]);
}
} else {
LOG(FATAL) << "Input axis error";
}
DDim in_dims = param_.X->dims(); DDim in_dims = param_.X->dims();
DDim out_dim = GetOutputShape(unsqueeze_dims, in_dims); DDim out_dim = GetOutputShape(final_axes, in_dims);
param_.Out->Resize(out_dim); param_.Out->Resize(out_dim);
return true; return true;
} }
...@@ -81,6 +102,29 @@ bool UnsqueezeOp::AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) { ...@@ -81,6 +102,29 @@ bool UnsqueezeOp::AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) {
if (opdesc.HasAttr("axes")) { if (opdesc.HasAttr("axes")) {
param_.axes = opdesc.GetAttr<std::vector<int>>("axes"); param_.axes = opdesc.GetAttr<std::vector<int>>("axes");
} }
if (opdesc.HasInput("AxesTensor") && opdesc.Input("AxesTensor").size() > 0) {
auto var = scope->FindVar(opdesc.Input("AxesTensor").front());
if (var != nullptr) {
param_.axes_tensor = var->GetMutable<lite::Tensor>();
VLOG(5) << "load AxesTensor";
}
}
if (opdesc.HasInput("AxesTensorList") &&
opdesc.Input("AxesTensorList").size() > 0) {
auto args = opdesc.Input("AxesTensorList");
/*
for (auto arg : args) {
auto *var = scope->FindVar(arg);
if (var != nullptr) {
param_.axes_tensor_vct.push_back(var->GetMutable<lite::Tensor>());
}
}
*/
auto *var = scope->FindVar(args.front());
param_.axes_tensor_vct = var->GetMutable<std::vector<lite::Tensor>>();
}
CHECK(param_.X) << "Input(X) of UnsqueezeOp should not be null."; CHECK(param_.X) << "Input(X) of UnsqueezeOp should not be null.";
CHECK(param_.Out) << "Output(Out) of UnsqueezeOp should not be null."; CHECK(param_.Out) << "Output(Out) of UnsqueezeOp should not be null.";
return true; return true;
......
...@@ -22,6 +22,7 @@ if((NOT LITE_WITH_OPENCL AND NOT LITE_WITH_FPGA AND NOT LITE_WITH_XPU) AND (LITE ...@@ -22,6 +22,7 @@ if((NOT LITE_WITH_OPENCL AND NOT LITE_WITH_FPGA AND NOT LITE_WITH_XPU) AND (LITE
#lite_cc_test(test_kernel_increment_compute SRCS increment_compute_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels}) #lite_cc_test(test_kernel_increment_compute SRCS increment_compute_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
#lite_cc_test(test_kernel_write_to_array_compute SRCS write_to_array_compute_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels}) #lite_cc_test(test_kernel_write_to_array_compute SRCS write_to_array_compute_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
#lite_cc_test(test_kernel_read_from_array_compute SRCS read_from_array_compute_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels}) #lite_cc_test(test_kernel_read_from_array_compute SRCS read_from_array_compute_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_concat_compute SRCS concat_compute_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
if(LITE_BUILD_EXTRA) if(LITE_BUILD_EXTRA)
lite_cc_test(test_gru_unit SRCS gru_unit_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels}) lite_cc_test(test_gru_unit SRCS gru_unit_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <gtest/gtest.h>
#include "lite/api/paddle_use_kernels.h"
#include "lite/api/paddle_use_ops.h"
#include "lite/core/arena/framework.h"
namespace paddle {
namespace lite {
DDim infer_shape(const std::vector<const Tensor*>& inputs, int in_axis) {
std::vector<DDim> input_dims;
for (auto* tensor : inputs) {
input_dims.push_back(tensor->dims());
}
size_t axis = static_cast<size_t>(in_axis);
DDim out_dims = input_dims[0];
for (size_t i = 1; i < input_dims.size(); i++) {
for (size_t j = 0; j < input_dims[0].size(); j++) {
if (j == axis) {
out_dims[axis] += input_dims[i][j];
} else {
if (out_dims[j] != input_dims[i][j]) {
LOG(FATAL) << "infer shape error.";
}
}
}
}
if (out_dims[axis] < 0) {
out_dims[axis] = -1;
}
return out_dims;
}
class ConcateComputeTester : public arena::TestCase {
protected:
// common attributes for this op.
std::vector<std::string> x_vct_{};
std::string out_ = "out";
std::string axis_tensor_ = "axis_tensor";
int axis_ = 0;
bool is_use_axis_tensor_ = false;
int x_num_ = 3;
DDim x_dims_{{2, 3, 4, 5}};
public:
ConcateComputeTester(const Place& place,
const std::string& alias,
int axis,
bool is_use_axis_tensor)
: TestCase(place, alias) {
axis_ = axis;
is_use_axis_tensor_ = is_use_axis_tensor;
}
void RunBaseline(Scope* scope) override {
std::vector<const Tensor*> x_vct;
for (std::string& name : x_vct_) {
x_vct.push_back(scope->FindTensor(name));
}
auto* out = scope->NewTensor(out_);
DDim output_dims = infer_shape(x_vct, axis_);
out->Resize(output_dims);
auto* output_data = out->mutable_data<float>();
int num = x_vct.size();
int rows = 1;
auto dim_0 = x_vct[0]->dims();
for (int i = 0; i < axis_; ++i) {
rows *= dim_0[i];
}
int out_rows = rows, out_cols = 0;
std::vector<int> input_cols(x_vct.size());
for (int i = 0; i < num; ++i) {
int input_i_numel = x_vct[i]->dims().size() == 0 ? 0 : 1;
for (int didx = 0; didx < x_vct[i]->dims().size(); ++didx) {
input_i_numel *= x_vct[i]->dims()[didx];
}
int t_cols = input_i_numel / rows;
out_cols += t_cols;
input_cols[i] = t_cols;
}
// computation
int col_idx = 0;
for (int j = 0; j < num; ++j) {
int col_len = input_cols[j];
auto input_data = x_vct[j]->data<float>();
for (int k = 0; k < out_rows; ++k) {
memcpy(output_data + k * out_cols + col_idx,
input_data + k * col_len,
sizeof(float) * col_len);
}
col_idx += col_len;
}
}
void PrepareOpDesc(cpp::OpDesc* op_desc) {
op_desc->SetType("concat");
op_desc->SetInput("X", x_vct_);
op_desc->SetAttr("axis", axis_);
if (is_use_axis_tensor_) {
op_desc->SetInput("AxisTensor", {axis_tensor_});
}
op_desc->SetOutput("Out", {out_});
}
void PrepareData() override {
for (int n = 0; n < x_num_; n++) {
std::vector<float> x_data(x_dims_.production());
for (int i = 0; i < x_dims_.production(); i++) {
x_data[i] = static_cast<float>(i + n);
}
const std::string x_name = "x_tensor_" + std::to_string(n);
x_vct_.push_back(x_name);
SetCommonTensor(x_name, x_dims_, x_data.data());
}
if (is_use_axis_tensor_) {
SetCommonTensor(axis_tensor_, DDim({1}), &axis_);
LOG(INFO) << "set axis tensor";
}
}
};
TEST(Concat, precision) {
LOG(INFO) << "test concat op, kARM";
#ifdef LITE_WITH_ARM
Place place(TARGET(kARM));
for (int axis : {1, 2}) {
for (bool is_use_axis_tensor : {false, true}) {
LOG(INFO) << "axis:" << axis
<< ", is_use_axis_tensor:" << is_use_axis_tensor;
std::unique_ptr<arena::TestCase> tester(
new ConcateComputeTester(place, "def", axis, is_use_axis_tensor));
arena::Arena arena(std::move(tester), place, 2e-5);
arena.TestPrecision();
}
}
#endif
#ifdef LITE_WITH_X86
Place place(TARGET(kX86));
LOG(INFO) << "test concate op, x86";
for (int axis : {1, 2}) {
for (bool is_use_axis_tensor : {false, true}) {
LOG(INFO) << "axis:" << axis
<< ", is_use_axis_tensor:" << is_use_axis_tensor;
std::unique_ptr<arena::TestCase> tester(
new ConcateComputeTester(place, "def", axis, is_use_axis_tensor));
arena::Arena arena(std::move(tester), place, 2e-5);
arena.TestPrecision();
}
}
#endif
}
} // namespace lite
} // namespace paddle
...@@ -13,10 +13,10 @@ ...@@ -13,10 +13,10 @@
// limitations under the License. // limitations under the License.
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include <string>
#include "lite/api/paddle_use_kernels.h" #include "lite/api/paddle_use_kernels.h"
#include "lite/api/paddle_use_ops.h" #include "lite/api/paddle_use_ops.h"
#include "lite/core/arena/framework.h" #include "lite/core/arena/framework.h"
namespace paddle { namespace paddle {
namespace lite { namespace lite {
...@@ -25,15 +25,24 @@ class UnsqueezeComputeTester : public arena::TestCase { ...@@ -25,15 +25,24 @@ class UnsqueezeComputeTester : public arena::TestCase {
// common attributes for this op. // common attributes for this op.
std::string x_ = "X"; std::string x_ = "X";
std::string out_ = "Out"; std::string out_ = "Out";
std::string axes_tensor_ = "AxesTensor";
std::vector<std::string> axes_tensor_list_;
std::vector<int> axes_; std::vector<int> axes_;
DDim dims_; DDim dims_;
// input_axes_flag_: 1 for axes, 2 for axes_tensor, 3 for axes_tensor_list
int input_axes_flag_ = 1;
public: public:
UnsqueezeComputeTester(const Place& place, UnsqueezeComputeTester(const Place& place,
const std::string& alias, const std::string& alias,
const std::vector<int>& axes, const std::vector<int>& axes,
DDim dims) DDim dims,
: TestCase(place, alias), axes_(axes), dims_(dims) {} int input_axes_flag)
: TestCase(place, alias), dims_(dims), input_axes_flag_(input_axes_flag) {
for (int v : axes) {
axes_.push_back(v);
}
}
void RunBaseline(Scope* scope) override { void RunBaseline(Scope* scope) override {
const auto* input = scope->FindTensor(x_); const auto* input = scope->FindTensor(x_);
...@@ -86,7 +95,15 @@ class UnsqueezeComputeTester : public arena::TestCase { ...@@ -86,7 +95,15 @@ class UnsqueezeComputeTester : public arena::TestCase {
op_desc->SetType("unsqueeze"); op_desc->SetType("unsqueeze");
op_desc->SetInput("X", {x_}); op_desc->SetInput("X", {x_});
op_desc->SetOutput("Out", {out_}); op_desc->SetOutput("Out", {out_});
if (input_axes_flag_ == 1) {
op_desc->SetAttr("axes", axes_); op_desc->SetAttr("axes", axes_);
} else if (input_axes_flag_ == 2) {
op_desc->SetInput("AxesTensor", {axes_tensor_});
} else if (input_axes_flag_ == 3) {
op_desc->SetInput("AxesTensorList", axes_tensor_list_);
} else {
LOG(FATAL) << "input input_axes_flag_ error. " << input_axes_flag_;
}
} }
void PrepareData() override { void PrepareData() override {
...@@ -95,6 +112,23 @@ class UnsqueezeComputeTester : public arena::TestCase { ...@@ -95,6 +112,23 @@ class UnsqueezeComputeTester : public arena::TestCase {
in_data[i] = i; in_data[i] = i;
} }
SetCommonTensor(x_, dims_, in_data.data()); SetCommonTensor(x_, dims_, in_data.data());
if (input_axes_flag_ == 2) {
DDim axes_tensor_dim{{static_cast<int>(axes_.size())}};
std::vector<int> axes_tensor_data(axes_.size());
for (int i = 0; i < axes_tensor_dim.production(); i++) {
axes_tensor_data[i] = axes_[i];
}
SetCommonTensor(axes_tensor_, axes_tensor_dim, axes_tensor_data.data());
} else if (input_axes_flag_ == 3) {
std::string name = "axes_tensor_";
for (size_t i = 0; i < axes_.size(); i++) {
name = name + std::to_string(i);
axes_tensor_list_.push_back(name);
std::vector<int> in_data = {axes_[i]};
SetCommonTensor(name, DDim({1}), in_data.data());
}
}
} }
}; };
...@@ -189,15 +223,19 @@ class Unsqueeze2ComputeTester : public arena::TestCase { ...@@ -189,15 +223,19 @@ class Unsqueeze2ComputeTester : public arena::TestCase {
}; };
void test_unsqueeze(Place place) { void test_unsqueeze(Place place) {
for (std::vector<int> axes : {std::vector<int>({}), for (std::vector<int> axes : {std::vector<int>({1}),
std::vector<int>({0, 2}), std::vector<int>({0, 2}),
std::vector<int>({0, -2})}) { std::vector<int>({0, -2})}) {
for (int N : {1}) { for (int N : {1}) {
for (int C : {3}) { for (int C : {3}) {
for (int H : {1}) { for (int H : {1}) {
for (int W : {5}) { for (int W : {5}) {
std::unique_ptr<arena::TestCase> tester(new UnsqueezeComputeTester( for (int input_axes_flag : {1, 2}) {
place, "def", axes, DDim({N, C, H, W}))); LOG(INFO) << N << " " << C << " " << H << " " << W << " "
<< input_axes_flag;
std::unique_ptr<arena::TestCase> tester(
new UnsqueezeComputeTester(
place, "def", axes, DDim({N, C, H, W}), input_axes_flag));
arena::Arena arena(std::move(tester), place, 2e-5); arena::Arena arena(std::move(tester), place, 2e-5);
arena.TestPrecision(); arena.TestPrecision();
} }
...@@ -205,10 +243,11 @@ void test_unsqueeze(Place place) { ...@@ -205,10 +243,11 @@ void test_unsqueeze(Place place) {
} }
} }
} }
}
} }
void test_unsqueeze2(Place place) { void test_unsqueeze2(Place place) {
for (std::vector<int> axes : {std::vector<int>({}), for (std::vector<int> axes : {std::vector<int>({0}),
std::vector<int>({0, 2}), std::vector<int>({0, 2}),
std::vector<int>({0, -2})}) { std::vector<int>({0, -2})}) {
for (int N : {1}) { for (int N : {1}) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册