提交 df66957e 编写于 作者: X xutianbing

clean a little bit code.

上级 86fa8c05
...@@ -232,7 +232,7 @@ public: ...@@ -232,7 +232,7 @@ public:
/// input grad and output grad have the same batch_size /// input grad and output grad have the same batch_size
CHECK_EQ(inouts[0].dims_[0], inputs[1].dims_[0]); CHECK_EQ(inouts[0].dims_[0], inputs[1].dims_[0]);
/// dim of output = dim of input * context_length /// dim of output = dim of input * context_length
CHECK_EQ(inputs[1].dims_[1], inputs[0].dims_[1] * context_length_); CHECK_EQ(inputs[1].dims_[1], inouts[0].dims_[1] * context_length_);
typename SequenceT<Device>::type seq_vec( typename SequenceT<Device>::type seq_vec(
inputs[0].dims_[0], reinterpret_cast<int*>(inputs[0].getData())); inputs[0].dims_[0], reinterpret_cast<int*>(inputs[0].getData()));
......
...@@ -256,7 +256,7 @@ __global__ void KeContextProjectionBackwardWeight(const real* out_grad, ...@@ -256,7 +256,7 @@ __global__ void KeContextProjectionBackwardWeight(const real* out_grad,
for (int seqId = idy; seqId < num_sequences; seqId += THREADS_Y) { for (int seqId = idy; seqId < num_sequences; seqId += THREADS_Y) {
int seq_start = sequence[seqId]; int seq_start = sequence[seqId];
int seq_end = sequence[seqId+1]; int seq_end = sequence[seqId+1];
output_r = const_cast<real*>(out_grad) output_r = const_cast<real*>(out_grad)
+ seq_start * w_dim * context_length; + seq_start * w_dim * context_length;
if (context_start < 0) { if (context_start < 0) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册