Unverified commit 3fb309c3, authored by cc, committed by GitHub

Support lac model, fix bug for argmax lod_reset (#2967)

* Support lac model, fix bug for argmax lod_reset, test=develop
Parent c7dd23b2
@@ -53,7 +53,7 @@ void argmax_func(const lite::Tensor *input,
               std::greater<std::pair<float, int>>());
     // out
-    float *out_ptr = output->mutable_data<float>() + n * out_channel + k;
+    int64_t *out_ptr = output->mutable_data<int64_t>() + n * out_channel + k;
     *out_ptr = vec[0].second;
   }
 }
...
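Note: arg_max produces element indices, which are integral values that Paddle exposes as INT64 tensors, so writing them through a float pointer mislabels the output dtype. A minimal standalone sketch of that convention (not taken from this patch; a std::vector stands in for lite::Tensor):

    #include <algorithm>
    #include <cstdint>
    #include <iostream>
    #include <iterator>
    #include <vector>

    // Index of the largest element; the result is integral, never float.
    int64_t argmax_1d(const std::vector<float>& v) {
      return std::distance(v.begin(), std::max_element(v.begin(), v.end()));
    }

    int main() {
      std::vector<float> x = {0.1f, 2.5f, -1.0f, 2.4f};
      int64_t idx = argmax_1d(x);  // int64 storage, matching PRECISION(kInt64)
      std::cout << idx << "\n";    // prints 1
    }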
@@ -246,6 +246,8 @@ class Arena {
         return tester_->CheckPrecision<int8_t>(var_name, abs_error_);
       case PRECISION(kInt32):
         return tester_->CheckPrecision<int32_t>(var_name, abs_error_);
+      case PRECISION(kInt64):
+        return tester_->CheckPrecision<int64_t>(var_name, abs_error_);
       case PRECISION(kBool):
         return tester_->CheckPrecision<bool>(var_name, abs_error_);
       default:
...
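Note: the Arena switch maps each PRECISION tag to the C++ element type used when comparing tensors, so without a kInt64 case an INT64 output such as arg_max's could not be precision-checked. A simplified sketch of the dispatch pattern (the enum values and the CheckPrecision<T> name mirror the diff; everything else here is assumed):

    #include <cstdint>

    enum class Precision { kInt8, kInt32, kInt64, kBool };

    // Stub: the real tester compares the named tensor element-wise as T.
    template <typename T>
    bool CheckPrecision() { return true; }

    bool Dispatch(Precision p) {
      switch (p) {
        case Precision::kInt8:  return CheckPrecision<int8_t>();
        case Precision::kInt32: return CheckPrecision<int32_t>();
        case Precision::kInt64: return CheckPrecision<int64_t>();  // newly covered
        case Precision::kBool:  return CheckPrecision<bool>();
      }
      return false;
    }

    int main() { return Dispatch(Precision::kInt64) ? 0 : 1; }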
@@ -30,6 +30,9 @@ void ArgmaxCompute::Run() {
   lite::Tensor* input = param.X;
   lite::Tensor* output = param.Out;
   int axis = param.Axis;
+  if (axis < 0) {
+    axis += input->dims().size();
+  }
   lite::arm::math::argmax_func(input, axis, output);
   return;
@@ -47,5 +50,5 @@ REGISTER_LITE_KERNEL(arg_max,
                      paddle::lite::kernels::arm::ArgmaxCompute,
                      def)
     .BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))})
-    .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))})
+    .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt64))})
     .Finalize();
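Note: the Run() change normalizes a negative axis before calling argmax_func, matching the usual convention that axis = -1 addresses the last dimension. A hedged sketch of that normalization as a standalone helper (the helper name is assumed, not from the patch):

    #include <cassert>

    // Maps axis in [-rank, rank) onto [0, rank); e.g. -1 on a rank-3 tensor -> 2.
    int NormalizeAxis(int axis, int rank) {
      assert(axis >= -rank && axis < rank);
      return axis < 0 ? axis + rank : axis;
    }

    int main() { return NormalizeAxis(-1, 3) == 2 ? 0 : 1; }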
@@ -33,7 +33,7 @@ void argmax_compute_ref(const operators::ArgmaxParam& param) {
   int axis = param.Axis;
   auto x_data = x->data<dtype>();
-  auto output_data = output->mutable_data<dtype>();
+  auto output_data = output->mutable_data<int64_t>();
   DDim x_dims = x->dims();
   DDim output_dims = output->dims();
@@ -59,7 +59,7 @@ void argmax_compute_ref(const operators::ArgmaxParam& param) {
               std::greater<std::pair<dtype, int>>());
     // out
-    dtype* out_ptr = output_data + n * out_channel + k;
+    auto* out_ptr = output_data + n * out_channel + k;
     *out_ptr = vec[0].second;
   }
 }
@@ -115,12 +115,12 @@ TEST(argmax_arm, compute) {
   param.Axis = axis;
   argmaxOp.SetParam(param);
   argmaxOp.Launch();
-  auto* output_data = output.mutable_data<float>();
+  auto* output_data = output.mutable_data<int64_t>();
   // obtain output_ref_data
   param.Out = &output_ref;
   argmax_compute_ref<float>(param);
-  auto* output_ref_data = output_ref.mutable_data<float>();
+  auto* output_ref_data = output_ref.mutable_data<int64_t>();
   // compare
   for (int i = 0; i < output.dims().production(); i++) {
...
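Note: with both buffers switched to int64_t, the kernel output and the reference output are compared as integral indices, so the match should be exact rather than within a float tolerance. A tiny sketch of that comparison style (assumed buffers, plain assert in place of the test framework's macros):

    #include <cassert>
    #include <cstdint>

    int main() {
      int64_t output_data[] = {1, 0, 2};
      int64_t output_ref_data[] = {1, 0, 2};
      for (int i = 0; i < 3; i++) {
        assert(output_data[i] == output_ref_data[i]);  // exact match for indices
      }
    }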
@@ -24,9 +24,7 @@ void LodResetCompute::PrepareForRun() {}
 void LodResetCompute::Run() {
   auto& ctx = this->ctx_->template As<ARMContext>();
   auto& param = this->Param<operators::LodResetParam>();
-  const auto* x_data = param.X->data<float>();
-  auto* o_data = param.Out->mutable_data<float>();
-  memcpy(o_data, x_data, sizeof(float) * param.X->numel());
+  param.Out->CopyDataFrom(*param.X);
   auto lod = param.Out->mutable_lod();
   if (param.Y) {
     if (param.Y->lod().size()) {
@@ -54,11 +52,11 @@ void LodResetCompute::Run() {
 REGISTER_LITE_KERNEL(lod_reset,
                      kARM,
-                     kFloat,
+                     kAny,
                      kNCHW,
                      paddle::lite::kernels::arm::LodResetCompute,
                      def)
-    .BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))})
-    .BindInput("Y", {LiteType::GetTensorTy(TARGET(kARM))})
-    .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))})
+    .BindInput("X", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kAny))})
+    .BindInput("Y", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kAny))})
+    .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kAny))})
     .Finalize();
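Note: the removed memcpy sized the copy as sizeof(float) * numel, which silently truncates any non-float input; for the INT64 tensors an LAC model presumably passes through lod_reset, that copies only half the payload. CopyDataFrom copies by the source tensor's own byte size and carries its dtype along, which is why the kernel can be re-registered as kAny. A standalone sketch of the failure mode (assumed plain vectors, not lite::Tensor):

    #include <cstdint>
    #include <cstring>
    #include <iostream>
    #include <vector>

    int main() {
      std::vector<int64_t> src = {1, 2, 3, 4};
      std::vector<int64_t> dst(src.size(), 0);
      // Buggy pattern: assumes float elements, so it copies 16 of the 32 bytes.
      std::memcpy(dst.data(), src.data(), sizeof(float) * src.size());
      std::cout << dst[3] << "\n";  // prints 0: the tail was never copied
      // dtype-agnostic sizing, roughly what CopyDataFrom amounts to:
      std::memcpy(dst.data(), src.data(), sizeof(int64_t) * src.size());
      std::cout << dst[3] << "\n";  // prints 4
    }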
@@ -22,7 +22,7 @@ namespace paddle {
 namespace lite {
 namespace kernels {
 namespace arm {
-class LodResetCompute : public KernelLite<TARGET(kARM), PRECISION(kFloat)> {
+class LodResetCompute : public KernelLite<TARGET(kARM), PRECISION(kAny)> {
  public:
  using param_t = operators::LodResetParam;
...
@@ -24,7 +24,8 @@ namespace operators {
 bool ArgmaxOpLite::CheckShape() const {
   CHECK_OR_FALSE(param_.X);
   CHECK_OR_FALSE(param_.Out);
-  CHECK_OR_FALSE(param_.Axis < (param_.X)->dims().size());
+  CHECK_OR_FALSE(param_.Axis < static_cast<int>((param_.X)->dims().size()));
+  CHECK_OR_FALSE(param_.Axis >= static_cast<int>(-(param_.X)->dims().size()));
   return true;
 }
@@ -32,7 +33,9 @@ bool ArgmaxOpLite::InferShape() const {
   auto x_dims = param_.X->dims();
   int x_rank = x_dims.size();
   int axis = param_.Axis;
-  if (axis < 0) axis += x_rank;
+  if (axis < 0) {
+    axis += x_rank;
+  }
   std::vector<int64_t> out_dims;
   for (int64_t i = 0; i < axis; i++) out_dims.push_back(x_dims[i]);
...
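Note: together the two checks constrain Axis to [-rank, rank), and InferShape then drops the reduced dimension after normalizing a negative axis. A hedged sketch of the resulting shape rule as a standalone helper (the helper name is assumed):

    #include <cstdint>
    #include <vector>

    // arg_max removes the reduced axis from the shape,
    // e.g. {2, 3, 4} with axis = -2 -> {2, 4}.
    std::vector<int64_t> InferArgmaxShape(std::vector<int64_t> dims, int axis) {
      int rank = static_cast<int>(dims.size());
      if (axis < 0) axis += rank;
      dims.erase(dims.begin() + axis);
      return dims;
    }

    int main() {
      auto out = InferArgmaxShape({2, 3, 4}, -2);
      return (out.size() == 2 && out[0] == 2 && out[1] == 4) ? 0 : 1;
    }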
@@ -48,7 +48,7 @@ class ArgmaxComputeTester : public arena::TestCase {
     output_shape.erase(output_shape.begin() + axis_);
     DDim output_dims(output_shape);
     out->Resize(output_dims);
-    auto* output_data = out->mutable_data<float>();
+    auto* output_data = out->mutable_data<int64_t>();
     auto* x = scope->FindTensor(input_);
     const auto* x_data = x->data<float>();
@@ -75,7 +75,7 @@ class ArgmaxComputeTester : public arena::TestCase {
                 std::greater<std::pair<float, int>>());
       // out
-      float* out_ptr = output_data + n * out_channel + k;
+      auto* out_ptr = output_data + n * out_channel + k;
       *out_ptr = vec[0].second;
     }
   }
...