提交 89ead8d1 编写于 作者: Y Yu Yang 提交者: Yi Wang

Feature/understand sentiment parallel do (#7994)

* Support parallel test for understand_sentiment

* Full test on understand_sentiment

* Skip normal tests

* Debug CI

* Enable benchmark

* Revert init.cc

* Make CI pass
上级 4f122c07
...@@ -126,6 +126,7 @@ void Executor::Run(const ProgramDesc& pdesc, Scope* scope, int block_id, ...@@ -126,6 +126,7 @@ void Executor::Run(const ProgramDesc& pdesc, Scope* scope, int block_id,
platform::RecordEvent record_event(op->Type(), pool.Get(place_)); platform::RecordEvent record_event(op->Type(), pool.Get(place_));
op->Run(*local_scope, place_); op->Run(*local_scope, place_);
// Wait current device context.
VLOG(3) << op->DebugStringEx(local_scope); VLOG(3) << op->DebugStringEx(local_scope);
if (FLAGS_benchmark) { if (FLAGS_benchmark) {
VLOG(2) << "Memory used after operator " + op->Type() + " running: " VLOG(2) << "Memory used after operator " + op->Type() + " running: "
......
...@@ -11,6 +11,7 @@ ...@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from __future__ import print_function
import unittest import unittest
import paddle.v2.fluid as fluid import paddle.v2.fluid as fluid
...@@ -23,7 +24,8 @@ import sys ...@@ -23,7 +24,8 @@ import sys
def convolution_net(data, label, input_dim, class_dim=2, emb_dim=32, def convolution_net(data, label, input_dim, class_dim=2, emb_dim=32,
hid_dim=32): hid_dim=32):
emb = fluid.layers.embedding(input=data, size=[input_dim, emb_dim]) emb = fluid.layers.embedding(
input=data, size=[input_dim, emb_dim], is_sparse=True)
conv_3 = fluid.nets.sequence_conv_pool( conv_3 = fluid.nets.sequence_conv_pool(
input=emb, input=emb,
num_filters=hid_dim, num_filters=hid_dim,
...@@ -41,8 +43,6 @@ def convolution_net(data, label, input_dim, class_dim=2, emb_dim=32, ...@@ -41,8 +43,6 @@ def convolution_net(data, label, input_dim, class_dim=2, emb_dim=32,
act="softmax") act="softmax")
cost = fluid.layers.cross_entropy(input=prediction, label=label) cost = fluid.layers.cross_entropy(input=prediction, label=label)
avg_cost = fluid.layers.mean(x=cost) avg_cost = fluid.layers.mean(x=cost)
adam_optimizer = fluid.optimizer.Adam(learning_rate=0.002)
adam_optimizer.minimize(avg_cost)
accuracy = fluid.layers.accuracy(input=prediction, label=label) accuracy = fluid.layers.accuracy(input=prediction, label=label)
return avg_cost, accuracy, prediction return avg_cost, accuracy, prediction
...@@ -56,7 +56,8 @@ def stacked_lstm_net(data, ...@@ -56,7 +56,8 @@ def stacked_lstm_net(data,
stacked_num=3): stacked_num=3):
assert stacked_num % 2 == 1 assert stacked_num % 2 == 1
emb = fluid.layers.embedding(input=data, size=[input_dim, emb_dim]) emb = fluid.layers.embedding(
input=data, size=[input_dim, emb_dim], is_sparse=True)
# add bias attr # add bias attr
# TODO(qijun) linear act # TODO(qijun) linear act
...@@ -79,8 +80,6 @@ def stacked_lstm_net(data, ...@@ -79,8 +80,6 @@ def stacked_lstm_net(data,
act='softmax') act='softmax')
cost = fluid.layers.cross_entropy(input=prediction, label=label) cost = fluid.layers.cross_entropy(input=prediction, label=label)
avg_cost = fluid.layers.mean(x=cost) avg_cost = fluid.layers.mean(x=cost)
adam_optimizer = fluid.optimizer.Adam(learning_rate=0.002)
adam_optimizer.minimize(avg_cost)
accuracy = fluid.layers.accuracy(input=prediction, label=label) accuracy = fluid.layers.accuracy(input=prediction, label=label)
return avg_cost, accuracy, prediction return avg_cost, accuracy, prediction
...@@ -93,7 +92,7 @@ def create_random_lodtensor(lod, place, low, high): ...@@ -93,7 +92,7 @@ def create_random_lodtensor(lod, place, low, high):
return res return res
def train(word_dict, net_method, use_cuda, save_dirname=None): def train(word_dict, net_method, use_cuda, parallel=False, save_dirname=None):
BATCH_SIZE = 128 BATCH_SIZE = 128
PASS_NUM = 5 PASS_NUM = 5
dict_dim = len(word_dict) dict_dim = len(word_dict)
...@@ -102,8 +101,30 @@ def train(word_dict, net_method, use_cuda, save_dirname=None): ...@@ -102,8 +101,30 @@ def train(word_dict, net_method, use_cuda, save_dirname=None):
data = fluid.layers.data( data = fluid.layers.data(
name="words", shape=[1], dtype="int64", lod_level=1) name="words", shape=[1], dtype="int64", lod_level=1)
label = fluid.layers.data(name="label", shape=[1], dtype="int64") label = fluid.layers.data(name="label", shape=[1], dtype="int64")
cost, acc_out, prediction = net_method(
data, label, input_dim=dict_dim, class_dim=class_dim) if not parallel:
cost, acc_out, prediction = net_method(
data, label, input_dim=dict_dim, class_dim=class_dim)
else:
places = fluid.layers.get_places()
pd = fluid.layers.ParallelDo(places)
with pd.do():
cost, acc, _ = net_method(
pd.read_input(data),
pd.read_input(label),
input_dim=dict_dim,
class_dim=class_dim)
pd.write_output(cost)
pd.write_output(acc)
cost, acc = pd()
cost = fluid.layers.mean(x=cost)
acc_out = fluid.layers.mean(x=acc)
prediction = None
assert save_dirname is None
adagrad = fluid.optimizer.Adagrad(learning_rate=0.002)
adagrad.minimize(cost)
train_data = paddle.batch( train_data = paddle.batch(
paddle.reader.shuffle( paddle.reader.shuffle(
...@@ -164,14 +185,16 @@ def infer(use_cuda, save_dirname=None): ...@@ -164,14 +185,16 @@ def infer(use_cuda, save_dirname=None):
print("Inference results: ", np_data) print("Inference results: ", np_data)
def main(word_dict, net_method, use_cuda): def main(word_dict, net_method, use_cuda, parallel=False, save_dirname=None):
if use_cuda and not fluid.core.is_compiled_with_cuda(): if use_cuda and not fluid.core.is_compiled_with_cuda():
return return
# Directory for saving the trained model train(
save_dirname = "understand_sentiment.inference.model" word_dict,
net_method,
train(word_dict, net_method, use_cuda, save_dirname) use_cuda,
parallel=parallel,
save_dirname=save_dirname)
infer(use_cuda, save_dirname) infer(use_cuda, save_dirname)
...@@ -191,20 +214,62 @@ class TestUnderstandSentiment(unittest.TestCase): ...@@ -191,20 +214,62 @@ class TestUnderstandSentiment(unittest.TestCase):
def test_conv_cpu(self): def test_conv_cpu(self):
with self.new_program_scope(): with self.new_program_scope():
main(self.word_dict, net_method=convolution_net, use_cuda=False) main(
self.word_dict,
net_method=convolution_net,
use_cuda=False,
save_dirname="understand_sentiment.inference.model")
def test_conv_cpu_parallel(self):
with self.new_program_scope():
main(
self.word_dict,
net_method=convolution_net,
use_cuda=False,
parallel=True)
@unittest.skip(reason="make CI faster")
def test_stacked_lstm_cpu(self): def test_stacked_lstm_cpu(self):
with self.new_program_scope(): with self.new_program_scope():
main(self.word_dict, net_method=stacked_lstm_net, use_cuda=False) main(self.word_dict, net_method=stacked_lstm_net, use_cuda=False)
def test_stacked_lstm_cpu_parallel(self):
with self.new_program_scope():
main(
self.word_dict,
net_method=stacked_lstm_net,
use_cuda=False,
parallel=True)
def test_conv_gpu(self): def test_conv_gpu(self):
with self.new_program_scope(): with self.new_program_scope():
main(self.word_dict, net_method=convolution_net, use_cuda=True) main(
self.word_dict,
net_method=convolution_net,
use_cuda=True,
save_dirname="understand_sentiment.inference.model")
def test_conv_gpu_parallel(self):
with self.new_program_scope():
main(
self.word_dict,
net_method=convolution_net,
use_cuda=True,
parallel=True)
@unittest.skip(reason="make CI faster")
def test_stacked_lstm_gpu(self): def test_stacked_lstm_gpu(self):
with self.new_program_scope(): with self.new_program_scope():
main(self.word_dict, net_method=stacked_lstm_net, use_cuda=True) main(self.word_dict, net_method=stacked_lstm_net, use_cuda=True)
def test_stacked_lstm_gpu_parallel(self):
with self.new_program_scope():
main(
self.word_dict,
net_method=stacked_lstm_net,
use_cuda=True,
parallel=True)
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册