Unverified commit a105bbdf, authored by Wilber, committed by GitHub

[CUDA] Support model run correctly. (#3975)

Parent commit: e45e6dc6
......@@ -30,9 +30,16 @@ namespace lite {
namespace cuda {
namespace math {
#define SIGMOID_THRESHOLD_MIN -40.0
#define SIGMOID_THRESHOLD_MAX 13.0
#define EXP_MAX_INPUT 40.0
template <typename Dtype>
inline __device__ Dtype Sigmoid(const Dtype a) {
return static_cast<Dtype>(1.0) / (static_cast<Dtype>(1.0) + expf(-a));
const Dtype min = SIGMOID_THRESHOLD_MIN;
const Dtype max = SIGMOID_THRESHOLD_MAX;
Dtype tmp = (a < min) ? min : ((a > max) ? max : a);
return static_cast<Dtype>(1.0) / (static_cast<Dtype>(1.0) + expf(-tmp));
}
template <>
......@@ -63,6 +70,7 @@ inline __device__ half ReLU(const half a) {
template <typename Dtype>
inline __device__ Dtype Tanh(const Dtype a) {
Dtype tmp = static_cast<Dtype>(-2.0) * a;
tmp = (tmp > EXP_MAX_INPUT) ? EXP_MAX_INPUT : tmp;
return (static_cast<Dtype>(2.0) / (static_cast<Dtype>(1.0) + expf(tmp))) -
static_cast<Dtype>(1.0);
}
......
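Note on the activation hunk above: the new code clamps the Sigmoid input to [SIGMOID_THRESHOLD_MIN, SIGMOID_THRESHOLD_MAX] and caps the Tanh intermediate at EXP_MAX_INPUT before calling expf. A minimal host-side sketch of the same clamping idea, reusing the thresholds from the diff (illustration only, not the Paddle-Lite device code):

#include <cmath>
#include <cstdio>

// Clamp the argument before exp(), mirroring the thresholds in the diff.
static float stable_sigmoid(float a) {
  const float kMin = -40.0f, kMax = 13.0f;  // SIGMOID_THRESHOLD_MIN / _MAX
  float t = a < kMin ? kMin : (a > kMax ? kMax : a);
  return 1.0f / (1.0f + std::exp(-t));
}

static float stable_tanh(float a) {
  float t = -2.0f * a;
  t = t > 40.0f ? 40.0f : t;  // EXP_MAX_INPUT
  return 2.0f / (1.0f + std::exp(t)) - 1.0f;
}

int main() {
  // Large-magnitude inputs stay well behaved (roughly 0 and -1 here).
  std::printf("%g %g\n", stable_sigmoid(-1000.0f), stable_tanh(-1000.0f));
  return 0;
}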
......@@ -22,10 +22,6 @@ namespace lite {
namespace cuda {
namespace math {
#define CUDA_KERNEL_LOOP(i, n) \
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \
i += blockDim.x * gridDim.x)
template <typename T>
__global__ void scale_kernel(int count,
const T* in_data,
......@@ -48,7 +44,6 @@ __global__ void scale_kernel(int count,
template <typename T>
__global__ void scale_kernel(
int count, const T* in_data, T* out_data, const T scale, const T bias) {
int tid = blockIdx.x * blockDim.x + threadIdx.x;
CUDA_KERNEL_LOOP(tid, count) { out_data[tid] = scale * in_data[tid] + bias; }
}
......@@ -133,12 +128,11 @@ void fp32_scale_nhwc(int num,
}
template <typename T>
void scale(int num, const T* in, T* out, T scale, cudaStream_t stream, T bias) {
void scale(int num, const T* in, T* out, T scale, T bias, cudaStream_t stream) {
int thread = 256;
int block = (num + thread - 1) / thread;
scale_kernel<<<block, thread, 0, stream>>>(num, in, out, scale, bias);
cudaError_t error = cudaGetLastError();
if (error != cudaSuccess) std::cout << cudaGetErrorString(error);
CUDA_POST_KERNEL_CHECK;
}
template <typename T>
......@@ -146,11 +140,10 @@ void scale(int num, const T* in, T* out, T scale, T bias) {
int thread = 256;
int block = (num + thread - 1) / thread;
scale_kernel<<<block, thread>>>(num, in, out, scale, bias);
cudaError_t error = cudaGetLastError();
if (error != cudaSuccess) std::cout << cudaGetErrorString(error);
CUDA_POST_KERNEL_CHECK;
}
template void scale(int num, const float*, float*, float, cudaStream_t, float);
template void scale(int num, const float*, float*, float, float, cudaStream_t);
template void scale(int num, const float*, float*, float, float);
} // namespace math
......
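In the scale.cu hunks above, the locally redefined CUDA_KERNEL_LOOP macro and the now-redundant tid computation are dropped, the stream parameter moves after bias in scale(...), and the hand-rolled cudaGetLastError() printout becomes CUDA_POST_KERNEL_CHECK. The kernel itself is the usual grid-stride pattern; a self-contained sketch of an equivalent launch (assumed names, not the Paddle-Lite sources):

#include <cstdio>
#include <cuda_runtime.h>

// Grid-stride loop, i.e. what CUDA_KERNEL_LOOP(i, n) expands to.
__global__ void scale_kernel(int n, const float* in, float* out, float scale, float bias) {
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n; i += blockDim.x * gridDim.x) {
    out[i] = scale * in[i] + bias;
  }
}

int main() {
  const int n = 1 << 20;
  float *in = nullptr, *out = nullptr;
  cudaMalloc(&in, n * sizeof(float));
  cudaMalloc(&out, n * sizeof(float));
  cudaMemset(in, 0, n * sizeof(float));
  cudaStream_t stream;
  cudaStreamCreate(&stream);
  const int threads = 256, blocks = (n + threads - 1) / threads;
  scale_kernel<<<blocks, threads, 0, stream>>>(n, in, out, 2.0f, 0.5f);
  // Same intent as CUDA_POST_KERNEL_CHECK: surface launch errors immediately.
  cudaError_t err = cudaGetLastError();
  if (err != cudaSuccess) std::printf("launch failed: %s\n", cudaGetErrorString(err));
  cudaStreamSynchronize(stream);
  cudaFree(in);
  cudaFree(out);
  cudaStreamDestroy(stream);
  return 0;
}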
......@@ -32,8 +32,7 @@ void fp32_scale_nhwc(int num,
cudaStream_t stream);
template <typename T>
void scale(
int num, const T* in, T* out, T scale, cudaStream_t stream, T bias = 0);
void scale(int num, const T* in, T* out, T scale, T bias, cudaStream_t stream);
template <typename T>
void scale(int num, const T* in, T* out, T scale, T bias = 0);
......
......@@ -32,7 +32,7 @@ __global__ void CopyMatrixRowsKernel(const T* src,
bool is_src_index) {
int idx = threadIdx.x;
int idy = threadIdx.y;
int row_id = blockDim.y * gridDim.x + idy;
int row_id = blockDim.y * blockIdx.x + idy;
if (row_id < height) {
int src_idx = is_src_index ? index[row_id] : row_id;
int dst_idx = is_src_index ? row_id : index[row_id];
......@@ -72,7 +72,7 @@ void CopyMatrixRowsFunctor<T>::operator()(
dim3 threads(128, 8);
dim3 grids((height + threads.y - 1) / threads.y);
CopyMatrixRowsKernel<T><<<grids, threads, 0, stream>>>(
src_data, dst_data, index_tensor_data, height, width, true);
src_data, dst_data, index_tensor_data, height, width, is_src_index);
CUDA_POST_KERNEL_CHECK;
}
......
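Two fixes in the sequence2batch.cu hunks: the global row index is now blockDim.y * blockIdx.x + idy (the old expression used gridDim.x, so every block computed the same row id), and the launch now forwards the caller's is_src_index flag instead of hard-coding true. A simplified sketch of the gather direction (is_src_index == true) with the same launch geometry; names are illustrative:

#include <cuda_runtime.h>

// Each block handles blockDim.y consecutive rows; threadIdx.x strides across a row.
__global__ void copy_rows(const float* src, float* dst, const int* index,
                          int height, int width) {
  int row_id = blockDim.y * blockIdx.x + threadIdx.y;  // fixed: blockIdx.x, not gridDim.x
  if (row_id < height) {
    for (int col = threadIdx.x; col < width; col += blockDim.x) {
      dst[row_id * width + col] = src[index[row_id] * width + col];
    }
  }
}

int main() {
  int height = 100, width = 64;
  float *src = nullptr, *dst = nullptr;
  int* index = nullptr;
  cudaMalloc(&src, height * width * sizeof(float));
  cudaMalloc(&dst, height * width * sizeof(float));
  cudaMalloc(&index, height * sizeof(int));
  cudaMemset(index, 0, height * sizeof(int));
  dim3 threads(128, 8), grids((height + threads.y - 1) / threads.y);  // as in the diff
  copy_rows<<<grids, threads>>>(src, dst, index, height, width);
  cudaDeviceSynchronize();
  cudaFree(src);
  cudaFree(dst);
  cudaFree(index);
  return 0;
}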
......@@ -53,11 +53,11 @@ class LoDTensor2BatchFunctor {
// s0: 0 0 0 0, s1: 1 1 1 1 1, s2: 2 2 2
// seq_info[3] = {(4, 5, 1), (0, 4, 0), (9, 3, 2)}
struct SeqInfo {
SeqInfo(size_t start, size_t length, size_t seq_idx)
: start_(start), length_(length), seq_idx_(seq_idx) {}
size_t start_;
size_t length_;
size_t seq_idx_;
SeqInfo(size_t start_val, size_t len_val, size_t seq_val)
: start(start_val), length(len_val), seq_idx(seq_val) {}
size_t start;
size_t length;
size_t seq_idx;
};
public:
......@@ -76,7 +76,7 @@ class LoDTensor2BatchFunctor {
}
std::sort(seq_info.begin(), seq_info.end(), [](SeqInfo a, SeqInfo b) {
return a.length_ > b.length_;
return a.length > b.length;
});
// Calculate the start position of each batch.
......@@ -106,7 +106,7 @@ class LoDTensor2BatchFunctor {
batch_lods.emplace_back(std::vector<uint64_t>{0});
// batch_lods[0] is the start positions for batch LoDTensor
size_t max_seqlen = seq_info[0].length_;
size_t max_seqlen = seq_info[0].length;
batch_lods[0].resize(max_seqlen + 1);
// batch_lods[1] is the raw index in the input LoDTensor
batch_lods[1].resize(static_cast<size_t>(lod_tensor.dims()[0]));
......@@ -119,8 +119,8 @@ class LoDTensor2BatchFunctor {
for (size_t n = 0; n < max_seqlen; ++n) {
size_t batch_id = batch_starts[n];
for (size_t i = 0; i < seq_info.size(); ++i) {
size_t seq_len = seq_info[i].length_;
size_t start = seq_info[i].start_;
size_t seq_len = seq_info[i].length;
size_t start = seq_info[i].start;
if (n < seq_len) {
seq2batch_idx[batch_id] =
is_reverse ? start + seq_len - 1 - n : start + n;
......@@ -133,7 +133,7 @@ class LoDTensor2BatchFunctor {
}
auto* seq_order = batch_lods[2].data();
for (size_t i = 0; i < seq_info.size(); ++i) {
seq_order[i] = seq_info[i].seq_idx_;
seq_order[i] = seq_info[i].seq_idx;
}
batch_tensor->set_lod(batch_lods);
......
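The sequence2batch.h hunk only renames the SeqInfo members (start_, length_, seq_idx_ become start, length, seq_idx). For reference, the descending-length sort those members feed reproduces the example in the header's comment, as in this small sketch:

#include <algorithm>
#include <cstdio>
#include <vector>

struct SeqInfo {
  size_t start, length, seq_idx;
};

int main() {
  // s0 has length 4 starting at 0, s1 length 5 starting at 4, s2 length 3 starting at 9.
  std::vector<SeqInfo> seq_info = {{0, 4, 0}, {4, 5, 1}, {9, 3, 2}};
  std::sort(seq_info.begin(), seq_info.end(),
            [](const SeqInfo& a, const SeqInfo& b) { return a.length > b.length; });
  for (const auto& s : seq_info) {
    std::printf("(%zu, %zu, %zu) ", s.start, s.length, s.seq_idx);
  }
  std::printf("\n");  // prints (4, 5, 1) (0, 4, 0) (9, 3, 2), matching the comment
  return 0;
}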
......@@ -86,8 +86,7 @@ void SequencePadding(T* pad_data,
seq_num,
pad_seq_len,
step_width);
cudaError_t error = cudaGetLastError();
if (error != cudaSuccess) LOG(ERROR) << cudaGetErrorString(error);
CUDA_POST_KERNEL_CHECK;
}
template <typename T>
......@@ -120,8 +119,7 @@ void SequenceUnpadding(T* seq_data,
seq_num,
pad_seq_len,
step_width);
cudaError_t error = cudaGetLastError();
if (error != cudaSuccess) LOG(ERROR) << cudaGetErrorString(error);
CUDA_POST_KERNEL_CHECK;
}
template void SequencePadding(float* pad_data,
......
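Several files in this commit (scale.cu, sequence_padding.cu, gru_compute.cc) replace the inline cudaGetLastError() check with CUDA_POST_KERNEL_CHECK from lite/backends/cuda/cuda_utils.h. That macro's definition is not shown in this diff; post-launch checks of this kind are commonly written along the lines below (a hypothetical stand-in, not the actual Paddle-Lite macro):

#include <cstdio>
#include <cstdlib>
#include <cuda_runtime.h>

// Hypothetical stand-in for a post-launch check; the real CUDA_POST_KERNEL_CHECK
// in lite/backends/cuda/cuda_utils.h may log or terminate differently.
#define POST_KERNEL_CHECK()                                           \
  do {                                                                \
    cudaError_t e = cudaGetLastError();                               \
    if (e != cudaSuccess) {                                           \
      std::fprintf(stderr, "CUDA kernel failed: %s (%s:%d)\n",        \
                   cudaGetErrorString(e), __FILE__, __LINE__);        \
      std::abort();                                                   \
    }                                                                 \
  } while (0)

int main() {
  // No kernel has been launched, so this check is a no-op here.
  POST_KERNEL_CHECK();
  return 0;
}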
......@@ -68,7 +68,7 @@ void AssignValueCompute::Run() {
REGISTER_LITE_KERNEL(assign_value,
kCUDA,
kAny,
kFloat,
kNCHW,
paddle::lite::kernels::cuda::AssignValueCompute,
def)
......
......@@ -23,6 +23,9 @@ namespace cuda {
void DropoutCompute::Run() {
auto& param = Param<operators::DropoutParam>();
auto& ctx = this->ctx_->template As<CUDAContext>();
auto stream = ctx.exec_stream();
const float* x_data = param.x->data<float>();
float* out_data = param.output->mutable_data<float>(TARGET(kCUDA));
int num = param.x->dims().production();
......@@ -31,7 +34,7 @@ void DropoutCompute::Run() {
if (param.dropout_implementation == "downgrade_in_infer") {
scale = 1.0f - prob_data;
}
lite::cuda::math::scale(num, x_data, out_data, scale, 0);
lite::cuda::math::scale(num, x_data, out_data, scale, 0.f, stream);
}
} // namespace cuda
......
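In the dropout hunk above, the downgrade_in_infer path simply rescales the input: with dropout_prob = 0.25, scale becomes 0.75 and the call reduces to out[i] = 0.75f * x[i] + 0.f for every element, now launched on the kernel's exec_stream() rather than the stream-less scale overload.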
......@@ -11,6 +11,8 @@
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/cuda/gru_compute.h"
#include <string>
#include "lite/backends/cuda/cuda_utils.h"
......@@ -19,7 +21,6 @@
#include "lite/backends/cuda/math/sequence2batch.h"
#include "lite/backends/cuda/target_wrapper.h"
#include "lite/core/op_registry.h"
#include "lite/kernels/cuda/gru_compute.h"
namespace paddle {
namespace lite {
......@@ -133,7 +134,6 @@ struct GRUUnitFunctor {
value.gate_value,
context);
}
CUDA_POST_KERNEL_CHECK;
lite::cuda::math::GruForwardResetOutput<
T><<<grids, threads, 0, context->exec_stream()>>>(
......@@ -143,7 +143,7 @@ struct GRUUnitFunctor {
frame_size,
batch_size,
active_gate,
batch_size == 1);
batch_size != 1);
CUDA_POST_KERNEL_CHECK;
if (value.prev_out_value) {
......@@ -163,7 +163,6 @@ struct GRUUnitFunctor {
value.gate_value + frame_size * 2,
context);
}
CUDA_POST_KERNEL_CHECK;
lite::cuda::math::GruForwardFinalOutput<
T><<<grids, threads, 0, context->exec_stream()>>>(value.gate_value,
......@@ -173,7 +172,7 @@ struct GRUUnitFunctor {
batch_size,
active_node,
origin_mode,
batch_size == 1);
batch_size != 1);
CUDA_POST_KERNEL_CHECK;
}
};
......@@ -218,7 +217,6 @@ struct GRUUnitFunctor<half> {
value.gate_value,
context);
}
CUDA_POST_KERNEL_CHECK;
lite::cuda::math::GruForwardResetOutput<
half><<<grids, threads, 0, context->exec_stream()>>>(
......@@ -248,7 +246,6 @@ struct GRUUnitFunctor<half> {
value.gate_value + frame_size * 2,
context);
}
CUDA_POST_KERNEL_CHECK;
lite::cuda::math::GruForwardFinalOutput<
half><<<grids, threads, 0, context->exec_stream()>>>(
......
......@@ -23,8 +23,11 @@ namespace cuda {
void ScaleCompute::Run() {
auto& param = Param<operators::ScaleParam>();
auto& ctx = this->ctx_->template As<CUDAContext>();
auto stream = ctx.exec_stream();
const float* x_data = param.x->data<float>();
float* output_data = param.output->mutable_data<float>();
float* output_data = param.output->mutable_data<float>(TARGET(kCUDA));
DDim x_dims = param.x->dims();
bool bias_after_scale = param.bias_after_scale;
float scale = param.scale;
......@@ -33,7 +36,7 @@ void ScaleCompute::Run() {
bias *= scale;
}
lite::cuda::math::scale(
x_dims.production(), x_data, output_data, scale, bias);
x_dims.production(), x_data, output_data, scale, bias, stream);
}
} // namespace cuda
......
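In scale_compute.cc the output is now allocated on kCUDA and the math runs on exec_stream(). The bias *= scale branch covers the bias-before-scale case: the kernel always computes scale * x + bias', and pre-multiplying gives bias' = scale * bias, i.e. scale * (x + bias). For example, scale = 2 and bias = 3 yield 2x + 3 when the bias is applied after scaling, and 2x + 6 = 2(x + 3) otherwise.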
......@@ -12,13 +12,13 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/cuda/sequence_mask_compute.h"
#include <thrust/device_ptr.h>
#include <thrust/functional.h>
#include <thrust/reduce.h>
#include "lite/backends/cuda/cuda_utils.h"
#include "lite/core/op_registry.h"
#include "lite/kernels/cuda/sequence_mask_compute.h"
namespace paddle {
namespace lite {
......@@ -44,7 +44,7 @@ void SequenceMaskCompute<T, Ptype>::Run() {
auto stream = ctx.exec_stream();
const auto* x = param.X;
auto* x_data = x->template data<int64_t>();
const int64_t* x_data = x->template data<int64_t>();
auto* y = param.Y;
int maxlen = param.maxlen;
......@@ -57,8 +57,11 @@ void SequenceMaskCompute<T, Ptype>::Run() {
}
if (maxlen < 0) {
maxlen = thrust::reduce(
x_data, x_data + x->numel(), 0, thrust::maximum<int64_t>());
maxlen = static_cast<int>(
thrust::reduce(thrust::device_pointer_cast(x_data),
thrust::device_pointer_cast(x_data) + x->numel(),
static_cast<int64_t>(0),
thrust::maximum<int64_t>()));
}
auto y_dim = x->dims().Vectorize();
......
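The sequence_mask fix wraps the raw device pointer with thrust::device_pointer_cast so that thrust::reduce dispatches the max-reduction over the GPU data instead of treating the address as a host iterator. The same pattern in a standalone form (illustrative names):

#include <cstdio>
#include <cuda_runtime.h>
#include <thrust/device_ptr.h>
#include <thrust/functional.h>
#include <thrust/reduce.h>

int main() {
  int64_t host[4] = {3, 7, 2, 5};
  int64_t* dev = nullptr;
  cudaMalloc(&dev, sizeof(host));
  cudaMemcpy(dev, host, sizeof(host), cudaMemcpyHostToDevice);

  // Wrapping the raw device pointer tells thrust to run the reduction on the device.
  int64_t max_len = thrust::reduce(thrust::device_pointer_cast(dev),
                                   thrust::device_pointer_cast(dev) + 4,
                                   static_cast<int64_t>(0),
                                   thrust::maximum<int64_t>());
  std::printf("%lld\n", static_cast<long long>(max_len));  // 7
  cudaFree(dev);
  return 0;
}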
......@@ -32,9 +32,19 @@ void SequencePadCompute<T, Ptype>::Run() {
const auto* pad_value = param.PadValue;
auto* out = param.Out;
auto* len_t = param.Length;
int padded_length = param.padded_length;
int seq_num = x->lod()[0].size() - 1;
int padded_length;
if (param.padded_length == -1) {
int max_seq_len = 0;
for (int i = 0; i < seq_num; ++i) {
max_seq_len = std::max(
max_seq_len, static_cast<int>(x->lod()[0][i + 1] - x->lod()[0][i]));
}
padded_length = max_seq_len;
} else {
padded_length = param.padded_length;
}
int max_seq_len = 0;
int step_width = x->numel() / x->dims()[0];
......
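The sequence_pad kernel now resolves padded_length == -1 from the level-0 LoD instead of relying on the attribute alone. For example, if x->lod()[0] is {0, 4, 9, 12}, then seq_num is 3, the per-sequence lengths are 4, 5 and 3, and padded_length becomes 5.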
......@@ -13,6 +13,7 @@
// limitations under the License.
#include <algorithm>
#include "lite/backends/cuda/math/sequence_padding.h"
#include "lite/core/op_registry.h"
#include "lite/core/target_wrapper.h"
......@@ -29,8 +30,39 @@ void SequenceUnpadCompute<T, Ptype>::Run() {
auto& ctx = this->ctx_->template As<CUDAContext>();
auto stream = ctx.exec_stream();
auto x_dims = param.X->dims();
auto len_dims = param.Length->dims();
auto* seq_len_ptr = param.Length->template data<int64_t>();
seq_len_cpu_.Resize(param.Length->dims());
TargetWrapperCuda::MemcpyAsync(seq_len_cpu_.mutable_data<int64_t>(),
seq_len_ptr,
sizeof(int64_t) * param.Length->numel(),
IoDirection::DtoH,
stream);
TargetWrapperCuda::StreamSync(stream);
int64_t batch_size = len_dims[0];
std::vector<uint64_t> out_lod0(batch_size + 1, 0);
for (int64_t i = 0; i < batch_size; ++i) {
out_lod0[i + 1] = out_lod0[i] + seq_len_cpu_.data<int64_t>()[i];
}
paddle::lite::LoD out_lod;
out_lod.push_back(out_lod0);
int64_t out_dim0 = out_lod0.back();
std::vector<int64_t> out_dims{out_dim0};
if (x_dims.size() == 2) {
out_dims.push_back(1);
} else {
for (size_t i = 2; i < x_dims.size(); ++i) {
out_dims.push_back(x_dims[i]);
}
}
param.Out->Resize(out_dims);
param.Out->set_lod(out_lod);
const auto* pad_tensor = param.X;
const auto* len_t = param.Length;
auto* seq_tensor = param.Out;
int padded_length = pad_tensor->dims()[1];
......
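Here sequence_unpad copies Length to the host (seq_len_cpu_), builds Out's LoD as a running sum of the lengths, and resizes Out at run time. This is the logic removed from SequenceUnpadOp::InferShapeImpl later in this commit, since the output shape depends on data that is only available during execution. For example, lengths {3, 1, 2} give out_lod0 = {0, 3, 4, 6} and a first output dimension of 6.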
......@@ -31,6 +31,7 @@ class SequenceUnpadCompute : public KernelLite<TARGET(kCUDA), Ptype> {
private:
lite::Tensor seq_offsets_;
lite::Tensor seq_len_cpu_;
std::vector<size_t> seq_offsets_vec_;
};
......
......@@ -184,6 +184,8 @@ using VarConvFp16 =
REGISTER_LITE_KERNEL(var_conv_2d, kCUDA, kFloat, kNCHW, VarConvFp32, def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kCUDA))})
.BindInput("W", {LiteType::GetTensorTy(TARGET(kCUDA))})
.BindInput("COLUMN", {LiteType::GetTensorTy(TARGET(kCUDA))})
.BindInput("ROW", {LiteType::GetTensorTy(TARGET(kCUDA))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kCUDA))})
.BindOutput("Col", {LiteType::GetTensorTy(TARGET(kCUDA))})
.Finalize();
......@@ -191,6 +193,9 @@ REGISTER_LITE_KERNEL(var_conv_2d, kCUDA, kFloat, kNCHW, VarConvFp32, def)
REGISTER_LITE_KERNEL(var_conv_2d, kCUDA, kFP16, kNCHW, VarConvFp16, def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kCUDA), PRECISION(kFP16))})
.BindInput("W", {LiteType::GetTensorTy(TARGET(kCUDA), PRECISION(kFP16))})
.BindInput("COLUMN",
{LiteType::GetTensorTy(TARGET(kCUDA), PRECISION(kFP16))})
.BindInput("ROW", {LiteType::GetTensorTy(TARGET(kCUDA), PRECISION(kFP16))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kCUDA), PRECISION(kFP16))})
.BindOutput("Col", {LiteType::GetTensorTy(TARGET(kCUDA), PRECISION(kFP16))})
.Finalize();
......@@ -75,9 +75,8 @@ bool GRUOpLite::AttachImpl(const cpp::OpDesc& op_desc, lite::Scope* scope) {
auto batch_reset_hidden_prev = op_desc.Output("BatchResetHiddenPrev").front();
auto batch_hidden = op_desc.Output("BatchHidden").front();
auto hidden = op_desc.Output("Hidden").front();
param_.input = scope->FindVar(input)->GetMutable<lite::Tensor>();
if (op_desc.Input("H0").size()) {
if (!op_desc.Input("H0").empty()) {
auto h0 = op_desc.Input("H0").front();
param_.h0 = scope->FindVar(h0)->GetMutable<lite::Tensor>();
}
......@@ -90,7 +89,7 @@ bool GRUOpLite::AttachImpl(const cpp::OpDesc& op_desc, lite::Scope* scope) {
scope->FindVar(batch_hidden)->GetMutable<lite::Tensor>();
param_.hidden = scope->FindVar(hidden)->GetMutable<lite::Tensor>();
if (op_desc.HasInput("Bias")) {
if (!op_desc.Input("Bias").empty()) {
auto bias = op_desc.Input("Bias").front();
param_.bias = scope->FindVar(bias)->GetMutable<lite::Tensor>();
}
......
......@@ -61,18 +61,19 @@ bool SequencePadOp::InferShapeImpl() const {
max_seq_len =
std::max(max_seq_len, static_cast<int>(x_lod_0[i + 1] - x_lod_0[i]));
}
if (param_.padded_length == -1) {
param_.padded_length = max_seq_len;
int real_padded_length = param_.padded_length;
if (real_padded_length == -1) {
real_padded_length = max_seq_len;
}
CHECK_GE(param_.padded_length, max_seq_len)
CHECK_GE(real_padded_length, max_seq_len)
<< "The SequencePadOp Attr(padded_length) should be greater than or "
"equal to the length of the longest original sequence. But the "
"padded_length we received is "
<< param_.padded_length
<< real_padded_length
<< ", the length of the longest original sequence is " << max_seq_len;
int out_dim_0 = seq_num;
std::vector<int64_t> out_dims_vec{out_dim_0, param_.padded_length};
std::vector<int64_t> out_dims_vec{out_dim_0, real_padded_length};
std::vector<int64_t> len_dims_vec{out_dim_0};
auto time_step_dims_vec = time_step_dims.Vectorize();
out_dims_vec.insert(
......@@ -87,7 +88,7 @@ bool SequencePadOp::AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) {
&scope->FindVar(opdesc.Input("X").front())->Get<lite::Tensor>());
param_.PadValue = const_cast<lite::Tensor *>(
&scope->FindVar(opdesc.Input("PadValue").front())->Get<lite::Tensor>());
param_.Length = scope->FindVar(opdesc.Input("Length").front())
param_.Length = scope->FindVar(opdesc.Output("Length").front())
->GetMutable<lite::Tensor>();
param_.Out =
scope->FindVar(opdesc.Output("Out").front())->GetMutable<lite::Tensor>();
......
......@@ -32,32 +32,7 @@ bool SequenceUnpadOp::CheckShape() const {
return true;
}
bool SequenceUnpadOp::InferShapeImpl() const {
auto x_dims = param_.X->dims();
auto len_dims = param_.Length->dims();
auto *seq_len_ptr = param_.Length->data<int64_t>();
int64_t batch_size = len_dims[0];
std::vector<uint64_t> out_lod0(batch_size + 1, 0);
for (int64_t i = 0; i < batch_size; ++i) {
out_lod0[i + 1] = out_lod0[i] + seq_len_ptr[i];
}
paddle::lite::LoD out_lod;
out_lod.push_back(out_lod0);
int64_t out_dim0 = out_lod0.back();
std::vector<int64_t> out_dims{out_dim0};
if (x_dims.size() == 2) {
out_dims.push_back(1);
} else {
for (size_t i = 2; i < x_dims.size(); ++i) {
out_dims.push_back(x_dims[i]);
}
}
param_.Out->Resize(out_dims);
param_.Out->set_lod(out_lod);
return true;
}
bool SequenceUnpadOp::InferShapeImpl() const { return true; }
bool SequenceUnpadOp::AttachImpl(const cpp::OpDesc &opdesc,
lite::Scope *scope) {
......
......@@ -26,10 +26,16 @@ bool VarConv2dOp::InferShapeImpl() const { return true; }
bool VarConv2dOp::AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) {
param_.X = const_cast<lite::Tensor *>(
&scope->FindVar(opdesc.Input("X").front())->Get<lite::Tensor>());
// param_.ROW = const_cast<lite::Tensor *>(
// &scope->FindVar(opdesc.Input("ROW").front())->Get<lite::Tensor>());
// param_.COLUMN = const_cast<lite::Tensor *>(
// &scope->FindVar(opdesc.Input("COLUMN").front())->Get<lite::Tensor>());
if (opdesc.HasInput("ROW") && !opdesc.Input("ROW").empty()) {
param_.ROW = const_cast<lite::Tensor *>(
&scope->FindVar(opdesc.Input("ROW").front())->Get<lite::Tensor>());
CHECK(param_.ROW) << "Input(ROW) of VarConv2dOP should not be null.";
}
if (opdesc.HasInput("COLUMN") && !opdesc.Input("COLUMN").empty()) {
param_.COLUMN = const_cast<lite::Tensor *>(
&scope->FindVar(opdesc.Input("COLUMN").front())->Get<lite::Tensor>());
CHECK(param_.COLUMN) << "Input(COLUMN) of VarConv2dOP should not be null.";
}
param_.W = const_cast<lite::Tensor *>(
&scope->FindVar(opdesc.Input("W").front())->Get<lite::Tensor>());
param_.Out =
......@@ -37,8 +43,6 @@ bool VarConv2dOp::AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) {
param_.Col =
scope->FindVar(opdesc.Output("Col").front())->GetMutable<lite::Tensor>();
CHECK(param_.X) << "X(Input) of VarConv2dOP should not be null.";
// CHECK(param_.ROW) << "Input(ROW) of VarConv2dOP should not be null.";
// CHECK(param_.COLUMN) << "Input(COLUMN) of VarConv2dOP should not be null.";
CHECK(param_.W) << "W(Input) of VarConv2dOP should not be null.";
CHECK(param_.Out) << "Out(Output) of VarConv2dOP should not be null.";
CHECK(param_.Col) << "Col(Output) of VarConv2dOP should not be null.";
......