diff --git a/lite/kernels/xpu/bridges/CMakeLists.txt b/lite/kernels/xpu/bridges/CMakeLists.txt
index 4e166a1c126db4a4c33d517e6ab78c970b34b300..29cb83b2b853d4953bfbe7faca8633f2789e1d50 100644
--- a/lite/kernels/xpu/bridges/CMakeLists.txt
+++ b/lite/kernels/xpu/bridges/CMakeLists.txt
@@ -24,6 +24,7 @@ lite_cc_library(subgraph_bridge_reshape_op_xpu SRCS reshape_op.cc DEPS ${xpu_sub
 lite_cc_library(subgraph_bridge_layer_norm_op_xpu SRCS layer_norm_op.cc DEPS ${xpu_subgraph_bridge_deps})
 lite_cc_library(subgraph_bridge_dropout_op_xpu SRCS dropout_op.cc DEPS ${xpu_subgraph_bridge_deps})
 lite_cc_library(subgraph_bridge_matmul_op_xpu SRCS matmul_op.cc DEPS ${xpu_subgraph_bridge_deps})
+lite_cc_library(subgraph_bridge_cast_op_xpu SRCS cast_op.cc DEPS ${xpu_subgraph_bridge_deps})
 
 set(xpu_subgraph_bridges
         subgraph_bridge_registry
@@ -46,6 +47,7 @@ set(xpu_subgraph_bridges
         subgraph_bridge_layer_norm_op_xpu
         subgraph_bridge_dropout_op_xpu
         subgraph_bridge_matmul_op_xpu
+        subgraph_bridge_cast_op_xpu
         CACHE INTERNAL "xpu_subgraph_bridges")
 
 message(STATUS "+++++ xpu_subgraph_bridges: ${xpu_subgraph_bridges}")
diff --git a/lite/kernels/xpu/bridges/cast_op.cc b/lite/kernels/xpu/bridges/cast_op.cc
new file mode 100644
index 0000000000000000000000000000000000000000..4b56abcd61a1459391b7520cb5d0b4c17f901f40
--- /dev/null
+++ b/lite/kernels/xpu/bridges/cast_op.cc
@@ -0,0 +1,99 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "lite/kernels/npu/bridges/registry.h"
+#include "lite/kernels/xpu/bridges/graph.h"
+#include "lite/kernels/xpu/bridges/utility.h"
+
+namespace paddle {
+namespace lite {
+namespace subgraph {
+namespace xpu {
+
+bool CvtDtype(int dtype, PrecisionType* ptype) {
+  switch (dtype) {
+    case 21:
+      *ptype = PRECISION(kInt8);
+      break;
+    case 1:
+      *ptype = PRECISION(kInt16);
+      break;
+    case 2:
+      *ptype = PRECISION(kInt32);
+      break;
+    case 3:
+      *ptype = PRECISION(kInt64);
+      break;
+    case 5:
+      *ptype = PRECISION(kFloat);
+      break;
+    default:
+      LOG(WARNING) << "[XPU] unsupported data type: " << dtype;
+      return false;
+  }
+  return true;
+}
+
+int CastConverter(void* ctx, OpLite* op, KernelBase* kernel) {
+  CHECK(ctx != nullptr);
+  CHECK(op != nullptr);
+  auto graph = static_cast<Graph*>(ctx);
+  auto op_info = op->op_info();
+  auto op_type = op_info->Type();
+  auto scope = op->scope();
+  VLOG(3) << "[XPU] Converting " + op_type + "...";
+
+  // Get input and output vars and op attributes
+  auto x_name = op_info->Input("X").front();
+  auto x = scope->FindMutableTensor(x_name);
+  auto out_name = op_info->Output("Out").front();
+
+  // BOOL = 0;INT16 = 1;INT32 = 2;INT64 = 3;FP16 = 4;FP32 = 5;FP64 = 6;
+  // SIZE_T = 19;UINT8 = 20;INT8 = 21;
+  int in_dtype = op_info->GetAttr<int>("in_dtype");
+  PrecisionType in_ptype;
+  if (!CvtDtype(in_dtype, &in_ptype)) {
+    return FAILED;
+  }
+
+  int out_dtype = op_info->GetAttr<int>("out_dtype");
+  PrecisionType out_ptype;
+  if (!CvtDtype(out_dtype, &out_ptype)) {
+    return FAILED;
+  }
+
+  // X node
+  std::shared_ptr<Node> x_node = nullptr;
+  if (graph->Has(x_name)) {
+    x_node = graph->Get(x_name);
+  } else {
+    x_node = graph->Add(x_name, *x, in_ptype);
+  }
+
+  // Cast node
+  graph->Add(
+      out_name,
+      graph->builder_.CreateCast(*x_node->data(), CvtPrecisionType(out_ptype)));
+
+  return SUCCESS;
+}
+
+}  // namespace xpu
+}  // namespace subgraph
+}  // namespace lite
+}  // namespace paddle
+
+REGISTER_SUBGRAPH_BRIDGE(cast,
+                         kXPU,
+                         paddle::lite::subgraph::xpu::CastConverter);
diff --git a/lite/kernels/xpu/bridges/paddle_use_bridges.h b/lite/kernels/xpu/bridges/paddle_use_bridges.h
index bed88034ae8c00cf2de4e747234c49283cc18c68..0c7886c5b2b431db7ba97d8557fb6a49750bd468 100644
--- a/lite/kernels/xpu/bridges/paddle_use_bridges.h
+++ b/lite/kernels/xpu/bridges/paddle_use_bridges.h
@@ -36,3 +36,4 @@ USE_SUBGRAPH_BRIDGE(layer_norm, kXPU);
 USE_SUBGRAPH_BRIDGE(gelu, kXPU);
 USE_SUBGRAPH_BRIDGE(dropout, kXPU);
 USE_SUBGRAPH_BRIDGE(matmul, kXPU);
+USE_SUBGRAPH_BRIDGE(cast, kXPU);
diff --git a/lite/tests/kernels/CMakeLists.txt b/lite/tests/kernels/CMakeLists.txt
index 8ee0255f2b90cc4c93aa94ab9d28d408eaddad11..1874eafd95f3efec291465221db19fd7b4815a4b 100644
--- a/lite/tests/kernels/CMakeLists.txt
+++ b/lite/tests/kernels/CMakeLists.txt
@@ -13,7 +13,7 @@ if((NOT LITE_WITH_OPENCL AND NOT LITE_WITH_FPGA) AND (LITE_WITH_X86 OR LITE_WITH
   lite_cc_test(test_kernel_axpy_compute SRCS axpy_compute_test.cc DEPS arena_framework ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
   lite_cc_test(test_kernel_conv2d_transpose_compute SRCS conv2d_transpose_compute_test.cc DEPS arena_framework ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
   lite_cc_test(test_kernel_norm_compute SRCS norm_compute_test.cc DEPS arena_framework ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
-  lite_cc_test(test_kernel_cast_compute SRCS cast_compute_test.cc DEPS arena_framework ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
+  lite_cc_test(test_kernel_cast_compute SRCS cast_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
   lite_cc_test(test_kernel_instance_norm_compute SRCS instance_norm_compute_test.cc DEPS arena_framework ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
   lite_cc_test(test_kernel_grid_sampler_compute SRCS grid_sampler_compute_test.cc DEPS arena_framework ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
   #lite_cc_test(test_kernel_sequence_softmax_compute SRCS sequence_softmax_compute_test.cc DEPS arena_framework ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
diff --git a/lite/tests/kernels/cast_compute_test.cc b/lite/tests/kernels/cast_compute_test.cc
index fea3452dbcf8bc9c46e034f2277fcd6e9d0edca0..a7316a6162ed9a1bbbaf4956d51ab19c017fd3e4 100644
--- a/lite/tests/kernels/cast_compute_test.cc
+++ b/lite/tests/kernels/cast_compute_test.cc
@@ -22,12 +22,13 @@ namespace lite {
 
 class CastComputeTester : public arena::TestCase {
  protected:
-  // common attributes for this op.
-  std::string input_ = "x";
-  std::string output_ = "out";
+  std::string x_ = "x";
+  std::string out_ = "out";
+  // BOOL = 0;INT16 = 1;INT32 = 2;INT64 = 3;FP16 = 4;FP32 = 5;FP64 = 6;
+  // SIZE_T = 19;UINT8 = 20;INT8 = 21;
   int in_dtype_;
   int out_dtype_;
-  DDim x_dims_{{2, 2}};
+  DDim dims_{{2, 2}};
 
  public:
   CastComputeTester(const Place& place,
@@ -36,91 +37,148 @@ class CastComputeTester : public arena::TestCase {
                     int out_dtype)
       : TestCase(place, alias), in_dtype_(in_dtype), out_dtype_(out_dtype) {}
 
-  void RunBaseline(Scope* scope) override {
-    auto* out = scope->NewTensor(output_);
+  template <typename InType, typename OutType>
+  void RunBaselineHelper(Scope* scope) {
+    auto* x = scope->FindTensor(x_);
+    auto* x_data = x->data<InType>();
+    auto* out = scope->NewTensor(out_);
     CHECK(out);
-    out->Resize(x_dims_);
+    out->Resize(dims_);
+    auto* out_data = out->mutable_data<OutType>();
+    for (int i = 0; i < dims_.production(); i++) {
+      *out_data = static_cast<OutType>(*x_data);
+      out_data++;
+      x_data++;
+    }
+  }
 
-    if (out_dtype_ == 5 && in_dtype_ == 20) {
-      auto* x = scope->FindTensor(input_);
-      auto* x_data = x->data<unsigned char>();
-      auto* output_data = out->mutable_data<float>();
-      for (int i = 0; i < x_dims_.production(); i++) {
-        *output_data = static_cast<float>(*x_data);
-        output_data++;
-        x_data++;
-      }
-    } else if (out_dtype_ == 5 && in_dtype_ == 21) {
-      auto* output_data = out->mutable_data<float>();
-      auto* x = scope->FindTensor(input_);
-      auto* x_data = x->data<char>();
-      for (int i = 0; i < x_dims_.production(); i++) {
-        *output_data = static_cast<float>(*x_data);
-        output_data++;
-        x_data++;
-      }
-    } else if (out_dtype_ == 5 && in_dtype_ == 2) {
-      auto* output_data = out->mutable_data<float>();
-      auto* x = scope->FindTensor(input_);
-      auto* x_data = x->data<int32_t>();
-      for (int i = 0; i < x_dims_.production(); i++) {
-        *output_data = static_cast<float>(*x_data);
-        output_data++;
-        x_data++;
-      }
+  void RunBaseline(Scope* scope) override {
+    if (in_dtype_ == 20 && out_dtype_ == 5) {
+      RunBaselineHelper<unsigned char, float>(scope);
+    } else if (in_dtype_ == 2 && out_dtype_ == 5) {
+      RunBaselineHelper<int32_t, float>(scope);
+    } else if (in_dtype_ == 3 && out_dtype_ == 5) {
+      RunBaselineHelper<int64_t, float>(scope);
+    } else if (in_dtype_ == 5 && out_dtype_ == 3) {
+      RunBaselineHelper<float, int64_t>(scope);
+    } else if (in_dtype_ == 21 && out_dtype_ == 5) {
+      RunBaselineHelper<char, float>(scope);
+    } else if (in_dtype_ == 5 && out_dtype_ == 21) {
+      RunBaselineHelper<float, char>(scope);
+    } else {
+      LOG(FATAL) << "unsupported";
     }
   }
 
   void PrepareOpDesc(cpp::OpDesc* op_desc) {
     op_desc->SetType("cast");
-    op_desc->SetInput("X", {input_});
-    op_desc->SetOutput("Out", {output_});
+    op_desc->SetInput("X", {x_});
+    op_desc->SetOutput("Out", {out_});
     op_desc->SetAttr("in_dtype", in_dtype_);
     op_desc->SetAttr("out_dtype", out_dtype_);
   }
 
+  template <typename T>
+  void PrepareDataHelper() {
+    std::vector<T> x_data(dims_.production());
+    for (int i = 0; i < dims_.production(); i++) {
+      x_data[i] = static_cast<T>(i % 128);
+    }
+    SetCommonTensor(x_, dims_, x_data.data());
+  }
+
   void PrepareData() override {
-    SetPrecisionType(output_, PRECISION(kFloat));
-    if (in_dtype_ == 20) {
-      std::vector<unsigned char> x_data(x_dims_.production());
-      for (int i = 0; i < x_dims_.production(); i++) {
-        x_data[i] = static_cast<unsigned char>(i % 128);
-      }
-      SetCommonTensor(input_, x_dims_, x_data.data());
-    } else if (in_dtype_ == 21) {
-      std::vector<char> x_data(x_dims_.production());
-      for (int i = 0; i < x_dims_.production(); i++) {
-        float sign = i % 3 == 0 ? -1.0f : 1.0f;
-        x_data[i] = sign * static_cast<char>(i % 128);
-      }
-      SetCommonTensor(input_, x_dims_, x_data.data());
-    } else if (in_dtype_ == 2) {
-      std::vector<int32_t> x_data(x_dims_.production());
-      for (int i = 0; i < x_dims_.production(); i++) {
-        int sign = i % 3 == 0 ? -1 : 1;
-        x_data[i] = sign * static_cast<int32_t>(i % 128);
-      }
-      SetCommonTensor(input_, x_dims_, x_data.data());
-    } else {
-      LOG(FATAL) << "not implemented!";
+    // BOOL = 0;INT16 = 1;INT32 = 2;INT64 = 3;FP16 = 4;FP32 = 5;FP64 = 6;
+    // SIZE_T = 19;UINT8 = 20;INT8 = 21;
+    switch (in_dtype_) {
+      case 20:
+        PrepareDataHelper<unsigned char>();
+        break;
+      case 21:
+        PrepareDataHelper<char>();
+        break;
+      case 1:
+        PrepareDataHelper<int16_t>();
+        break;
+      case 2:
+        PrepareDataHelper<int32_t>();
+        break;
+      case 3:
+        PrepareDataHelper<int64_t>();
+        break;
+      case 5:
+        PrepareDataHelper<float>();
+        break;
+      case 6:
+        PrepareDataHelper<double>();
+        break;
+      case 19:
+        PrepareDataHelper<size_t>();
+        break;
+      default:
+        LOG(FATAL) << "unsupported data type: " << in_dtype_;
+        break;
+    }
+
+    PrecisionType out_ptype;
+    switch (out_dtype_) {
+      case 0:
+        out_ptype = PRECISION(kBool);
+        break;
+      case 21:
+        out_ptype = PRECISION(kInt8);
+        break;
+      case 1:
+        out_ptype = PRECISION(kInt16);
+        break;
+      case 2:
+        out_ptype = PRECISION(kInt32);
+        break;
+      case 3:
+        out_ptype = PRECISION(kInt64);
+        break;
+      case 4:
+        out_ptype = PRECISION(kFP16);
+        break;
+      case 5:
+        out_ptype = PRECISION(kFloat);
+        break;
+      default:
+        LOG(FATAL) << "unsupported data type: " << out_dtype_;
+        break;
     }
+    SetPrecisionType(out_, out_ptype);
   }
 };
 
-TEST(Cast, precision) {
-  LOG(INFO) << "test cast op";
-#ifdef LITE_WITH_ARM
-  Place place(TARGET(kARM));
-
+void TestCast(Place place, float abs_error, int in_dtype, int out_dtype) {
   std::unique_ptr<arena::TestCase> tester(
-      new CastComputeTester(place, "def", 20, 5));
-  arena::Arena arena(std::move(tester), place, 2e-5);
+      new CastComputeTester(place, "def", in_dtype, out_dtype));
+  arena::Arena arena(std::move(tester), place, abs_error);
   arena.TestPrecision();
+}
 
-  std::unique_ptr<arena::TestCase> tester1(
-      new CastComputeTester(place, "def", 2, 5));
-  arena::Arena arena1(std::move(tester1), place, 2e-5);
-  arena1.TestPrecision();
+TEST(Cast, precision) {
+  LOG(INFO) << "test cast op";
+  Place place;
+  float abs_error = 2e-5;
+#if defined(LITE_WITH_ARM)
+  place = TARGET(kARM);
+#elif defined(LITE_WITH_XPU)
+  place = TARGET(kXPU);
+#else
+  return;
+#endif
+
+// BOOL = 0;INT16 = 1;INT32 = 2;INT64 = 3;FP16 = 4;FP32 = 5;FP64 = 6;
+// SIZE_T = 19;UINT8 = 20;INT8 = 21;
+#ifndef LITE_WITH_XPU
+  TestCast(place, abs_error, 20, 5);
+#endif
+  TestCast(place, abs_error, 2, 5);
+#ifdef LITE_WITH_XPU
+  TestCast(place, abs_error, 3, 5);
+  TestCast(place, abs_error, 5, 3);
 #endif
 }