提交 4f917867 编写于 作者: Z zhupengyang 提交者: GitHub

[XPU] cast op bridge and ut (#2738)

上级 a0f455ee
...@@ -24,6 +24,7 @@ lite_cc_library(subgraph_bridge_reshape_op_xpu SRCS reshape_op.cc DEPS ${xpu_sub ...@@ -24,6 +24,7 @@ lite_cc_library(subgraph_bridge_reshape_op_xpu SRCS reshape_op.cc DEPS ${xpu_sub
lite_cc_library(subgraph_bridge_layer_norm_op_xpu SRCS layer_norm_op.cc DEPS ${xpu_subgraph_bridge_deps}) lite_cc_library(subgraph_bridge_layer_norm_op_xpu SRCS layer_norm_op.cc DEPS ${xpu_subgraph_bridge_deps})
lite_cc_library(subgraph_bridge_dropout_op_xpu SRCS dropout_op.cc DEPS ${xpu_subgraph_bridge_deps}) lite_cc_library(subgraph_bridge_dropout_op_xpu SRCS dropout_op.cc DEPS ${xpu_subgraph_bridge_deps})
lite_cc_library(subgraph_bridge_matmul_op_xpu SRCS matmul_op.cc DEPS ${xpu_subgraph_bridge_deps}) lite_cc_library(subgraph_bridge_matmul_op_xpu SRCS matmul_op.cc DEPS ${xpu_subgraph_bridge_deps})
lite_cc_library(subgraph_bridge_cast_op_xpu SRCS cast_op.cc DEPS ${xpu_subgraph_bridge_deps})
set(xpu_subgraph_bridges set(xpu_subgraph_bridges
subgraph_bridge_registry subgraph_bridge_registry
...@@ -46,6 +47,7 @@ set(xpu_subgraph_bridges ...@@ -46,6 +47,7 @@ set(xpu_subgraph_bridges
subgraph_bridge_layer_norm_op_xpu subgraph_bridge_layer_norm_op_xpu
subgraph_bridge_dropout_op_xpu subgraph_bridge_dropout_op_xpu
subgraph_bridge_matmul_op_xpu subgraph_bridge_matmul_op_xpu
subgraph_bridge_cast_op_xpu
CACHE INTERNAL "xpu_subgraph_bridges") CACHE INTERNAL "xpu_subgraph_bridges")
message(STATUS "+++++ xpu_subgraph_bridges: ${xpu_subgraph_bridges}") message(STATUS "+++++ xpu_subgraph_bridges: ${xpu_subgraph_bridges}")
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/npu/bridges/registry.h"
#include "lite/kernels/xpu/bridges/graph.h"
#include "lite/kernels/xpu/bridges/utility.h"
namespace paddle {
namespace lite {
namespace subgraph {
namespace xpu {
bool CvtDtype(int dtype, PrecisionType* ptype) {
switch (dtype) {
case 21:
*ptype = PRECISION(kInt8);
break;
case 1:
*ptype = PRECISION(kInt16);
break;
case 2:
*ptype = PRECISION(kInt32);
break;
case 3:
*ptype = PRECISION(kInt64);
break;
case 5:
*ptype = PRECISION(kFloat);
break;
default:
LOG(WARNING) << "[XPU] unsupported date type: " << dtype;
return false;
}
return true;
}
int CastConverter(void* ctx, OpLite* op, KernelBase* kernel) {
CHECK(ctx != nullptr);
CHECK(op != nullptr);
auto graph = static_cast<Graph*>(ctx);
auto op_info = op->op_info();
auto op_type = op_info->Type();
auto scope = op->scope();
VLOG(3) << "[XPU] Converting " + op_type + "...";
// Get input and output vars and op attributes
auto x_name = op_info->Input("X").front();
auto x = scope->FindMutableTensor(x_name);
auto out_name = op_info->Output("Out").front();
// BOOL = 0;INT16 = 1;INT32 = 2;INT64 = 3;FP16 = 4;FP32 = 5;FP64 = 6;
// SIZE_T = 19;UINT8 = 20;INT8 = 21;
int in_dtype = op_info->GetAttr<int>("in_dtype");
PrecisionType in_ptype;
if (!CvtDtype(in_dtype, &in_ptype)) {
return FAILED;
}
int out_dtype = op_info->GetAttr<int>("out_dtype");
PrecisionType out_ptype;
if (!CvtDtype(out_dtype, &out_ptype)) {
return FAILED;
}
// X node
std::shared_ptr<Node> x_node = nullptr;
if (graph->Has(x_name)) {
x_node = graph->Get(x_name);
} else {
x_node = graph->Add(x_name, *x, in_ptype);
}
// Cast node
graph->Add(
out_name,
graph->builder_.CreateCast(*x_node->data(), CvtPrecisionType(out_ptype)));
return SUCCESS;
}
} // namespace xpu
} // namespace subgraph
} // namespace lite
} // namespace paddle
REGISTER_SUBGRAPH_BRIDGE(cast,
kXPU,
paddle::lite::subgraph::xpu::CastConverter);
...@@ -36,3 +36,4 @@ USE_SUBGRAPH_BRIDGE(layer_norm, kXPU); ...@@ -36,3 +36,4 @@ USE_SUBGRAPH_BRIDGE(layer_norm, kXPU);
USE_SUBGRAPH_BRIDGE(gelu, kXPU); USE_SUBGRAPH_BRIDGE(gelu, kXPU);
USE_SUBGRAPH_BRIDGE(dropout, kXPU); USE_SUBGRAPH_BRIDGE(dropout, kXPU);
USE_SUBGRAPH_BRIDGE(matmul, kXPU); USE_SUBGRAPH_BRIDGE(matmul, kXPU);
USE_SUBGRAPH_BRIDGE(cast, kXPU);
...@@ -13,7 +13,7 @@ if((NOT LITE_WITH_OPENCL AND NOT LITE_WITH_FPGA) AND (LITE_WITH_X86 OR LITE_WITH ...@@ -13,7 +13,7 @@ if((NOT LITE_WITH_OPENCL AND NOT LITE_WITH_FPGA) AND (LITE_WITH_X86 OR LITE_WITH
lite_cc_test(test_kernel_axpy_compute SRCS axpy_compute_test.cc DEPS arena_framework ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels}) lite_cc_test(test_kernel_axpy_compute SRCS axpy_compute_test.cc DEPS arena_framework ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_conv2d_transpose_compute SRCS conv2d_transpose_compute_test.cc DEPS arena_framework ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels}) lite_cc_test(test_kernel_conv2d_transpose_compute SRCS conv2d_transpose_compute_test.cc DEPS arena_framework ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_norm_compute SRCS norm_compute_test.cc DEPS arena_framework ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels}) lite_cc_test(test_kernel_norm_compute SRCS norm_compute_test.cc DEPS arena_framework ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_cast_compute SRCS cast_compute_test.cc DEPS arena_framework ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels}) lite_cc_test(test_kernel_cast_compute SRCS cast_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_instance_norm_compute SRCS instance_norm_compute_test.cc DEPS arena_framework ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels}) lite_cc_test(test_kernel_instance_norm_compute SRCS instance_norm_compute_test.cc DEPS arena_framework ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_grid_sampler_compute SRCS grid_sampler_compute_test.cc DEPS arena_framework ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels}) lite_cc_test(test_kernel_grid_sampler_compute SRCS grid_sampler_compute_test.cc DEPS arena_framework ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
#lite_cc_test(test_kernel_sequence_softmax_compute SRCS sequence_softmax_compute_test.cc DEPS arena_framework ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels}) #lite_cc_test(test_kernel_sequence_softmax_compute SRCS sequence_softmax_compute_test.cc DEPS arena_framework ${x86_kernels} ${cuda_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
......
...@@ -22,12 +22,13 @@ namespace lite { ...@@ -22,12 +22,13 @@ namespace lite {
class CastComputeTester : public arena::TestCase { class CastComputeTester : public arena::TestCase {
protected: protected:
// common attributes for this op. std::string x_ = "x";
std::string input_ = "x"; std::string out_ = "out";
std::string output_ = "out"; // BOOL = 0;INT16 = 1;INT32 = 2;INT64 = 3;FP16 = 4;FP32 = 5;FP64 = 6;
// SIZE_T = 19;UINT8 = 20;INT8 = 21;
int in_dtype_; int in_dtype_;
int out_dtype_; int out_dtype_;
DDim x_dims_{{2, 2}}; DDim dims_{{2, 2}};
public: public:
CastComputeTester(const Place& place, CastComputeTester(const Place& place,
...@@ -36,91 +37,148 @@ class CastComputeTester : public arena::TestCase { ...@@ -36,91 +37,148 @@ class CastComputeTester : public arena::TestCase {
int out_dtype) int out_dtype)
: TestCase(place, alias), in_dtype_(in_dtype), out_dtype_(out_dtype) {} : TestCase(place, alias), in_dtype_(in_dtype), out_dtype_(out_dtype) {}
void RunBaseline(Scope* scope) override { template <typename T1, typename T2>
auto* out = scope->NewTensor(output_); void RunBaselineHelper(Scope* scope) {
auto* x = scope->FindTensor(x_);
auto* x_data = x->data<T1>();
auto* out = scope->NewTensor(out_);
CHECK(out); CHECK(out);
out->Resize(x_dims_); out->Resize(dims_);
auto* out_data = out->mutable_data<T2>();
for (int i = 0; i < dims_.production(); i++) {
*out_data = static_cast<T2>(*x_data);
out_data++;
x_data++;
}
}
if (out_dtype_ == 5 && in_dtype_ == 20) { void RunBaseline(Scope* scope) override {
auto* x = scope->FindTensor(input_); if (in_dtype_ == 20 && out_dtype_ == 5) {
auto* x_data = x->data<unsigned char>(); RunBaselineHelper<uint8_t, float>(scope);
auto* output_data = out->mutable_data<float>(); } else if (in_dtype_ == 2 && out_dtype_ == 5) {
for (int i = 0; i < x_dims_.production(); i++) { RunBaselineHelper<int32_t, float>(scope);
*output_data = static_cast<float>(*x_data); } else if (in_dtype_ == 3 && out_dtype_ == 5) {
output_data++; RunBaselineHelper<int64_t, float>(scope);
x_data++; } else if (in_dtype_ == 5 && out_dtype_ == 3) {
} RunBaselineHelper<float, int64_t>(scope);
} else if (out_dtype_ == 5 && in_dtype_ == 21) { } else if (in_dtype_ == 21 && out_dtype_ == 5) {
auto* output_data = out->mutable_data<float>(); RunBaselineHelper<int8_t, float>(scope);
auto* x = scope->FindTensor(input_); } else if (in_dtype_ == 5 && out_dtype_ == 21) {
auto* x_data = x->data<char>(); RunBaselineHelper<float, int8_t>(scope);
for (int i = 0; i < x_dims_.production(); i++) { } else {
*output_data = static_cast<float>(*x_data); LOG(FATAL) << "unsupported";
output_data++;
x_data++;
}
} else if (out_dtype_ == 5 && in_dtype_ == 2) {
auto* output_data = out->mutable_data<float>();
auto* x = scope->FindTensor(input_);
auto* x_data = x->data<int32_t>();
for (int i = 0; i < x_dims_.production(); i++) {
*output_data = static_cast<float>(*x_data);
output_data++;
x_data++;
}
} }
} }
void PrepareOpDesc(cpp::OpDesc* op_desc) { void PrepareOpDesc(cpp::OpDesc* op_desc) {
op_desc->SetType("cast"); op_desc->SetType("cast");
op_desc->SetInput("X", {input_}); op_desc->SetInput("X", {x_});
op_desc->SetOutput("Out", {output_}); op_desc->SetOutput("Out", {out_});
op_desc->SetAttr("in_dtype", in_dtype_); op_desc->SetAttr("in_dtype", in_dtype_);
op_desc->SetAttr("out_dtype", out_dtype_); op_desc->SetAttr("out_dtype", out_dtype_);
} }
template <typename T1>
void PrepareDataHelper() {
std::vector<T1> x_data(dims_.production());
for (int i = 0; i < dims_.production(); i++) {
x_data[i] = static_cast<T1>(i % 128);
}
SetCommonTensor(x_, dims_, x_data.data());
}
void PrepareData() override { void PrepareData() override {
SetPrecisionType(output_, PRECISION(kFloat)); // BOOL = 0;INT16 = 1;INT32 = 2;INT64 = 3;FP16 = 4;FP32 = 5;FP64 = 6;
if (in_dtype_ == 20) { // SIZE_T = 19;UINT8 = 20;INT8 = 21;
std::vector<unsigned char> x_data(x_dims_.production()); switch (in_dtype_) {
for (int i = 0; i < x_dims_.production(); i++) { case 20:
x_data[i] = static_cast<unsigned char>(i % 128); PrepareDataHelper<uint8_t>();
} break;
SetCommonTensor(input_, x_dims_, x_data.data()); case 21:
} else if (in_dtype_ == 21) { PrepareDataHelper<int8_t>();
std::vector<char> x_data(x_dims_.production()); break;
for (int i = 0; i < x_dims_.production(); i++) { case 1:
float sign = i % 3 == 0 ? -1.0f : 1.0f; PrepareDataHelper<int16_t>();
x_data[i] = sign * static_cast<char>(i % 128); break;
} case 2:
SetCommonTensor(input_, x_dims_, x_data.data()); PrepareDataHelper<int32_t>();
} else if (in_dtype_ == 2) { break;
std::vector<int32_t> x_data(x_dims_.production()); case 3:
for (int i = 0; i < x_dims_.production(); i++) { PrepareDataHelper<int64_t>();
int sign = i % 3 == 0 ? -1 : 1; break;
x_data[i] = sign * static_cast<int32_t>(i % 128); case 5:
} PrepareDataHelper<float>();
SetCommonTensor(input_, x_dims_, x_data.data()); break;
} else { case 6:
LOG(FATAL) << "not implemented!"; PrepareDataHelper<double>();
break;
case 19:
PrepareDataHelper<size_t>();
break;
default:
LOG(FATAL) << "unsupported data type: " << in_dtype_;
break;
}
PrecisionType out_ptype;
switch (out_dtype_) {
case 0:
out_ptype = PRECISION(kBool);
break;
case 21:
out_ptype = PRECISION(kInt8);
break;
case 1:
out_ptype = PRECISION(kInt16);
break;
case 2:
out_ptype = PRECISION(kInt32);
break;
case 3:
out_ptype = PRECISION(kInt64);
break;
case 4:
out_ptype = PRECISION(kFP16);
break;
case 5:
out_ptype = PRECISION(kFloat);
break;
default:
LOG(FATAL) << "unsupported data type: " << out_dtype_;
break;
} }
SetPrecisionType(out_, out_ptype);
} }
}; };
TEST(Cast, precision) { void TestCast(Place place, float abs_error, int in_dtype, int out_dtype) {
LOG(INFO) << "test cast op";
#ifdef LITE_WITH_ARM
Place place(TARGET(kARM));
std::unique_ptr<arena::TestCase> tester( std::unique_ptr<arena::TestCase> tester(
new CastComputeTester(place, "def", 20, 5)); new CastComputeTester(place, "def", in_dtype, out_dtype));
arena::Arena arena(std::move(tester), place, 2e-5); arena::Arena arena(std::move(tester), place, abs_error);
arena.TestPrecision(); arena.TestPrecision();
}
std::unique_ptr<arena::TestCase> tester1( TEST(Cast, precision) {
new CastComputeTester(place, "def", 2, 5)); LOG(INFO) << "test cast op";
arena::Arena arena1(std::move(tester1), place, 2e-5); Place place;
arena1.TestPrecision(); float abs_error = 2e-5;
#if defined(LITE_WITH_ARM)
place = TARGET(kARM);
#elif defined(LITE_WITH_XPU)
place = TARGET(kXPU);
#else
return;
#endif
// BOOL = 0;INT16 = 1;INT32 = 2;INT64 = 3;FP16 = 4;FP32 = 5;FP64 = 6;
// SIZE_T = 19;UINT8 = 20;INT8 = 21;
#ifndef LITE_WITH_XPU
TestCast(place, abs_error, 20, 5);
#endif
TestCast(place, abs_error, 2, 5);
#ifdef LITE_WITH_XPU
TestCast(place, abs_error, 3, 5);
TestCast(place, abs_error, 5, 3);
#endif #endif
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册