未验证 提交 3723451b 编写于 作者: Y yiicy 提交者: GitHub

[ARM] add grid_sampler op and ut, test=develop (#2598)

上级 df160500
...@@ -213,7 +213,17 @@ class Arena { ...@@ -213,7 +213,17 @@ class Arena {
} }
auto duration = std::chrono::duration_cast<std::chrono::milliseconds>( auto duration = std::chrono::duration_cast<std::chrono::milliseconds>(
std::chrono::high_resolution_clock::now() - timer); std::chrono::high_resolution_clock::now() - timer);
LOG(INFO) << "average duration: " << duration.count() << " ms";
timer = std::chrono::high_resolution_clock::now();
for (int i = 0; i < times; i++) {
tester_->RunBaseline(tester_->baseline_scope());
}
auto duration_basic = std::chrono::duration_cast<std::chrono::milliseconds>(
std::chrono::high_resolution_clock::now() - timer);
LOG(INFO) << "average lite duration: " << duration.count() << " ms";
LOG(INFO) << "average basic duration: " << duration_basic.count() << " ms";
LOG(INFO) << "speed up ratio: lite_speed / basic_speed: "
<< static_cast<float>(duration_basic.count()) / duration.count();
} }
private: private:
......
...@@ -49,6 +49,7 @@ add_kernel(range_compute_arm ARM basic SRCS range_compute.cc DEPS ${lite_kernel_ ...@@ -49,6 +49,7 @@ add_kernel(range_compute_arm ARM basic SRCS range_compute.cc DEPS ${lite_kernel_
add_kernel(dropout_compute_arm ARM basic SRCS dropout_compute.cc DEPS ${lite_kernel_deps} math_arm) add_kernel(dropout_compute_arm ARM basic SRCS dropout_compute.cc DEPS ${lite_kernel_deps} math_arm)
add_kernel(layout_compute_arm ARM basic SRCS layout_compute.cc DEPS ${lite_kernel_deps} math_arm) add_kernel(layout_compute_arm ARM basic SRCS layout_compute.cc DEPS ${lite_kernel_deps} math_arm)
add_kernel(instance_norm_compute_arm ARM basic SRCS instance_norm_compute.cc DEPS ${lite_kernel_deps} math_arm) add_kernel(instance_norm_compute_arm ARM basic SRCS instance_norm_compute.cc DEPS ${lite_kernel_deps} math_arm)
add_kernel(grid_sampler_compute_arm ARM basic SRCS grid_sampler_compute.cc DEPS ${lite_kernel_deps} math_arm)
## 2.other basic kernels: basic kernels that not used in basic models ## 2.other basic kernels: basic kernels that not used in basic models
add_kernel(negative_compute_arm ARM extra SRCS negative_compute.cc DEPS ${lite_kernel_deps} math_arm) add_kernel(negative_compute_arm ARM extra SRCS negative_compute.cc DEPS ${lite_kernel_deps} math_arm)
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/arm/grid_sampler_compute.h"
#include "lite/backends/arm/math/funcs.h"
#include "lite/core/op_registry.h"
#include "lite/core/type_system.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace arm {
void GridSamplerCompute::PrepareForRun() {}
void GridSamplerCompute::Run() {
auto& param = this->Param<param_t>();
auto n = param.x->dims()[0];
auto c = param.x->dims()[1];
auto h = param.x->dims()[2];
auto w = param.x->dims()[3];
const float* in = param.x->data<float>();
const float* grid = param.grid->data<float>();
float* out = param.out->mutable_data<float>();
auto& ctx = this->ctx_->template As<ARMContext>();
const size_t coor_size = n * h * w;
const size_t workspace_size = coor_size * 12 * sizeof(float);
ctx.ExtendWorkspace(workspace_size);
int32_t* coor_p = ctx.workspace_data<int>();
float* dis_p = reinterpret_cast<float*>(coor_p) + coor_size * 4;
uint32_t* bound_p = reinterpret_cast<uint32_t*>(dis_p) + coor_size * 4;
float x_max = static_cast<float>(w - 1);
float y_max = static_cast<float>(h - 1);
float32x4_t vxmax = vdupq_n_f32(x_max);
float32x4_t vymax = vdupq_n_f32(y_max);
float32x4_t vone = vdupq_n_f32(1.f);
float32x4_t vzero = vdupq_n_f32(0.f);
// compute coor, dis, bound
int i = coor_size;
for (; i > 3; i -= 4) {
float32x4x2_t xy = vld2q_f32(grid);
float32x4_t grid_x = vmulq_n_f32(vaddq_f32(xy.val[0], vone), 0.5 * x_max);
float32x4_t grid_y = vmulq_n_f32(vaddq_f32(xy.val[1], vone), 0.5 * y_max);
grid += 8;
// compute xw, we, yn, ys
int32x4x4_t vcoor;
vcoor.val[0] = vcvtq_s32_f32(grid_x);
vcoor.val[2] = vcvtq_s32_f32(grid_y);
float32x4_t vxwf = vcvtq_f32_s32(vcoor.val[0]);
float32x4_t vynf = vcvtq_f32_s32(vcoor.val[2]);
float32x4_t vxef = vaddq_f32(vxwf, vone);
float32x4_t vysf = vaddq_f32(vynf, vone);
vcoor.val[1] = vcvtq_s32_f32(vxef);
vcoor.val[3] = vcvtq_s32_f32(vysf);
vst4q_s32(coor_p, vcoor);
coor_p += 16;
// compute dw, dn ,de, ds
float32x4x4_t vdis;
vdis.val[0] = vsubq_f32(grid_x, vxwf);
vdis.val[2] = vsubq_f32(grid_y, vynf);
vdis.val[1] = vsubq_f32(vxef, grid_x);
vdis.val[3] = vsubq_f32(vysf, grid_y);
vst4q_f32(dis_p, vdis);
dis_p += 16;
// compute bound
uint32x4x4_t vbound;
uint32x4_t logic_xw =
vorrq_u32(vcltq_f32(vxwf, vzero), vcgtq_f32(vxwf, vxmax));
uint32x4_t logic_xe =
vorrq_u32(vcltq_f32(vxef, vzero), vcgtq_f32(vxef, vxmax));
uint32x4_t logic_yn =
vorrq_u32(vcltq_f32(vynf, vzero), vcgtq_f32(vynf, vymax));
uint32x4_t logic_ys =
vorrq_u32(vcltq_f32(vysf, vzero), vcgtq_f32(vysf, vymax));
vbound.val[0] = vmvnq_u32(vorrq_u32(logic_xw, logic_yn));
vbound.val[1] = vmvnq_u32(vorrq_u32(logic_xe, logic_yn));
vbound.val[2] = vmvnq_u32(vorrq_u32(logic_xw, logic_ys));
vbound.val[3] = vmvnq_u32(vorrq_u32(logic_xe, logic_ys));
vst4q_u32(bound_p, vbound);
bound_p += 16;
}
for (; i > 0; i--) {
float x = grid[0];
float y = grid[1];
float grid_x = (x + 1) * 0.5 * x_max;
float grid_y = (y + 1) * 0.5 * y_max;
grid += 2;
// compute xw, xe, yn, ys
int32_t xw = static_cast<int32_t>(floor(grid_x));
int32_t xe = xw + 1;
int32_t yn = static_cast<int32_t>(floor(grid_y));
int32_t ys = yn + 1;
*coor_p++ = xw;
*coor_p++ = xe;
*coor_p++ = yn;
*coor_p++ = ys;
// compute dw, de, dn, ds
float dw = grid_x - xw;
float de = xe - grid_x;
float dn = grid_y - yn;
float ds = ys - grid_y;
*dis_p++ = dw;
*dis_p++ = de;
*dis_p++ = dn;
*dis_p++ = ds;
// compute bound
bool logic_xw = (xw < 0.f || xw > x_max);
bool logic_xe = (xe < 0.f || xe > x_max);
bool logic_yn = (yn < 0.f || yn > y_max);
bool logic_ys = (ys < 0.f || ys > y_max);
*bound_p++ = ((logic_xw || logic_yn) ? 0 : 0xffffffff);
*bound_p++ = ((logic_xe || logic_yn) ? 0 : 0xffffffff);
*bound_p++ = ((logic_xw || logic_ys) ? 0 : 0xffffffff);
*bound_p++ = ((logic_xe || logic_ys) ? 0 : 0xffffffff);
}
size_t cube_size = c * h * w;
size_t spatial_size = h * w;
// compute output
for (int i = 0; i < n; ++i) {
const float* in_n = in + i * cube_size;
float* out_n = out + i * cube_size;
int32_t* coor_n = ctx.workspace_data<int>() + i * spatial_size * 4;
float* dis_n = reinterpret_cast<float*>(coor_n) + coor_size * 4;
uint32_t* bound_n = reinterpret_cast<uint32_t*>(dis_n) + coor_size * 4;
#pragma omp parallel for
for (int j = 0; j < c; ++j) {
int32_t* coor_ptr = coor_n;
float* dis_ptr = dis_n;
uint32_t* bound_ptr = bound_n;
const float* in_c = in_n + j * spatial_size;
float* out_c = out_n + j * spatial_size;
for (int k = 0; k < spatial_size; k++) {
int32x4_t vcoor = vld1q_s32(coor_ptr);
float32x4_t vdis = vld1q_f32(dis_ptr);
int32_t xw = vgetq_lane_s32(vcoor, 0);
int32_t xe = vgetq_lane_s32(vcoor, 1);
int32_t yn = vgetq_lane_s32(vcoor, 2);
int32_t ys = vgetq_lane_s32(vcoor, 3);
uint32x4_t vbound = vld1q_u32(bound_ptr);
float dw = vgetq_lane_f32(vdis, 0);
float de = vgetq_lane_f32(vdis, 1);
float dn = vgetq_lane_f32(vdis, 2);
float ds = vgetq_lane_f32(vdis, 3);
uint32_t wnbound = vgetq_lane_u32(vbound, 0);
uint32_t enbound = vgetq_lane_u32(vbound, 1);
uint32_t wsbound = vgetq_lane_u32(vbound, 2);
uint32_t esbound = vgetq_lane_u32(vbound, 3);
float in_wn = wnbound ? in_c[yn * w + xw] : 0.f;
float in_en = enbound ? in_c[yn * w + xe] : 0.f;
float in_ws = wsbound ? in_c[ys * w + xw] : 0.f;
float in_es = esbound ? in_c[ys * w + xe] : 0.f;
coor_ptr += 4;
dis_ptr += 4;
bound_ptr += 4;
*out_c++ =
ds * (in_wn * de + in_en * dw) + dn * (in_ws * de + in_es * dw);
}
}
}
}
} // namespace arm
} // namespace kernels
} // namespace lite
} // namespace paddle
REGISTER_LITE_KERNEL(grid_sampler,
kARM,
kFloat,
kNCHW,
paddle::lite::kernels::arm::GridSamplerCompute,
def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))})
.BindInput("Grid", {LiteType::GetTensorTy(TARGET(kARM))})
.BindOutput("Output", {LiteType::GetTensorTy(TARGET(kARM))})
.Finalize();
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "lite/core/kernel.h"
#include "lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace arm {
class GridSamplerCompute : public KernelLite<TARGET(kARM), PRECISION(kFloat)> {
public:
using param_t = operators::GridSamplerParam;
void PrepareForRun() override;
void Run() override;
virtual ~GridSamplerCompute() = default;
private:
};
} // namespace arm
} // namespace kernels
} // namespace lite
} // namespace paddle
...@@ -49,6 +49,7 @@ add_operator(dropout_op basic SRCS dropout_op.cc DEPS ${op_DEPS}) ...@@ -49,6 +49,7 @@ add_operator(dropout_op basic SRCS dropout_op.cc DEPS ${op_DEPS})
add_operator(layout_op basic SRCS layout_op.cc DEPS ${op_DEPS}) add_operator(layout_op basic SRCS layout_op.cc DEPS ${op_DEPS})
add_operator(instance_norm_op basic SRCS instance_norm_op.cc DEPS ${op_DEPS}) add_operator(instance_norm_op basic SRCS instance_norm_op.cc DEPS ${op_DEPS})
add_operator(subgraph_op basic SRCS subgraph_op.cc DEPS ${op_DEPS}) add_operator(subgraph_op basic SRCS subgraph_op.cc DEPS ${op_DEPS})
add_operator(grid_sampler_op basic SRCS grid_sampler_op.cc DEPS ${op_DEPS})
# 2.basic ops not used in basic models # 2.basic ops not used in basic models
add_operator(negative_op extra SRCS negative_op.cc DEPS ${op_DEPS}) add_operator(negative_op extra SRCS negative_op.cc DEPS ${op_DEPS})
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/operators/grid_sampler_op.h"
#include <string>
#include <vector>
#include "lite/core/op_lite.h"
#include "lite/core/op_registry.h"
#include "lite/core/tensor.h"
namespace paddle {
namespace lite {
namespace operators {
bool GridSamplerOp::CheckShape() const {
CHECK_OR_FALSE(param_.x);
CHECK_OR_FALSE(param_.out);
CHECK_OR_FALSE(param_.grid);
auto x_dims = param_.x->dims();
auto grid_dims = param_.grid->dims();
CHECK_EQ(x_dims.size(), 4UL) << "Input must have 4 dimensions.";
CHECK_EQ(grid_dims.size(), 4UL) << "Grid must have 4 dimensions.";
CHECK_EQ(grid_dims[0], x_dims[0])
<< "Input(X) dims[0] and Input(Grid) dims[0] should be equal.";
CHECK_EQ(grid_dims[1], x_dims[2])
<< "Input(X) dims[2] and Input(Grid) dims[1] should be equal.";
CHECK_EQ(grid_dims[2], x_dims[3])
<< "Input(X) dims[3] and Input(Grid) dims[2] should be equal.";
return true;
}
bool GridSamplerOp::InferShape() const {
auto x_dims = param_.x->dims();
param_.out->Resize(x_dims);
return true;
}
bool GridSamplerOp::AttachImpl(const cpp::OpDesc& op_desc, lite::Scope* scope) {
param_.x = scope->FindVar(op_desc.Input("X").front())->GetMutable<Tensor>();
param_.grid =
scope->FindVar(op_desc.Input("Grid").front())->GetMutable<Tensor>();
param_.out =
scope->FindVar(op_desc.Output("Output").front())->GetMutable<Tensor>();
return true;
}
} /* namespace operators */
} /* namespace lite */
} /* namespace paddle */
REGISTER_LITE_OP(grid_sampler, paddle::lite::operators::GridSamplerOp);
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include <vector>
#include "lite/core/op_lite.h"
#include "lite/core/scope.h"
#include "lite/utils/all.h"
namespace paddle {
namespace lite {
namespace operators {
class GridSamplerOp : public OpLite {
public:
GridSamplerOp() {}
explicit GridSamplerOp(const std::string &op_type) : OpLite(op_type) {}
bool CheckShape() const override;
bool InferShape() const override;
bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); }
std::string DebugString() const override { return "grid_sampler"; }
private:
mutable GridSamplerParam param_;
};
} /* namespace operators */
} /* namespace lite */
} /* namespace paddle */
...@@ -1118,6 +1118,12 @@ struct InstanceNormParam { ...@@ -1118,6 +1118,12 @@ struct InstanceNormParam {
lite::Tensor* saved_variance{}; lite::Tensor* saved_variance{};
float epsilon; float epsilon;
}; };
/// --------------------- grid sampler operators --------------------
struct GridSamplerParam {
lite::Tensor* x{};
lite::Tensor* out{};
lite::Tensor* grid{};
};
} // namespace operators } // namespace operators
} // namespace lite } // namespace lite
......
...@@ -15,6 +15,7 @@ if((NOT LITE_WITH_OPENCL AND NOT LITE_WITH_FPGA) AND (LITE_WITH_X86 OR LITE_WITH ...@@ -15,6 +15,7 @@ if((NOT LITE_WITH_OPENCL AND NOT LITE_WITH_FPGA) AND (LITE_WITH_X86 OR LITE_WITH
lite_cc_test(test_kernel_norm_compute SRCS norm_compute_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels}) lite_cc_test(test_kernel_norm_compute SRCS norm_compute_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_cast_compute SRCS cast_compute_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels}) lite_cc_test(test_kernel_cast_compute SRCS cast_compute_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_instance_norm_compute SRCS instance_norm_compute_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels}) lite_cc_test(test_kernel_instance_norm_compute SRCS instance_norm_compute_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_grid_sampler_compute SRCS grid_sampler_compute_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
#lite_cc_test(test_kernel_sequence_softmax_compute SRCS sequence_softmax_compute_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels}) #lite_cc_test(test_kernel_sequence_softmax_compute SRCS sequence_softmax_compute_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
#lite_cc_test(test_kernel_im2sequence_compute SRCS im2sequence_compute_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels}) #lite_cc_test(test_kernel_im2sequence_compute SRCS im2sequence_compute_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
#lite_cc_test(test_kernel_compare_compute SRCS compare_compute_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels}) #lite_cc_test(test_kernel_compare_compute SRCS compare_compute_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <gtest/gtest.h>
#include "lite/api/paddle_use_kernels.h"
#include "lite/api/paddle_use_ops.h"
#include "lite/core/arena/framework.h"
#include "lite/tests/utils/fill_data.h"
namespace paddle {
namespace lite {
class GridSamplerComputeTest : public arena::TestCase {
protected:
// common attributes for this op.
std::string input_ = "x";
std::string output_ = "y";
std::string grid_ = "grid";
DDim dims_{{4, 5, 19, 19}};
public:
GridSamplerComputeTest(const Place& place,
const std::string& alias,
DDim dims)
: TestCase(place, alias), dims_(dims) {}
void RunBaseline(Scope* scope) override {
auto x = scope->FindTensor(input_);
auto grid = scope->FindTensor(grid_);
auto out = scope->NewTensor(output_);
CHECK(out);
out->Resize(dims_);
const float* x_data = x->data<float>();
const float* grid_data = grid->data<float>();
float* out_data = out->mutable_data<float>();
int num = x->dims()[0];
int channel = x->dims()[1];
int height = x->dims()[2];
int width = x->dims()[3];
int spatial_size = height * width;
auto inbound = [](int x, int y, float x_max, float y_max) {
if (x < 0 || x > x_max || y < 0 || y > y_max) {
return false;
}
return true;
};
for (int n = 0; n < num; ++n) {
const float* x_n = x_data + n * channel * height * width;
float* out_n = out_data + n * channel * height * width;
const float* grid_n = grid_data + n * height * width * 2;
for (int c = 0; c < channel; ++c) {
const float* x_c = x_n + c * spatial_size;
float* out_c = out_n + c * spatial_size;
for (int s = 0; s < spatial_size; ++s) {
float x = grid_n[s * 2];
float y = grid_n[s * 2 + 1];
float xwf = (x + 1.f) * 0.5 * (width - 1);
float ynf = (y + 1.f) * 0.5 * (height - 1);
int xw = floor(xwf);
int xe = xw + 1;
int yn = floor(ynf);
int ys = yn + 1;
float dw = xwf - xw;
float de = xe - xwf;
float dn = ynf - yn;
float ds = ys - ynf;
float wn = inbound(xw,
yn,
static_cast<float>(width - 1),
static_cast<float>(height - 1))
? x_c[yn * width + xw]
: 0.f;
float en = inbound(xe,
yn,
static_cast<float>(width - 1),
static_cast<float>(height - 1))
? x_c[yn * width + xe]
: 0.f;
float ws = inbound(xw,
ys,
static_cast<float>(width - 1),
static_cast<float>(height - 1))
? x_c[ys * width + xw]
: 0.f;
float es = inbound(xe,
ys,
static_cast<float>(width - 1),
static_cast<float>(height - 1))
? x_c[ys * width + xe]
: 0.f;
out_c[s] = wn * de * ds + en * dw * ds + ws * de * dn + es * dw * dn;
}
}
}
}
void PrepareOpDesc(cpp::OpDesc* op_desc) {
op_desc->SetType("grid_sampler");
op_desc->SetInput("X", {input_});
op_desc->SetInput("Grid", {grid_});
op_desc->SetOutput("Output", {output_});
}
void PrepareData() override {
std::vector<float> din(dims_.production());
fill_data_rand(din.data(), -1.f, 1.f, dims_.production());
DDim gird_dims{{dims_[0], dims_[2], dims_[3], 2}};
std::vector<float> grid(gird_dims.production());
fill_data_rand(grid.data(), -1.f, 1.f, gird_dims.production());
SetCommonTensor(input_, dims_, din.data());
SetCommonTensor(grid_, gird_dims, grid.data());
}
};
void test_grid_sampler(Place place) {
for (auto& n : {1, 13}) {
for (auto& c : {1, 3, 8}) {
for (auto& h : {1, 3, 8, 64}) {
for (auto& w : {2, 4, 9, 63}) {
DDim dim_in({n, c, h, w});
std::unique_ptr<arena::TestCase> tester(
new GridSamplerComputeTest(place, "def", dim_in));
#ifdef LITE_WITH_ARM
auto& ctx = tester->context()->As<ARMContext>();
ctx.SetRunMode(lite_api::LITE_POWER_HIGH, 1);
#endif
arena::Arena arena(std::move(tester), place, 6e-5);
LOG(INFO) << "run n: " << n << ", c: " << c << ", h: " << h
<< ", w: " << w;
if (!arena.TestPrecision()) {
LOG(ERROR) << "No Pass!!";
return;
}
// if you want to test this op performance, uncomment the following
// line
// arena.TestPerformance();
}
}
}
}
}
TEST(GridSampler, precision) {
#ifdef LITE_WITH_ARM
Place place(TARGET(kARM));
test_grid_sampler(place);
#endif
}
} // namespace lite
} // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册