未验证 提交 4ff6bc17 编写于 作者: S Siddharth Goyal 提交者: GitHub

Add row conv operator (#6013)

* Fix documentation

* Address review comments
上级 0ca62744
/* Copyright (c) 2017 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/row_conv_op.h"
#include "paddle/framework/eigen.h"
namespace paddle {
namespace operators {
using LoDTensor = framework::LoDTensor;
using framework::Tensor;
template <typename T, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;
class RowConvOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"),
"Input(X) of RowConvOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Filter"),
"Input(Filter) of RowConvOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(Out) of RowConvOp should not be null.");
auto x_dims = ctx->GetInputDim("X");
auto filter_dims = ctx->GetInputDim("Filter");
PADDLE_ENFORCE_EQ(x_dims.size(), 2, "Input(X)'s rank should be 2.");
PADDLE_ENFORCE_EQ(filter_dims.size(), 2, "Input(Y)'s rank should be 2.");
PADDLE_ENFORCE_EQ(
x_dims[1], filter_dims[1],
"The 2nd dimension of Input(X) and Input(Filter) should be same.");
ctx->SetOutputDim("Out", x_dims);
ctx->ShareLoD("X", "Out");
}
};
class RowConvGradOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Filter"),
"Input(Filter) should not be null.");
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
"Gradient of output(Out) should not be null.");
auto x_grad_name = framework::GradVarName("X");
if (ctx->HasOutput(x_grad_name)) {
auto x_dims = ctx->GetInputDim("X");
ctx->SetOutputDim(x_grad_name, x_dims);
}
auto filter_grad_name = framework::GradVarName("Filter");
if (ctx->HasOutput(filter_grad_name)) {
auto filter_dims = ctx->GetInputDim("Filter");
ctx->SetOutputDim(filter_grad_name, filter_dims);
}
}
};
class RowConvOpMaker : public framework::OpProtoAndCheckerMaker {
public:
RowConvOpMaker(framework::OpProto *proto,
framework::OpAttrChecker *op_checker)
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X",
"(LoDTensor), the input(X) is a LodTensor, which supports "
"variable time-length input sequences. The underlying tensor "
"in this LoDTensor is a matrix with shape (T x N), where T "
"is the total time steps in this mini-batch and N is the input "
"data dimension.");
AddInput("Filter",
"(Tensor), the input(Filter) is a learnable parameter. It "
"is a 2-D tensor with shape (future_context x N), where, "
"future_context is the future context length and N is the data "
"dimension.");
AddOutput("Out",
"(LoDTensor), the output(Out) is a LodTensor, which supports "
"variable time-length input sequences. The underlying tensor "
"in this LodTensor is a matrix with shape T x N, i.e., the "
"same shape as X.");
AddComment(R"DOC(
Row-convolution Operator.
The row convolution is called lookahead convolution. This operator was
introduced in the following paper for DeepSpeech2:
http://www.cs.cmu.edu/~dyogatam/papers/wang+etal.iclrworkshop2016.pdf
The main motivation is that a bidirectional RNN, useful in DeepSpeech
like speech models, learns representation for a sequence by performing a
forward and a backward pass through the entire sequence. However, unlike
unidirectional RNNs, bidirectional RNNs are challenging to deploy in an online
and low-latency setting. The lookahead convolution incorporates information
from future subsequences in a computationally efficient manner to improve
unidirectional recurrent neural networks. The row convolution operator is
different from the 1D sequence convolution, and is computed as follows:
Given an input sequence $in$ of length $t$ and input dimension $d$,
and a filter ($W$) of size $context \times d$,
the output sequence is convolved as:
$$
out_{i, :} = \sum_{j=i}^{i + context} in_{j,:} \dot W_{i-j, :}
$$
)DOC");
}
};
template <typename T>
class RowConvKernel<platform::CPUPlace, T> : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &context) const override {
auto *x = context.Input<LoDTensor>("X");
auto *filter = context.Input<Tensor>("Filter");
auto *out = context.Output<LoDTensor>("Out");
out->mutable_data<T>(context.GetPlace());
auto batch_indices = x->lod()[0];
auto input_dim = x->dims()[1]; // 'in' is of size T x N
size_t num_sequence = batch_indices.size() - 1;
auto future_context = filter->dims()[0];
auto weights = EigenMatrix<T>::From(*filter);
for (size_t i = 0; i < num_sequence; i++) {
int start = static_cast<int>(batch_indices[i]);
int end = static_cast<int>(batch_indices[i + 1]);
int current_timesteps = end - start;
Tensor cur_input_sequence =
x->Slice(start, end); // Current input sequence
Tensor cur_output_sequence =
out->Slice(start, end); // Current output sequence
auto cip_seq = EigenMatrix<T>::From(cur_input_sequence);
auto cot_seq = EigenMatrix<T>::From(cur_output_sequence);
for (int k = 0; k < current_timesteps;
k++) { // For different time steps in the same sequence
for (int w = 0; (w < future_context) && ((k + w) < current_timesteps);
w++) {
for (int d = 0; d < input_dim; d++) {
if (w == 0) {
cot_seq(k, d) = weights(w, d) * cip_seq(k + w, d);
} else {
cot_seq(k, d) += weights(w, d) * cip_seq(k + w, d);
}
}
}
}
}
}
};
template <typename T>
class RowConvGradKernel<platform::CPUPlace, T> : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &context) const override {
auto *x = context.Input<LoDTensor>("X");
auto *filter = context.Input<Tensor>("Filter");
auto *d_out = context.Input<LoDTensor>(framework::GradVarName("Out"));
auto *dx = context.Output<LoDTensor>(framework::GradVarName("X"));
auto *d_filter = context.Output<Tensor>(framework::GradVarName("Filter"));
auto input_dim = x->dims()[1]; // 'x' is of size T x N
auto batch_indices = x->lod()[0];
size_t num_sequence = batch_indices.size() - 1;
auto future_context = filter->dims()[0];
if (d_filter) {
d_filter->mutable_data<T>(context.GetPlace());
auto dweights =
EigenMatrix<T>::From(*d_filter); // Gradient of weight matrix
dweights.setZero();
for (size_t i = 0; i < num_sequence; i++) { // For different sequences
int start = static_cast<int>(batch_indices[i]);
int end = static_cast<int>(batch_indices[i + 1]);
Tensor cur_input = x->Slice(start, end); // Current input sequence
Tensor cur_doutput =
d_out->Slice(start, end); // Current output grad sequence
auto cur_ip = EigenMatrix<T>::From(cur_input);
auto cur_dout = EigenMatrix<T>::From(cur_doutput);
int current_timesteps = end - start;
for (int k = 0; k < current_timesteps;
k++) { // For different time steps in the same sequence
for (int w = 0; (w < future_context) && ((k + w) < current_timesteps);
w++) {
// For dweights (Updating the gradient of weight matrix)
for (int d = 0; d < input_dim; d++) {
dweights(w, d) += cur_ip(k + w, d) * cur_dout(k, d);
}
}
}
}
}
if (dx) {
dx->mutable_data<T>(context.GetPlace());
auto weights = EigenMatrix<T>::From(*filter);
for (size_t i = 0; i < num_sequence; i++) { // For different sequences
int start = static_cast<int>(batch_indices[i]);
int end = static_cast<int>(batch_indices[i + 1]);
Tensor cur_doutput =
d_out->Slice(start, end); // Current output grad sequence
Tensor cur_dinput =
dx->Slice(start, end); // Current input grad sequence
auto cur_dout = EigenMatrix<T>::From(cur_doutput);
auto cur_dip = EigenMatrix<T>::From(cur_dinput);
cur_dip.setZero();
int current_timesteps = end - start;
for (int k = 0; k < current_timesteps;
k++) { // For different time steps in the same sequence
for (int w = 0; (w < future_context) && ((k + w) < current_timesteps);
w++) {
// For dinput (Updating the gradient wrt input)
for (int d = 0; d < input_dim; d++) {
cur_dip(k + w, d) += weights(w, d) * cur_dout(k, d);
}
}
}
}
}
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP(row_conv, ops::RowConvOp, ops::RowConvOpMaker, row_conv_grad,
ops::RowConvGradOp);
REGISTER_OP_CPU_KERNEL(row_conv,
ops::RowConvKernel<paddle::platform::CPUPlace, float>);
REGISTER_OP_CPU_KERNEL(
row_conv_grad, ops::RowConvGradKernel<paddle::platform::CPUPlace, float>);
/* Copyright (c) 2017 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/math/math_function.h"
#include "paddle/operators/row_conv_op.h"
#include "paddle/platform/cuda_helper.h"
namespace paddle {
namespace operators {
using LoDTensor = framework::LoDTensor;
using framework::Tensor;
namespace {
inline int DivUp(int x, int y) { return (x + y - 1) / y; }
// Forward prop (shared memory version, for small future_context)
template <typename T>
__global__ void RowConvForwardSharedMemory(const T *in, const T *wt,
int num_sequence, int input_dim,
int future_context,
const size_t *batch_indices,
T *out) {
int blx = blockDim.x;
int bly = blockDim.y;
int thx = threadIdx.x;
int thy = threadIdx.y;
int d = blockIdx.x * blx + thx; // index along input dim
extern __shared__ T mem[];
T *sw = mem;
if (thy < future_context) {
sw[thy * blx + thx] =
(d < input_dim) ? wt[thy * input_dim + d] : static_cast<T>(0);
}
__syncthreads();
for (size_t i = 0; i < num_sequence; i++) {
int start = static_cast<int>(batch_indices[i]);
int end = static_cast<int>(batch_indices[i + 1]);
int current_timesteps = end - start;
for (int k = thy; k < current_timesteps; k += bly) {
T sum = 0;
for (int w = 0; (w < future_context) && ((k + w) < current_timesteps);
w++) {
sum += (d < input_dim)
? sw[w * blx + thx] * in[(start + k + w) * input_dim + d]
: static_cast<T>(0);
}
if (d < input_dim) {
out[(start + k) * input_dim + d] = sum;
}
}
}
}
// Forward prop (naive version)
template <typename T>
__global__ void RowConvForward(const T *in, const T *wt, int num_sequence,
int input_dim, int future_context,
const size_t *batch_indices, T *out) {
int d = blockIdx.x * blockDim.x + threadIdx.x; // index along input_dim
int bly = blockDim.y;
int thy = threadIdx.y;
if (d >= input_dim) return;
for (size_t i = 0; i < num_sequence; i++) {
int start = static_cast<int>(batch_indices[i]);
int end = static_cast<int>(batch_indices[i + 1]);
int current_timesteps = end - start;
for (int k = thy; k < current_timesteps; k += bly) {
T sum = 0;
for (int w = 0; (w < future_context) && ((k + w) < current_timesteps);
w++) {
sum += (wt[w * input_dim + d] * in[(start + k + w) * input_dim + d]);
}
out[(start + k) * input_dim + d] = sum;
}
}
}
// Compute input gradient (shared memory version, for small future_context)
template <typename T>
__global__ void RowConvGradInputSharedMemory(const T *dout, const T *wt,
int num_sequence, int input_dim,
int future_context,
const size_t *batch_indices,
T *din) {
int blx = blockDim.x;
int bly = blockDim.y;
int thx = threadIdx.x;
int thy = threadIdx.y;
int d = blockIdx.x * blx + thx; // index along input dim
extern __shared__ T mem[];
T *sw = mem;
if (thy < future_context) {
sw[thy * blx + thx] =
(d < input_dim) ? wt[thy * input_dim + d] : static_cast<T>(0);
}
__syncthreads();
for (int i = 0; i < num_sequence; i++) {
int start = static_cast<int>(batch_indices[i]);
int end = static_cast<int>(batch_indices[i + 1]);
int current_timesteps = end - start;
for (int k = thy; k < current_timesteps; k += bly) {
T sum = 0;
for (int w = 0; (w < future_context) && ((k - w) >= 0); w++) {
sum += (d < input_dim)
? (sw[w * blx + thx] * dout[(k + start - w) * input_dim + d])
: static_cast<T>(0);
}
if (d < input_dim) {
din[(k + start) * input_dim + d] = sum;
}
}
}
}
// Compute input gradient (Naive version)
template <typename T>
__global__ void RowConvGradInput(const T *dout, const T *wt, int num_sequence,
int input_dim, int future_context,
const size_t *batch_indices, T *din) {
int d = blockIdx.x * blockDim.x + threadIdx.x; // index along input_dim
int bly = blockDim.y;
int thy = threadIdx.y;
if (d >= input_dim) return;
for (int i = 0; i < num_sequence; i++) {
int start = static_cast<int>(batch_indices[i]);
int end = static_cast<int>(batch_indices[i + 1]);
int current_timesteps = end - start;
for (int k = thy; k < current_timesteps; k += bly) {
T sum = 0;
for (int w = 0; (w < future_context) && ((k - w) >= 0); w++) {
sum += (wt[w * input_dim + d] * dout[(k + start - w) * input_dim + d]);
}
din[(k + start) * input_dim + d] = sum;
}
}
}
// Compute W gradient (small future_context version)
template <typename T>
__global__ void RowConvGradFilterImproved(const T *in, const T *dout,
int num_sequence, int input_dim,
int future_context, int block_x,
int block_y,
const size_t *batch_indices,
T *dfilter) {
int blx = blockDim.x;
int bly = blockDim.y;
int thx = threadIdx.x;
int thy = threadIdx.y;
int gx = blockIdx.x * blx;
int d = gx + thx; // index along input dim
extern __shared__ T mem[];
int xdim_sh_in = block_y;
int xdim_sh_dout = block_y;
// int xdim_sh_dfilter = future_context;
int ydim_sh_in = block_x;
int ydim_sh_dout = block_x + future_context - 1;
int ydim_sh_dfilter = block_y;
T *sh_in = mem;
T *sh_dout = &mem[xdim_sh_in * ydim_sh_in];
T *sh_dfilter = &mem[xdim_sh_in * ydim_sh_in + xdim_sh_dout * ydim_sh_dout];
if (thy < future_context) {
sh_dfilter[thy * ydim_sh_dfilter + thx] = static_cast<T>(0);
}
__syncthreads();
for (int i = 0; i < num_sequence; i++) {
int start = static_cast<int>(batch_indices[i]);
int end = static_cast<int>(batch_indices[i + 1]);
int current_timesteps = end - start;
int scaled_cur_steps =
((current_timesteps + block_x - 1) / block_x) * block_x;
for (int k = thy; k < scaled_cur_steps; k += block_x) {
int pos = start + k;
sh_in[thx * ydim_sh_in + thy] =
(d < input_dim && pos < end) ? in[pos * input_dim + d] : T(0);
sh_dout[thx * ydim_sh_dout + thy + future_context - 1] =
(d < input_dim && pos < end) ? dout[pos * input_dim + d] : T(0);
__syncthreads();
if (thy < future_context - 1) {
int pos_offset = pos - future_context + 1;
sh_dout[thx * ydim_sh_dout + thy] =
(d < input_dim && pos_offset >= start)
? dout[pos_offset * input_dim + d]
: T(0);
}
__syncthreads();
for (int w = 0; w < future_context; w++) {
T val = sh_in[thy * ydim_sh_in + thx] *
sh_dout[thy * ydim_sh_dout + thx + future_context - 1 - w];
__syncthreads();
for (int offset = 16; offset > 0;
offset = offset / 2) { // blockDim.x is 32.
val += __shfl_down(val, offset);
}
__syncthreads();
if (thx == 0) {
sh_dfilter[w * ydim_sh_dfilter + thy] += val;
}
__syncthreads();
}
}
}
for (int w = thy; (w < future_context) && (d < input_dim); w += bly) {
dfilter[w * input_dim + d] += sh_dfilter[w * ydim_sh_dfilter + thx];
}
}
// Compute weight(filter) gradient
template <typename T>
__global__ void RowConvGradFilter(const T *in, const T *dout, int num_sequence,
int input_dim, int future_context,
int block_x, int block_y,
const size_t *batch_indices, T *dfilter) {
int blx = blockDim.x;
int bly = blockDim.y;
int thx = threadIdx.x;
int thy = threadIdx.y;
int gx = blockIdx.x * blx;
int d = gx + thx; // index along input dim
extern __shared__ T mem[];
T *sh_in = mem;
T *sh_dout = &mem[block_x * block_y];
for (int i = 0; i < num_sequence; i++) {
int start = static_cast<int>(batch_indices[i]);
int end = static_cast<int>(batch_indices[i + 1]);
int current_timesteps = end - start;
int scaled_cur_steps =
((current_timesteps + block_x - 1) / block_x) * block_x;
for (int k = thy; k < scaled_cur_steps; k += block_x) {
int pos = start + k;
sh_in[thx * block_y + thy] =
(d < input_dim && pos < end) ? in[pos * input_dim + d] : 0.0;
__syncthreads();
for (int w = 0; w < future_context; w++) {
sh_dout[thx * block_y + thy] =
(d < input_dim && (k - w) >= 0 && (k - w) < current_timesteps)
? dout[(pos - w) * input_dim + d]
: 0.0;
__syncthreads();
T val = sh_in[thy * block_y + thx] * sh_dout[thy * block_y + thx];
__syncthreads();
for (int offset = 16; offset > 0;
offset = offset / 2) { // blockDim.x is 32.
val += __shfl_down(val, offset);
}
__syncthreads();
if (thx == 0 && (gx + thy) < input_dim) {
dfilter[w * input_dim + gx + thy] += val;
}
}
}
}
}
} // namespace
template <typename T>
class RowConvKernel<platform::GPUPlace, T> : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &context) const override {
auto *X = context.Input<LoDTensor>("X");
auto *Filter = context.Input<Tensor>("Filter");
auto *Out = context.Output<LoDTensor>("Out");
const T *in = X->data<T>();
const T *weight = Filter->data<T>();
T *out = Out->mutable_data<T>(context.GetPlace());
auto batch_indices = X->lod()[0];
int input_dim = X->dims()[1];
int num_sequence = batch_indices.size() - 1;
int future_context = Filter->dims()[0];
size_t *idx = batch_indices.data();
auto stream = context.cuda_device_context().stream();
if (future_context <= 32) {
dim3 block_dim = dim3(32, 32);
dim3 grid_dim = dim3(DivUp(input_dim, block_dim.x), 1);
int mem_per_block = (future_context * block_dim.x) * sizeof(T);
RowConvForwardSharedMemory<
T><<<grid_dim, block_dim, mem_per_block, stream>>>(
in, weight, num_sequence, input_dim, future_context, idx, out);
} else {
dim3 block_dim = dim3(32, 32);
dim3 grid_dim = dim3(DivUp(input_dim, block_dim.x), 1);
RowConvForward<T><<<grid_dim, block_dim, 0, stream>>>(
in, weight, num_sequence, input_dim, future_context, idx, out);
}
}
};
template <typename T>
class RowConvGradKernel<platform::GPUPlace, T> : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &context) const override {
auto *X = context.Input<LoDTensor>("X");
auto *Filter = context.Input<Tensor>("Filter");
auto *dOut = context.Input<LoDTensor>(framework::GradVarName("Out"));
const T *in = X->data<T>();
const T *weights = Filter->data<T>();
const T *dout = dOut->data<T>();
Tensor *dX = context.Output<LoDTensor>(framework::GradVarName("X"));
Tensor *dFilter = context.Output<Tensor>(framework::GradVarName("Filter"));
auto batch_indices = X->lod()[0];
int input_dim = X->dims()[1];
int num_sequence = batch_indices.size() - 1;
int future_context = Filter->dims()[0];
size_t *idx = batch_indices.data();
auto &device_ctx = context.cuda_device_context();
math::SetConstant<platform::GPUPlace, T> zero;
if (dFilter) {
T *dfilter = dFilter->mutable_data<T>(context.GetPlace());
zero(device_ctx, dFilter, static_cast<T>(0.0));
if (future_context <= 32) {
dim3 block_dim = dim3(32, 32);
dim3 grid_dim = dim3(DivUp(input_dim, block_dim.x), 1);
int block_x = block_dim.x;
int block_y = block_dim.y;
int mem_per_block =
(block_y * block_x + block_y * (block_x + future_context - 1) +
future_context * block_y) *
sizeof(T);
RowConvGradFilterImproved<
T><<<grid_dim, block_dim, mem_per_block, device_ctx.stream()>>>(
in, dout, num_sequence, input_dim, future_context, block_x, block_y,
idx, dfilter);
} else {
dim3 block_dim = dim3(32, 32);
dim3 grid_dim = dim3(DivUp(input_dim, block_dim.x), 1);
int block_x = block_dim.x;
int block_y = block_dim.y;
int mem_per_block =
(block_x * block_y * 2) * sizeof(T); // For 2 arrays of size 32x32
RowConvGradFilter<
T><<<grid_dim, block_dim, mem_per_block, device_ctx.stream()>>>(
in, dout, num_sequence, input_dim, future_context, block_x, block_y,
idx, dfilter);
}
}
if (dX) {
T *din = dX->mutable_data<T>(context.GetPlace());
if (future_context <= 32) {
dim3 block_dim = dim3(32, 32);
dim3 grid_dim = dim3(DivUp(input_dim, block_dim.x), 1);
int mem_per_block = (future_context * block_dim.x) * sizeof(T);
RowConvGradInputSharedMemory<
T><<<grid_dim, block_dim, mem_per_block, device_ctx.stream()>>>(
dout, weights, num_sequence, input_dim, future_context, idx, din);
} else {
dim3 block_dim = dim3(32, 32);
dim3 grid_dim = dim3(DivUp(input_dim, block_dim.x), 1);
RowConvGradInput<T><<<grid_dim, block_dim, 0, device_ctx.stream()>>>(
dout, weights, num_sequence, input_dim, future_context, idx, din);
}
}
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(row_conv,
ops::RowConvKernel<paddle::platform::GPUPlace, float>);
REGISTER_OP_GPU_KERNEL(
row_conv_grad, ops::RowConvGradKernel<paddle::platform::GPUPlace, float>);
/* Copyright (c) 2017 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/framework/op_registry.h"
namespace paddle {
namespace operators {
template <typename Place, typename T>
class RowConvKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &context) const override;
};
template <typename Place, typename T>
class RowConvGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &context) const override;
};
} // namespace operators
} // namespace paddle
import unittest
import numpy as np
from op_test import OpTest
def row_conv_forward(x, lod, wt):
out = np.zeros_like(x)
seq_info = lod[0]
num_sequences = len(seq_info) - 1
context_length = wt.shape[0]
for i in range(num_sequences): # loop over number of sequences
start = seq_info[i]
end = seq_info[i + 1]
curinput = x[start:end, :]
curoutput = out[start:end, :]
cur_timesteps = end - start
for j in range(cur_timesteps): # loop over different timesteps
for k in range(context_length):
if j + k >= cur_timesteps:
continue
curoutput[j, :] += curinput[j + k, :] * wt[k, :]
return out
class TestRowConvOp1(OpTest):
def setUp(self):
self.op_type = "row_conv"
lod = [[0, 2, 5, 7]]
T = lod[0][-1]
D = 16
context_length = 2
x = np.random.random((T, D)).astype("float32")
wt = np.random.random((context_length, D)).astype("float32")
self.inputs = {'X': (x, lod), 'Filter': wt}
out = row_conv_forward(x, lod, wt)
self.outputs = {'Out': (out, lod)}
def test_check_output(self):
self.check_output()
def test_check_grad_normal(self):
self.check_grad(['X', 'Filter'], 'Out', max_relative_error=0.05)
def test_check_grad_ignore_x(self):
self.check_grad(
['Filter'], 'Out', max_relative_error=0.05, no_grad_set=set('X'))
def test_check_grad_ignore_wt(self):
self.check_grad(
['X'], 'Out', max_relative_error=0.05, no_grad_set=set('Filter'))
class TestRowConvOp2(OpTest):
def setUp(self):
self.op_type = "row_conv"
lod = [[0, 20, 50, 100]]
T = lod[0][-1]
D = 35
context_length = 35
x = np.random.random((T, D)).astype("float32")
wt = np.random.random((context_length, D)).astype("float32")
self.inputs = {'X': (x, lod), 'Filter': wt}
out = row_conv_forward(x, lod, wt)
self.outputs = {'Out': (out, lod)}
def test_check_output(self):
self.check_output()
#max_relative_error is increased from 0.05 to 0.06 as for higher
#dimensional input, the dX on CPU for some values has max_rel_error
#slightly more than 0.05
def test_check_grad_normal(self):
self.check_grad(['X', 'Filter'], 'Out', max_relative_error=0.06)
def test_check_grad_ignore_x(self):
self.check_grad(
['Filter'], 'Out', max_relative_error=0.06, no_grad_set=set('X'))
def test_check_grad_ignore_wt(self):
self.check_grad(
['X'], 'Out', max_relative_error=0.06, no_grad_set=set('Filter'))
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册