提交 31dc0193 编写于 作者: C chengduoZH

fix ContextProjectFunctor parameter order

上级 e25bfc75
......@@ -88,9 +88,10 @@ template <typename Place, typename T>
class ContextProjectFunctor {
public:
void operator()(const platform::DeviceContext& context, const LoDTensor& in,
const Tensor& padding_data, Tensor& col,
bool padding_trainable, int context_start, int context_length,
int context_stride, int up_pad, int down_pad) {
const Tensor& padding_data, bool padding_trainable,
const int context_start, const int context_length,
const int context_stride, const int up_pad,
const int down_pad, Tensor* col) {
auto lod_level_0 = in.lod()[0];
math::Im2ColFunctor<math::ColFormat::kOCF, Place, float> im2col_ocf;
......@@ -109,8 +110,8 @@ class ContextProjectFunctor {
: static_cast<int>(lod_level_0[i]);
input_row_end = static_cast<int>(lod_level_0[i + 1]);
Tensor out_t = col.Slice(static_cast<int>(lod_level_0[i]),
static_cast<int>(lod_level_0[i + 1]));
Tensor out_t = col->Slice(static_cast<int>(lod_level_0[i]),
static_cast<int>(lod_level_0[i + 1]));
sequence_height = static_cast<int>(out_t.dims()[0]);
......@@ -133,8 +134,8 @@ class ContextProjectFunctor {
}
if (padding_trainable) {
for (int i = 0; i < static_cast<int>(lod_level_0.size()) - 1; ++i) {
Tensor out_t = col.Slice(static_cast<int>(lod_level_0[i]),
static_cast<int>(lod_level_0[i + 1]));
Tensor out_t = col->Slice(static_cast<int>(lod_level_0[i]),
static_cast<int>(lod_level_0[i + 1]));
sequence_height = static_cast<int>(out_t.dims()[0]);
......@@ -197,10 +198,11 @@ class ContextProjectFunctor {
template <typename Place, typename T>
class ContextProjectGradFunctor {
public:
void operator()(const platform::DeviceContext& context, LoDTensor& in,
Tensor& padding_data, Tensor& col, bool padding_trainable,
int context_start, int context_length, int context_stride,
int up_pad, int down_pad, bool input_grad, bool pad_grad) {
void operator()(const platform::DeviceContext& context, const LoDTensor& in,
bool padding_trainable, const int context_start,
const int context_length, const int context_stride,
const int up_pad, const int down_pad, bool pad_grad,
bool input_grad, Tensor* padding_data, Tensor* col) {
auto lod_level_0 = in.lod()[0];
math::Col2ImFunctor<math::ColFormat::kOCF, Place, float> col2im_ocf;
......@@ -220,8 +222,8 @@ class ContextProjectGradFunctor {
: static_cast<int>(lod_level_0[i]);
input_row_end = static_cast<int>(lod_level_0[i + 1]);
Tensor out_t = col.Slice(static_cast<int>(lod_level_0[i]),
static_cast<int>(lod_level_0[i + 1]));
Tensor out_t = col->Slice(static_cast<int>(lod_level_0[i]),
static_cast<int>(lod_level_0[i + 1]));
sequence_height = static_cast<int>(out_t.dims()[0]);
......@@ -247,8 +249,8 @@ class ContextProjectGradFunctor {
if (pad_grad) {
if (padding_trainable) {
for (int i = 0; i < static_cast<int>(lod_level_0.size()) - 1; ++i) {
Tensor out_t = col.Slice(static_cast<int>(lod_level_0[i]),
static_cast<int>(lod_level_0[i + 1]));
Tensor out_t = col->Slice(static_cast<int>(lod_level_0[i]),
static_cast<int>(lod_level_0[i + 1]));
sequence_height = static_cast<int>(out_t.dims()[0]);
out_t.Resize({sequence_height * context_length, sequence_width});
......@@ -262,7 +264,7 @@ class ContextProjectGradFunctor {
k + context_length < up_pad ? context_length : up_pad - k;
Tensor out_t_sub = out_t.Slice(k * context_length,
k * context_length + padding_size);
Tensor w_sub = padding_data.Slice(k, k + padding_size);
Tensor w_sub = padding_data->Slice(k, k + padding_size);
auto out_t_sub_e = EigenMatrix<T>::From(out_t_sub);
auto w_sub_e = EigenMatrix<T>::From(w_sub);
w_sub_e.device(*context.GetEigenDevice<Place>()) =
......@@ -295,7 +297,7 @@ class ContextProjectGradFunctor {
Tensor out_t_sub = out_t.Slice(
(down_pad_begin_row + t) * context_length - padding_size,
(down_pad_begin_row + t) * context_length);
Tensor w_sub = padding_data.Slice(
Tensor w_sub = padding_data->Slice(
up_pad + padding_idx, up_pad + padding_idx + padding_size);
auto out_t_sub_e = EigenMatrix<T>::From(out_t_sub);
auto w_sub_e = EigenMatrix<T>::From(w_sub);
......
......@@ -174,10 +174,9 @@ __global__ void col2vol(int num_kernels, const T* data_col, int depth,
int data_col_index =
(((((c * filter_depth + d_off) * filter_height + h_off) *
filter_width +
w_off) *
output_detph +
d_col) *
output_height +
w_off)));
data_col_index =
((data_col_index * output_detph + d_col) * output_height +
h_col) *
output_width +
w_col;
......
......@@ -62,9 +62,9 @@ class SequenceConvKernel : public framework::OpKernel<T> {
math::ContextProjectFunctor<Place, T> seq_project_functor;
seq_project_functor(context.device_context(), *in, *padding_data, col,
seq_project_functor(context.device_context(), *in, *padding_data,
padding_trainable, context_start, context_length,
context_stride, up_pad, down_pad);
context_stride, up_pad, down_pad, &col);
math::matmul<Place, T>(context.device_context(), col, false, filter, false,
static_cast<T>(1.0), out, static_cast<T>(0.0));
......@@ -117,10 +117,10 @@ class SequenceConvGradKernel : public framework::OpKernel<T> {
in_g->set_lod(in->lod());
set_zero(context.device_context(), in_g, static_cast<T>(0));
seq_project_grad_functor(context.device_context(), *in_g, *padding_data_g,
col, padding_trainable, context_start,
context_length, context_stride, up_pad, down_pad,
true, false);
seq_project_grad_functor(context.device_context(), *in_g,
padding_trainable, context_start, context_length,
context_stride, up_pad, down_pad, false, true,
padding_data_g, &col);
}
if (padding_trainable && padding_data_g) {
......@@ -129,9 +129,9 @@ class SequenceConvGradKernel : public framework::OpKernel<T> {
LoDTensor* input = const_cast<LoDTensor*>(in);
seq_project_grad_functor(context.device_context(), *input,
*padding_data_g, col, padding_trainable,
context_start, context_length, context_stride,
up_pad, down_pad, false, true);
padding_trainable, context_start, context_length,
context_stride, up_pad, down_pad, true, false,
padding_data_g, &col);
}
if (filter_g) {
......@@ -146,9 +146,9 @@ class SequenceConvGradKernel : public framework::OpKernel<T> {
padding_data = context.Input<Tensor>("PaddingData");
}
seq_project_functor(context.device_context(), *in, *padding_data, col,
seq_project_functor(context.device_context(), *in, *padding_data,
padding_trainable, context_start, context_length,
context_stride, up_pad, down_pad);
context_stride, up_pad, down_pad, &col);
math::matmul<Place, T>(context.device_context(), col, true, out_grad,
false, T(1.0), &filter_grad, T(1.0));
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册