提交 834b82f1 编写于 作者: C chengduoZH

fix sequence_project_op forward and backward

上级 40688d22
......@@ -38,24 +38,23 @@ class SequenceProjectOp : public framework::OperatorWithKernel {
PADDLE_ENFORCE(
ctx->HasInput("PaddingData"),
"Output(PaddingData) of SequenceProjectOp should not be null.");
framework::DDim padding_dim = ctx->GetOutputDim("PaddingData");
framework::DDim padding_dim = ctx->GetInputDim("PaddingData");
int up_pad = std::max(0, -context_start);
int down_pad = std::max(0, context_start + context_length - 1);
int total_pad = up_pad + down_pad;
int input_width = static_cast<int>(in_dims[1]);
if (context_start == 0 && context_length == 1) {
PADDLE_THROW(
"if context_start == 0 && context_length == 1, padding_trainable "
"should be false.");
}
PADDLE_ENFORCE(padding_dim.size() == 2,
"Input(PaddingData) should be 2-D tensor.");
PADDLE_ENFORCE(
padding_dim[0] == total_pad && padding_dim[1] == input_width,
"Input(PaddingData)'s shape is not consistent with 'context_start' "
"and 'context_length'.");
if (context_start == 0 && context_length == 1) {
PADDLE_THROW(
"if context_start == 0 && context_length == 1, padding_trainable "
"should be false.");
}
}
in_dims[1] = in_dims[1] * context_length;
......@@ -74,9 +73,11 @@ class SequenceProjectGradOp : public framework::OperatorWithKernel {
PADDLE_ENFORCE(ctx->HasInput("X"), "The input X should not be null.");
if (ctx->Attrs().Get<bool>("padding_trainable")) {
PADDLE_ENFORCE(
ctx->HasOutput("PaddingData"),
"Output(PaddingData) of SequenceProjectOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("PaddingData")),
"Output(PaddingData@GRAD) of SequenceProjectGradOp should "
"not be null.");
auto padding_dims = ctx->GetInputDim("PaddingData");
ctx->SetOutputDim(framework::GradVarName("PaddingData"), padding_dims);
}
ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
}
......@@ -93,8 +94,8 @@ class SequenceProjectOpMaker : public framework::OpProtoAndCheckerMaker {
AddOutput(
"Out",
"A float LoDTensor, the variable-length output of SequenceProjectOp.");
AddOutput("PaddingData",
"A float LoDTensor, the padding data of SequenceProjectOp.");
AddInput("PaddingData", // PaddingData can be a float tensor
"A float LoDTensor, the padding data of SequenceProjectOp.");
AddAttr<bool>("padding_trainable",
"(bool, default false) the padding data of SequenceProjectOp "
......@@ -110,7 +111,8 @@ class SequenceProjectOpMaker : public framework::OpProtoAndCheckerMaker {
AddAttr<int>("context_stride",
"(int, default 1) the xx of SequenceProjectOp.")
.SetDefault(1)
.GreaterThan(0);
.GreaterThan(
0); // Currently, sequence_project_op only support context_stride=1
AddComment(R"DOC(
SequenceProjectOp projects features of context_length time-steps of each instance.
......
......@@ -23,6 +23,9 @@ namespace operators {
using Tensor = framework::Tensor;
using LoDTensor = framework::LoDTensor;
template <typename T, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenVector = framework::EigenVector<T, MajorType, IndexType>;
template <typename T, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;
......@@ -34,6 +37,13 @@ class SequenceProjectKernel : public framework::OpKernel<T> {
auto* in = context.Input<LoDTensor>("X");
auto* out = context.Output<LoDTensor>("Out");
out->mutable_data<T>(context.GetPlace());
// need discuss, is it necessary to set zeros ?
// Because if padding_trainable is false, padding data should be zeros.
auto temp = framework::EigenVector<T>::Flatten(*out);
temp.device(context.GetEigenDevice<Place>()) =
temp.constant(static_cast<T>(0));
auto place = context.GetEigenDevice<Place>();
int context_start = context.Attr<int>("context_start");
......@@ -45,10 +55,10 @@ class SequenceProjectKernel : public framework::OpKernel<T> {
PADDLE_ENFORCE_EQ(in->lod().size(), 1UL,
"Only support one level sequence now.");
auto lod_level_0 = in->lod()[0];
int64_t input_stride = in->dims()[1];
int64_t output_stride = out->dims()[1];
int64_t padding_stride = 0;
PADDLE_ENFORCE(input_stride * context_length == output_stride,
int64_t input_width = in->dims()[1];
int64_t output_width = out->dims()[1];
int64_t padding_width = 0;
PADDLE_ENFORCE(input_width * context_length == output_width,
"Input size and pooling size should be consistent.");
const LoDTensor* padding_data = nullptr;
......@@ -56,73 +66,105 @@ class SequenceProjectKernel : public framework::OpKernel<T> {
padding_data = context.Input<LoDTensor>("PaddingData");
PADDLE_ENFORCE_EQ(padding_data->dims().size(), 2UL,
"Only support one level sequence now.");
padding_stride = padding_data->dims()[1];
PADDLE_ENFORCE(padding_stride == input_stride,
padding_width = padding_data->dims()[1];
PADDLE_ENFORCE(padding_width == input_width,
"Input size and pooling size should be consistent.");
}
int up_pad = std::max(0, -context_start);
int down_pad = std::max(0, context_start + context_length - 1);
int sequence_height, sequence_width;
int input_row_begin, input_row_end;
paddle::operators::math::Im2ColFunctor<
paddle::operators::math::ColFormat::kOCF, Place, float>
im2col_ocf;
for (int i = 0; i < static_cast<int>(lod_level_0.size()) - 1; ++i) {
Tensor in_t = in->Slice<T>(static_cast<int>(lod_level_0[i]),
static_cast<int>(lod_level_0[i + 1]));
input_row_begin = (context_start > 0)
? static_cast<int>(lod_level_0[i]) + context_start
: static_cast<int>(lod_level_0[i]);
input_row_end = static_cast<int>(lod_level_0[i + 1]);
Tensor out_t = out->Slice<T>(static_cast<int>(lod_level_0[i]),
static_cast<int>(lod_level_0[i + 1]));
int sequence_height = in_t.dims()[0];
int sequence_width = in_t.dims()[1];
sequence_height = static_cast<int>(out_t.dims()[0]);
sequence_width = static_cast<int>(in->dims()[1]);
std::vector<int64_t> output_shape(
{sequence_height, 1, 1, context_length,
sequence_width}); // output_height, output_width,
// input_channels,
// filter_height, filter_width
// input_channels, filter_height, filter_width
out_t.Resize(framework::make_ddim(output_shape));
std::vector<int64_t> input_shape(
{1, sequence_height,
sequence_width}); // input_channels, input_height, input_width
in_t.Resize(framework::make_ddim(input_shape));
for (int j = 0; j < context_length; ++j) {
if (input_row_begin < input_row_end) {
Tensor in_t = in->Slice<T>(input_row_begin, input_row_end);
std::vector<int64_t> input_shape(
{1, input_row_end - input_row_begin,
sequence_width}); // input_channels, input_height, input_width
in_t.Resize(framework::make_ddim(input_shape));
im2col_ocf(context.device_context(), in_t, out_t,
/*stride_height*/ context_stride, /*stride_width*/ 0, up_pad,
down_pad);
if (padding_trainable) {
// add up trainable data
out_t.Resize(framework::make_ddim(
{sequence_height * context_length, sequence_width}));
if (up_pad != 0) {
for (int k = 0; k < up_pad; ++k) {
Tensor out_t_sub = out_t.Slice<T>(
k * context_length, k * context_length + (up_pad - k));
Tensor w_sub = padding_data->Slice<T>(k, context_length - k);
auto out_t_sub_e = EigenMatrix<T>::From(out_t_sub);
auto w_sub_e = EigenMatrix<T>::From(w_sub);
out_t_sub_e.device(place) = w_sub_e;
}
}
if (padding_trainable) {
// add up trainable data
out_t.Resize(framework::make_ddim(
{sequence_height * context_length, sequence_width}));
if (up_pad > 0) { // add up pad
int padding_rows = std::min(
up_pad, static_cast<int>(lod_level_0[i + 1] - lod_level_0[i]));
for (int k = 0; k < padding_rows; ++k) {
int padding_size =
k + context_length < up_pad ? context_length : up_pad - k;
Tensor out_t_sub = out_t.Slice<T>(
k * context_length, k * context_length + padding_size);
Tensor w_sub = padding_data->Slice<T>(k, k + padding_size);
// in this block, using EigenVector<T>::Flatten is ok too.
auto out_t_sub_e = EigenMatrix<T>::From(out_t_sub);
auto w_sub_e = EigenMatrix<T>::From(w_sub);
out_t_sub_e.device(place) = w_sub_e;
}
if (down_pad != 0) {
int k =
(sequence_height + up_pad - context_length) / context_stride +
1;
for (int t = 0; t + k < sequence_height; ++t) {
Tensor out_t_sub =
out_t.Slice<T>((k + t) * context_length * sequence_width -
t * sequence_width,
(k + t) * context_length * sequence_width);
Tensor w_sub = padding_data->Slice<T>(up_pad + 1, up_pad + 1 + t);
auto out_t_sub_e = EigenMatrix<T>::From(out_t_sub);
auto w_sub_e = EigenMatrix<T>::From(w_sub);
out_t_sub_e.device(place) = w_sub_e;
}
if (down_pad > 0) { // add down pad
int down_pad_begin_row =
std::max(0,
(sequence_height - context_start - context_length) + 1) +
1;
int padding_begin = std::max(0, context_start - sequence_height);
int padding_size =
sequence_height - context_start >= context_length
? 1
: context_length - (sequence_height - context_start);
if (context_start >= sequence_height) padding_size = context_length;
int padding_idx = padding_begin;
for (int t = 0; t + down_pad_begin_row <= sequence_height;
++t, ++padding_size) {
if (context_start >= sequence_height) padding_size = context_length;
if (padding_size > context_length) {
padding_size = context_length;
padding_idx++;
}
if (padding_begin > 0 || sequence_height == context_start)
padding_idx = padding_begin + t;
Tensor out_t_sub = out_t.Slice<T>(
(down_pad_begin_row + t) * context_length - padding_size,
(down_pad_begin_row + t) * context_length);
Tensor w_sub = padding_data->Slice<T>(
up_pad + padding_idx, up_pad + padding_idx + padding_size);
auto out_t_sub_e = EigenMatrix<T>::From(out_t_sub);
auto w_sub_e = EigenMatrix<T>::From(w_sub);
out_t_sub_e.device(place) = w_sub_e;
}
out_t.Resize(framework::make_ddim(
{sequence_height, context_length * sequence_width}));
}
}
out_t.Resize(framework::make_ddim(
{sequence_height, context_length * sequence_width}));
}
}
};
......@@ -131,95 +173,136 @@ template <typename Place, typename T>
class SequenceProjectGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
// auto* in = context.Input<LoDTensor>("X");
auto* out_g = context.Input<LoDTensor>(framework::GradVarName("Out"));
auto* in_g = context.Output<LoDTensor>(framework::GradVarName("X"));
auto* in = context.Input<LoDTensor>("X");
in_g->mutable_data<T>(context.GetPlace());
auto place = context.GetEigenDevice<Place>();
int context_start = context.Attr<int>("context_start");
int context_length = context.Attr<int>("context_length");
bool padding_trainable = context.Attr<bool>("padding_trainable");
int context_stride = context.Attr<bool>("context_stride");
int context_stride = context.Attr<int>("context_stride");
// InferShape by in_lod
PADDLE_ENFORCE_EQ(in_g->lod().size(), 1UL,
PADDLE_ENFORCE_EQ(in->lod().size(), 1UL,
"Only support one level sequence now.");
auto lod_g_level_0 = in_g->lod()[0];
auto lod_g_level_0 = in->lod()[0];
int64_t input_width = in_g->dims()[1];
int64_t output_width = out_g->dims()[1];
int64_t padding_width = 0;
PADDLE_ENFORCE(input_width * context_length == output_width,
"Input size and pooling size should be consistent.");
LoDTensor* padding_data = nullptr;
LoDTensor* padding_data_g = nullptr;
if (padding_trainable) {
padding_data = context.Output<LoDTensor>("PaddingData");
padding_data->mutable_data<T>(context.GetPlace());
PADDLE_ENFORCE_EQ(padding_data->dims().size(), 2UL,
padding_data_g =
context.Output<LoDTensor>(framework::GradVarName("PaddingData"));
padding_data_g->mutable_data<T>(context.GetPlace());
PADDLE_ENFORCE_EQ(padding_data_g->dims().size(), 2UL,
"Only support one level sequence now.");
padding_width = padding_data->dims()[1];
padding_width = padding_data_g->dims()[1];
PADDLE_ENFORCE(padding_width == input_width,
"Input size and pooling size should be consistent.");
}
int up_pad = std::max(0, -context_start);
int down_pad = std::max(0, context_start + context_length - 1);
int sequence_height, sequence_width;
int input_row_begin, input_row_end;
paddle::operators::math::Col2ImFunctor<
paddle::operators::math::ColFormat::kOCF, Place, float>
col2im_ocf;
for (int i = 0; i < static_cast<int>(lod_g_level_0.size()) - 1; ++i) {
Tensor in_g_t = in_g->Slice<T>(static_cast<int>(lod_g_level_0[i]),
static_cast<int>(lod_g_level_0[i + 1]));
input_row_begin = (context_start > 0)
? static_cast<int>(lod_g_level_0[i]) + context_start
: static_cast<int>(lod_g_level_0[i]);
input_row_end = static_cast<int>(lod_g_level_0[i + 1]);
Tensor out_g_t = out_g->Slice<T>(static_cast<int>(lod_g_level_0[i]),
static_cast<int>(lod_g_level_0[i + 1]));
int sequence_height = in_g_t.dims()[0];
int sequence_width = in_g_t.dims()[1];
for (int j = 0; j < context_length; ++j) {
if (padding_trainable) {
out_g_t.Resize(framework::make_ddim(
{sequence_height * context_length, sequence_width}));
if (up_pad != 0) {
for (int k = 0; k < up_pad; ++k) {
Tensor out_t_sub = out_g_t.Slice<T>(
k * context_length, k * context_length + (up_pad - k));
Tensor w_sub = padding_data->Slice<T>(k, context_length - k);
auto out_t_sub_e = EigenMatrix<T>::From(out_t_sub);
auto w_sub_e = EigenMatrix<T>::From(w_sub);
w_sub_e.device(place) = w_sub_e + out_t_sub_e;
// out_t_sub_e.device(place) = 0;
}
sequence_height = static_cast<int>(out_g_t.dims()[0]);
sequence_width = static_cast<int>(in_g->dims()[1]);
if (padding_trainable) {
// add up trainable data
out_g_t.Resize(framework::make_ddim(
{sequence_height * context_length, sequence_width}));
if (up_pad > 0) { // add up pad
int padding_rows = std::min(
up_pad,
static_cast<int>(lod_g_level_0[i + 1] - lod_g_level_0[i]));
for (int k = 0; k < padding_rows; ++k) {
int padding_size =
k + context_length < up_pad ? context_length : up_pad - k;
Tensor out_t_sub = out_g_t.Slice<T>(
k * context_length, k * context_length + padding_size);
Tensor w_sub = padding_data_g->Slice<T>(k, k + padding_size);
// in this block, using EigenVector<T>::Flatten is ok too.
auto out_t_sub_e = EigenMatrix<T>::From(out_t_sub);
auto w_sub_e = EigenMatrix<T>::From(w_sub);
w_sub_e.device(place) = w_sub_e + out_t_sub_e;
}
if (down_pad != 0) {
int k =
(sequence_height + up_pad - context_length) / context_stride +
1;
for (int t = 0; t + k < sequence_height; ++t) {
Tensor out_t_sub =
out_g_t.Slice<T>((k + t) * context_length * sequence_width -
t * sequence_width,
(k + t) * context_length * sequence_width);
Tensor w_sub = padding_data->Slice<T>(up_pad + 1, up_pad + 1 + t);
auto out_t_sub_e = EigenMatrix<T>::From(out_t_sub);
auto w_sub_e = EigenMatrix<T>::From(w_sub);
w_sub_e.device(place) = w_sub_e + out_t_sub_e;
// out_t_sub_e.device(place) = 0;
}
if (down_pad > 0) { // add down pad
int down_pad_begin_row =
std::max(0,
(sequence_height - context_start - context_length) + 1) +
1;
int padding_begin = std::max(0, context_start - sequence_height);
int padding_size =
sequence_height - context_start >= context_length
? 1
: context_length - (sequence_height - context_start);
if (context_start >= sequence_height) padding_size = context_length;
int padding_idx = padding_begin;
for (int t = 0; t + down_pad_begin_row <= sequence_height;
++t, ++padding_size) {
if (context_start >= sequence_height) padding_size = context_length;
if (padding_size > context_length) {
padding_size = context_length;
padding_idx++;
}
if (padding_begin > 0 || sequence_height == context_start)
padding_idx = padding_begin + t;
Tensor out_t_sub = out_g_t.Slice<T>(
(down_pad_begin_row + t) * context_length - padding_size,
(down_pad_begin_row + t) * context_length);
Tensor w_sub = padding_data_g->Slice<T>(
up_pad + padding_idx, up_pad + padding_idx + padding_size);
auto out_t_sub_e = EigenMatrix<T>::From(out_t_sub);
auto w_sub_e = EigenMatrix<T>::From(w_sub);
w_sub_e.device(place) = w_sub_e + out_t_sub_e;
}
}
out_g_t.Resize(framework::make_ddim(
{sequence_height, 1, 1, context_length, sequence_width}));
}
if (in && input_row_begin < input_row_end) {
Tensor in_t = in_g->Slice<T>(input_row_begin, input_row_end);
col2im_ocf(context.device_context(), in_g_t, out_g_t,
std::vector<int64_t> output_shape(
{sequence_height, 1, 1, context_length,
sequence_width}); // output_height, output_width,
// input_channels, filter_height, filter_width
out_g_t.Resize(framework::make_ddim(output_shape));
std::vector<int64_t> input_shape(
{1, input_row_end - input_row_begin,
sequence_width}); // input_channels, input_height, input_width
in_t.Resize(framework::make_ddim(input_shape));
col2im_ocf(context.device_context(), in_t, out_g_t,
/*stride_height*/ context_stride, /*stride_width*/ 0, up_pad,
down_pad);
// out_g_t back to orign size
}
out_g_t.Resize(framework::make_ddim(
{sequence_height, context_length * sequence_width}));
}
}
};
......
import unittest
import numpy as np
import random
from op_test import OpTest
......@@ -10,18 +11,22 @@ class TestSeqProject(OpTest):
# one level, batch size
x = np.random.uniform(
0.1, 1, [self.input_size[0], self.input_size[1]]).astype('float32')
lod = [[0, 4, 5, 8, self.input_size[0]]]
self.begin_pad = np.max([0, -self.context_start])
self.end_pad = np.max([0, self.context_start + self.context_length - 1])
self.total_pad = self.begin_pad + self.end_pad
w = np.ones((self.total_pad, self.input_size[1])) * 100
self.inputs = {'X': (x, lod), 'PaddingData': w}
# w = np.ones((self.total_pad, self.input_size[1])) * 100
w = np.array(range(self.total_pad * self.input_size[1]))
w.shape = self.total_pad, self.input_size[1]
self.inputs = {
'X': (x, self.lod),
'PaddingData': (w, [[0, self.total_pad]])
}
self.attrs = {
'context_start': self.context_start,
'context_length': self.context_length,
'padding_trainable': self.padding_trainable
'padding_trainable': self.padding_trainable,
'context_stride': self.context_stride
}
out = np.zeros((self.input_size[0], self.input_size[1] *
self.context_length)).astype('float32')
......@@ -30,9 +35,10 @@ class TestSeqProject(OpTest):
def compute(self):
x, lod = self.inputs['X']
w = self.inputs['PaddingData']
w, _ = self.inputs['PaddingData']
out = self.outputs['Out']
lod = lod[0]
begin_pad = np.max([0, -self.context_start])
for i in range(len(lod) - 1):
for j in range(self.context_length):
......@@ -43,22 +49,20 @@ class TestSeqProject(OpTest):
if in_begin < lod[i]:
pad_size = np.min([lod[i] - in_begin, lod[i + 1] - lod[i]])
if self.padding_trainable:
sub_w = w[j:pad_size, :]
sub_w = w[j:j + pad_size, :]
out[lod[i]:lod[i] + pad_size, j * self.input_size[1]:(
j + 1) * self.input_size[1]] = sub_w
# pass
out_begin = lod[i] + pad_size
in_begin = lod[i]
if in_end > lod[i + 1]:
pad_size = np.min(
[in_end - lod[i + 1], lod[i + 1] - lod[i]])
out_sub = out[lod[i + 1] - pad_size:lod[i + 1], :]
if self.padding_trainable:
sub_w = w[j - pad_size:j, :]
sub_w = w[begin_pad + self.context_start + j - pad_size:
begin_pad + self.context_start + j, :]
out[lod[i + 1] - pad_size:lod[i + 1], j * self.
input_size[1]:(j + 1) * self.input_size[1]] = sub_w
# pass
in_end = lod[i + 1]
out_end = lod[i + 1] - pad_size
if in_end <= in_begin:
......@@ -69,28 +73,105 @@ class TestSeqProject(OpTest):
self.input_size[1]] += in_sub
def init_test_case(self):
self.input_size = [11, 23]
self.input_row = 11
self.input_size = [self.input_row, 23]
self.lod = [[0, 4, 5, 8, self.input_row]]
self.op_type = "sequence_project"
self.context_start = -1
self.context_length = 3
self.padding_trainable = False
self.padding_trainable = True
self.context_stride = 1
def test_check_output(self):
self.check_output()
# def test_check_grad(self):
# self.check_grad(["X"], "Out")
# self.check_grad(
# set(['X', 'PaddingData']), 'Out', max_relative_error=0.05)
# class TestSeqAvgPool2D(TestSeqProject):
# def init_test_case(self):
# self.input_size = [11, 23]
# self.op_type = "sequence_project"
# def test_check_grad_no_filter(self):
# self.check_grad(
# ['X'],
# 'Out',
# max_relative_error=0.05,
# no_grad_set=set(['PaddingData']))
#
# self.context_start = -1
# self.context_length = 3
# self.padding_trainable = True
# def test_check_grad_no_input(self):
# self.check_grad(
# ['PaddingData'],
# 'Out',
# max_relative_error=0.05,
# no_grad_set=set(['X']))
'''
class TestSeqProjectCases(TestSeqProject):
def setUp(self):
self.init_test_case()
self.op_type = 'sequence_project'
num = 0
for context_start in [-5, -3, -1, 0, 3]:
for context_length in [1, 2, 5, 7]:
for batch_size in [1, 2, 5, 7]:
for padding_trainable in [False, True]:
if context_length == 1 and context_start == 0 and padding_trainable:
continue
self.context_start = context_start
self.context_length = context_length
self.padding_trainable = padding_trainable
self.input_size = [batch_size, 23]
x = np.random.uniform(0.1, 1,
self.input_size).astype('float32')
self.lod = [[0, self.input_size[0]]]
if self.input_size[0] > 2:
idx = range(self.input_size[0])
del idx[0]
self.lod = [
[0] + np.sort(random.sample(idx, 2)).tolist() +
[self.input_size[0]]
]
self.begin_pad = np.max([0, -self.context_start])
self.end_pad = np.max(
[0, self.context_start + self.context_length - 1])
self.total_pad = self.begin_pad + self.end_pad
# w = np.ones((self.total_pad, self.input_size[1])) * 100
w = np.array(range(self.total_pad * self.input_size[1]))
w.shape = self.total_pad, self.input_size[1]
if self.total_pad * self.input_size[1] == 0:
w = np.random.uniform(
0.1, 1,
(1, self.input_size[1])).astype('float32')
self.total_pad = 1
self.inputs = {
'X': (x, self.lod),
'PaddingData': (w, [[0, self.total_pad]])
}
self.attrs = {
'context_start': self.context_start,
'context_length': self.context_length,
'padding_trainable': self.padding_trainable,
'context_stride': self.context_stride
}
out = np.zeros((self.input_size[0], self.input_size[1] *
self.context_length)).astype('float32')
self.outputs = {'Out': out}
print num
print self.attrs
print batch_size
print padding_trainable
print "$$$$$$$$$$$$$"
self.compute()
self.test_check_output()
num += 1
'''
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册