Commit aff2b94a authored by cc, committed by GitHub

Support lac model, fix bug for argmax lod_reset (#2967)

* Support lac model, fix bug for argmax lod_reset, test=develop
Parent b9cdf544
@@ -53,7 +53,7 @@ void argmax_func(const lite::Tensor *input,
                     std::greater<std::pair<float, int>>());
     // out
-    float *out_ptr = output->mutable_data<float>() + n * out_channel + k;
+    int64_t *out_ptr = output->mutable_data<int64_t>() + n * out_channel + k;
     *out_ptr = vec[0].second;
   }
 }
......
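Note on the hunk above: arg_max produces element indices rather than values, and the framework declares the arg_max output as int64, so writing the result through a float pointer both mismatched the expected dtype and lost precision for large indices. A minimal standalone sketch of the corrected behavior (plain C++, not the Paddle Lite API):

```cpp
#include <algorithm>
#include <cstdint>
#include <vector>

// Row-wise argmax: the result is an *index*, so it is stored as int64_t,
// mirroring the dtype fix in the hunk above.
std::vector<int64_t> argmax_rows(const std::vector<float>& data,
                                 int rows, int cols) {
  std::vector<int64_t> out(rows);
  for (int r = 0; r < rows; ++r) {
    const float* row = data.data() + r * cols;
    out[r] = std::max_element(row, row + cols) - row;  // index of the max
  }
  return out;
}
```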
@@ -246,6 +246,8 @@ class Arena {
         return tester_->CheckPrecision<int8_t>(var_name, abs_error_);
       case PRECISION(kInt32):
         return tester_->CheckPrecision<int32_t>(var_name, abs_error_);
+      case PRECISION(kInt64):
+        return tester_->CheckPrecision<int64_t>(var_name, abs_error_);
       case PRECISION(kBool):
         return tester_->CheckPrecision<bool>(var_name, abs_error_);
       default:
......
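The test harness dispatches its output comparison on the declared precision, so a kInt64 branch is required before the new arg_max output can be verified. A hedged sketch of the dispatch pattern (`check_precision` is an illustrative name, not the Arena API):

```cpp
#include <cmath>
#include <cstdint>

// Templated element-wise comparison; a switch over the tensor's declared
// precision (as in the hunk above) picks the right instantiation.
template <typename T>
bool check_precision(const T* out, const T* ref, int n, float abs_error) {
  for (int i = 0; i < n; ++i) {
    if (std::fabs(static_cast<double>(out[i]) - static_cast<double>(ref[i])) >
        abs_error) {
      return false;
    }
  }
  return true;
}
```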
@@ -30,6 +30,9 @@ void ArgmaxCompute::Run() {
   lite::Tensor* input = param.X;
   lite::Tensor* output = param.Out;
   int axis = param.Axis;
+  if (axis < 0) {
+    axis += input->dims().size();
+  }
   lite::arm::math::argmax_func(input, axis, output);
   return;
@@ -47,5 +50,5 @@ REGISTER_LITE_KERNEL(arg_max,
                      paddle::lite::kernels::arm::ArgmaxCompute,
                      def)
     .BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))})
-    .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))})
+    .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt64))})
     .Finalize();
@@ -33,7 +33,7 @@ void argmax_compute_ref(const operators::ArgmaxParam& param) {
   int axis = param.Axis;
   auto x_data = x->data<dtype>();
-  auto output_data = output->mutable_data<dtype>();
+  auto output_data = output->mutable_data<int64_t>();
   DDim x_dims = x->dims();
   DDim output_dims = output->dims();
@@ -59,7 +59,7 @@ void argmax_compute_ref(const operators::ArgmaxParam& param) {
                     std::greater<std::pair<dtype, int>>());
     // out
-    dtype* out_ptr = output_data + n * out_channel + k;
+    auto* out_ptr = output_data + n * out_channel + k;
     *out_ptr = vec[0].second;
   }
 }
@@ -115,12 +115,12 @@ TEST(argmax_arm, compute) {
           param.Axis = axis;
           argmaxOp.SetParam(param);
           argmaxOp.Launch();
-          auto* output_data = output.mutable_data<float>();
+          auto* output_data = output.mutable_data<int64_t>();
           // obtain output_ref_data
           param.Out = &output_ref;
           argmax_compute_ref<float>(param);
-          auto* output_ref_data = output_ref.mutable_data<float>();
+          auto* output_ref_data = output_ref.mutable_data<int64_t>();
           // compare
           for (int i = 0; i < output.dims().production(); i++) {
......
@@ -24,9 +24,7 @@ void LodResetCompute::PrepareForRun() {}
 void LodResetCompute::Run() {
   auto& ctx = this->ctx_->template As<ARMContext>();
   auto& param = this->Param<operators::LodResetParam>();
-  const auto* x_data = param.X->data<float>();
-  auto* o_data = param.Out->mutable_data<float>();
-  memcpy(o_data, x_data, sizeof(float) * param.X->numel());
+  param.Out->CopyDataFrom(*param.X);
   auto lod = param.Out->mutable_lod();
   if (param.Y) {
     if (param.Y->lod().size()) {
@@ -54,11 +52,11 @@ void LodResetCompute::Run() {
 REGISTER_LITE_KERNEL(lod_reset,
                      kARM,
-                     kFloat,
+                     kAny,
                      kNCHW,
                      paddle::lite::kernels::arm::LodResetCompute,
                      def)
-    .BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))})
-    .BindInput("Y", {LiteType::GetTensorTy(TARGET(kARM))})
-    .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))})
+    .BindInput("X", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kAny))})
+    .BindInput("Y", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kAny))})
+    .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kAny))})
     .Finalize();
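The two lod_reset hunks above replace a float-typed memcpy with Tensor::CopyDataFrom and re-register the kernel with kAny precision. The old code copied sizeof(float) bytes per element, which truncates the buffer whenever the input holds wider types such as the int64 token ids a LAC model feeds through lod_reset. A standalone repro of that failure mode (plain C++, independent of Paddle Lite):

```cpp
#include <cstdint>
#include <cstring>
#include <vector>

int main() {
  std::vector<int64_t> src = {1, 2, 3, 4};
  std::vector<int64_t> dst(4, 0);
  // The old bug, reproduced: sizeof(float) is 4 but sizeof(int64_t) is 8,
  // so only the first half of the buffer is copied; dst ends up {1, 2, 0, 0}.
  std::memcpy(dst.data(), src.data(), sizeof(float) * src.size());
  // Copying by the element width actually stored (what a dtype-agnostic
  // copy such as CopyDataFrom must do) moves the full 32 bytes.
  std::memcpy(dst.data(), src.data(), sizeof(int64_t) * src.size());
  return 0;
}
```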
@@ -22,7 +22,7 @@ namespace paddle {
 namespace lite {
 namespace kernels {
 namespace arm {
-class LodResetCompute : public KernelLite<TARGET(kARM), PRECISION(kFloat)> {
+class LodResetCompute : public KernelLite<TARGET(kARM), PRECISION(kAny)> {
  public:
   using param_t = operators::LodResetParam;
......
@@ -24,7 +24,8 @@ namespace operators {
 bool ArgmaxOpLite::CheckShape() const {
   CHECK_OR_FALSE(param_.X);
   CHECK_OR_FALSE(param_.Out);
-  CHECK_OR_FALSE(param_.Axis < (param_.X)->dims().size());
+  CHECK_OR_FALSE(param_.Axis < static_cast<int>((param_.X)->dims().size()));
+  CHECK_OR_FALSE(param_.Axis >= static_cast<int>(-(param_.X)->dims().size()));
   return true;
 }
@@ -32,7 +33,9 @@ bool ArgmaxOpLite::InferShape() const {
   auto x_dims = param_.X->dims();
   int x_rank = x_dims.size();
   int axis = param_.Axis;
-  if (axis < 0) axis += x_rank;
+  if (axis < 0) {
+    axis += x_rank;
+  }
   std::vector<int64_t> out_dims;
   for (int64_t i = 0; i < axis; i++) out_dims.push_back(x_dims[i]);
......
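InferShape implements the usual arg_max shape rule: the reduced dimension is dropped from the output. A brief sketch with a worked example (`argmax_out_dims` is a hypothetical helper, not the op's actual code):

```cpp
#include <cstdint>
#include <vector>

// arg_max removes the reduced axis: a {2, 3, 4} input with axis = 1
// yields {2, 4}; with axis = -1 it yields {2, 3}.
std::vector<int64_t> argmax_out_dims(std::vector<int64_t> x_dims, int axis) {
  if (axis < 0) axis += static_cast<int>(x_dims.size());
  x_dims.erase(x_dims.begin() + axis);
  return x_dims;
}
```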
@@ -48,7 +48,7 @@ class ArgmaxComputeTester : public arena::TestCase {
     output_shape.erase(output_shape.begin() + axis_);
     DDim output_dims(output_shape);
     out->Resize(output_dims);
-    auto* output_data = out->mutable_data<float>();
+    auto* output_data = out->mutable_data<int64_t>();
     auto* x = scope->FindTensor(input_);
     const auto* x_data = x->data<float>();
@@ -75,7 +75,7 @@ class ArgmaxComputeTester : public arena::TestCase {
                       std::greater<std::pair<float, int>>());
       // out
-      float* out_ptr = output_data + n * out_channel + k;
+      auto* out_ptr = output_data + n * out_channel + k;
       *out_ptr = vec[0].second;
     }
   }
......