未验证 提交 448fee3d 编写于 作者: W whs 提交者: GitHub

Merge pull request #7414 from wanghaoshuang/warpctc

Adapt warpctc grad op for gradient checking
......@@ -149,7 +149,7 @@ op_library(sequence_pool_op DEPS sequence_pooling)
op_library(lstm_op DEPS sequence2batch lstm_compute)
op_library(gru_op DEPS sequence2batch gru_compute)
op_library(recurrent_op DEPS executor)
op_library(warpctc_op DEPS dynload_warpctc sequence_padding math_function)
op_library(warpctc_op DEPS dynload_warpctc sequence_padding sequence_scale math_function)
op_library(cos_sim_op DEPS cos_sim_functor)
op_library(parallel_do_op DEPS executor)
......
......@@ -13,6 +13,7 @@ if(WITH_GPU)
nv_library(context_project SRCS context_project.cc context_project.cu DEPS device_context math_function)
nv_library(sequence2batch SRCS sequence2batch.cc sequence2batch.cu DEPS device_context tensor)
nv_library(sequence_padding SRCS sequence_padding.cc sequence_padding.cu DEPS lod_tensor device_context)
nv_library(sequence_scale SRCS sequence_scale.cc sequence_scale.cu DEPS lod_tensor device_context)
nv_library(lstm_compute SRCS lstm_compute.cc lstm_compute.cu DEPS device_context activation_functions)
nv_library(maxouting SRCS maxouting.cc maxouting.cu DEPS device_context)
nv_library(unpooling SRCS unpooling.cc unpooling.cu DEPS device_context)
......@@ -29,6 +30,7 @@ else()
cc_library(context_project SRCS context_project.cc DEPS device_context math_function)
cc_library(sequence2batch SRCS sequence2batch.cc DEPS device_context tensor)
cc_library(sequence_padding SRCS sequence_padding.cc DEPS lod_tensor device_context)
cc_library(sequence_scale SRCS sequence_scale.cc DEPS lod_tensor device_context)
cc_library(lstm_compute SRCS lstm_compute.cc DEPS device_context activation_functions)
cc_library(maxouting SRCS maxouting.cc DEPS device_context)
cc_library(unpooling SRCS unpooling.cc DEPS device_context)
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/math/sequence_scale.h"
namespace paddle {
namespace operators {
namespace math {
template <typename T>
class ScaleLoDTensorFunctor<platform::CPUDeviceContext, T> {
public:
void operator()(const platform::CPUDeviceContext& context,
framework::LoDTensor& seq, const T* scales) {
const size_t level = 0;
auto lod = seq.lod();
const size_t num_seq = lod[level].size() - 1;
size_t seq_width = seq.dims()[1];
framework::LoD abs_offset_lod = framework::ToAbsOffset(lod);
T* seq_data = seq.mutable_data<T>(context.GetPlace());
for (size_t i = 0; i < num_seq; ++i) {
for (size_t j = lod[level][i] * seq_width;
j < lod[level][i + 1] * seq_width; ++j) {
seq_data[j] *= scales[i];
}
}
}
};
template class ScaleLoDTensorFunctor<platform::CPUDeviceContext, float>;
} // namespace math
} // namespace operators
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/math/sequence_scale.h"
#include "paddle/platform/cuda_helper.h"
namespace paddle {
namespace operators {
namespace math {
using platform::PADDLE_CUDA_NUM_THREADS;
template <typename T, int BlockSize>
__global__ void SequenceScaleKernel(T* seq, size_t* lod, const T* scales,
const size_t seq_width) {
for (int i = threadIdx.x;
i < (lod[blockIdx.x + 1] - lod[blockIdx.x]) * seq_width;
i += BlockSize) {
int idx = lod[blockIdx.x] * seq_width + i;
seq[idx] *= scales[blockIdx.x];
}
}
template <typename T>
class ScaleLoDTensorFunctor<platform::CUDADeviceContext, T> {
public:
void operator()(const platform::CUDADeviceContext& context,
framework::LoDTensor& seq, const T* scales) {
const size_t level = 0;
auto lod = seq.lod();
const size_t num_seq = lod[level].size() - 1;
const size_t seq_width = seq.numel() / seq.dims()[0];
framework::LoD abs_offset_lod = framework::ToAbsOffset(lod);
T* seq_data = seq.mutable_data<T>(context.GetPlace());
SequenceScaleKernel<T, PADDLE_CUDA_NUM_THREADS><<<
num_seq, PADDLE_CUDA_NUM_THREADS, 0, context.stream()>>>(
seq_data, abs_offset_lod[level].data(), scales, seq_width);
}
};
template class ScaleLoDTensorFunctor<platform::CUDADeviceContext, float>;
} // namespace math
} // namespace operators
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/framework/lod_tensor.h"
#include "paddle/platform/device_context.h"
namespace paddle {
namespace operators {
namespace math {
/*
* \brief Scale a sequence.
*
* All sequences will be padded to the same length and stored in a transposed
* shape.
* Example:
* Given:
* seq = (s0, s0, s0, s0; s1, s1; s2, s2, s2; s3)
* scales = (2, 3, 4, 5)
* then:
* result = (2*s0, 2*s0, 2*s0, 2*s0; 3*s1, 3*s1; 4*s2, 4*s2, 4*s2; 5*s3)
*
* \param context Device context of this functor.
* \param seq LoDTensor which is stored in sequence format, the shape
* is [total_sequence_length, sequence_width] where
* total_sequence_length is the sum of all sequences'
* length.
* \param scales Array<T>. The i-th sequence will be scaled by scales[i].
* \param num_seq Number of sequence
*
*/
template <typename DeviceContext, typename T>
class ScaleLoDTensorFunctor {
public:
void operator()(const DeviceContext& context, framework::LoDTensor& seq,
const T* scales);
};
} // namespace math
} // namespace operators
} // namespace paddle
......@@ -17,6 +17,7 @@ limitations under the License. */
#include "paddle/framework/op_registry.h"
#include "paddle/operators/math/math_function.h"
#include "paddle/operators/math/sequence_padding.h"
#include "paddle/operators/math/sequence_scale.h"
#include "paddle/platform/dynload/warpctc.h"
namespace paddle {
......@@ -178,11 +179,14 @@ class WarpCTCKernel : public framework::OpKernel<T> {
T* warpctc_grad_data =
warpctc_grad->mutable_data<T>(warpctc_logits.dims(), ctx.GetPlace());
math::SetConstant<DeviceContext, T>()(
ctx.template device_context<DeviceContext>(), warpctc_grad,
static_cast<T>(0));
// warpctc accesses labels in CPU memory
Tensor warpctc_label;
Copy(*label, platform::CPUPlace(), ctx.device_context(), &warpctc_label);
const int* warpctc_label_data = warpctc_label.data<int>();
// warpctc stores loss in CPU memory
Tensor warpctc_loss;
T* warpctc_loss_data =
......@@ -206,11 +210,18 @@ class WarpCTCGradKernel : public framework::OpKernel<T> {
void Compute(const framework::ExecutionContext& ctx) const override {
auto* warpctc_grad = ctx.Input<Tensor>("WarpCTCGrad");
auto* logits_grad = ctx.Output<LoDTensor>(framework::GradVarName("Logits"));
const Tensor* loss_grad = ctx.Input<Tensor>(framework::GradVarName("Loss"));
logits_grad->mutable_data<T>(ctx.GetPlace());
bool norm_by_times = ctx.Attr<bool>("norm_by_times");
math::UnpaddingLoDTensorFunctor<DeviceContext, T>()(
ctx.template device_context<DeviceContext>(), *logits_grad,
*warpctc_grad, norm_by_times);
const T* loss_grad_data = loss_grad->data<T>();
math::ScaleLoDTensorFunctor<DeviceContext, T>()(
ctx.template device_context<DeviceContext>(), *logits_grad,
loss_grad_data);
}
};
......
......@@ -17,6 +17,8 @@ import numpy as np
from op_test import OpTest
from test_softmax_op import stable_softmax
CUDA_BLOCK_SIZE = 512
class CTCForward(object):
def __init__(self, softmax, softmax_lod, labels, labels_lod, blank,
......@@ -167,47 +169,63 @@ class CTCForward(object):
class TestWarpCTCOp(OpTest):
def config(self):
self.batch_size = 4
self.num_classes = 8
self.logits_lod = [[0, 4, 5, 8, 11]]
self.labels_lod = [[0, 3, 4, 8, 12]]
self.blank = self.num_classes - 1
self.norm_by_times = False
def setUp(self):
self.op_type = "warpctc"
self.config()
batch_size = 4
num_classes = 8
logits_lod = [[0, 4, 5, 8, 11]]
logits = np.random.uniform(0.1, 1.0,
[11, num_classes]).astype("float32")
logits = np.random.uniform(
0.1, 1.0,
[self.logits_lod[0][-1], self.num_classes]).astype("float32")
softmax = np.apply_along_axis(stable_softmax, 1, logits)
labels_lod = [[0, 3, 4, 8, 12]]
# labels should not be blank
labels = np.random.randint(0, num_classes - 1, [12, 1], dtype="int32")
blank = num_classes - 1
norm_by_times = False
labels = np.random.randint(
0, self.num_classes - 1, [self.labels_lod[0][-1], 1], dtype="int32")
ctc = CTCForward(softmax, logits_lod, labels, labels_lod, blank,
norm_by_times)
ctc = CTCForward(softmax, self.logits_lod, labels, self.labels_lod,
self.blank, self.norm_by_times)
loss = ctc.forward()
max_sequence_length = 0
for i in range(batch_size):
max_sequence_length = max(max_sequence_length,
logits_lod[0][i + 1] - logits_lod[0][i])
gradient = np.zeros(
[max_sequence_length, batch_size, num_classes], dtype="float32")
for i in range(self.batch_size):
max_sequence_length = max(
max_sequence_length,
self.logits_lod[0][i + 1] - self.logits_lod[0][i])
self.gradient = np.zeros(
[max_sequence_length, self.batch_size, self.num_classes],
dtype="float32")
self.inputs = {
"Logits": (logits, logits_lod),
"Label": (labels, labels_lod)
"Logits": (logits, self.logits_lod),
"Label": (labels, self.labels_lod)
}
self.outputs = {"Loss": loss}
self.attrs = {"blank": blank, "norm_by_times": norm_by_times}
self.attrs = {"blank": self.blank, "norm_by_times": self.norm_by_times}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.outputs['WarpCTCGrad'] = self.gradient
self.check_grad(["Logits"], "Loss", max_relative_error=0.007)
class TestWarpCTCOpCase1(TestWarpCTCOp):
def config(self):
self.batch_size = 4
self.num_classes = CUDA_BLOCK_SIZE + 2
self.logits_lod = [[0, 4, 5, 8, 11]]
self.labels_lod = [[0, 3, 4, 8, 12]]
self.blank = 0
self.norm_by_times = False
# def test_check_grad(self):
# self.outputs["WarpCTCGrad"] = None
# self.check_grad(["Logits"], "Loss", max_relative_error=0.01)
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册