Commit 1d0f0431 authored by Double_V, committed by Bai Yifan

fix row_conv_op to make it support LoDTensor and Tensor input simultaneously, test=develop (#19412)

Support Tensor input for row_conv_op
Parent 1ce0a09e
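A minimal usage sketch of what this change enables (not part of the patch; the row_conv signature matches the API spec below, while the fluid.layers.data calls and names are illustrative assumptions): row_conv can now take either a lod_level-1 LoDTensor of shape (T, N) or a plain 3-D Tensor of shape (B, T, N).

import paddle.fluid as fluid

# LoDTensor input: (T, N) rows with a level-1 LoD marking sequence boundaries.
x_lod = fluid.layers.data(name='x_lod', shape=[16], dtype='float32', lod_level=1)
out_lod = fluid.layers.row_conv(input=x_lod, future_context_size=2)

# Plain Tensor input (new in this commit): (B, T, N) with a fixed T per sample.
x_3d = fluid.layers.data(name='x_3d', shape=[9, 16], dtype='float32')
out_3d = fluid.layers.row_conv(input=x_3d, future_context_size=2)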
@@ -167,7 +167,7 @@ paddle.fluid.layers.nce (ArgSpec(args=['input', 'label', 'num_total_classes', 's
 paddle.fluid.layers.sampled_softmax_with_cross_entropy (ArgSpec(args=['logits', 'label', 'num_samples', 'num_true', 'remove_accidental_hits', 'use_customized_samples', 'customized_samples', 'customized_probabilities', 'seed'], varargs=None, keywords=None, defaults=(1, True, False, None, None, 0)), ('document', 'd4435a63d34203339831ee6a86ef9242'))
 paddle.fluid.layers.hsigmoid (ArgSpec(args=['input', 'label', 'num_classes', 'param_attr', 'bias_attr', 'name', 'path_table', 'path_code', 'is_custom', 'is_sparse'], varargs=None, keywords=None, defaults=(None, None, None, None, None, False, False)), ('document', 'b83e7dfa81059b39bb137922dc914f50'))
 paddle.fluid.layers.beam_search (ArgSpec(args=['pre_ids', 'pre_scores', 'ids', 'scores', 'beam_size', 'end_id', 'level', 'is_accumulated', 'name', 'return_parent_idx'], varargs=None, keywords=None, defaults=(0, True, None, False)), ('document', '1270395ce97a4e1b556104abbb14f096'))
-paddle.fluid.layers.row_conv (ArgSpec(args=['input', 'future_context_size', 'param_attr', 'act'], varargs=None, keywords=None, defaults=(None, None)), ('document', '17485788fffe4e2d36dc58c2ac8d174e'))
+paddle.fluid.layers.row_conv (ArgSpec(args=['input', 'future_context_size', 'param_attr', 'act'], varargs=None, keywords=None, defaults=(None, None)), ('document', '1d8a1c8b686b55631ba1b77805e4eacf'))
 paddle.fluid.layers.multiplex (ArgSpec(args=['inputs', 'index'], varargs=None, keywords=None, defaults=None), ('document', '2c4d1ae83da6ed35e3b36ba1b3b51d23'))
 paddle.fluid.layers.layer_norm (ArgSpec(args=['input', 'scale', 'shift', 'begin_norm_axis', 'epsilon', 'param_attr', 'bias_attr', 'act', 'name'], varargs=None, keywords=None, defaults=(True, True, 1, 1e-05, None, None, None, None)), ('document', '79797f827d89ae72c77960e9696883a9'))
 paddle.fluid.layers.group_norm (ArgSpec(args=['input', 'groups', 'epsilon', 'param_attr', 'bias_attr', 'act', 'data_layout', 'name'], varargs=None, keywords=None, defaults=(1e-05, None, None, None, 'NCHW', None)), ('document', '96b24820e8863d6044d5be4eaaddb9fd'))
......
 /* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
 You may obtain a copy of the License at
@@ -43,13 +42,7 @@ class RowConvOp : public framework::OperatorWithKernel {
     auto x_dims = ctx->GetInputDim("X");
     auto filter_dims = ctx->GetInputDim("Filter");
-    PADDLE_ENFORCE_EQ(x_dims.size(), 2, "Input(X)'s rank should be 2.");
     PADDLE_ENFORCE_EQ(filter_dims.size(), 2, "Input(Y)'s rank should be 2.");
-    if (ctx->IsRuntime() || (x_dims[1] > 0 && filter_dims[1] > 0)) {
-      PADDLE_ENFORCE_EQ(
-          x_dims[1], filter_dims[1],
-          "The 2nd dimension of Input(X) and Input(Filter) should be same.");
-    }
     ctx->SetOutputDim("Out", x_dims);
     ctx->ShareLoD("X", "Out");
@@ -84,11 +77,12 @@ class RowConvOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
   void Make() override {
     AddInput("X",
-             "the input(X) is a LodTensor, which supports "
+             "the input(X) is a LodTensor or tensor, LodTensor(X) supports "
              "variable time-length input sequences. The underlying tensor "
              "in this LoDTensor is a matrix with shape (T x N), where T "
              "is the total time steps in this mini-batch and N is the input "
-             "data dimension.");
+             "data dimension. the shape of Tensor input(X) has shape "
+             "(B x T x N), B is batch size;");
     AddInput("Filter",
              "the input(Filter) is a learnable parameter. It "
              "is a 2-D tensor with shape (future_context x N), where, "
@@ -152,8 +146,26 @@ class RowConvKernel<platform::CPUDeviceContext, T>
     out->mutable_data<T>(context.GetPlace());
-    auto batch_indices = x->lod()[0];
-    auto input_dim = x->dims()[1];  // 'in' is of size T x N
+    bool is_tensor = x->lod().empty();
+    int batch_size = 0;
+    if (is_tensor) {
+      batch_size = x->dims()[0];
+    } else {
+      batch_size = x->lod()[0].size() - 1;
+    }
+    framework::Vector<size_t> batch_indices(batch_size + 1);
+    int input_dim = 0;
+    int timesteps = 0;
+    if (is_tensor) {
+      for (int i = 0; i < batch_size + 1; i++) {
+        batch_indices[i] = i;
+      }
+      input_dim = x->dims()[2];
+      timesteps = x->dims()[1];
+    } else {
+      batch_indices = x->lod()[0];
+      input_dim = x->dims()[1];
+    }
     size_t num_sequence = batch_indices.size() - 1;
     auto future_context = filter->dims()[0];
@@ -162,11 +174,23 @@ class RowConvKernel<platform::CPUDeviceContext, T>
     for (size_t i = 0; i < num_sequence; i++) {
       int start = static_cast<int>(batch_indices[i]);
       int end = static_cast<int>(batch_indices[i + 1]);
-      int current_timesteps = end - start;
+      int current_timesteps = 0;
+      if (is_tensor) {
+        current_timesteps = timesteps;
+      } else {
+        current_timesteps = end - start;
+      }
+      // int current_timesteps = end - start;
       Tensor cur_input_sequence =
           x->Slice(start, end);  // Current input sequence
+      cur_input_sequence =
+          cur_input_sequence.Resize({current_timesteps, input_dim});
       Tensor cur_output_sequence =
           out->Slice(start, end);  // Current output sequence
+      cur_output_sequence =
+          cur_output_sequence.Resize({current_timesteps, input_dim});
       auto cip_seq = EigenMatrix<T>::From(cur_input_sequence);
       auto cot_seq = EigenMatrix<T>::From(cur_output_sequence);
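As a rough Python analogue of the control flow above (a hypothetical helper, only for illustration): the CPU kernel derives one per-sequence range from the LoD when it is present, and otherwise slices along the batch dimension of the (B, T, N) Tensor with a fixed number of timesteps.

def cpu_batch_ranges(lod, dims):
    # Mirrors the CPU kernel: lod is x->lod(), dims is x->dims().
    if not lod:                                   # plain Tensor input of shape (B, T, N)
        batch_indices = list(range(dims[0] + 1))  # sample i is Slice(i, i + 1), then Resize to (T, N)
        timesteps, input_dim = dims[1], dims[2]
    else:                                         # LoDTensor input of shape (T_total, N)
        batch_indices = list(lod[0])              # sequence i spans rows start:end
        timesteps, input_dim = None, dims[1]      # timesteps = end - start per sequence
    return batch_indices, timesteps, input_dim

# e.g. cpu_batch_ranges([], (2, 4, 16))          -> ([0, 1, 2], 4, 16)
#      cpu_batch_ranges([[0, 3, 5, 9]], (9, 16)) -> ([0, 3, 5, 9], None, 16)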
@@ -198,11 +222,30 @@ class RowConvGradKernel<platform::CPUDeviceContext, T>
     auto *dx = context.Output<LoDTensor>(framework::GradVarName("X"));
     auto *d_filter = context.Output<Tensor>(framework::GradVarName("Filter"));
-    auto input_dim = x->dims()[1];  // 'x' is of size T x N
-    auto batch_indices = x->lod()[0];
+    auto &x_lod = x->lod();
+    bool is_tensor = x_lod.empty();
+    int batch_size = 0;
+    if (is_tensor) {
+      batch_size = x->dims()[0];
+    } else {
+      batch_size = x->lod()[0].size() - 1;
+    }
+    framework::Vector<size_t> batch_indices(batch_size + 1);
+    int timesteps = 0;
+    int input_dim = 0;
+    if (is_tensor) {
+      for (int i = 0; i < batch_size + 1; i++) {
+        batch_indices[i] = i;
+      }
+      input_dim = x->dims()[2];
+      timesteps = x->dims()[1];
+    } else {
+      batch_indices = x->lod()[0];
+      input_dim = x->dims()[1];
+    }
     size_t num_sequence = batch_indices.size() - 1;
     auto future_context = filter->dims()[0];
     if (d_filter) {
       d_filter->mutable_data<T>(context.GetPlace());
       auto dweights =
@@ -213,14 +256,19 @@ class RowConvGradKernel<platform::CPUDeviceContext, T>
       int start = static_cast<int>(batch_indices[i]);
       int end = static_cast<int>(batch_indices[i + 1]);
+      int current_timesteps = 0;
+      if (is_tensor) {
+        current_timesteps = timesteps;
+      } else {
+        current_timesteps = end - start;
+      }
       Tensor cur_input = x->Slice(start, end);  // Current input sequence
+      cur_input = cur_input.Resize({current_timesteps, input_dim});
       Tensor cur_doutput =
           d_out->Slice(start, end);  // Current output grad sequence
+      cur_doutput = cur_doutput.Resize({current_timesteps, input_dim});
       auto cur_ip = EigenMatrix<T>::From(cur_input);
       auto cur_dout = EigenMatrix<T>::From(cur_doutput);
-      int current_timesteps = end - start;
       for (int k = 0; k < current_timesteps;
            k++) {  // For different time steps in the same sequence
         for (int w = 0; (w < future_context) && ((k + w) < current_timesteps);
@@ -241,15 +289,23 @@ class RowConvGradKernel<platform::CPUDeviceContext, T>
       int start = static_cast<int>(batch_indices[i]);
       int end = static_cast<int>(batch_indices[i + 1]);
+      int current_timesteps = 0;
+      if (is_tensor) {
+        current_timesteps = timesteps;
+      } else {
+        current_timesteps = end - start;
+      }
       Tensor cur_doutput =
           d_out->Slice(start, end);  // Current output grad sequence
+      cur_doutput = cur_doutput.Resize({current_timesteps, input_dim});
       Tensor cur_dinput =
           dx->Slice(start, end);  // Current input grad sequence
+      cur_dinput = cur_dinput.Resize({current_timesteps, input_dim});
       auto cur_dout = EigenMatrix<T>::From(cur_doutput);
       auto cur_dip = EigenMatrix<T>::From(cur_dinput);
       cur_dip.setZero();
-      int current_timesteps = end - start;
       for (int k = 0; k < current_timesteps;
            k++) {  // For different time steps in the same sequence
......
 /* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
 You may obtain a copy of the License at
@@ -47,11 +46,11 @@ __global__ void RowConvForwardSharedMemory(const T *in, const T *wt,
         (d < input_dim) ? wt[thy * input_dim + d] : static_cast<T>(0);
   }
   __syncthreads();
   for (size_t i = 0; i < num_sequence; i++) {
     int start = static_cast<int>(batch_indices[i]);
     int end = static_cast<int>(batch_indices[i + 1]);
     int current_timesteps = end - start;
     for (int k = thy; k < current_timesteps; k += bly) {
       T sum = 0;
       for (int w = 0; (w < future_context) && ((k + w) < current_timesteps);
@@ -77,11 +76,11 @@ __global__ void RowConvForward(const T *in, const T *wt, int num_sequence,
   int thy = threadIdx.y;
   if (d >= input_dim) return;
   for (size_t i = 0; i < num_sequence; i++) {
     int start = static_cast<int>(batch_indices[i]);
     int end = static_cast<int>(batch_indices[i + 1]);
     int current_timesteps = end - start;
     for (int k = thy; k < current_timesteps; k += bly) {
       T sum = 0;
       for (int w = 0; (w < future_context) && ((k + w) < current_timesteps);
@@ -114,10 +113,12 @@ __global__ void RowConvGradInputSharedMemory(const T *dout, const T *wt,
   }
   __syncthreads();
+  int current_timesteps = 0;
   for (int i = 0; i < num_sequence; i++) {
     int start = static_cast<int>(batch_indices[i]);
     int end = static_cast<int>(batch_indices[i + 1]);
-    int current_timesteps = end - start;
+    current_timesteps = end - start;
     for (int k = thy; k < current_timesteps; k += bly) {
       T sum = 0;
       for (int w = 0; (w < future_context) && ((k - w) >= 0); w++) {
@@ -142,10 +143,13 @@ __global__ void RowConvGradInput(const T *dout, const T *wt, int num_sequence,
   int thy = threadIdx.y;
   if (d >= input_dim) return;
+  int current_timesteps = 0;
   for (int i = 0; i < num_sequence; i++) {
     int start = static_cast<int>(batch_indices[i]);
     int end = static_cast<int>(batch_indices[i + 1]);
-    int current_timesteps = end - start;
+    current_timesteps = end - start;
     for (int k = thy; k < current_timesteps; k += bly) {
       T sum = 0;
       for (int w = 0; (w < future_context) && ((k - w) >= 0); w++) {
@@ -175,7 +179,6 @@ __global__ void RowConvGradFilterImproved(const T *in, const T *dout,
   int xdim_sh_in = block_y;
   int xdim_sh_dout = block_y;
-  // int xdim_sh_dfilter = future_context;
   int ydim_sh_in = block_x;
   int ydim_sh_dout = block_x + future_context - 1;
   int ydim_sh_dfilter = block_y;
@@ -197,6 +200,7 @@ __global__ void RowConvGradFilterImproved(const T *in, const T *dout,
     int start = static_cast<int>(batch_indices[i]);
     int end = static_cast<int>(batch_indices[i + 1]);
     int current_timesteps = end - start;
     int scaled_cur_steps =
         ((current_timesteps + block_x - 1) / block_x) * block_x;
@@ -258,11 +262,11 @@ __global__ void RowConvGradFilter(const T *in, const T *dout, int num_sequence,
   // NOTE(zcd): temporary solution
   unsigned mask = 0u;
   CREATE_SHFL_MASK(mask, true);
   for (int i = 0; i < num_sequence; i++) {
     int start = static_cast<int>(batch_indices[i]);
     int end = static_cast<int>(batch_indices[i + 1]);
     int current_timesteps = end - start;
     int scaled_cur_steps =
         ((current_timesteps + block_x - 1) / block_x) * block_x;
@@ -310,9 +314,26 @@ class RowConvKernel<platform::CUDADeviceContext, T>
     const T *in = X->data<T>();
     const T *weight = Filter->data<T>();
     T *out = Out->mutable_data<T>(context.GetPlace());
-    auto batch_indices = X->lod()[0];
-    int input_dim = X->dims()[1];
+    bool is_tensor = X->lod().empty();
+    int batch_size = 0;
+    if (is_tensor) {
+      batch_size = X->dims()[0];
+    } else {
+      batch_size = X->lod()[0].size() - 1;
+    }
+    int input_dim = 0;
+    framework::Vector<size_t> batch_indices(batch_size + 1);
+    int timesteps = X->dims()[1];
+    if (is_tensor) {
+      for (int i = 0; i < batch_size + 1; i++) {
+        batch_indices[i] = i * timesteps;
+      }
+      input_dim = X->dims()[2];
+    } else {
+      batch_indices = X->lod()[0];
+      input_dim = X->dims()[1];
+    }
     int num_sequence = batch_indices.size() - 1;
     int future_context = Filter->dims()[0];
     size_t *idx = batch_indices.CUDAMutableData(context.GetPlace());
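Note that, unlike the CPU kernel, the CUDA kernels walk a flattened (rows x input_dim) buffer, so for plain Tensor input the per-sample boundaries are scaled by timesteps rather than being sample indices. A small hypothetical Python helper, for illustration only:

def gpu_batch_offsets(lod, dims):
    # Mirrors the CUDA kernels: offsets index rows of the flattened (B*T, N) data.
    if not lod:                                    # Tensor input of shape (B, T, N)
        batch_size, timesteps = dims[0], dims[1]
        return [i * timesteps for i in range(batch_size + 1)]
    return list(lod[0])                            # LoDTensor input: reuse the LoD offsets

# e.g. gpu_batch_offsets([], (2, 4, 16)) -> [0, 4, 8], vs. [0, 1, 2] in the CPU path.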
@@ -348,9 +369,27 @@ class RowConvGradKernel<platform::CUDADeviceContext, T>
     Tensor *dX = context.Output<LoDTensor>(framework::GradVarName("X"));
     Tensor *dFilter = context.Output<Tensor>(framework::GradVarName("Filter"));
-    auto batch_indices = X->lod()[0];
-    int input_dim = X->dims()[1];
+    int batch_size = 0;
+    bool is_tensor = X->lod().empty();
+    if (is_tensor) {
+      batch_size = X->dims()[0];
+    } else {
+      batch_size = X->lod()[0].size() - 1;
+    }
+    int input_dim = 0;
+    framework::Vector<size_t> batch_indices(batch_size + 1);
+    int timesteps = X->dims()[1];
+    if (is_tensor) {
+      for (int i = 0; i < batch_size + 1; i++) {
+        batch_indices[i] = i * timesteps;
+      }
+      input_dim = X->dims()[2];
+    } else {
+      batch_indices = X->lod()[0];
+      input_dim = X->dims()[1];
+    }
+    // int input_dim = X->dims()[1];
     int num_sequence = batch_indices.size() - 1;
     int future_context = Filter->dims()[0];
     size_t *idx = batch_indices.CUDAMutableData(context.GetPlace());
......
@@ -94,7 +94,7 @@ class TestRowConvOp2(OpTest):
         self.check_output()

     #max_relative_error is increased from 0.05 to 0.06 as for higher
     #dimensional input, the dX on CPU for some values has max_rel_error
     #slightly more than 0.05
     def test_check_grad_normal(self):
         self.check_grad(['X', 'Filter'], 'Out', max_relative_error=0.06)
@@ -108,5 +108,52 @@ class TestRowConvOp2(OpTest):
             ['X'], 'Out', max_relative_error=0.06, no_grad_set=set('Filter'))


+def row_conv_foward_Tensor(x, wt):
+    out = np.zeros_like(x)
+    num_sequence = x.shape[0]
+    timesteps = x.shape[1]
+    context_length = wt.shape[0]
+    for i in range(num_sequence):
+        cur_in = x[i:i + 1, :][0]
+        cur_out = out[i:i + 1, :][0]
+        for j in range(timesteps):
+            for k in range(context_length):
+                if j + k >= timesteps:
+                    continue
+                cur_out[j, :] += cur_in[j + k, :] * wt[k, :]
+    return out
+
+
+class TestRowOpWithTensorInput(OpTest):
+    def setUp(self):
+        self.op_type = "row_conv"
+        length = [3, 2, 4]
+        B = 2
+        T = sum(length)
+        D = 16
+        context_length = 2
+
+        x = np.random.random((B, T, D)).astype("float32")
+        wt = np.random.random((context_length, D)).astype("float32")
+        self.inputs = {'X': x, 'Filter': wt}
+
+        out = row_conv_foward_Tensor(x, wt)
+        self.outputs = {'Out': out}
+
+    def test_check_output(self):
+        self.check_output()
+
+    def test_check_grad_ignore_x(self):
+        self.check_grad(
+            ['Filter'], 'Out', max_relative_error=0.05, no_grad_set=set('X'))
+
+    def test_check_grad_normal(self):
+        self.check_grad(['X', 'Filter'], 'Out', max_relative_error=0.05)
+
+    def test_check_grad_ignore_wt(self):
+        self.check_grad(
+            ['X'], 'Out', max_relative_error=0.05, no_grad_set=set('Filter'))
+
+
 if __name__ == '__main__':
     unittest.main()