Unverified Commit 6fcfd32e authored by Yiqun Liu and committed by GitHub

Check and correct the output's lod_level in DynamicRNN related operators (#19144)

* Refine the InferShape of the ReadFromArray and WriteToArray ops, and add a comment to explain why ShareLoD is not called at runtime.
test=develop

* Add a comment for the ReorderLoDTensorByRank op.

* Add a comment for the lod_tensor_to_array op to explain why DecreaseLoDLevel is only called at compile time.
test=develop

* The ShrinkRNNMemory op should call ShareLoD at compile time.
test=develop

* Add the implementation of IncreaseLoDLevel and add a compile-time check of lod_level in the InferShape of sequence_pool.
test=develop

* Refine the unittest of DynamicRNN.
test=develop

* Change PADDLE_ENFORCE to PADDLE_ENFORCE_NE.
test=develop
Parent b5f3be83
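To make the intent of the change concrete, here is a minimal compile-time sketch, assuming the fluid 1.x API that the unittest in this diff already uses (the vocabulary size and layer widths are arbitrary): after this fix, the variable returned by DynamicRNN carries the same compile-time lod_level as the sequence it steps over.

```python
# Compile-time sketch only; no data is fed and nothing is executed.
import paddle.fluid as fluid

main_program = fluid.Program()
startup_program = fluid.Program()
with fluid.program_guard(main_program, startup_program):
    # lod_level=1 marks a batch of variable-length sequences.
    sentence = fluid.layers.data(
        name='word', shape=[1], dtype='int64', lod_level=1)
    sent_emb = fluid.layers.embedding(
        input=sentence, size=[100, 32], dtype='float32')  # vocab size 100 is arbitrary

    drnn = fluid.layers.DynamicRNN()
    with drnn.block():
        in_ = drnn.step_input(sent_emb)
        mem = drnn.memory(shape=[32], dtype='float32')
        out_ = fluid.layers.fc(input=[in_, mem], size=32, act='tanh')
        drnn.update_memory(mem, out_)
        drnn.output(out_)
    drnn_result = drnn()

    # With this commit, the output's compile-time lod_level matches the
    # input's (previously it could be left at 0).
    assert drnn_result.lod_level == sent_emb.lod_level
```

The refined unittest at the bottom of this diff asserts the same property both at compile time (lod_level) and at runtime (the actual lod).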
......@@ -86,7 +86,7 @@ class CompileTimeInferShapeContext : public InferShapeContext {
auto *out_var = block_.FindVarRecursive(Outputs(out)[j]);
if (in_var->GetType() != proto::VarType::LOD_TENSOR &&
in_var->GetType() != proto::VarType::LOD_TENSOR_ARRAY) {
VLOG(3) << "input " << in << " is not LodTensor or LodTensorArray.";
VLOG(3) << "input " << in << " is not LoDTensor or LoDTensorArray.";
return;
}
out_var->SetLoDLevel(in_var->GetLoDLevel());
......@@ -94,6 +94,8 @@ class CompileTimeInferShapeContext : public InferShapeContext {
void DecreaseLoDLevel(const std::string &in, const std::string &out,
size_t i = 0, size_t j = 0) const override {
// When in is a LoDTensor and out is a LoDTensorArray, the lod_level
// may need to be decreased.
PADDLE_ENFORCE_LT(i, Inputs(in).size());
PADDLE_ENFORCE_LT(j, Outputs(out).size());
PADDLE_ENFORCE(Inputs(in)[i] != framework::kEmptyVarName,
......@@ -102,17 +104,35 @@ class CompileTimeInferShapeContext : public InferShapeContext {
"The %s[%d] is @EMPTY@", out, j);
auto *in_var = block_.FindVarRecursive(Inputs(in)[i]);
auto *out_var = block_.FindVarRecursive(Outputs(out)[j]);
PADDLE_ENFORCE(out_var->GetType() == proto::VarType::LOD_TENSOR_ARRAY ||
out_var->GetType() == proto::VarType::LOD_TENSOR,
"The input %s should be LodTensorArray or LodTensor.",
out_var->Name());
PADDLE_ENFORCE(in_var->GetType() == proto::VarType::LOD_TENSOR,
"The input %s should be LodTensor.", in_var->Name());
PADDLE_ENFORCE_EQ(in_var->GetType(), proto::VarType::LOD_TENSOR,
"The input %s should be LoDTensor.", in_var->Name());
PADDLE_ENFORCE_EQ(out_var->GetType(), proto::VarType::LOD_TENSOR_ARRAY,
"The output %s should be LoDTensorArray.",
out_var->Name());
if (in_var->GetLoDLevel() > 0) {
out_var->SetLoDLevel(in_var->GetLoDLevel() - 1);
}
}
void IncreaseLoDLevel(const std::string &in, const std::string &out,
size_t i = 0, size_t j = 0) const override {
// When in is a LoDTensorArray and out is a LoDTensor, the lod_level
// may need to be increased.
PADDLE_ENFORCE_LT(i, Inputs(in).size());
PADDLE_ENFORCE_LT(j, Outputs(out).size());
PADDLE_ENFORCE_NE(Inputs(in)[i], framework::kEmptyVarName,
"The %s[%d] is @EMPTY@", in, i);
PADDLE_ENFORCE_NE(Outputs(out)[j], framework::kEmptyVarName,
"The %s[%d] is @EMPTY@", out, j);
auto *in_var = block_.FindVarRecursive(Inputs(in)[i]);
auto *out_var = block_.FindVarRecursive(Outputs(out)[j]);
PADDLE_ENFORCE_EQ(in_var->GetType(), proto::VarType::LOD_TENSOR_ARRAY,
"The input %s should be LoDTensorArray.", in_var->Name());
PADDLE_ENFORCE_EQ(out_var->GetType(), proto::VarType::LOD_TENSOR,
"The output %s should be LoDTensor.", out_var->Name());
out_var->SetLoDLevel(in_var->GetLoDLevel() + 1);
}
std::vector<InferShapeVarPtr> GetInputVarPtrs(
const std::string &name) override {
const std::vector<std::string> arg_names = Inputs(name);
......
......@@ -657,7 +657,18 @@ class RuntimeInferShapeContext : public InferShapeContext {
void DecreaseLoDLevel(const std::string& in, const std::string& out,
size_t i = 0, size_t j = 0) const override {
PADDLE_THROW("DecreaseLoDLevel is only used in compile time.");
PADDLE_THROW(
"DecreaseLoDLevel is only used in compile time. The calculation of "
"output's actual lod is different among operators so that should be "
"set in the runtime kernel.");
}
void IncreaseLoDLevel(const std::string& in, const std::string& out,
size_t i = 0, size_t j = 0) const override {
PADDLE_THROW(
"IncreaseLoDLevel is only used in compile time. The calculation of "
"output's actual lod is different among operators so that should be "
"set in the runtime kernel.");
}
bool IsRuntime() const override { return true; }
......
......@@ -68,6 +68,9 @@ class InferShapeContext {
virtual void DecreaseLoDLevel(const std::string &in, const std::string &out,
size_t i = 0, size_t j = 0) const = 0;
virtual void IncreaseLoDLevel(const std::string &in, const std::string &out,
size_t i = 0, size_t j = 0) const = 0;
virtual bool IsRuntime() const = 0;
virtual std::vector<InferShapeVarPtr> GetInputVarPtrs(
......
......@@ -192,7 +192,21 @@ class ArrayToLoDTensorInferShape : public framework::InferShapeBase {
"ArrayToLoDTensorOp must has input X.");
PADDLE_ENFORCE(context->HasInput("RankTable"),
"ArrayToLoDTensorOp must has input RankTable.");
// At compile time, the first dim of input X and output Out should be -1.
// At runtime, the first dim of output Out should be the sum of all
// elements' first dims in input X. The output's dims will be re-computed
// in the kernel implementation.
context->SetOutputDim("Out", context->GetInputDim("X"));
// The output LoDTensor's lod_level should be input X's lod_level + 1.
// At compile time, we call IncreaseLoDLevel to set the output's lod_level.
// At runtime, the output LoDTensor's lod is determined by input X's lod and
// the level specified by input RankTable.
// We cannot get X's actual lod or RankTable's level in this function, so
// this work is left to the kernel implementation.
if (!context->IsRuntime()) {
context->IncreaseLoDLevel("X", /*->*/ "Out");
}
}
};
......
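As a rough illustration of the IncreaseLoDLevel call added above, and of the matching DecreaseLoDLevel call in the lod_tensor_to_array hunk further down, here is a hedged compile-time sketch using the fluid 1.x control-flow helpers that the unittest below imports; only the lod_level bookkeeping is shown, nothing is executed.

```python
import paddle.fluid as fluid
from paddle.fluid.layers.control_flow import (array_to_lod_tensor,
                                              lod_rank_table,
                                              lod_tensor_to_array)

x = fluid.layers.data(name='x', shape=[10], dtype='float32', lod_level=1)
table = lod_rank_table(x=x)
# lod_tensor_to_array triggers DecreaseLoDLevel at compile time: 1 -> 0.
arr = lod_tensor_to_array(x=x, table=table)
# array_to_lod_tensor triggers IncreaseLoDLevel at compile time: 0 -> 1.
y = array_to_lod_tensor(x=arr, table=table)
assert y.lod_level == x.lod_level
```

The final assert mirrors the compile-time check added to test_plain_while_op in the unittest below.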
......@@ -88,8 +88,21 @@ class WriteToArrayInferShape : public framework::InferShapeBase {
if (!context->HasInput("X")) {
return;
}
PADDLE_ENFORCE(context->HasOutput("Out"), NotHasOutError());
context->SetOutputDim("Out", context->GetInputDim("X"));
// At compile time, we need to:
// - for ReadFromArray, share tensor_array X's lod_level to Out
// - for WriteToArray, share X's lod_level to tensor_array Out
// At runtime, we need to:
// - for ReadFromArray, share X[I]'s lod to Out
// - for WriteToArray, share X's lod to Out[I]
// but we cannot get I's value here, so this work is left to the
// kernel implementation.
if (!context->IsRuntime()) {
context->ShareLoD("X", /*->*/ "Out");
}
}
protected:
......@@ -166,19 +179,6 @@ $$T = A[i]$$
};
class ReadFromArrayInferShape : public WriteToArrayInferShape {
public:
void operator()(framework::InferShapeContext *context) const override {
WriteToArrayInferShape::operator()(context);
if (!context->HasInput("X")) {
return;
}
// FIXME: just for compile time.
if (!context->IsRuntime()) {
context->ShareLoD("X", /*->*/ "Out");
}
}
protected:
const char *NotHasXError() const override {
return "The input array X must be set";
......
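The shared InferShape above now propagates the lod_level through both directions of the tensor array at compile time. A small hedged sketch of what that means at the Python level (fluid 1.x API; the constant index only exists to satisfy the op's I input):

```python
import paddle.fluid as fluid

x = fluid.layers.data(name='x', shape=[10], dtype='float32', lod_level=1)
i = fluid.layers.fill_constant(shape=[1], dtype='int64', value=0)

arr = fluid.layers.array_write(x=x, i=i)     # WriteToArray: X -> Out (tensor_array)
y = fluid.layers.array_read(array=arr, i=i)  # ReadFromArray: tensor_array -> Out

# At compile time both ops share the lod_level; the actual lod of Out[I]/X[I]
# is only known inside the runtime kernels.
assert y.lod_level == x.lod_level
```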
......@@ -106,9 +106,10 @@ class LoDTensorToArrayOp : public framework::OperatorBase {
auto max_seq_len = items[0].length;
auto rank_level = rank_table.level();
PADDLE_ENFORCE_LT(rank_level, x.lod().size(),
"Input should be a LOD tensor, and size is at least %d",
rank_level + 1);
PADDLE_ENFORCE_LT(
rank_level, x.lod().size(),
"Input should be a LoDTensor, and its lod_level should be at least %d",
rank_level + 1);
out.resize(max_seq_len);
std::vector<std::vector<CopyRange>> copy_ranges(max_seq_len);
......@@ -167,10 +168,21 @@ class LoDTensorToArrayOp : public framework::OperatorBase {
class LoDTensorToArrayOpProtoMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X", "");
AddInput("RankTable", "");
AddOutput("Out", "");
AddComment("");
AddInput("X",
"(LoDTensor), the input lod tensor is a minibatch of sequences, "
"and will be split to a tensor_array according to "
"Input(RankTable).");
AddInput("RankTable", "(LoDRankTable), the rank table.");
AddOutput("Out",
"(LoDTensorArray), the result tensor_array, which is actually a "
"std::vector<LoDTensor>.");
AddComment(R"DOC(LoDTensorToArray operator.
Input(X) is a minibatch of sequences. Input(RankTable) stores the order of the input sequences.
The lod_tensor_to_array operator will split the input sequences into a tensor_array, with each
element storing one sequence, according to the input rank_table.
NOTE: this operator is an internal component of DynamicRNN and is not intended to be called directly by users.
)DOC");
}
};
......@@ -187,10 +199,18 @@ class LoDTensorToArrayInferShape : public framework::InferShapeBase {
"Output(Out) of LoDTensorToArrayOp should not be null.");
auto x_dim = context->GetInputDim("X");
// The first dim of each LoDTensor in Output can only be set at run-time.;
// We still have to Resize each LoDTensor in Output.
// At compile time, the first dim of input X and output Out should be -1.
// At runtime, the first dim of input X should be the sum of all elements'
// first dims in output Out. The output's dims will be re-computed in the
// kernel implementation.
context->SetOutputDim("Out", x_dim);
// The lod level should be passed to out in compile time.
// The output LoDTensor's lod_level should be input X's lod_level - 1.
// At compile time, we call DecreaseLoDLevel to set the output's lod_level.
// At runtime, the output LoDTensor's lod is determined by input X's lod and
// the level specified by input RankTable.
// We cannot get X's actual lod or RankTable's level in this function, so
// this work is left to the kernel implementation.
if (!context->IsRuntime()) {
context->DecreaseLoDLevel("X", /*->*/ "Out");
}
......
......@@ -202,6 +202,9 @@ class IdentityInferShape : public framework::InferShapeBase {
public:
void operator()(framework::InferShapeContext *context) const override {
context->SetOutputDim("Out", context->GetInputDim("X"));
// X's lod and Out's lod differ at runtime, so there is no need to call
// ShareLoD at runtime. The setting of Out's lod is done in the kernel
// implementation.
if (!context->IsRuntime()) {
context->ShareLoD("X", /*->*/ "Out");
}
......
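For context, reorder_lod_tensor_by_rank is the Python-facing wrapper around this op. A hedged compile-time sketch of the behaviour described in the comment above (fluid 1.x API; the keyword names are as I recall them):

```python
import paddle.fluid as fluid
from paddle.fluid.layers.control_flow import lod_rank_table

x = fluid.layers.data(name='x', shape=[10], dtype='float32', lod_level=1)
ref = fluid.layers.data(name='ref', shape=[10], dtype='float32', lod_level=1)

table = lod_rank_table(x=ref)
y = fluid.layers.reorder_lod_tensor_by_rank(x=x, rank_table=table)

# Compile time: ShareLoD keeps the lod_level; runtime: the kernel rebuilds
# the reordered lod itself.
assert y.lod_level == x.lod_level
```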
......@@ -28,6 +28,16 @@ class SequencePoolOp : public framework::OperatorWithKernel {
"Input(X) of SequencePoolOp should not be null.");
PADDLE_ENFORCE_EQ(ctx->HasOutput("Out"), true,
"Output(Out) of SequencePoolOp should not be null.");
if (!ctx->IsRuntime()) {
// Check the lod_level at compile time.
framework::VarDesc* x_desc =
boost::get<framework::VarDesc*>(ctx->GetInputVarPtrs("X")[0]);
PADDLE_ENFORCE_GT(
x_desc->GetLoDLevel(), 0,
"The LoD level of Input(X) of sequence_pool should be larger than 0.");
}
ctx->SetOutputDim("Out", ctx->GetInputDim("X"));
if (ctx->Attrs().Get<std::string>("pooltype") == "MAX") {
PADDLE_ENFORCE_EQ(
......
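The new branch above turns a runtime failure into a compile-time one. A hedged sketch of what the check means at the Python level (fluid 1.x API; the exact exception type is not asserted):

```python
import paddle.fluid as fluid

# A real sequence input (lod_level > 0) is accepted as before.
seq = fluid.layers.data(name='seq', shape=[10], dtype='float32', lod_level=1)
pooled = fluid.layers.sequence_pool(input=seq, pool_type='sum')

# A dense input (lod_level == 0) is now rejected when the op is appended,
# instead of failing later inside the runtime kernel.
dense = fluid.layers.data(name='dense', shape=[10], dtype='float32')
try:
    fluid.layers.sequence_pool(input=dense, pool_type='sum')
except Exception:
    pass  # PADDLE_ENFORCE_GT on the compile-time lod_level fires here.
```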
......@@ -100,8 +100,10 @@ class ShrinkRNNMemoryInferShape : public framework::InferShapeBase {
PADDLE_ENFORCE(context->HasInput("I"));
PADDLE_ENFORCE(context->HasInput("RankTable"));
context->SetOutputDim("Out", context->GetInputDim("X"));
// At runtime, the output's lod is computed from the input's lod, with the
// finished sequences removed. It is set in the kernel implementation.
if (!context->IsRuntime()) {
context->DecreaseLoDLevel("X", /*->*/ "Out");
context->ShareLoD("X", /*->*/ "Out");
}
}
};
......
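shrink_memory is the Python wrapper the unittest below imports for this op. A hedged compile-time sketch of the changed behaviour; the argument names x/i/table are my assumption about the fluid 1.x control_flow helper:

```python
import paddle.fluid as fluid
from paddle.fluid.layers.control_flow import lod_rank_table, shrink_memory

x = fluid.layers.data(name='x', shape=[10], dtype='float32', lod_level=1)
table = lod_rank_table(x=x)
i = fluid.layers.fill_constant(shape=[1], dtype='int64', value=0)

# Argument names below are assumed, not confirmed by this diff.
mem = shrink_memory(x=x, i=i, table=table)

# Compile time now uses ShareLoD instead of DecreaseLoDLevel, so the
# lod_level is kept; the runtime lod (minus finished sequences) is set
# by the kernel.
assert mem.lod_level == x.lod_level
```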
......@@ -27,13 +27,53 @@ from paddle.fluid.layers.control_flow import shrink_memory
from fake_reader import fake_imdb_reader
class TestDynRNN(unittest.TestCase):
class TestDynamicRNN(unittest.TestCase):
def setUp(self):
self.word_dict_len = 5147
self.BATCH_SIZE = 2
reader = fake_imdb_reader(self.word_dict_len, self.BATCH_SIZE * 100)
self.train_data = paddle.batch(reader, batch_size=self.BATCH_SIZE)
def _train(self,
main_program,
startup_program,
feed_list,
fetch_list,
is_nested=False,
max_iters=1):
place = fluid.CPUPlace()
exe = fluid.Executor(place)
exe.run(startup_program)
feeder = fluid.DataFeeder(feed_list=feed_list, place=place)
data = next(self.train_data())
for iter_id in range(max_iters):
fetch_outs = exe.run(main_program,
feed=feeder.feed(data),
fetch_list=fetch_list,
return_numpy=False)
if len(fetch_list) == 3:
rnn_in_seq = fetch_outs[0]
rnn_out_seq = fetch_outs[1]
if not is_nested:
# Check the lod set at runtime. When lod_level is 1,
# the lod of DynamicRNN's output should be the same as the input's.
self.assertEqual(rnn_in_seq.lod(), rnn_out_seq.lod())
loss_i = numpy.array(fetch_outs[2])
elif len(fetch_list) == 1:
loss_i = numpy.array(fetch_outs[0])
#print(loss_i)
self.assertEqual((1, ), loss_i.shape)
self.assertFalse(numpy.isnan(loss_i))
if iter_id == 0:
loss_0 = loss_i
if max_iters > 10:
# loss should be smaller after 10 mini-batches
self.assertLess(loss_i[0], loss_0[0])
def test_plain_while_op(self):
main_program = fluid.Program()
startup_program = fluid.Program()
......@@ -44,10 +84,7 @@ class TestDynRNN(unittest.TestCase):
sent_emb = fluid.layers.embedding(
input=sentence, size=[self.word_dict_len, 32], dtype='float32')
label = fluid.layers.data(name='label', shape=[1], dtype='float32')
rank_table = lod_rank_table(x=sent_emb)
sent_emb_array = lod_tensor_to_array(x=sent_emb, table=rank_table)
seq_len = max_sequence_len(rank_table=rank_table)
......@@ -61,7 +98,6 @@ class TestDynRNN(unittest.TestCase):
shape=[-1, 100],
dtype='float32')
boot_mem.stop_gradient = False
mem_array = fluid.layers.array_write(x=boot_mem, i=i)
cond = fluid.layers.less_than(x=i, y=seq_len)
......@@ -82,27 +118,29 @@ class TestDynRNN(unittest.TestCase):
fluid.layers.array_write(x=hidden, i=i, array=mem_array)
fluid.layers.less_than(x=i, y=seq_len, cond=cond)
all_timesteps = array_to_lod_tensor(x=out, table=rank_table)
last = fluid.layers.sequence_last_step(input=all_timesteps)
result_all_timesteps = array_to_lod_tensor(x=out, table=rank_table)
last = fluid.layers.sequence_last_step(input=result_all_timesteps)
logits = fluid.layers.fc(input=last, size=1, act=None)
label = fluid.layers.data(name='label', shape=[1], dtype='float32')
loss = fluid.layers.sigmoid_cross_entropy_with_logits(
x=logits, label=label)
loss = fluid.layers.mean(loss)
sgd = fluid.optimizer.SGD(1e-4)
sgd.minimize(loss=loss)
cpu = fluid.CPUPlace()
exe = fluid.Executor(cpu)
exe.run(startup_program)
feeder = fluid.DataFeeder(feed_list=[sentence, label], place=cpu)
data = next(self.train_data())
val = exe.run(main_program, feed=feeder.feed(data),
fetch_list=[loss])[0]
self.assertEqual((1, ), val.shape)
print(val)
self.assertFalse(numpy.isnan(val))
# Check the lod_level set at compile time.
self.assertEqual(sent_emb.lod_level, result_all_timesteps.lod_level)
self._train(
main_program=main_program,
startup_program=startup_program,
feed_list=[sentence, label],
fetch_list=[sent_emb, result_all_timesteps, loss],
is_nested=False,
max_iters=1)
def test_train_dyn_rnn(self):
def test_train_dynamic_rnn(self):
main_program = fluid.Program()
startup_program = fluid.Program()
with fluid.program_guard(main_program, startup_program):
......@@ -111,17 +149,18 @@ class TestDynRNN(unittest.TestCase):
sent_emb = fluid.layers.embedding(
input=sentence, size=[self.word_dict_len, 32], dtype='float32')
rnn = fluid.layers.DynamicRNN()
with rnn.block():
in_ = rnn.step_input(sent_emb)
mem = rnn.memory(shape=[100], dtype='float32')
drnn = fluid.layers.DynamicRNN()
with drnn.block():
in_ = drnn.step_input(sent_emb)
mem = drnn.memory(shape=[100], dtype='float32')
out_ = fluid.layers.fc(input=[in_, mem], size=100, act='tanh')
rnn.update_memory(mem, out_)
rnn.output(out_)
drnn.update_memory(mem, out_)
drnn.output(out_)
last = fluid.layers.sequence_last_step(input=rnn())
drnn_result = drnn()
last = fluid.layers.sequence_last_step(input=drnn_result)
logits = fluid.layers.fc(input=last, size=1, act=None)
label = fluid.layers.data(name='label', shape=[1], dtype='float32')
loss = fluid.layers.sigmoid_cross_entropy_with_logits(
x=logits, label=label)
......@@ -129,37 +168,30 @@ class TestDynRNN(unittest.TestCase):
sgd = fluid.optimizer.Adam(1e-3)
sgd.minimize(loss=loss)
cpu = fluid.CPUPlace()
exe = fluid.Executor(cpu)
exe.run(startup_program)
feeder = fluid.DataFeeder(feed_list=[sentence, label], place=cpu)
data = next(self.train_data())
loss_0 = exe.run(main_program,
feed=feeder.feed(data),
fetch_list=[loss])[0]
for _ in range(100):
val = exe.run(main_program,
feed=feeder.feed(data),
fetch_list=[loss])[0]
# loss should be small after 100 mini-batch
self.assertLess(val[0], loss_0[0])
# Check the lod_level set at compile time.
self.assertEqual(sent_emb.lod_level, drnn_result.lod_level)
# this unit test is just used to the two layer nested dyn_rnn.
def test_train_nested_dyn_rnn(self):
word_dict = [i for i in range(30)]
self._train(
main_program=main_program,
startup_program=startup_program,
feed_list=[sentence, label],
fetch_list=[sent_emb, drnn_result, loss],
is_nested=False,
max_iters=100)
def fake_reader():
seq_len, label = [[2, 2]], [0, 1]
data = []
for ele in seq_len:
for j in ele:
data.append([numpy.random.randint(30) \
for _ in range(j)])
def _fake_reader(self):
seq_len, label = [[2, 2]], [0, 1]
data = []
for ele in seq_len:
for j in ele:
data.append([numpy.random.randint(30) for _ in range(j)])
while True:
yield data, label
while True:
yield data, label
train_data = paddle.batch(fake_reader, batch_size=2)
# This unit test covers the two-layer nested dynamic RNN.
def test_train_nested_dynamic_rnn(self):
word_dict = [i for i in range(30)]
main_program = fluid.Program()
startup_program = fluid.Program()
......@@ -169,63 +201,50 @@ class TestDynRNN(unittest.TestCase):
label = fluid.layers.data(
name='label', shape=[1], dtype='float32', lod_level=1)
rnn = fluid.layers.DynamicRNN()
with rnn.block():
in_ = rnn.step_input(sentence)
assert in_.lod_level == 1, "the lod level of in_ should be 1"
sent_emb = fluid.layers.embedding(
input=in_, size=[len(word_dict), 32], dtype='float32')
out_ = fluid.layers.fc(input=sent_emb, size=100, act='tanh')
rnn1 = fluid.layers.DynamicRNN()
with rnn1.block():
in_1 = rnn1.step_input(out_)
drnn0 = fluid.layers.DynamicRNN()
with drnn0.block():
in_0 = drnn0.step_input(sentence)
assert in_0.lod_level == 1, "the lod level of in_0 should be 1"
sentence_emb = fluid.layers.embedding(
input=in_0, size=[len(word_dict), 32], dtype='float32')
out_0 = fluid.layers.fc(input=sentence_emb,
size=100,
act='tanh')
drnn1 = fluid.layers.DynamicRNN()
with drnn1.block():
in_1 = drnn1.step_input(out_0)
assert in_1.lod_level == 0, "the lod level of in_1 should be 0"
out_1 = fluid.layers.fc(input=[in_1], size=100, act='tanh')
rnn1.output(out_1)
drnn1.output(out_1)
last = fluid.layers.sequence_last_step(input=rnn1())
rnn.output(last)
drnn1_result = drnn1()
last_1 = fluid.layers.sequence_last_step(input=drnn1_result)
drnn0.output(last_1)
last = rnn()
last = drnn0()
logits = fluid.layers.fc(input=last, size=1, act=None)
loss = fluid.layers.sigmoid_cross_entropy_with_logits(
x=logits, label=label)
loss = fluid.layers.mean(loss)
sgd = fluid.optimizer.SGD(1e-3)
#sgd = fluid.optimizer.Adam(1e-3)
sgd.minimize(loss=loss)
cpu = fluid.CPUPlace()
exe = fluid.Executor(cpu)
exe.run(startup_program)
feeder = fluid.DataFeeder(feed_list=[sentence, label], place=cpu)
data = next(train_data())
val = exe.run(main_program, feed=feeder.feed(data),
fetch_list=[loss])[0]
for _ in range(100):
val = exe.run(main_program,
feed=feeder.feed(data),
fetch_list=[loss])[0]
print(val)
train_data_orig = self.train_data
self.train_data = paddle.batch(self._fake_reader, batch_size=2)
self._train(
main_program=main_program,
startup_program=startup_program,
feed_list=[sentence, label],
fetch_list=[loss],
is_nested=True,
max_iters=100)
self.train_data = train_data_orig
# This unit test covers the two-layer nested dynamic RNN.
def test_train_nested_dyn_rnn2(self):
def test_train_nested_dynamic_rnn2(self):
word_dict = [i for i in range(30)]
def fake_reader():
seq_len, label = [[2, 2]], [0, 1]
data = []
for ele in seq_len:
for j in ele:
data.append([numpy.random.randint(30) \
for _ in range(j)])
while True:
yield data, label
train_data = paddle.batch(fake_reader, batch_size=2)
hidden_size = 32
main_program = fluid.Program()
startup_program = fluid.Program()
......@@ -235,14 +254,14 @@ class TestDynRNN(unittest.TestCase):
label = fluid.layers.data(
name='label', shape=[1], dtype='float32', lod_level=1)
rnn = fluid.layers.DynamicRNN()
with rnn.block():
in_ = rnn.step_input(sentence)
sent_emb = fluid.layers.embedding(
input=in_,
drnn0 = fluid.layers.DynamicRNN()
with drnn0.block():
in_0 = drnn0.step_input(sentence)
sentence_emb = fluid.layers.embedding(
input=in_0,
size=[len(word_dict), hidden_size],
dtype='float32')
input_forward_proj = fluid.layers.fc(input=sent_emb,
input_forward_proj = fluid.layers.fc(input=sentence_emb,
size=hidden_size * 4,
act=None,
bias_attr=False)
......@@ -251,36 +270,33 @@ class TestDynRNN(unittest.TestCase):
size=hidden_size * 4,
use_peepholes=False)
rnn1 = fluid.layers.DynamicRNN()
with rnn1.block():
in_1 = rnn1.step_input(forward)
drnn1 = fluid.layers.DynamicRNN()
with drnn1.block():
in_1 = drnn1.step_input(forward)
out_1 = fluid.layers.fc(input=[in_1], size=100, act='tanh')
rnn1.output(out_1)
drnn1.output(out_1)
last = fluid.layers.sequence_last_step(input=rnn1())
rnn.output(last)
last = fluid.layers.sequence_last_step(input=drnn1())
drnn0.output(last)
last = rnn()
last = drnn0()
logits = fluid.layers.fc(input=last, size=1, act=None)
loss = fluid.layers.sigmoid_cross_entropy_with_logits(
x=logits, label=label)
loss = fluid.layers.mean(loss)
sgd = fluid.optimizer.SGD(1e-3)
#sgd = fluid.optimizer.Adam(1e-3)
sgd.minimize(loss=loss)
cpu = fluid.CPUPlace()
exe = fluid.Executor(cpu)
exe.run(startup_program)
feeder = fluid.DataFeeder(feed_list=[sentence, label], place=cpu)
data = next(train_data())
val = exe.run(main_program, feed=feeder.feed(data),
fetch_list=[loss])[0]
for _ in range(100):
val = exe.run(main_program,
feed=feeder.feed(data),
fetch_list=[loss])[0]
train_data_orig = self.train_data
self.train_data = paddle.batch(self._fake_reader, batch_size=2)
self._train(
main_program=main_program,
startup_program=startup_program,
feed_list=[sentence, label],
fetch_list=[loss],
is_nested=True,
max_iters=100)
self.train_data = train_data_orig
if __name__ == '__main__':
......