提交 137f0dfc 编写于 作者: W wanghaoshuang

1. Fix warpctc grad tensor initialization bug.

2. Remove num_seq arguments.
3. Refine CUDA kernel of ScaleLoDTensorFunctor.
4. Change max_relative_error of gradient unittest to 0.007
上级 fd24e195
...@@ -22,10 +22,10 @@ template <typename T> ...@@ -22,10 +22,10 @@ template <typename T>
class ScaleLoDTensorFunctor<platform::CPUDeviceContext, T> { class ScaleLoDTensorFunctor<platform::CPUDeviceContext, T> {
public: public:
void operator()(const platform::CPUDeviceContext& context, void operator()(const platform::CPUDeviceContext& context,
framework::LoDTensor& seq, const T* scales, framework::LoDTensor& seq, const T* scales) {
const size_t num_seq) {
const size_t level = 0; const size_t level = 0;
auto lod = seq.lod(); auto lod = seq.lod();
const size_t num_seq = lod[level].size() - 1;
size_t seq_width = seq.dims()[1]; size_t seq_width = seq.dims()[1];
framework::LoD abs_offset_lod = framework::ToAbsOffset(lod); framework::LoD abs_offset_lod = framework::ToAbsOffset(lod);
......
...@@ -20,18 +20,10 @@ namespace math { ...@@ -20,18 +20,10 @@ namespace math {
template <typename T>
__global__ void SequenceScaleKernel(T* seq, size_t* lod, const T* scales,
                                    const size_t seq_width) {
  // Scales one LoD sequence per block, in place: every element of sequence
  // blockIdx.x is multiplied by scales[blockIdx.x].
  //
  // Launch contract: gridDim.x == number of sequences; `lod` holds the
  // absolute (ToAbsOffset) level-0 offsets, so lod[i+1] - lod[i] is the
  // length (in time steps) of sequence i, and each step is `seq_width`
  // elements wide. `lod` must be device-visible memory.
  const int num_elements =
      static_cast<int>((lod[blockIdx.x + 1] - lod[blockIdx.x]) * seq_width);
  const int base = static_cast<int>(lod[blockIdx.x] * seq_width);
  const T scale = scales[blockIdx.x];
  // Thread-stride loop. The previous version only checked
  // `threadIdx.x < num_elements` without looping, which silently left
  // elements beyond blockDim.x (1024) unscaled for long sequences.
  for (int i = threadIdx.x; i < num_elements; i += blockDim.x) {
    seq[base + i] *= scale;
  }
}
...@@ -39,18 +31,17 @@ template <typename T> ...@@ -39,18 +31,17 @@ template <typename T>
template <typename T>
class ScaleLoDTensorFunctor<platform::CUDADeviceContext, T> {
 public:
  // Scales each sequence of `seq` in place on the GPU, multiplying every
  // element of sequence i by scales[i]. The sequence count is derived from
  // the level-0 LoD, so callers no longer pass num_seq.
  void operator()(const platform::CUDADeviceContext& context,
                  framework::LoDTensor& seq, const T* scales) {
    const size_t level = 0;
    auto lod = seq.lod();
    // Level-0 LoD has one offset per sequence boundary, hence size() - 1
    // sequences.
    const size_t seq_count = lod[level].size() - 1;
    // Elements per time step; numel() / dims()[0] also handles rank > 2.
    const size_t step_width = seq.numel() / seq.dims()[0];
    framework::LoD abs_offset_lod = framework::ToAbsOffset(lod);
    T* seq_data = seq.mutable_data<T>(context.GetPlace());

    // One block per sequence, 1024 threads per block.
    const int kThreadsPerBlock = 1024;
    // NOTE(review): abs_offset_lod[level].data() must resolve to
    // device-visible memory for the kernel to read it — confirm the LoD
    // vector type guarantees that here.
    SequenceScaleKernel<T><<<seq_count, kThreadsPerBlock, 0,
                             context.stream()>>>(
        seq_data, abs_offset_lod[level].data(), scales, step_width);
  }
};
......
...@@ -47,7 +47,7 @@ template <typename DeviceContext, typename T> ...@@ -47,7 +47,7 @@ template <typename DeviceContext, typename T>
class ScaleLoDTensorFunctor { class ScaleLoDTensorFunctor {
public: public:
void operator()(const DeviceContext& context, framework::LoDTensor& seq, void operator()(const DeviceContext& context, framework::LoDTensor& seq,
const T* scales, const size_t num_seq); const T* scales);
}; };
} // namespace math } // namespace math
......
...@@ -179,6 +179,10 @@ class WarpCTCKernel : public framework::OpKernel<T> { ...@@ -179,6 +179,10 @@ class WarpCTCKernel : public framework::OpKernel<T> {
T* warpctc_grad_data = T* warpctc_grad_data =
warpctc_grad->mutable_data<T>(warpctc_logits.dims(), ctx.GetPlace()); warpctc_grad->mutable_data<T>(warpctc_logits.dims(), ctx.GetPlace());
math::SetConstant<DeviceContext, T>()(
ctx.template device_context<DeviceContext>(), warpctc_grad,
static_cast<T>(0));
// warpctc accesses labels in CPU memory // warpctc accesses labels in CPU memory
Tensor warpctc_label; Tensor warpctc_label;
Copy(*label, platform::CPUPlace(), ctx.device_context(), &warpctc_label); Copy(*label, platform::CPUPlace(), ctx.device_context(), &warpctc_label);
...@@ -215,10 +219,9 @@ class WarpCTCGradKernel : public framework::OpKernel<T> { ...@@ -215,10 +219,9 @@ class WarpCTCGradKernel : public framework::OpKernel<T> {
*warpctc_grad, norm_by_times); *warpctc_grad, norm_by_times);
const T* loss_grad_data = loss_grad->data<T>(); const T* loss_grad_data = loss_grad->data<T>();
const size_t num_seq = loss_grad->dims()[0];
math::ScaleLoDTensorFunctor<DeviceContext, T>()( math::ScaleLoDTensorFunctor<DeviceContext, T>()(
ctx.template device_context<DeviceContext>(), *logits_grad, ctx.template device_context<DeviceContext>(), *logits_grad,
loss_grad_data, num_seq); loss_grad_data);
} }
}; };
......
...@@ -193,7 +193,7 @@ class TestWarpCTCOp(OpTest): ...@@ -193,7 +193,7 @@ class TestWarpCTCOp(OpTest):
def test_check_grad(self):
    # Register the precomputed analytic gradient as the expected
    # WarpCTCGrad output, then compare numeric vs. analytic gradients of
    # Logits w.r.t. Loss within the given relative tolerance.
    self.outputs['WarpCTCGrad'] = self.gradient
    self.check_grad(["Logits"], "Loss", max_relative_error=0.007)
if __name__ == "__main__": if __name__ == "__main__":
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册