From 628aa0df397d6fd7750608b221c261765f1439cb Mon Sep 17 00:00:00 2001
From: zhupengyang <1165938320@qq.com>
Date: Sat, 7 Mar 2020 09:32:47 +0800
Subject: [PATCH] get precision from real tensor or tensor_array of base_scope
 in arena/framework (#3092)

* get the precision from the real tensor or tensor_array of base_scope in
  arena/framework

* register assign and assign_value to kAny
---
 lite/core/arena/framework.cc                  | 109 ++++++++++++++
 lite/core/arena/framework.h                   | 137 ++----------------
 lite/kernels/arm/assign_compute.cc            |  15 +-
 lite/kernels/arm/assign_compute.h             |   4 +-
 lite/kernels/arm/assign_value_compute.cc      |   4 +-
 lite/kernels/arm/assign_value_compute.h       |   2 +-
 lite/tests/kernels/assign_compute_test.cc     |  11 +-
 .../kernels/assign_value_compute_test.cc      |   9 +-
 lite/tests/kernels/cast_compute_test.cc       |  29 ----
 ...l_constant_batch_size_like_compute_test.cc |   2 -
 .../kernels/fill_constant_compute_test.cc     |   1 -
 .../kernels/read_from_array_compute_test.cc   |   2 -
 lite/tests/kernels/unsqueeze_compute_test.cc  |   2 -
 .../kernels/write_to_array_compute_test.cc    |   2 -
 14 files changed, 142 insertions(+), 187 deletions(-)
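
Note (this commentary sits between the diffstat and the diff, in the region
that `git am` ignores): the core of this change replaces the per-test
SetPrecisionType bookkeeping with a runtime dispatch. When a kernel is
registered with PRECISION(kAny), the expected element type is now read from
the baseline tensor that RunBaseline produced in base_scope_. The minimal,
self-contained C++ sketch below mirrors that dispatch pattern; `Precision`,
`FakeTensor`, and the helper names are illustrative stand-ins, not
Paddle-Lite APIs.

    // Sketch of the kAny precision dispatch: a runtime precision tag picks
    // the template instantiation that compares raw buffers element-wise.
    #include <cmath>
    #include <cstddef>
    #include <cstdint>
    #include <cstdlib>
    #include <iostream>

    enum class Precision { kAny, kFloat, kInt32, kInt64, kBool };

    struct FakeTensor {  // illustrative stand-in for lite::Tensor
      Precision precision;
      const void* raw_data;
      size_t size;
    };

    template <typename T>
    bool CheckTensorPrecision(const FakeTensor& a, const FakeTensor& b,
                              float abs_error) {
      const T* a_data = static_cast<const T*>(a.raw_data);
      const T* b_data = static_cast<const T*>(b.raw_data);
      for (size_t i = 0; i < a.size; i++) {
        if (std::fabs(static_cast<float>(a_data[i]) -
                      static_cast<float>(b_data[i])) > abs_error) {
          return false;
        }
      }
      return true;
    }

    bool CheckPrecision(const FakeTensor& a, const FakeTensor& b,
                        float abs_error, Precision declared) {
      // kAny means "trust the baseline": resolve the tag at runtime from
      // the baseline tensor instead of a per-test precision map.
      Precision resolved =
          (declared == Precision::kAny) ? b.precision : declared;
      switch (resolved) {
        case Precision::kFloat:
          return CheckTensorPrecision<float>(a, b, abs_error);
        case Precision::kInt32:
          return CheckTensorPrecision<int32_t>(a, b, abs_error);
        case Precision::kInt64:
          return CheckTensorPrecision<int64_t>(a, b, abs_error);
        case Precision::kBool:
          return CheckTensorPrecision<bool>(a, b, abs_error);
        default:
          std::abort();  // unsupported precision
      }
    }

    int main() {
      const int32_t out[] = {1, 2, 3};  // output of the kernel under test
      const int32_t ref[] = {1, 2, 3};  // output of the baseline scope
      FakeTensor a{Precision::kInt32, out, 3};
      FakeTensor b{Precision::kInt32, ref, 3};
      // The kernel is registered as kAny (like assign/assign_value in this
      // patch), yet the check still takes the int32 path via the baseline.
      bool ok = CheckPrecision(a, b, 1e-5f, Precision::kAny);
      std::cout << (ok ? "precision check passed" : "precision check FAILED")
                << "\n";
      return ok ? 0 : 1;
    }

Reading the precision off the baseline tensor keeps the expected type next to
the reference data, which is why Arena::CheckPrecision in framework.h below
collapses from a per-type switch to a single call.
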
diff --git a/lite/core/arena/framework.cc b/lite/core/arena/framework.cc
index 66af1b9b03..ac822e3764 100644
--- a/lite/core/arena/framework.cc
+++ b/lite/core/arena/framework.cc
@@ -123,6 +123,115 @@ void TestCase::PrepareInputsForInstruction() {
   }
 }
 
+template <typename T>
+bool TestCase::CheckTensorPrecision(const Tensor* a_tensor,
+                                    const Tensor* b_tensor,
+                                    float abs_error) {
+  CHECK(a_tensor);
+  CHECK(b_tensor);
+
+  CHECK(ShapeEquals(a_tensor->dims(), b_tensor->dims()));
+
+  CHECK(a_tensor->lod() == b_tensor->lod()) << "lod not match";
+
+  // The baseline should output in host devices.
+  CHECK(b_tensor->target() == TARGET(kHost) ||
+        b_tensor->target() == TARGET(kX86) ||
+        b_tensor->target() == TARGET(kARM));
+
+  const T* a_data{};
+  switch (a_tensor->target()) {
+    case TARGET(kX86):
+    case TARGET(kHost):
+    case TARGET(kARM):
+      a_data = static_cast<const T*>(a_tensor->raw_data());
+      break;
+
+    default:
+      // Before compare, need to copy data from `target` device to host.
+      LOG(FATAL) << "Not supported";
+  }
+
+  CHECK(a_data);
+
+  const T* b_data = static_cast<const T*>(b_tensor->raw_data());
+
+  bool success = true;
+  for (int i = 0; i < a_tensor->dims().production(); i++) {
+    EXPECT_NEAR(a_data[i], b_data[i], abs_error);
+    if (fabsf(a_data[i] - b_data[i]) > abs_error) {
+      success = false;
+    }
+  }
+  return success;
+}
+
+bool TestCase::CheckPrecision(const Tensor* a_tensor,
+                              const Tensor* b_tensor,
+                              float abs_error,
+                              PrecisionType precision_type) {
+  PrecisionType precision_type_t = precision_type;
+  if (precision_type == PRECISION(kAny)) {
+    precision_type_t = b_tensor->precision();
+  }
+  CHECK(precision_type_t == b_tensor->precision())
+      << "arg precision type and base tensor precision type are not matched! "
+         "arg precision type is: "
+      << PrecisionToStr(precision_type) << ", base tensor precision type is: "
+      << PrecisionToStr(b_tensor->precision());
+  CHECK(a_tensor->precision() == b_tensor->precision())
+      << "real tensor precision type and base tensor precision type are not "
+         "matched! real tensor precision type is: "
+      << PrecisionToStr(a_tensor->precision())
+      << ", base tensor precision type is: "
+      << PrecisionToStr(b_tensor->precision());
+  switch (precision_type_t) {
+    case PRECISION(kFloat):
+      return CheckTensorPrecision<float>(a_tensor, b_tensor, abs_error);
+    case PRECISION(kInt8):
+      return CheckTensorPrecision<int8_t>(a_tensor, b_tensor, abs_error);
+    case PRECISION(kInt32):
+      return CheckTensorPrecision<int32_t>(a_tensor, b_tensor, abs_error);
+    case PRECISION(kInt64):
+      return CheckTensorPrecision<int64_t>(a_tensor, b_tensor, abs_error);
+    case PRECISION(kBool):
+      return CheckTensorPrecision<bool>(a_tensor, b_tensor, abs_error);
+    default:
+      LOG(FATAL) << "not support type: " << PrecisionToStr(precision_type);
+      return false;
+  }
+}
+
+bool TestCase::CheckPrecision(const std::string& var_name,
+                              float abs_error,
+                              PrecisionType precision_type) {
+  bool success = true;
+  if (inst_scope_->FindVar(var_name)->IsType<Tensor>()) {
+    auto a_tensor = inst_scope_->FindTensor(var_name);
+    auto b_tensor = base_scope_->FindTensor(var_name);
+    success = success &&
+              CheckPrecision(a_tensor, b_tensor, abs_error, precision_type);
+  } else if (inst_scope_->FindVar(var_name)->IsType<std::vector<Tensor>>()) {
+    auto a_tensor_array =
+        inst_scope_->FindVar(var_name)->GetMutable<std::vector<Tensor>>();
+    auto b_tensor_array =
+        base_scope_->FindVar(var_name)->GetMutable<std::vector<Tensor>>();
+    CHECK_EQ(a_tensor_array->size(), b_tensor_array->size());
+    for (int i = 0; i < a_tensor_array->size(); i++) {
+      Tensor* a_tensor = &(a_tensor_array->at(i));
+      Tensor* b_tensor = &(b_tensor_array->at(i));
+      if (a_tensor->dims().size() == 0 && b_tensor->dims().size() == 0) {
+        continue;
+      }
+      success = success &&
+                CheckPrecision(a_tensor, b_tensor, abs_error, precision_type);
+    }
+  } else {
+    LOG(FATAL) << "unsupported var type";
+  }
+  return success;
+}
+
 TestCase::~TestCase() {
   if (op_desc_->Type() == "subgraph") {
     // Release the subblock desc of Subgraph op
diff --git a/lite/core/arena/framework.h b/lite/core/arena/framework.h
index 3c90ef325c..7050355fbf 100644
--- a/lite/core/arena/framework.h
+++ b/lite/core/arena/framework.h
@@ -66,19 +66,24 @@ class TestCase {
   /// output.
   virtual void RunBaseline(Scope* scope) = 0;
 
-  // checkout the precision of the two tensors. b_tensor is from the baseline
+  // checkout the precision of the two tensors with type T. b_tensor is baseline
   template <typename T>
   bool CheckTensorPrecision(const Tensor* a_tensor,
                             const Tensor* b_tensor,
                             float abs_error);
 
+  // checkout the precision of the two tensors. b_tensor is baseline
+  bool CheckPrecision(const Tensor* a_tensor,
+                      const Tensor* b_tensor,
+                      float abs_error,
+                      PrecisionType precision_type);
+
   /// Check the precision of the output variables. It will compare the same
-  /// tensor
-  /// (or all tensors of tensor_array) in two scopes, one of the instruction
-  /// execution,
-  /// and the other for the baseline.
-  template <typename T>
-  bool CheckPrecision(const std::string& var_name, float abs_error);
+  /// tensor (or all tensors of the tensor_array) in two scopes, one of the
+  /// instruction execution, and the other for the baseline.
+  bool CheckPrecision(const std::string& var_name,
+                      float abs_error,
+                      PrecisionType precision_type);
 
   const cpp::OpDesc& op_desc() { return *op_desc_; }
 
@@ -86,20 +91,6 @@ class TestCase {
   // kernel registry.
   void CheckKernelConsistWithDefinition() {}
 
-  // Get the real precision of the output for check precision. When the declare
-  // precision obtained from the kernel is any, we should set the precision of
-  // the output in test case.
-  bool GetPrecisonType(const std::string& var_name,
-                       PrecisionType* precision_type) {
-    auto res = precision_type_map_.find(var_name);
-    if (res == precision_type_map_.end()) {
-      return false;
-    } else {
-      *precision_type = precision_type_map_.at(var_name);
-      return true;
-    }
-  }
-
   Scope& scope() { return *scope_; }
 
   Scope* baseline_scope() { return base_scope_; }
@@ -159,19 +150,6 @@ class TestCase {
   // Prepare for the operator.
   virtual void PrepareOpDesc(cpp::OpDesc* op_desc) = 0;
 
-  // Set the real precision of the output for check precision. When the declare
-  // precision obtained from the kernel is any, we should set the precision of
-  // the output in test case.
-  void SetPrecisionType(const std::string& var_name,
-                        const PrecisionType& precision_type) {
-    auto res = precision_type_map_.find(var_name);
-    if (res == precision_type_map_.end()) {
-      precision_type_map_.insert({var_name, precision_type});
-    } else {
-      precision_type_map_.at(var_name) = precision_type;
-    }
-  }
-
  public:
   const Instruction& instruction() { return *instruction_; }
 
@@ -215,7 +193,6 @@ class TestCase {
   Scope* base_scope_{};
   std::unique_ptr<cpp::OpDesc> op_desc_;
   std::unique_ptr<Instruction> instruction_;
-  std::unordered_map<std::string, PrecisionType> precision_type_map_;
 };
 
 class Arena {
@@ -272,24 +249,7 @@ class Arena {
     const Type* type =
         tester_->instruction().kernel()->GetOutputDeclType(arg_name);
     auto precision_type = type->precision();
-    if (precision_type == PRECISION(kAny)) {
-      CHECK(tester_->GetPrecisonType(var_name, &precision_type));
-    }
-    switch (precision_type) {
-      case PRECISION(kFloat):
-        return tester_->CheckPrecision<float>(var_name, abs_error_);
-      case PRECISION(kInt8):
-        return tester_->CheckPrecision<int8_t>(var_name, abs_error_);
-      case PRECISION(kInt32):
-        return tester_->CheckPrecision<int32_t>(var_name, abs_error_);
-      case PRECISION(kInt64):
-        return tester_->CheckPrecision<int64_t>(var_name, abs_error_);
-      case PRECISION(kBool):
-        return tester_->CheckPrecision<bool>(var_name, abs_error_);
-      default:
-        LOG(FATAL) << "not support type " << PrecisionToStr(type->precision());
-        return false;
-    }
+    return tester_->CheckPrecision(var_name, abs_error_, precision_type);
   }
 
  private:
@@ -298,77 +258,6 @@ class Arena {
   float abs_error_;
 };
 
-template <typename T>
-bool TestCase::CheckTensorPrecision(const Tensor* a_tensor,
-                                    const Tensor* b_tensor,
-                                    float abs_error) {
-  CHECK(a_tensor);
-  CHECK(b_tensor);
-
-  CHECK(ShapeEquals(a_tensor->dims(), b_tensor->dims()));
-
-  CHECK(a_tensor->lod() == b_tensor->lod()) << "lod not match";
-
-  // The baseline should output in host devices.
-  CHECK(b_tensor->target() == TARGET(kHost) ||
-        b_tensor->target() == TARGET(kX86) ||
-        b_tensor->target() == TARGET(kARM));
-
-  const T* a_data{};
-  switch (a_tensor->target()) {
-    case TARGET(kX86):
-    case TARGET(kHost):
-    case TARGET(kARM):
-      a_data = static_cast<const T*>(a_tensor->raw_data());
-      break;
-
-    default:
-      // Before compare, need to copy data from `target` device to host.
-      LOG(FATAL) << "Not supported";
-  }
-
-  CHECK(a_data);
-
-  const T* b_data = static_cast<const T*>(b_tensor->raw_data());
-
-  bool success = true;
-  for (int i = 0; i < a_tensor->dims().production(); i++) {
-    EXPECT_NEAR(a_data[i], b_data[i], abs_error);
-    if (fabsf(a_data[i] - b_data[i]) > abs_error) {
-      success = false;
-    }
-  }
-  return success;
-}
-
-template <typename T>
-bool TestCase::CheckPrecision(const std::string& var_name, float abs_error) {
-  bool success = true;
-  if (inst_scope_->FindVar(var_name)->IsType<Tensor>()) {
-    auto a_tensor = inst_scope_->FindTensor(var_name);
-    auto b_tensor = base_scope_->FindTensor(var_name);
-    success = success && CheckTensorPrecision<T>(a_tensor, b_tensor, abs_error);
-  } else if (inst_scope_->FindVar(var_name)->IsType<std::vector<Tensor>>()) {
-    auto a_tensor_array =
-        inst_scope_->FindVar(var_name)->GetMutable<std::vector<Tensor>>();
-    auto b_tensor_array =
-        base_scope_->FindVar(var_name)->GetMutable<std::vector<Tensor>>();
-    CHECK_EQ(a_tensor_array->size(), b_tensor_array->size());
-    for (int i = 0; i < a_tensor_array->size(); i++) {
-      Tensor* a_tensor = &(a_tensor_array->at(i));
-      Tensor* b_tensor = &(b_tensor_array->at(i));
-      if (a_tensor->dims().size() == 0 && b_tensor->dims().size() == 0) {
-        continue;
-      }
-      success =
-          success && CheckTensorPrecision<T>(a_tensor, b_tensor, abs_error);
-    }
-  } else {
-    LOG(FATAL) << "unsupported var type";
-  }
-  return success;
-}
-
 }  // namespace arena
 }  // namespace lite
 }  // namespace paddle
diff --git a/lite/kernels/arm/assign_compute.cc b/lite/kernels/arm/assign_compute.cc
index b0a5529c36..8398634bb3 100644
--- a/lite/kernels/arm/assign_compute.cc
+++ b/lite/kernels/arm/assign_compute.cc
@@ -23,16 +23,9 @@ namespace lite {
 namespace kernels {
 namespace arm {
 
-void AssignCompute::PrepareForRun() {
-  // CHECK_OR_FALSE(param_t.Out);
-}
-
 void AssignCompute::Run() {
-  // LOG(INFO) << "into kernel compute run";
   auto& param = Param<param_t>();
-  const lite::Tensor* input = param.X;
-  lite::Tensor* output = param.Out;
-  output->CopyDataFrom(*input);
+  param.Out->CopyDataFrom(*param.X);
 }
 
 }  // namespace arm
@@ -41,7 +34,7 @@ void AssignCompute::Run() {
 }  // namespace paddle
 
 REGISTER_LITE_KERNEL(
-    assign, kARM, kFloat, kNCHW, paddle::lite::kernels::arm::AssignCompute, def)
-    .BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))})
-    .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))})
+    assign, kARM, kAny, kNCHW, paddle::lite::kernels::arm::AssignCompute, def)
+    .BindInput("X", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kAny))})
+    .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kAny))})
     .Finalize();
diff --git a/lite/kernels/arm/assign_compute.h b/lite/kernels/arm/assign_compute.h
index 3f0dd8e281..e144486b59 100644
--- a/lite/kernels/arm/assign_compute.h
+++ b/lite/kernels/arm/assign_compute.h
@@ -22,10 +22,10 @@ namespace lite {
 namespace kernels {
 namespace arm {
 
-class AssignCompute : public KernelLite<TARGET(kARM), PRECISION(kFloat)> {
+class AssignCompute : public KernelLite<TARGET(kARM), PRECISION(kAny)> {
  public:
   using param_t = operators::AssignParam;
-  void PrepareForRun() override;
+
   void Run() override;
 
   virtual ~AssignCompute() = default;
diff --git a/lite/kernels/arm/assign_value_compute.cc b/lite/kernels/arm/assign_value_compute.cc
index 45f28ba363..1d097e336f 100644
--- a/lite/kernels/arm/assign_value_compute.cc
+++ b/lite/kernels/arm/assign_value_compute.cc
@@ -58,9 +58,9 @@ void AssignValueCompute::Run() {
 
 REGISTER_LITE_KERNEL(assign_value,
                      kARM,
-                     kFloat,
+                     kAny,
                      kNCHW,
                      paddle::lite::kernels::arm::AssignValueCompute,
                      def)
-    .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))})
+    .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kAny))})
     .Finalize();
diff --git a/lite/kernels/arm/assign_value_compute.h b/lite/kernels/arm/assign_value_compute.h
index f0c33f865b..32b1fb41ab 100644
--- a/lite/kernels/arm/assign_value_compute.h
+++ b/lite/kernels/arm/assign_value_compute.h
@@ -22,7 +22,7 @@ namespace lite {
 namespace kernels {
 namespace arm {
 
-class AssignValueCompute : public KernelLite<TARGET(kARM), PRECISION(kFloat)> {
+class AssignValueCompute : public KernelLite<TARGET(kARM), PRECISION(kAny)> {
  public:
   using param_t = operators::AssignValueParam;
diff --git a/lite/tests/kernels/assign_compute_test.cc b/lite/tests/kernels/assign_compute_test.cc
index 92f885f8da..eec875608a 100644
--- a/lite/tests/kernels/assign_compute_test.cc
+++ b/lite/tests/kernels/assign_compute_test.cc
@@ -67,13 +67,14 @@ void TestAssign(const Place& place) {
 }
 
 TEST(Assign, precision) {
-#ifdef LITE_WITH_X86
-  Place place(TARGET(kX86));
-#endif
+  Place place;
 #ifdef LITE_WITH_ARM
-  Place place(TARGET(kARM));
-  TestAssign(place);
+  place = {TARGET(kARM), PRECISION(kAny)};
+#else
+  return;
 #endif
+
+  TestAssign(place);
 }
 
 }  // namespace lite
diff --git a/lite/tests/kernels/assign_value_compute_test.cc b/lite/tests/kernels/assign_value_compute_test.cc
index 96959e507d..c3b190f8b8 100644
--- a/lite/tests/kernels/assign_value_compute_test.cc
+++ b/lite/tests/kernels/assign_value_compute_test.cc
@@ -95,10 +95,12 @@ class AssignValueComputeTester : public arena::TestCase {
 };
 
 TEST(AssignValue, precision) {
-  LOG(INFO) << "test argmax op";
+  Place place;
 #ifdef LITE_WITH_ARM
-  LOG(INFO) << "test argmax arm";
-  Place place(TARGET(kARM));
+  place = {TARGET(kARM), PRECISION(kAny)};
+#else
+  return;
+#endif
 
   for (int dtype : {2, 5}) {
     for (int n : {1}) {
@@ -114,7 +116,6 @@ TEST(AssignValue, precision) {
       }
     }
   }
-#endif
 }
 
 }  // namespace lite
diff --git a/lite/tests/kernels/cast_compute_test.cc b/lite/tests/kernels/cast_compute_test.cc
index e3b27ce627..3a57a8ab44 100644
--- a/lite/tests/kernels/cast_compute_test.cc
+++ b/lite/tests/kernels/cast_compute_test.cc
@@ -119,35 +119,6 @@ class CastComputeTester : public arena::TestCase {
         LOG(FATAL) << "unsupported data type: " << in_dtype_;
         break;
     }
-
-    PrecisionType out_ptype;
-    switch (out_dtype_) {
-      case 0:
-        out_ptype = PRECISION(kBool);
-        break;
-      case 21:
-        out_ptype = PRECISION(kInt8);
-        break;
-      case 1:
-        out_ptype = PRECISION(kInt16);
-        break;
-      case 2:
-        out_ptype = PRECISION(kInt32);
-        break;
-      case 3:
-        out_ptype = PRECISION(kInt64);
-        break;
-      case 4:
-        out_ptype = PRECISION(kFP16);
-        break;
-      case 5:
-        out_ptype = PRECISION(kFloat);
-        break;
-      default:
-        LOG(FATAL) << "unsupported data type: " << out_dtype_;
-        break;
-    }
-    SetPrecisionType(out_, out_ptype);
   }
 };
diff --git a/lite/tests/kernels/fill_constant_batch_size_like_compute_test.cc b/lite/tests/kernels/fill_constant_batch_size_like_compute_test.cc
index e225d81ad0..734b92c186 100644
--- a/lite/tests/kernels/fill_constant_batch_size_like_compute_test.cc
+++ b/lite/tests/kernels/fill_constant_batch_size_like_compute_test.cc
@@ -86,8 +86,6 @@ class FillConstantBatchSizeLikeComputeTester : public arena::TestCase {
     std::vector<float> din(in_dims_.production());
     fill_data_rand(din.data(), -1.f, 1.f, in_dims_.production());
     SetCommonTensor(input_, in_dims_, din.data(), in_lod_);
-
-    SetPrecisionType(out_, PRECISION(kFloat));
   }
 };
diff --git a/lite/tests/kernels/fill_constant_compute_test.cc b/lite/tests/kernels/fill_constant_compute_test.cc
index ef5a4fa778..047bbf13da 100644
--- a/lite/tests/kernels/fill_constant_compute_test.cc
+++ b/lite/tests/kernels/fill_constant_compute_test.cc
@@ -109,7 +109,6 @@ class FillConstantComputeTester : public arena::TestCase {
         SetCommonTensor(shape_tensor_list_[i], DDim({1}), dshape_tensor.data());
       }
     }
-    SetPrecisionType(out_, PRECISION(kFloat));
   }
 };
diff --git a/lite/tests/kernels/read_from_array_compute_test.cc b/lite/tests/kernels/read_from_array_compute_test.cc
index 73eb6cb463..2e6bed3b15 100644
--- a/lite/tests/kernels/read_from_array_compute_test.cc
+++ b/lite/tests/kernels/read_from_array_compute_test.cc
@@ -68,8 +68,6 @@ class ReadFromArrayComputeTester : public arena::TestCase {
     std::vector<int64_t> didn(1);
     didn[0] = id_;
     SetCommonTensor(idn_, DDim{{1}}, didn.data());
-
-    SetPrecisionType(out_, PRECISION(kFloat));
   }
 };
diff --git a/lite/tests/kernels/unsqueeze_compute_test.cc b/lite/tests/kernels/unsqueeze_compute_test.cc
index d8ec2b01f7..aba7bed4f1 100644
--- a/lite/tests/kernels/unsqueeze_compute_test.cc
+++ b/lite/tests/kernels/unsqueeze_compute_test.cc
@@ -107,7 +107,6 @@ class UnsqueezeComputeTester : public arena::TestCase {
   }
 
   void PrepareData() override {
-    SetPrecisionType(out_, PRECISION(kFloat));
     std::vector<float> in_data(dims_.production());
     for (int i = 0; i < dims_.production(); ++i) {
       in_data[i] = i;
@@ -214,7 +213,6 @@ class Unsqueeze2ComputeTester : public arena::TestCase {
   }
 
   void PrepareData() override {
-    SetPrecisionType(out_, PRECISION(kFloat));
     std::vector<float> in_data(dims_.production());
     for (int i = 0; i < dims_.production(); ++i) {
       in_data[i] = i;
diff --git a/lite/tests/kernels/write_to_array_compute_test.cc b/lite/tests/kernels/write_to_array_compute_test.cc
index 7b409258cb..7d3410a980 100644
--- a/lite/tests/kernels/write_to_array_compute_test.cc
+++ b/lite/tests/kernels/write_to_array_compute_test.cc
@@ -66,8 +66,6 @@ class WriteToArrayComputeTester : public arena::TestCase {
     std::vector<int64_t> didn(1);
     didn[0] = id_;
    SetCommonTensor(idn_, DDim{{1}}, didn.data());
-
-    SetPrecisionType(out_, PRECISION(kFloat));
   }
 };
-- 
GitLab