未验证 提交 5bcbacb7 编写于 作者: L liu zhengxi 提交者: GitHub

Add cast op for x86 platform on Paddle-Lite (#2413)

* add cast op for x86 platform

* alter the struct to class to hide the data and alter the pointer
上级 8a1d942a
...@@ -5,6 +5,7 @@ add_kernel(activation_compute_x86 X86 basic SRCS activation_compute.cc DEPS ${li ...@@ -5,6 +5,7 @@ add_kernel(activation_compute_x86 X86 basic SRCS activation_compute.cc DEPS ${li
# lite_cc_library(fc_compute_x86 SRCS fc_compute.cc DEPS ${lite_kernel_deps}) # lite_cc_library(fc_compute_x86 SRCS fc_compute.cc DEPS ${lite_kernel_deps})
add_kernel(scale_compute_x86 X86 basic SRCS scale_compute.cc DEPS ${lite_kernel_deps}) add_kernel(scale_compute_x86 X86 basic SRCS scale_compute.cc DEPS ${lite_kernel_deps})
add_kernel(cast_compute_x86 X86 basic SRCS cast_compute.cc DEPS ${lite_kernel_deps} fluid_data_type)
add_kernel(slice_compute_x86 X86 basic SRCS slice_compute.cc DEPS ${lite_kernel_deps}) add_kernel(slice_compute_x86 X86 basic SRCS slice_compute.cc DEPS ${lite_kernel_deps})
add_kernel(squeeze_compute_x86 X86 basic SRCS squeeze_compute.cc DEPS ${lite_kernel_deps}) add_kernel(squeeze_compute_x86 X86 basic SRCS squeeze_compute.cc DEPS ${lite_kernel_deps})
add_kernel(fill_constant_batch_size_like_compute_x86 X86 basic SRCS fill_constant_batch_size_like_compute.cc DEPS ${lite_kernel_deps} math_function) add_kernel(fill_constant_batch_size_like_compute_x86 X86 basic SRCS fill_constant_batch_size_like_compute.cc DEPS ${lite_kernel_deps} math_function)
...@@ -64,7 +65,7 @@ lite_cc_test(test_gelu_compute_x86 SRCS gelu_compute_test.cc DEPS activation_com ...@@ -64,7 +65,7 @@ lite_cc_test(test_gelu_compute_x86 SRCS gelu_compute_test.cc DEPS activation_com
lite_cc_test(test_sequence_expand_as_compute_x86 SRCS sequence_expand_as_compute_test.cc DEPS sequence_expand_as_compute_x86) lite_cc_test(test_sequence_expand_as_compute_x86 SRCS sequence_expand_as_compute_test.cc DEPS sequence_expand_as_compute_x86)
lite_cc_test(test_gru_compute_x86 SRCS gru_compute_test.cc DEPS gru_compute_x86) lite_cc_test(test_gru_compute_x86 SRCS gru_compute_test.cc DEPS gru_compute_x86)
lite_cc_test(test_matmul_compute_x86 SRCS matmul_compute_test.cc DEPS matmul_compute_x86) lite_cc_test(test_matmul_compute_x86 SRCS matmul_compute_test.cc DEPS matmul_compute_x86)
lite_cc_test(test_cast_compute_x86 SRCS cast_compute_test.cc DEPS cast_compute_x86)
lite_cc_test(test_pool2d_compute_x86 SRCS pool_compute_test.cc DEPS pool_compute_x86) lite_cc_test(test_pool2d_compute_x86 SRCS pool_compute_test.cc DEPS pool_compute_x86)
lite_cc_test(test_dropout_compute_x86 SRCS dropout_compute_test.cc DEPS dropout_compute_x86) lite_cc_test(test_dropout_compute_x86 SRCS dropout_compute_test.cc DEPS dropout_compute_x86)
lite_cc_test(test_transpose_compute_x86 SRCS transpose_compute_test.cc DEPS transpose_compute_x86) lite_cc_test(test_transpose_compute_x86 SRCS transpose_compute_test.cc DEPS transpose_compute_x86)
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/x86/cast_compute.h"
REGISTER_LITE_KERNEL(cast,
kX86,
kFloat,
kNCHW,
paddle::lite::kernels::x86::CastCompute<float>,
def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kX86))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kX86))})
.Finalize();
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "lite/core/kernel.h"
#include "lite/core/op_registry.h"
#include "lite/core/types.h"
#include "lite/fluid/data_type.h"
#include "lite/fluid/hostdevice.h"
#include "lite/fluid/transform.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace x86 {
template <typename InT, typename OutT>
struct CastOpTransformFunctor {
HOSTDEVICE OutT operator()(InT in) const { return static_cast<OutT>(in); }
};
template <lite::TargetType Target, typename InT>
class CastOpFunctor {
public:
CastOpFunctor(const lite::Tensor* in,
lite::Tensor* out,
const lite::Context<Target>& context)
: input(in), output(out), ctx(context) {}
template <typename OutT>
void apply() const {
auto* in_begin = input->data<InT>();
auto numel = input->dims().production();
auto* in_end = in_begin + numel;
auto* out_begin = output->mutable_data<OutT>();
paddle::lite::fluid::Transform<lite::TargetType::kX86> trans;
trans(
ctx, in_begin, in_end, out_begin, CastOpTransformFunctor<InT, OutT>());
}
private:
const lite::Tensor* input;
lite::Tensor* output;
const lite::Context<Target>& ctx;
};
template <typename InT>
class CastCompute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> {
public:
using param_t = operators::CastParam;
void Run() override {
auto param = param_.get_mutable<param_t>();
auto& context = ctx_->As<X86Context>();
auto x = param->X;
auto out = param->Out;
auto out_dtype = param->out_dtype;
paddle::lite::fluid::VisitDataType(
static_cast<framework::proto::VarType::Type>(out_dtype),
CastOpFunctor<lite::TargetType::kX86, InT>(x, out, context));
}
virtual ~CastCompute() = default;
};
} // namespace x86
} // namespace kernels
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/x86/cast_compute.h"
#include <gtest/gtest.h>
#include <memory>
#include <utility>
#include <vector>
#include "lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace x86 {
TEST(cast_x86, retrive_op) {
auto cast =
KernelRegistry::Global().Create<TARGET(kX86), PRECISION(kFloat)>("cast");
ASSERT_FALSE(cast.empty());
ASSERT_TRUE(cast.front());
}
TEST(cast_x86, init) {
CastCompute<float> cast;
ASSERT_EQ(cast.precision(), PRECISION(kFloat));
ASSERT_EQ(cast.target(), TARGET(kX86));
}
TEST(cast_x86, run_test) {
lite::Tensor x, out;
constexpr int batch_size = 1;
std::vector<int64_t> x_shape{batch_size, 1, 3, 3};
x.Resize(lite::DDim(x_shape));
std::vector<int64_t> out_shape{batch_size, 1, 3, 3};
out.Resize(lite::DDim(out_shape));
auto x_data = x.mutable_data<float>();
auto out_data = out.mutable_data<int32_t>();
for (int64_t i = 0; i < x.dims().production(); i++) {
x_data[i] = static_cast<float>(1);
}
CastCompute<float> cast;
operators::CastParam param;
param.X = &x;
param.Out = &out;
std::unique_ptr<KernelContext> ctx(new KernelContext);
ctx->As<X86Context>();
cast.SetContext(std::move(ctx));
cast.SetParam(param);
cast.Run();
std::vector<int32_t> ref_results = {1, 1, 1, 1, 1, 1, 1, 1, 1};
for (int i = 0; i < out.dims().production(); i++) {
EXPECT_NEAR(out_data[i], ref_results[i], 1e-5);
}
}
} // namespace x86
} // namespace kernels
} // namespace lite
} // namespace paddle
USE_LITE_KERNEL(cast, kX86, kFloat, kNCHW, def);
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册