未验证 提交 5ab1d7e7 编写于 作者: W Wilber 提交者: GitHub

[CUDA] [Kernel] Add assign_value cuda kernel. (#3861)

上级 126691f0
......@@ -47,6 +47,7 @@ add_kernel(search_aligned_mat_mul_compute_cuda CUDA extra SRCS search_aligned_ma
add_kernel(search_seq_fc_compute_cuda CUDA extra SRCS search_seq_fc_compute.cu DEPS ${lite_kernel_deps} cuda_gemm)
add_kernel(var_conv_2d_compute_cuda CUDA extra SRCS var_conv_2d_compute.cu DEPS ${lite_kernel_deps} ${math_cuda})
add_kernel(topk_pooling_compute_cuda CUDA extra SRCS topk_pooling_compute.cu DEPS ${lite_kernel_deps})
add_kernel(assign_value_compute_cuda CUDA extra SRCS assign_value_compute.cu DEPS ${lite_kernel_deps})
# unit test
lite_cc_test(calib_compute_cuda_test SRCS calib_compute_cuda_test.cc DEPS calib_compute_cuda)
......@@ -85,4 +86,5 @@ if(LITE_BUILD_EXTRA)
nv_test(sequence_arithmetic_compute_cuda_test SRCS sequence_arithmetic_compute_test.cc DEPS sequence_arithmetic_compute_cuda)
#nv_test(search_fc_cuda_test SRCS search_fc_compute_test.cc DEPS search_fc_compute_cuda)
nv_test(topk_pooling_compute_cuda_test SRCS topk_pooling_compute_test.cc DEPS topk_pooling_compute_cuda)
nv_test(assign_value_compute_cuda_test SRCS assign_value_compute_test.cc DEPS assign_value_compute_cuda)
endif()
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <string>
#include <vector>
#include "lite/backends/cuda/target_wrapper.h"
#include "lite/core/op_registry.h"
#include "lite/core/types.h"
#include "lite/kernels/cuda/assign_value_compute.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace cuda {
// Copies the elements of the host vector `src` into the CUDA buffer of `dst`.
//
// The destination buffer is obtained via dst->mutable_data<T>(TARGET(kCUDA))
// and the transfer is issued with TargetWrapperCuda::MemcpyAsync on `*stream`
// in the host-to-device direction. Exactly src.size() * sizeof(T) bytes are
// transferred; this function does not resize `dst` and does not synchronize
// the stream.
template <class T>
void TensorFromVector(const std::vector<T>& src,
                      lite::Tensor* dst,
                      cudaStream_t* stream) {
  const auto num_bytes = sizeof(T) * src.size();
  const void* host_buf = src.data();
  void* device_buf = dst->mutable_data<T>(TARGET(kCUDA));
  TargetWrapperCuda::MemcpyAsync(
      device_buf, host_buf, num_bytes, IoDirection::HtoD, *stream);
}
// Uploads the constant values held by the op parameter into the output tensor.
//
// param.dtype selects which value vector is the source (INT32, FP32, INT64 or
// BOOL); any other dtype is reported through LOG(FATAL). The transfer is issued
// on the context's exec stream via TensorFromVector and is not synchronized
// here.
void AssignValueCompute::Run() {
  auto& param = Param<operators::AssignValueParam>();
  auto& ctx = this->ctx_->template As<CUDAContext>();
  auto stream = ctx.exec_stream();
  auto* out = param.Out;
  const int dtype = param.dtype;

  // The value vectors are read in place from the parameter struct instead of
  // being copied into locals: at most one of them is used per call, and the
  // host buffer handed to the asynchronous copy is then not a function-local
  // temporary.
  if (dtype == static_cast<int>(lite::core::FluidType::INT32)) {
    TensorFromVector(param.int32_values, out, &stream);
  } else if (dtype == static_cast<int>(lite::core::FluidType::FP32)) {
    TensorFromVector(param.fp32_values, out, &stream);
  } else if (dtype == static_cast<int>(lite::core::FluidType::INT64)) {
    TensorFromVector(param.int64_values, out, &stream);
  } else if (dtype == static_cast<int>(lite::core::FluidType::BOOL)) {
    // bool_values is a vector of int, so the output is written as int
    // elements. NOTE(review): confirm consumers of a BOOL output expect this
    // element width.
    TensorFromVector(param.bool_values, out, &stream);
  } else {
    LOG(FATAL) << "Unsupported dtype for assign_value_op:" << dtype;
  }
}
} // namespace cuda
} // namespace kernels
} // namespace lite
} // namespace paddle
// Registers AssignValueCompute as the `assign_value` kernel for the kCUDA
// target with kAny precision and kNCHW layout, under the alias `def`.
// The op has no bound inputs here; its single output "Out" is declared as a
// CUDA tensor of any precision.
REGISTER_LITE_KERNEL(assign_value,
kCUDA,
kAny,
kNCHW,
paddle::lite::kernels::cuda::AssignValueCompute,
def)
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kCUDA), PRECISION(kAny))})
.Finalize();
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "lite/core/kernel.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace cuda {
// CUDA kernel class for the assign_value operator.
//
// NOTE(review): the base class is instantiated with PRECISION(kFloat) while
// the kernel registration uses kAny -- confirm that this mismatch is intended.
class AssignValueCompute : public KernelLite<TARGET(kCUDA), PRECISION(kFloat)> {
 public:
  // Parameter struct this kernel operates on.
  using param_t = operators::AssignValueParam;

