# lite/kernels/x86/CMakeLists.txt — reconstructed modified region
# (recovered from a mangled diff; shows the post-patch state).

# Register the new cast kernel. It needs fluid_data_type in DEPS because the
# kernel dispatches on CastParam::out_dtype via lite::fluid::VisitDataType.
add_kernel(scale_compute_x86 X86 basic SRCS scale_compute.cc DEPS ${lite_kernel_deps})
add_kernel(cast_compute_x86 X86 basic SRCS cast_compute.cc DEPS ${lite_kernel_deps} fluid_data_type)
add_kernel(slice_compute_x86 X86 basic SRCS slice_compute.cc DEPS ${lite_kernel_deps})
add_kernel(squeeze_compute_x86 X86 basic SRCS squeeze_compute.cc DEPS ${lite_kernel_deps})
add_kernel(fill_constant_batch_size_like_compute_x86 X86 basic SRCS fill_constant_batch_size_like_compute.cc DEPS ${lite_kernel_deps} math_function)

# Unit test for the cast kernel (placed with the other x86 kernel tests).
lite_cc_test(test_matmul_compute_x86 SRCS matmul_compute_test.cc DEPS matmul_compute_x86)
lite_cc_test(test_cast_compute_x86 SRCS cast_compute_test.cc DEPS cast_compute_x86)
lite_cc_test(test_pool2d_compute_x86 SRCS pool_compute_test.cc DEPS pool_compute_x86)

# --- lite/kernels/x86/cast_compute.cc begins below ---
# Copyright (c) 2019 PaddlePaddle
Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lite/kernels/x86/cast_compute.h" + +REGISTER_LITE_KERNEL(cast, + kX86, + kFloat, + kNCHW, + paddle::lite::kernels::x86::CastCompute, + def) + .BindInput("X", {LiteType::GetTensorTy(TARGET(kX86))}) + .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kX86))}) + .Finalize(); diff --git a/lite/kernels/x86/cast_compute.h b/lite/kernels/x86/cast_compute.h new file mode 100644 index 0000000000000000000000000000000000000000..06e47e9a5023ea149510e8f10bf719cd6a854349 --- /dev/null +++ b/lite/kernels/x86/cast_compute.h @@ -0,0 +1,80 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include "lite/core/kernel.h" +#include "lite/core/op_registry.h" +#include "lite/core/types.h" +#include "lite/fluid/data_type.h" +#include "lite/fluid/hostdevice.h" +#include "lite/fluid/transform.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace x86 { + +template +struct CastOpTransformFunctor { + HOSTDEVICE OutT operator()(InT in) const { return static_cast(in); } +}; + +template +class CastOpFunctor { + public: + CastOpFunctor(const lite::Tensor* in, + lite::Tensor* out, + const lite::Context& context) + : input(in), output(out), ctx(context) {} + + template + void apply() const { + auto* in_begin = input->data(); + auto numel = input->dims().production(); + auto* in_end = in_begin + numel; + auto* out_begin = output->mutable_data(); + paddle::lite::fluid::Transform trans; + trans( + ctx, in_begin, in_end, out_begin, CastOpTransformFunctor()); + } + + private: + const lite::Tensor* input; + lite::Tensor* output; + const lite::Context& ctx; +}; + +template +class CastCompute : public KernelLite { + public: + using param_t = operators::CastParam; + + void Run() override { + auto param = param_.get_mutable(); + auto& context = ctx_->As(); + auto x = param->X; + auto out = param->Out; + auto out_dtype = param->out_dtype; + paddle::lite::fluid::VisitDataType( + static_cast(out_dtype), + CastOpFunctor(x, out, context)); + } + virtual ~CastCompute() = default; +}; + +} // namespace x86 +} // namespace kernels +} // namespace lite +} // namespace paddle diff --git a/lite/kernels/x86/cast_compute_test.cc b/lite/kernels/x86/cast_compute_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..f7aa52ca6d0dde603357f009220b4a3a53f56833 --- /dev/null +++ b/lite/kernels/x86/cast_compute_test.cc @@ -0,0 +1,77 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lite/kernels/x86/cast_compute.h" +#include +#include +#include +#include +#include "lite/core/op_registry.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace x86 { + +TEST(cast_x86, retrive_op) { + auto cast = + KernelRegistry::Global().Create("cast"); + ASSERT_FALSE(cast.empty()); + ASSERT_TRUE(cast.front()); +} + +TEST(cast_x86, init) { + CastCompute cast; + ASSERT_EQ(cast.precision(), PRECISION(kFloat)); + ASSERT_EQ(cast.target(), TARGET(kX86)); +} + +TEST(cast_x86, run_test) { + lite::Tensor x, out; + constexpr int batch_size = 1; + std::vector x_shape{batch_size, 1, 3, 3}; + x.Resize(lite::DDim(x_shape)); + + std::vector out_shape{batch_size, 1, 3, 3}; + out.Resize(lite::DDim(out_shape)); + + auto x_data = x.mutable_data(); + auto out_data = out.mutable_data(); + + for (int64_t i = 0; i < x.dims().production(); i++) { + x_data[i] = static_cast(1); + } + + CastCompute cast; + operators::CastParam param; + param.X = &x; + param.Out = &out; + std::unique_ptr ctx(new KernelContext); + ctx->As(); + cast.SetContext(std::move(ctx)); + cast.SetParam(param); + cast.Run(); + + std::vector ref_results = {1, 1, 1, 1, 1, 1, 1, 1, 1}; + for (int i = 0; i < out.dims().production(); i++) { + EXPECT_NEAR(out_data[i], ref_results[i], 1e-5); + } +} + +} // namespace x86 +} // namespace kernels +} // namespace lite +} // namespace paddle + +USE_LITE_KERNEL(cast, kX86, kFloat, kNCHW, def);