Commit 31dc0193 authored by chengduoZH

fix ContextProjectFunctor parameter order

Parent: e25bfc75
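In short, both ContextProjectFunctor and ContextProjectGradFunctor now take their read-only inputs first (const references, with the scalar settings promoted to const int) and the tensors they write moved to the end as raw pointers, so every call site marks mutation explicitly with &. A minimal sketch of the convention, assuming the motivation is the usual C++ style rule that output arguments are passed as pointers; the Project* names and the stand-in Tensor type below are illustrative, not from this patch:

    #include <vector>

    // Hypothetical stand-in for the framework's Tensor, for illustration only.
    struct Tensor {
      std::vector<float> data;
    };

    // Before: the output 'col' is a non-const reference buried among the
    // inputs, so nothing at the call site signals that it is mutated.
    void ProjectOld(const Tensor& in, Tensor& col, bool trainable, int start) {
      col.data = in.data;  // placeholder body
    }

    // After: inputs first, scalars const, outputs last as pointers.
    void ProjectNew(const Tensor& in, bool trainable, const int start,
                    Tensor* col) {
      col->data = in.data;  // placeholder body
    }

    int main() {
      Tensor in{{1.f, 2.f}};
      Tensor col;
      ProjectOld(in, col, true, 0);   // mutation of 'col' is invisible here
      ProjectNew(in, true, 0, &col);  // '&col' makes the output explicit
      return 0;
    }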
@@ -88,9 +88,10 @@ template <typename Place, typename T>
 class ContextProjectFunctor {
  public:
   void operator()(const platform::DeviceContext& context, const LoDTensor& in,
-                  const Tensor& padding_data, Tensor& col,
-                  bool padding_trainable, int context_start, int context_length,
-                  int context_stride, int up_pad, int down_pad) {
+                  const Tensor& padding_data, bool padding_trainable,
+                  const int context_start, const int context_length,
+                  const int context_stride, const int up_pad,
+                  const int down_pad, Tensor* col) {
     auto lod_level_0 = in.lod()[0];
     math::Im2ColFunctor<math::ColFormat::kOCF, Place, float> im2col_ocf;
@@ -109,7 +110,7 @@ class ContextProjectFunctor {
                             : static_cast<int>(lod_level_0[i]);
       input_row_end = static_cast<int>(lod_level_0[i + 1]);
-      Tensor out_t = col.Slice(static_cast<int>(lod_level_0[i]),
+      Tensor out_t = col->Slice(static_cast<int>(lod_level_0[i]),
                                static_cast<int>(lod_level_0[i + 1]));
       sequence_height = static_cast<int>(out_t.dims()[0]);
@@ -133,7 +134,7 @@ class ContextProjectFunctor {
     }
     if (padding_trainable) {
       for (int i = 0; i < static_cast<int>(lod_level_0.size()) - 1; ++i) {
-        Tensor out_t = col.Slice(static_cast<int>(lod_level_0[i]),
+        Tensor out_t = col->Slice(static_cast<int>(lod_level_0[i]),
                                  static_cast<int>(lod_level_0[i + 1]));
         sequence_height = static_cast<int>(out_t.dims()[0]);
@@ -197,10 +198,11 @@ class ContextProjectFunctor {
 template <typename Place, typename T>
 class ContextProjectGradFunctor {
  public:
-  void operator()(const platform::DeviceContext& context, LoDTensor& in,
-                  Tensor& padding_data, Tensor& col, bool padding_trainable,
-                  int context_start, int context_length, int context_stride,
-                  int up_pad, int down_pad, bool input_grad, bool pad_grad) {
+  void operator()(const platform::DeviceContext& context, const LoDTensor& in,
+                  bool padding_trainable, const int context_start,
+                  const int context_length, const int context_stride,
+                  const int up_pad, const int down_pad, bool pad_grad,
+                  bool input_grad, Tensor* padding_data, Tensor* col) {
     auto lod_level_0 = in.lod()[0];
     math::Col2ImFunctor<math::ColFormat::kOCF, Place, float> col2im_ocf;
@@ -220,7 +222,7 @@ class ContextProjectGradFunctor {
                             : static_cast<int>(lod_level_0[i]);
       input_row_end = static_cast<int>(lod_level_0[i + 1]);
-      Tensor out_t = col.Slice(static_cast<int>(lod_level_0[i]),
+      Tensor out_t = col->Slice(static_cast<int>(lod_level_0[i]),
                                static_cast<int>(lod_level_0[i + 1]));
       sequence_height = static_cast<int>(out_t.dims()[0]);
@@ -247,7 +249,7 @@ class ContextProjectGradFunctor {
     if (pad_grad) {
       if (padding_trainable) {
         for (int i = 0; i < static_cast<int>(lod_level_0.size()) - 1; ++i) {
-          Tensor out_t = col.Slice(static_cast<int>(lod_level_0[i]),
+          Tensor out_t = col->Slice(static_cast<int>(lod_level_0[i]),
                                    static_cast<int>(lod_level_0[i + 1]));
           sequence_height = static_cast<int>(out_t.dims()[0]);
@@ -262,7 +264,7 @@ class ContextProjectGradFunctor {
                   k + context_length < up_pad ? context_length : up_pad - k;
               Tensor out_t_sub = out_t.Slice(k * context_length,
                                              k * context_length + padding_size);
-              Tensor w_sub = padding_data.Slice(k, k + padding_size);
+              Tensor w_sub = padding_data->Slice(k, k + padding_size);
               auto out_t_sub_e = EigenMatrix<T>::From(out_t_sub);
               auto w_sub_e = EigenMatrix<T>::From(w_sub);
               w_sub_e.device(*context.GetEigenDevice<Place>()) =
@@ -295,7 +297,7 @@ class ContextProjectGradFunctor {
               Tensor out_t_sub = out_t.Slice(
                   (down_pad_begin_row + t) * context_length - padding_size,
                   (down_pad_begin_row + t) * context_length);
-              Tensor w_sub = padding_data.Slice(
+              Tensor w_sub = padding_data->Slice(
                   up_pad + padding_idx, up_pad + padding_idx + padding_size);
               auto out_t_sub_e = EigenMatrix<T>::From(out_t_sub);
               auto w_sub_e = EigenMatrix<T>::From(w_sub);
...
@@ -174,10 +174,9 @@ __global__ void col2vol(int num_kernels, const T* data_col, int depth,
         int data_col_index =
             (((((c * filter_depth + d_off) * filter_height + h_off) *
                   filter_width +
-              w_off) *
-                 output_detph +
-             d_col) *
-                output_height +
+              w_off)));
+        data_col_index =
+            ((data_col_index * output_detph + d_col) * output_height +
              h_col) *
                 output_width +
             w_col;
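The col2vol change above only splits one deeply nested index expression into two statements; the computed offset is unchanged (the output_detph spelling is carried over from the existing source). A quick standalone check of the equivalence, with arbitrary example values:

    #include <cassert>

    int main() {
      // Arbitrary example values for the kernel's dimensions and offsets.
      int c = 2, d_off = 1, h_off = 0, w_off = 3;
      int filter_depth = 3, filter_height = 3, filter_width = 5;
      int output_detph = 7, output_height = 9, output_width = 11;
      int d_col = 4, h_col = 5, w_col = 6;

      // Old form: one nested expression.
      int idx_old =
          (((((c * filter_depth + d_off) * filter_height + h_off) *
                 filter_width +
             w_off) *
                output_detph +
            d_col) *
               output_height +
           h_col) *
              output_width +
          w_col;

      // New form: the same arithmetic split across two statements.
      int idx_new =
          ((c * filter_depth + d_off) * filter_height + h_off) * filter_width +
          w_off;
      idx_new = ((idx_new * output_detph + d_col) * output_height + h_col) *
                    output_width +
                w_col;

      assert(idx_old == idx_new);
      return 0;
    }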
...
@@ -62,9 +62,9 @@ class SequenceConvKernel : public framework::OpKernel<T> {
     math::ContextProjectFunctor<Place, T> seq_project_functor;
-    seq_project_functor(context.device_context(), *in, *padding_data, col,
-                        padding_trainable, context_start, context_length,
-                        context_stride, up_pad, down_pad);
+    seq_project_functor(context.device_context(), *in, *padding_data,
+                        padding_trainable, context_start, context_length,
+                        context_stride, up_pad, down_pad, &col);
     math::matmul<Place, T>(context.device_context(), col, false, filter, false,
                            static_cast<T>(1.0), out, static_cast<T>(0.0));
@@ -117,10 +117,10 @@ class SequenceConvGradKernel : public framework::OpKernel<T> {
       in_g->set_lod(in->lod());
       set_zero(context.device_context(), in_g, static_cast<T>(0));
-      seq_project_grad_functor(context.device_context(), *in_g, *padding_data_g,
-                               col, padding_trainable, context_start,
-                               context_length, context_stride, up_pad, down_pad,
-                               true, false);
+      seq_project_grad_functor(context.device_context(), *in_g,
+                               padding_trainable, context_start, context_length,
+                               context_stride, up_pad, down_pad, false, true,
+                               padding_data_g, &col);
     }
     if (padding_trainable && padding_data_g) {
@@ -129,9 +129,9 @@ class SequenceConvGradKernel : public framework::OpKernel<T> {
       LoDTensor* input = const_cast<LoDTensor*>(in);
       seq_project_grad_functor(context.device_context(), *input,
-                               *padding_data_g, col, padding_trainable,
-                               context_start, context_length, context_stride,
-                               up_pad, down_pad, false, true);
+                               padding_trainable, context_start, context_length,
+                               context_stride, up_pad, down_pad, true, false,
+                               padding_data_g, &col);
     }
     if (filter_g) {
@@ -146,9 +146,9 @@ class SequenceConvGradKernel : public framework::OpKernel<T> {
         padding_data = context.Input<Tensor>("PaddingData");
       }
-      seq_project_functor(context.device_context(), *in, *padding_data, col,
-                          padding_trainable, context_start, context_length,
-                          context_stride, up_pad, down_pad);
+      seq_project_functor(context.device_context(), *in, *padding_data,
+                          padding_trainable, context_start, context_length,
+                          context_stride, up_pad, down_pad, &col);
       math::matmul<Place, T>(context.device_context(), col, true, out_grad,
                              false, T(1.0), &filter_grad, T(1.0));
...
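One subtlety in the sequence_conv call sites above: besides the tensors moving to the end as pointers, the two boolean flags of the grad functor swapped order. The old signature ended with (input_grad, pad_grad); the new one takes pad_grad before input_grad. Hence the input-gradient path that used to pass true, false now passes false, true, and the padding-gradient path flips the other way; the semantics are unchanged. A sketch of the new call shape, using stub types and a hypothetical free function in place of the functor:

    // Stub types standing in for the framework's, for illustration only.
    struct Tensor {};
    struct LoDTensor {};
    struct DeviceContext {};

    // Mirrors the grad functor's new parameter order: const inputs, const
    // scalar settings, pad_grad, input_grad, then the pointer outputs.
    void SeqProjectGrad(const DeviceContext& context, const LoDTensor& in,
                        bool padding_trainable, const int context_start,
                        const int context_length, const int context_stride,
                        const int up_pad, const int down_pad, bool pad_grad,
                        bool input_grad, Tensor* padding_data, Tensor* col) {}

    int main() {
      DeviceContext ctx;
      LoDTensor in_g;
      Tensor padding_data_g, col;
      // Input-gradient path: pad_grad = false, input_grad = true
      // (written "true, false" under the old flag order).
      SeqProjectGrad(ctx, in_g, true, 0, 3, 1, 1, 1, false, true,
                     &padding_data_g, &col);
      return 0;
    }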