  virtual ~AssignValueCompute() = default;

  // Writes the values selected by the parameter's dtype into the output
  // tensor; the definition lives in the kernel's source file.
  void Run() override;
};
} // namespace cuda
} // namespace kernels
} // namespace lite
} // namespace paddle
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/cuda/assign_value_compute.h"
#include <gtest/gtest.h>
#include <functional>
#include <memory>
#include <numeric>
#include <random>
#include <utility>
#include <vector>
#include "lite/api/test_helper.h"
#include "lite/backends/cuda/cuda_utils.h"
#include "lite/core/types.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace cuda {
class AssignValueTest : public ::testing::Test {
protected:
AssignValueTest() : dtype(5), shape({1}) {
int num =
std::accumulate(shape.begin(), shape.end(), 1, std::multiplies<int>());
fp32_values.resize(num);
int32_values.resize(num);
int64_values.resize(num);
bool_values.resize(num);
for (int i = 0; i < num; ++i) {
fp32_values[i] = i + 5;
int32_values[i] = i;
int64_values[i] = i;
bool_values[i] = i;
}
std::vector<int64_t> out_shape(shape.size(), 0);
for (size_t i = 0; i < shape.size(); ++i) out_shape[i] = shape[i];
Out_ref.Resize(lite::DDim(out_shape));
Out_gpu.Resize(Out_ref.dims());
Out_cpu.Resize(Out_ref.dims());
cpu_base(&Out_ref);
device_init();
}
void device_init() {
ctx.reset(new KernelContext);
cudaStreamCreate(&stream);
param.shape = shape;
param.dtype = dtype;
param.fp32_values = fp32_values;
param.int32_values = int32_values;
param.int64_values = int64_values;
param.bool_values = bool_values;
param.Out = &Out_gpu;
}
void float_data_init() {}
void half_data_init() {}
void cpu_base(lite::Tensor* Out) {
if (dtype == static_cast<int>(lite::core::FluidType::INT32)) {
for (size_t i = 0; i < int32_values.size(); ++i) {
Out->mutable_data<int>()[i] = int32_values[i];
}
} else if (dtype == static_cast<int>(lite::core::FluidType::FP32)) {
for (size_t i = 0; i < fp32_values.size(); ++i) {
Out->mutable_data<float>()[i] = fp32_values[i];
}
} else if (dtype == static_cast<int>(lite::core::FluidType::INT64)) {
for (size_t i = 0; i < int64_values.size(); ++i) {
Out->mutable_data<int64_t>()[i] = int64_values[i];
}
} else if (dtype == static_cast<bool>(lite::core::FluidType::BOOL)) {
for (size_t i = 0; i < bool_values.size(); ++i) {
Out->mutable_data<bool>()[i] = bool_values[i];
}
} else {
LOG(FATAL) << "Unsupported dtype for assign_value_op:" << dtype;
}
}
int dtype;
std::vector<int> shape;
std::vector<float> fp32_values;
std::vector<int> int32_values;
std::vector<int64_t> int64_values;
std::vector<int> bool_values;
lite::Tensor Out_ref;
lite::Tensor Out_gpu;
lite::Tensor Out_cpu;
operators::AssignValueParam param;
std::unique_ptr<KernelContext> ctx;
cudaStream_t stream;
};
// Runs the assign_value kernel with the fixture's parameter, reports the
// average wall-clock time per Run() call, and compares the device output
// (copied back into Out_cpu) against the host reference in Out_ref.
TEST_F(AssignValueTest, fp32) {
  float_data_init();
  auto& context = ctx->As<CUDAContext>();
  context.SetExecStream(stream);

  AssignValueCompute kernel;
  kernel.SetParam(param);
  kernel.SetContext(std::move(ctx));

  for (int i = 0; i < FLAGS_warmup; ++i) {
    kernel.Launch();
    // Asynchronous failures only surface at a synchronizing call, so the
    // return code has to be checked here.
    ASSERT_EQ(cudaSuccess, cudaDeviceSynchronize());
  }

  auto start = GetCurrentUS();
  kernel.PrepareForRun();
  for (int i = 0; i < FLAGS_repeats; ++i) {
    kernel.Run();
  }
  ASSERT_EQ(cudaSuccess, cudaDeviceSynchronize());
  auto duration = (GetCurrentUS() - start) / 1000.0;
  LOG(INFO) << "fp32, warmup: " << FLAGS_warmup
            << ", repeats: " << FLAGS_repeats << ", spend "
            << duration / FLAGS_repeats << " ms in average.";

  CopySync<TARGET(kCUDA)>(Out_cpu.mutable_data<float>(),
                          Out_gpu.data<float>(),
                          sizeof(float) * Out_gpu.numel(),
                          IoDirection::DtoH);

  // Fetch the host pointers once instead of on every iteration.
  const auto* actual = Out_cpu.data<float>();
  const auto* expected = Out_ref.data<float>();
  for (int64_t i = 0; i < Out_gpu.numel(); ++i) {
    EXPECT_NEAR(actual[i], expected[i], 1e-5);
  }
}
} // namespace cuda
} // namespace kernels
} // namespace lite
} // namespace paddle
......@@ -26,12 +26,15 @@ bool AssignValueOpLite::CheckShape() const {
auto shape = param_.shape;
auto int32_values = param_.int32_values;
auto fp32_values = param_.fp32_values;
auto int64_values = param_.int64_values;
auto bool_values = param_.bool_values;
size_t shape_num = 1;
for (int i = 0; i < shape.size(); i++) {
for (size_t i = 0; i < shape.size(); i++) {
shape_num *= shape[i];
}
CHECK_OR_FALSE(shape_num == int32_values.size() ||
shape_num == fp32_values.size());
CHECK_OR_FALSE(
shape_num == int32_values.size() || shape_num == fp32_values.size() ||
shape_num == int64_values.size() || shape_num == bool_values.size());
return true;
}
......@@ -47,9 +50,18 @@ bool AssignValueOpLite::AttachImpl(const cpp::OpDesc &op_desc,
lite::Scope *scope) {
param_.shape = op_desc.GetAttr<std::vector<int>>("shape");
param_.dtype = op_desc.GetAttr<int>("dtype");
param_.fp32_values = op_desc.GetAttr<std::vector<float>>("fp32_values");
param_.int32_values = op_desc.GetAttr<std::vector<int>>("int32_values");
if (op_desc.HasAttr("fp32_values")) {
param_.fp32_values = op_desc.GetAttr<std::vector<float>>("fp32_values");
}
if (op_desc.HasAttr("int32_values")) {
param_.int32_values = op_desc.GetAttr<std::vector<int>>("int32_values");
}
if (op_desc.HasAttr("int64_values")) {
param_.int64_values = op_desc.GetAttr<std::vector<int64_t>>("int64_values");
}
if (op_desc.HasAttr("bool_values")) {
param_.bool_values = op_desc.GetAttr<std::vector<int>>("bool_values");
}
auto out = op_desc.Output("Out").front();
param_.Out = scope->FindVar(out)->GetMutable<lite::Tensor>();
return true;
......
......@@ -1338,6 +1338,8 @@ struct AssignValueParam : ParamBase {
int dtype{};
std::vector<float> fp32_values{};
std::vector<int> int32_values{};
std::vector<int64_t> int64_values{};
std::vector<int> bool_values{};
lite::Tensor* Out{};
};
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册