提交 99c6f44a 编写于 作者: C chengduoZH

follow comments

上级 dcb3da59
...@@ -128,7 +128,7 @@ op_library(softmax_with_cross_entropy_op DEPS cross_entropy softmax) ...@@ -128,7 +128,7 @@ op_library(softmax_with_cross_entropy_op DEPS cross_entropy softmax)
op_library(sum_op DEPS net_op) op_library(sum_op DEPS net_op)
op_library(pool_op DEPS pooling) op_library(pool_op DEPS pooling)
op_library(pool_with_index_op DEPS pooling) op_library(pool_with_index_op DEPS pooling)
op_library(sequence_conv_op DEPS sequence_project) op_library(sequence_conv_op DEPS context_project)
op_library(lstm_op DEPS sequence2batch lstm_compute) op_library(lstm_op DEPS sequence2batch lstm_compute)
list(REMOVE_ITEM GENERAL_OPS ${DEPS_OPS}) list(REMOVE_ITEM GENERAL_OPS ${DEPS_OPS})
......
...@@ -9,7 +9,7 @@ if(WITH_GPU) ...@@ -9,7 +9,7 @@ if(WITH_GPU)
nv_library(cross_entropy SRCS cross_entropy.cc cross_entropy.cu DEPS operator) nv_library(cross_entropy SRCS cross_entropy.cc cross_entropy.cu DEPS operator)
nv_library(pooling SRCS pooling.cc pooling.cu DEPS device_context) nv_library(pooling SRCS pooling.cc pooling.cu DEPS device_context)
nv_library(vol2col SRCS vol2col.cc vol2col.cu DEPS device_context) nv_library(vol2col SRCS vol2col.cc vol2col.cu DEPS device_context)
nv_library(sequence_project SRCS sequence_project.cc sequence_project.cu DEPS device_context) nv_library(context_project SRCS context_project.cc context_project.cu DEPS device_context)
nv_library(sequence2batch SRCS sequence2batch.cc sequence2batch.cu DEPS device_context) nv_library(sequence2batch SRCS sequence2batch.cc sequence2batch.cu DEPS device_context)
nv_library(lstm_compute SRCS lstm_compute.cc lstm_compute.cu DEPS device_context activation_functions) nv_library(lstm_compute SRCS lstm_compute.cc lstm_compute.cu DEPS device_context activation_functions)
else() else()
...@@ -19,7 +19,7 @@ else() ...@@ -19,7 +19,7 @@ else()
cc_library(cross_entropy SRCS cross_entropy.cc DEPS operator) cc_library(cross_entropy SRCS cross_entropy.cc DEPS operator)
cc_library(pooling SRCS pooling.cc DEPS device_context) cc_library(pooling SRCS pooling.cc DEPS device_context)
cc_library(vol2col SRCS vol2col.cc DEPS device_context) cc_library(vol2col SRCS vol2col.cc DEPS device_context)
cc_library(sequence_project SRCS sequence_project.cc DEPS device_context) cc_library(context_project SRCS context_project.cc DEPS device_context)
cc_library(sequence2batch SRCS sequence2batch.cc DEPS device_context) cc_library(sequence2batch SRCS sequence2batch.cc DEPS device_context)
cc_library(lstm_compute SRCS lstm_compute.cc DEPS device_context activation_functions) cc_library(lstm_compute SRCS lstm_compute.cc DEPS device_context activation_functions)
endif() endif()
......
...@@ -12,14 +12,14 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,14 +12,14 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/operators/math/sequence_project.h" #include "paddle/operators/math/context_project.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
namespace math { namespace math {
template class SequenceProjectFunctor<platform::CPUPlace, float>; template class ContextProjectFunctor<platform::CPUPlace, float>;
template class SequenceProjectFunctor<platform::CPUPlace, double>; template class ContextProjectFunctor<platform::CPUPlace, double>;
} // namespace math } // namespace math
} // namespace operators } // namespace operators
......
...@@ -14,14 +14,14 @@ limitations under the License. */ ...@@ -14,14 +14,14 @@ limitations under the License. */
#define EIGEN_USE_GPU #define EIGEN_USE_GPU
#include "paddle/operators/math/sequence_project.h" #include "paddle/operators/math/context_project.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
namespace math { namespace math {
template class SequenceProjectFunctor<platform::GPUPlace, float>; template class ContextProjectFunctor<platform::GPUPlace, float>;
template class SequenceProjectFunctor<platform::GPUPlace, double>; template class ContextProjectFunctor<platform::GPUPlace, double>;
} // namespace math } // namespace math
} // namespace operators } // namespace operators
......
...@@ -23,31 +23,29 @@ namespace paddle { ...@@ -23,31 +23,29 @@ namespace paddle {
namespace operators { namespace operators {
namespace math { namespace math {
// template <typename T, int MajorType = Eigen::RowMajor,
// typename IndexType = Eigen::DenseIndex>
// using EigenVector = framework::EigenVector<T, MajorType, IndexType>;
template <typename T, int MajorType = Eigen::RowMajor, template <typename T, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex> typename IndexType = Eigen::DenseIndex>
using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>; using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;
/* /*
* \brief SequenceProject projects features of context_length time-steps of each * \brief Context projection concatenate features in adjacent time steps in
* instance. * a sequence. The i-th row of the output is the concatenation of
* * context_length rows of the input. The context_length rows are the
* consecutive rows from the i+shift_start row.
* \param in Input data. * \param in Input data.
* \param inShape The shape of Input data, * \param Shape The shape of Input data,
* [minibatch, number_of_input_features]. * [minibatch, number_of_input_features].
* \param inShape A float LoDTensor. * \param type A float LoDTensor.
* *
* \param padding_data Padding data. * \param padding_data Padding data.
* \param inShape The shape of Padding data, * \param Shape The shape of Padding data,
* [up_pad + down_pad, number_of_input_features]. * [up_pad + down_pad, number_of_input_features].
* \param inShape A float LoDTensor. * \param type A float Tensor.
* *
* \param col Col data. * \param col Col data.
* \param inShape The shape of Col data, * \param Shape The shape of Col data,
* [minibatch, 1]. * [minibatch, context_length * number_of_input_features].
* \param inShape A float LoDTensor. * \param type A float Tensor.
* *
* For a mini-batch of 2 variable lengths sentences, containing 3, and 1 * For a mini-batch of 2 variable lengths sentences, containing 3, and 1
* time-steps: * time-steps:
...@@ -87,7 +85,7 @@ using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>; ...@@ -87,7 +85,7 @@ using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;
*/ */
template <typename Place, typename T> template <typename Place, typename T>
class SequenceProjectFunctor { class ContextProjectFunctor {
public: public:
void operator()(const platform::DeviceContext& context, void operator()(const platform::DeviceContext& context,
framework::LoDTensor& in, framework::Tensor& padding_data, framework::LoDTensor& in, framework::Tensor& padding_data,
...@@ -147,8 +145,7 @@ class SequenceProjectFunctor { ...@@ -147,8 +145,7 @@ class SequenceProjectFunctor {
/*stride_height*/ context_stride, /*stride_width*/ 1, /*stride_height*/ context_stride, /*stride_width*/ 1,
up_pad, down_pad, 0, 0); up_pad, down_pad, 0, 0);
} }
out_t.Resize(framework::make_ddim( out_t.Resize({sequence_height, context_length * sequence_width});
{sequence_height, context_length * sequence_width}));
} }
} }
} }
...@@ -162,8 +159,7 @@ class SequenceProjectFunctor { ...@@ -162,8 +159,7 @@ class SequenceProjectFunctor {
sequence_height = static_cast<int>(out_t.dims()[0]); sequence_height = static_cast<int>(out_t.dims()[0]);
// add up trainable data // add up trainable data
out_t.Resize(framework::make_ddim( out_t.Resize({sequence_height * context_length, sequence_width});
{sequence_height * context_length, sequence_width}));
if (up_pad > 0) { // add up pad if (up_pad > 0) { // add up pad
int padding_rows = std::min( int padding_rows = std::min(
...@@ -223,8 +219,7 @@ class SequenceProjectFunctor { ...@@ -223,8 +219,7 @@ class SequenceProjectFunctor {
} }
} }
} }
out_t.Resize(framework::make_ddim( out_t.Resize({sequence_height, context_length * sequence_width});
{sequence_height, context_length * sequence_width}));
} }
} }
} }
......
...@@ -38,10 +38,9 @@ class SequenceConvOp : public framework::OperatorWithKernel { ...@@ -38,10 +38,9 @@ class SequenceConvOp : public framework::OperatorWithKernel {
auto filter_dims = ctx->GetInputDim("Filter"); auto filter_dims = ctx->GetInputDim("Filter");
PADDLE_ENFORCE(in_dims.size() == 2 && filter_dims.size() == 2, PADDLE_ENFORCE(in_dims.size() == 2 && filter_dims.size() == 2,
"Input(X, Filter) should be 2-D tensor."); "Input(X, Filter) should be 2-D tensor.");
PADDLE_ENFORCE( PADDLE_ENFORCE(filter_dims[0] == context_length * in_dims[1],
filter_dims[0] == context_length && filter_dims[1] == in_dims[1], "Filter's height should be context_length * "
"Filter's shape should be (context_length x " "number_of_input_features .");
"number_of_input_features).");
if (padding_trainable) { if (padding_trainable) {
PADDLE_ENFORCE( PADDLE_ENFORCE(
...@@ -66,8 +65,9 @@ class SequenceConvOp : public framework::OperatorWithKernel { ...@@ -66,8 +65,9 @@ class SequenceConvOp : public framework::OperatorWithKernel {
"and 'context_length'."); "and 'context_length'.");
} }
in_dims[1] = 1; in_dims[1] = filter_dims[1];
ctx->SetOutputDim("Out", in_dims); ctx->SetOutputDim("Out", in_dims);
ctx->ShareLoD("X", "Out");
} }
}; };
...@@ -101,35 +101,51 @@ class SequenceConvOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -101,35 +101,51 @@ class SequenceConvOpMaker : public framework::OpProtoAndCheckerMaker {
SequenceConvOpMaker(framework::OpProto* proto, SequenceConvOpMaker(framework::OpProto* proto,
framework::OpAttrChecker* op_checker) framework::OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) { : OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", AddInput(
"(A float LoDTensor) the input of SequenceConvOp, a vector of " "X",
"2-D matrix of size (minibatch, number_of_input_features)."); "(LoDTensor) the input(X) is a LodTensor, which support "
"variable-time length input sequence. The underlying tensor in "
"this LoDTensor is a matrix with shape (T, D), where, T is the "
"total time steps in this mini-batch, D is the input feature size.");
AddInput("PaddingData", AddInput("PaddingData",
"(Tensor) the input of SequenceConvOp, a vector of " "(Tensor, optional) the input(PaddingData) is an optional "
"2-D matrix of size (up_pad + down_pad, " "parameter, and it is learnable. "
"number_of_input_features). ") "This is a tensor with shape (N, D), where N is the "
"top_pad + bottom_pad, D is the input feature size. In order to "
"ensure the equal length of sequence before and after "
"convolution, it is necessary to fill the top and bottom of each "
"sequence according to context_length, context_stride and "
"context_start")
.AsDispensable(); .AsDispensable();
AddInput("Filter", AddInput("Filter",
"(Tensor) the input of SequenceConvOp, a vector of " "(Tensor) the input(Filter) is an learnable parameter."
"2-D matrix of size (context_length x number_of_input_features)."); "This is a tensor with shape (N, D), where N is the "
AddOutput("Out", "context_length, D is the output feature size.");
"(A float LoDTensor) the output of SequenceConvOp, a vector " AddOutput(
"of 2-D matrix of size (minibatch, 1)."); "Out",
"(LoDTensor) the output(Out) is a LodTensor, which support "
"variable-time length output sequence. The underlying tensor in "
"this LoDTensor is a matrix with shape (T, D), where, T is the "
"total time steps in this mini-batch, D is the output feature size.");
AddAttr<bool>("padding_trainable", AddAttr<bool>("padding_trainable",
"(bool, default false) the padding data of SequenceConvOp " "(bool, default false) the padding data of SequenceConvOp "
"is trainable or not.") "is trainable or not.")
.SetDefault(false); .SetDefault(false);
AddAttr<int>("context_length", AddAttr<int>("context_length",
"(int, default 3) the context_length of SequenceConvOp.") "(int, default 3) the context_length of SequenceConvOp is the "
"height of the convolution kernel.")
.SetDefault(3) .SetDefault(3)
.GreaterThan(0); .GreaterThan(0);
AddAttr<int>("context_start", AddAttr<int>("context_start",
"(int, default 0) the context_start of SequenceConvOp.") "(int, default 0) the context_start of SequenceConvOp "
"represents the beginning of the convolution of the number of "
"rows of sequence, which can be negative.")
.SetDefault(0); .SetDefault(0);
AddAttr<int>("context_stride", AddAttr<int>("context_stride",
"(int, default 1) the context_stride of SequenceConvOp. " "(int, default 1) the context_stride of SequenceConvOp "
"Currently, sequence_project_op only support " "represents the step length of convolution. "
"Currently, SequenceConvOp only supports"
"context_stride=1.") "context_stride=1.")
.SetDefault(1) .SetDefault(1)
.GreaterThan(0); .GreaterThan(0);
...@@ -139,14 +155,10 @@ class SequenceConvOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -139,14 +155,10 @@ class SequenceConvOpMaker : public framework::OpProtoAndCheckerMaker {
context_length time-steps of each instance. context_length time-steps of each instance.
The convolution operation calculates the output based on the input, filter The convolution operation calculates the output based on the input, filter
and strides, paddings parameters. The size of each dimension of the and strides, paddings parameters. The size of each dimension of the
parameters is checked in the infer-shape. parameters is checked in the infer-shape. In order to ensure the equal
length of sequence before and after convolution, it is necessary to fill
Example: the top and bottom of each sequence according to context_length,
Input: context_stride and context_start.
X shape: (minibatch, number_of_input_features)
Filter shape: (context_length, number_of_input_features)
Output:
Out shape: (minibatch, 1)
)DOC"); )DOC");
} }
}; };
......
...@@ -15,20 +15,14 @@ limitations under the License. */ ...@@ -15,20 +15,14 @@ limitations under the License. */
#pragma once #pragma once
#include "paddle/framework/eigen.h" #include "paddle/framework/eigen.h"
#include "paddle/framework/op_registry.h" #include "paddle/framework/op_registry.h"
#include "paddle/operators/math/context_project.h"
#include "paddle/operators/math/math_function.h" #include "paddle/operators/math/math_function.h"
#include "paddle/operators/math/sequence_project.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
using Tensor = framework::Tensor; using Tensor = framework::Tensor;
using LoDTensor = framework::LoDTensor; using LoDTensor = framework::LoDTensor;
// template <typename T, int MajorType = Eigen::RowMajor,
// typename IndexType = Eigen::DenseIndex>
// using EigenVector = framework::EigenVector<T, MajorType, IndexType>;
template <typename T, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;
template <typename Place, typename T> template <typename Place, typename T>
class SequenceConvKernel : public framework::OpKernel<T> { class SequenceConvKernel : public framework::OpKernel<T> {
...@@ -39,7 +33,7 @@ class SequenceConvKernel : public framework::OpKernel<T> { ...@@ -39,7 +33,7 @@ class SequenceConvKernel : public framework::OpKernel<T> {
auto filter = *context.Input<Tensor>("Filter"); auto filter = *context.Input<Tensor>("Filter");
out->mutable_data<T>(context.GetPlace()); out->mutable_data<T>(context.GetPlace());
// out->set_lod(in->lod()); context.ShareLoD("X", "Out");
int context_start = context.Attr<int>("context_start"); int context_start = context.Attr<int>("context_start");
int context_length = context.Attr<int>("context_length"); int context_length = context.Attr<int>("context_length");
...@@ -60,17 +54,16 @@ class SequenceConvKernel : public framework::OpKernel<T> { ...@@ -60,17 +54,16 @@ class SequenceConvKernel : public framework::OpKernel<T> {
int sequence_width; int sequence_width;
sequence_width = static_cast<int>(in->dims()[1]); sequence_width = static_cast<int>(in->dims()[1]);
// use col_shape in the im2col calculation // Use col_shape in the im2col calculation.
framework::DDim col_shape = {in->dims()[0], framework::DDim col_shape = {in->dims()[0],
sequence_width * context_length}; sequence_width * context_length};
Tensor col; Tensor col;
col.mutable_data<T>(col_shape, context.GetPlace()); col.mutable_data<T>(col_shape, context.GetPlace());
math::SetConstant<Place, T> set_zero;
// Because if padding_trainable is false, padding data should be zeros. // Because if padding_trainable is false, padding data should be zeros.
auto temp = framework::EigenVector<T>::Flatten(col); set_zero(context.device_context(), &col, static_cast<T>(0));
temp.device(context.GetEigenDevice<Place>()) =
temp.constant(static_cast<T>(0));
paddle::operators::math::SequenceProjectFunctor<Place, T> paddle::operators::math::ContextProjectFunctor<Place, T>
seq_project_functor; seq_project_functor;
LoDTensor* input = const_cast<LoDTensor*>(in); LoDTensor* input = const_cast<LoDTensor*>(in);
Tensor* pad_data = const_cast<Tensor*>(padding_data); Tensor* pad_data = const_cast<Tensor*>(padding_data);
...@@ -79,9 +72,8 @@ class SequenceConvKernel : public framework::OpKernel<T> { ...@@ -79,9 +72,8 @@ class SequenceConvKernel : public framework::OpKernel<T> {
padding_trainable, context_start, context_length, padding_trainable, context_start, context_length,
context_stride, up_pad, down_pad, false, false, false); context_stride, up_pad, down_pad, false, false, false);
filter.Resize(framework::make_ddim({context_length * sequence_width, 1}));
math::matmul<Place, T>(context.device_context(), col, false, filter, false, math::matmul<Place, T>(context.device_context(), col, false, filter, false,
T(1.0), out, T(0.0)); static_cast<T>(1.0), out, static_cast<T>(0.0));
} }
}; };
...@@ -102,7 +94,6 @@ class SequenceConvGradKernel : public framework::OpKernel<T> { ...@@ -102,7 +94,6 @@ class SequenceConvGradKernel : public framework::OpKernel<T> {
int context_stride = context.Attr<int>("context_stride"); int context_stride = context.Attr<int>("context_stride");
bool padding_trainable = context.Attr<bool>("padding_trainable"); bool padding_trainable = context.Attr<bool>("padding_trainable");
// InferShape by in_lod
PADDLE_ENFORCE_EQ(in->lod().size(), 1UL, PADDLE_ENFORCE_EQ(in->lod().size(), 1UL,
"Only support one level sequence now."); "Only support one level sequence now.");
auto lod_g_level_0 = in->lod()[0]; auto lod_g_level_0 = in->lod()[0];
...@@ -111,6 +102,7 @@ class SequenceConvGradKernel : public framework::OpKernel<T> { ...@@ -111,6 +102,7 @@ class SequenceConvGradKernel : public framework::OpKernel<T> {
int down_pad = std::max(0, context_start + context_length - 1); int down_pad = std::max(0, context_start + context_length - 1);
int sequence_width = static_cast<int>(in->dims()[1]); int sequence_width = static_cast<int>(in->dims()[1]);
math::SetConstant<Place, T> set_zero;
// use col_shape in the im2col calculation // use col_shape in the im2col calculation
framework::DDim col_shape = {in->dims()[0], framework::DDim col_shape = {in->dims()[0],
sequence_width * context_length}; sequence_width * context_length};
...@@ -119,22 +111,17 @@ class SequenceConvGradKernel : public framework::OpKernel<T> { ...@@ -119,22 +111,17 @@ class SequenceConvGradKernel : public framework::OpKernel<T> {
if (in_g || filter_g || (padding_trainable && padding_data_g)) { if (in_g || filter_g || (padding_trainable && padding_data_g)) {
col.mutable_data<T>(col_shape, context.GetPlace()); col.mutable_data<T>(col_shape, context.GetPlace());
// Because if padding_trainable is false, padding data should be zeros. // Because if padding_trainable is false, padding data should be zeros.
auto temp = framework::EigenVector<T>::Flatten(col); set_zero(context.device_context(), &col, static_cast<T>(0));
temp.device(context.GetEigenDevice<Place>()) =
temp.constant(static_cast<T>(0));
math::matmul<Place, T>(context.device_context(), *out_g, false, *filter, math::matmul<Place, T>(context.device_context(), *out_g, false, *filter,
true, T(1.0), &col, T(1.0)); true, T(1.0), &col, T(1.0));
} }
paddle::operators::math::SequenceProjectFunctor<Place, T> paddle::operators::math::ContextProjectFunctor<Place, T>
seq_project_functor; seq_project_functor;
if (in_g) { if (in_g) {
in_g->mutable_data<T>(context.GetPlace()); in_g->mutable_data<T>(context.GetPlace());
in_g->set_lod(in->lod()); in_g->set_lod(in->lod());
set_zero(context.device_context(), in_g, static_cast<T>(0));
math::SetConstant<Place, T> functor;
functor(context.device_context(), in_g, 0);
seq_project_functor(context.device_context(), *in_g, *padding_data_g, col, seq_project_functor(context.device_context(), *in_g, *padding_data_g, col,
padding_trainable, context_start, context_length, padding_trainable, context_start, context_length,
...@@ -143,9 +130,7 @@ class SequenceConvGradKernel : public framework::OpKernel<T> { ...@@ -143,9 +130,7 @@ class SequenceConvGradKernel : public framework::OpKernel<T> {
if (padding_trainable && padding_data_g) { if (padding_trainable && padding_data_g) {
padding_data_g->mutable_data<T>(context.GetPlace()); padding_data_g->mutable_data<T>(context.GetPlace());
set_zero(context.device_context(), padding_data_g, static_cast<T>(0));
math::SetConstant<Place, T> functor;
functor(context.device_context(), padding_data_g, 0);
LoDTensor* input = const_cast<LoDTensor*>(in); LoDTensor* input = const_cast<LoDTensor*>(in);
seq_project_functor(context.device_context(), *input, *padding_data_g, seq_project_functor(context.device_context(), *input, *padding_data_g,
...@@ -155,12 +140,10 @@ class SequenceConvGradKernel : public framework::OpKernel<T> { ...@@ -155,12 +140,10 @@ class SequenceConvGradKernel : public framework::OpKernel<T> {
if (filter_g) { if (filter_g) {
filter_g->mutable_data<T>(context.GetPlace()); filter_g->mutable_data<T>(context.GetPlace());
set_zero(context.device_context(), filter_g, static_cast<T>(0));
math::SetConstant<Place, T> functor; Tensor filter_grad = *filter_g;
functor(context.device_context(), filter_g, 0); LoDTensor out_grad = *out_g;
Tensor filter_grad_ = *filter_g;
LoDTensor out_grad_ = *out_g;
const Tensor* padding_data = nullptr; const Tensor* padding_data = nullptr;
if (padding_trainable) { if (padding_trainable) {
...@@ -177,11 +160,8 @@ class SequenceConvGradKernel : public framework::OpKernel<T> { ...@@ -177,11 +160,8 @@ class SequenceConvGradKernel : public framework::OpKernel<T> {
context_stride, up_pad, down_pad, false, false, context_stride, up_pad, down_pad, false, false,
false); false);
filter_grad_.Resize( math::matmul<Place, T>(context.device_context(), col, true, out_grad,
framework::make_ddim({context_length * sequence_width, 1})); false, T(1.0), &filter_grad, T(1.0));
math::matmul<Place, T>(context.device_context(), col, true, out_grad_,
false, T(1.0), &filter_grad_, T(1.0));
} }
} }
}; };
......
...@@ -20,8 +20,9 @@ class TestSeqProject(OpTest): ...@@ -20,8 +20,9 @@ class TestSeqProject(OpTest):
# one level, batch size # one level, batch size
x = np.random.uniform(0.1, 1, [self.input_size[0], x = np.random.uniform(0.1, 1, [self.input_size[0],
self.input_size[1]]).astype('float32') self.input_size[1]]).astype('float32')
w = np.random.uniform( w = np.random.uniform(0.1, 1, [
0.1, 1, [self.context_length, self.input_size[1]]).astype('float32') self.context_length * self.input_size[1], self.output_represention
]).astype('float32')
begin_pad = np.max([0, -self.context_start]) begin_pad = np.max([0, -self.context_start])
end_pad = np.max([0, self.context_start + self.context_length - 1]) end_pad = np.max([0, self.context_start + self.context_length - 1])
...@@ -49,7 +50,8 @@ class TestSeqProject(OpTest): ...@@ -49,7 +50,8 @@ class TestSeqProject(OpTest):
'padding_trainable': self.padding_trainable, 'padding_trainable': self.padding_trainable,
'context_stride': self.context_stride 'context_stride': self.context_stride
} }
out = np.zeros((self.input_size[0], 1)).astype('float32') out = np.zeros(
(self.input_size[0], self.output_represention)).astype('float32')
self.outputs = {'Out': out} self.outputs = {'Out': out}
self.compute() self.compute()
...@@ -95,13 +97,7 @@ class TestSeqProject(OpTest): ...@@ -95,13 +97,7 @@ class TestSeqProject(OpTest):
out[out_begin:out_end, j * self.input_size[1]:(j + 1) * out[out_begin:out_end, j * self.input_size[1]:(j + 1) *
self.input_size[1]] += in_sub self.input_size[1]] += in_sub
filter_dim = filter.shape
output_dim = self.outputs['Out'].shape
filter.shape = filter_dim[0] * filter_dim[1]
self.outputs['Out'].shape = (output_dim[0], )
np.dot(out, filter, out=self.outputs['Out']) np.dot(out, filter, out=self.outputs['Out'])
filter.shape = filter_dim
self.outputs['Out'].shape = output_dim
def test_check_output(self): def test_check_output(self):
self.check_output() self.check_output()
...@@ -166,6 +162,7 @@ class TestSeqProject(OpTest): ...@@ -166,6 +162,7 @@ class TestSeqProject(OpTest):
self.input_size = [self.input_row, 23] self.input_size = [self.input_row, 23]
self.lod = [[0, 4, 5, 8, self.input_row]] self.lod = [[0, 4, 5, 8, self.input_row]]
self.output_represention = 8 # output feature size
class TestSeqProjectCase1(TestSeqProject): class TestSeqProjectCase1(TestSeqProject):
...@@ -178,6 +175,7 @@ class TestSeqProjectCase1(TestSeqProject): ...@@ -178,6 +175,7 @@ class TestSeqProjectCase1(TestSeqProject):
self.input_size = [self.input_row, 23] self.input_size = [self.input_row, 23]
self.lod = [[0, 4, 5, 8, self.input_row]] self.lod = [[0, 4, 5, 8, self.input_row]]
self.output_represention = 8 # output feature size
class TestSeqProjectCase2(TestSeqProject): class TestSeqProjectCase2(TestSeqProject):
...@@ -193,6 +191,7 @@ class TestSeqProjectCase2(TestSeqProject): ...@@ -193,6 +191,7 @@ class TestSeqProjectCase2(TestSeqProject):
del idx[0] del idx[0]
self.lod = [[0] + np.sort(random.sample(idx, 8)).tolist() + self.lod = [[0] + np.sort(random.sample(idx, 8)).tolist() +
[self.input_size[0]]] [self.input_size[0]]]
self.output_represention = 8 # output feature size
if __name__ == '__main__': if __name__ == '__main__':
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册