From 0ac8c74e630d3fd0c3d9cad7cf3207973e970111 Mon Sep 17 00:00:00 2001 From: Yu Yang Date: Mon, 27 Nov 2017 13:45:34 +0800 Subject: [PATCH] Unify fluid submodules to fluid module (#5924) Change books just use `import fluid`, not submodules --- python/paddle/v2/fluid/__init__.py | 52 +++++++--- python/paddle/v2/fluid/evaluator.py | 7 +- python/paddle/v2/fluid/executor.py | 6 +- python/paddle/v2/fluid/framework.py | 8 +- python/paddle/v2/fluid/initializer.py | 23 ++++- python/paddle/v2/fluid/layer_helper.py | 11 +-- python/paddle/v2/fluid/layers.py | 36 ++++--- python/paddle/v2/fluid/nets.py | 2 +- python/paddle/v2/fluid/optimizer.py | 43 ++++++--- python/paddle/v2/fluid/regularizer.py | 19 +++- .../v2/fluid/tests/book/test_fit_a_line.py | 57 +++++------ .../book/test_image_classification_train.py | 95 +++++++------------ .../tests/book/test_label_semantic_roles.py | 72 +++++++------- .../tests/book/test_recognize_digits_conv.py | 50 ++++------ .../tests/book/test_recognize_digits_mlp.py | 77 +++++++-------- .../book/test_understand_sentiment_conv.py | 54 +++++------ .../test_understand_sentiment_dynamic_lstm.py | 60 ++++++------ .../book/test_understand_sentiment_lstm.py | 49 +++++----- .../v2/fluid/tests/book/test_word2vec.py | 85 ++++++----------- 19 files changed, 381 insertions(+), 425 deletions(-) diff --git a/python/paddle/v2/fluid/__init__.py b/python/paddle/v2/fluid/__init__.py index 5df612bf3..9677c9568 100644 --- a/python/paddle/v2/fluid/__init__.py +++ b/python/paddle/v2/fluid/__init__.py @@ -1,11 +1,41 @@ -import sys -import core -__all__ = ['proto'] -argv = [] -if core.is_compile_gpu(): - argv = list(sys.argv) + [ - "--tryfromenv=fraction_of_gpu_memory_to_use,use_pinned_memory" - ] -else: - argv = list(sys.argv) + ["--tryfromenv=use_pinned_memory"] -core.init_gflags(argv) +# import all class inside framework into fluid module +import framework +from framework import * +# import all class inside executor into fluid module +import executor +from executor import * + +import io +import evaluator +import initializer +import layers +import nets +import optimizer +import backward +import regularizer + +from core import LoDTensor, CPUPlace, GPUPlace + +Tensor = LoDTensor +__all__ = framework.__all__ + executor.__all__ + [ + 'io', 'initializer', 'layers', 'nets', 'optimizer', 'backward', + 'regularizer', 'LoDTensor', 'CPUPlace', 'GPUPlace', 'Tensor' +] + + +def __read_gflags_from_env__(): + """ + Enable reading gflags from environment variables. + + Returns: + None + """ + import sys + import core + read_env_flags = ['use_pinned_memory'] + if core.is_compile_gpu(): + read_env_flags.append('fraction_of_gpu_memory_to_use') + core.init_gflags(sys.argv + ["--tryfromenv=" + ",".join(read_env_flags)]) + + +__read_gflags_from_env__() diff --git a/python/paddle/v2/fluid/evaluator.py b/python/paddle/v2/fluid/evaluator.py index c37fca856..bd4a6fda1 100644 --- a/python/paddle/v2/fluid/evaluator.py +++ b/python/paddle/v2/fluid/evaluator.py @@ -1,9 +1,8 @@ import numpy as np -import paddle.v2.fluid.layers as layers -from paddle.v2.fluid.framework import Program, unique_name, \ - Variable -from paddle.v2.fluid.layer_helper import LayerHelper +import layers +from framework import Program, unique_name, Variable +from layer_helper import LayerHelper __all__ = ['Accuracy'] diff --git a/python/paddle/v2/fluid/executor.py b/python/paddle/v2/fluid/executor.py index bd98d6b15..3e26d1b98 100644 --- a/python/paddle/v2/fluid/executor.py +++ b/python/paddle/v2/fluid/executor.py @@ -1,6 +1,8 @@ import numpy as np -import paddle.v2.fluid.core as core -from paddle.v2.fluid.framework import Block, Program, g_main_program +from . import core +from framework import Program, g_main_program + +__all__ = ['Executor', 'g_scope'] g_scope = core.Scope() diff --git a/python/paddle/v2/fluid/framework.py b/python/paddle/v2/fluid/framework.py index 872c19c2f..9a62698b8 100644 --- a/python/paddle/v2/fluid/framework.py +++ b/python/paddle/v2/fluid/framework.py @@ -1,12 +1,12 @@ -import paddle.v2.fluid.core as core -import paddle.v2.fluid.proto.framework_pb2 as framework_pb2 import collections + import numpy as np -import copy +from . import core +import proto.framework_pb2 as framework_pb2 __all__ = [ 'Block', 'Variable', 'Program', 'Operator', 'default_startup_program', - 'default_main_program' + 'default_main_program', 'g_startup_program', 'g_main_program' ] diff --git a/python/paddle/v2/fluid/initializer.py b/python/paddle/v2/fluid/initializer.py index 9f23e68a7..d3f648f84 100644 --- a/python/paddle/v2/fluid/initializer.py +++ b/python/paddle/v2/fluid/initializer.py @@ -1,10 +1,7 @@ -import paddle.v2.fluid.framework as framework +import framework import numpy as np -__all__ = [ - 'ConstantInitializer', 'UniformInitializer', 'NormalInitializer', - 'XavierInitializer' -] +__all__ = ['Constant', 'Uniform', 'Normal', 'Xavier'] class Initializer(object): @@ -368,3 +365,19 @@ class MSRAInitializer(Initializer): }) var.op = op return op + + +# We short the class name, since users will use the initializer with the package +# name. The sample code: +# +# import paddle.fluid as fluid +# +# hidden = fluid.layers.fc(..., +# param_attr=ParamAttr(fluid.initializer.Xavier())) +# +# It is no need to add an `Initializer` as the class suffix +Constant = ConstantInitializer +Uniform = UniformInitializer +Normal = NormalInitializer +Xavier = XavierInitializer +MSRA = MSRAInitializer diff --git a/python/paddle/v2/fluid/layer_helper.py b/python/paddle/v2/fluid/layer_helper.py index e0880354f..5f8855551 100644 --- a/python/paddle/v2/fluid/layer_helper.py +++ b/python/paddle/v2/fluid/layer_helper.py @@ -1,10 +1,9 @@ import copy import itertools -from paddle.v2.fluid.framework import Variable, g_main_program, \ - g_startup_program, unique_name, Program, dtype_is_floating -from paddle.v2.fluid.initializer import ConstantInitializer, \ - UniformInitializer, XavierInitializer +from framework import Variable, g_main_program, \ + g_startup_program, unique_name, dtype_is_floating +from paddle.v2.fluid.initializer import Constant, Xavier class LayerHelper(object): @@ -209,7 +208,7 @@ class LayerHelper(object): def _get_default_initializer(self, dtype): if dtype is None or dtype_is_floating(dtype) is True: - return XavierInitializer() + return Xavier() else: # For integer and boolean types, initialize with all zeros - return ConstantInitializer() + return Constant() diff --git a/python/paddle/v2/fluid/layers.py b/python/paddle/v2/fluid/layers.py index ca0c10e70..db388c142 100644 --- a/python/paddle/v2/fluid/layers.py +++ b/python/paddle/v2/fluid/layers.py @@ -1,9 +1,7 @@ -import paddle.v2.fluid.core as core -import paddle.v2.fluid.proto.framework_pb2 as framework_pb2 -from paddle.v2.fluid.framework import OpProtoHolder, Variable, Program, \ - Operator -from paddle.v2.fluid.initializer import ConstantInitializer, \ - NormalInitializer, XavierInitializer +from . import core +import proto.framework_pb2 as framework_pb2 +from framework import OpProtoHolder, Variable, Program, Operator +from initializer import Constant, Normal, Xavier from paddle.v2.fluid.layer_helper import LayerHelper, unique_name import re import cStringIO @@ -58,10 +56,10 @@ def fc(input, """ def _get_default_param_initializer(): - return XavierInitializer() + return Xavier() def _get_default_bias_initializer(): - return ConstantInitializer() + return Constant() helper = LayerHelper('fc', **locals()) @@ -139,7 +137,7 @@ def embedding(input, """ def _get_default_param_initializer(): - return XavierInitializer() + return Xavier() helper = LayerHelper('embedding', **locals()) w = helper.create_parameter( @@ -477,7 +475,7 @@ def linear_chain_crf(input, main_program=None, startup_program=None): def _get_default_param_initializer(): - return XavierInitializer() + return Xavier() helper = LayerHelper('linear_chain_crf', **locals()) size = input.shape[1] @@ -661,10 +659,10 @@ def sequence_conv(input, """ def _get_default_bias_initializer(): - return ConstantInitializer() + return Constant() def _get_default_param_initializer(): - return XavierInitializer() + return Xavier() # FIXME(dzh) : want to unify the argument of python layer # function. So we ignore some unecessary attributes. @@ -725,11 +723,11 @@ def conv2d(input, """ def _get_default_bias_initializer(): - return ConstantInitializer() + return Constant() def _get_default_param_initializer(filter_size, num_channels): std = (2.0 / (filter_size[0]**2 * num_channels))**0.5 - return NormalInitializer(0.0, std, 0) + return Normal(0.0, std, 0) helper = LayerHelper('conv2d', **locals()) dtype = helper.input_dtype() @@ -878,22 +876,20 @@ def batch_norm(input, attr=helper.param_attr, shape=param_shape, dtype=dtype, - initializer=ConstantInitializer(1.0)) + initializer=Constant(1.0)) bias = helper.create_parameter( attr=helper.param_attr, shape=param_shape, dtype=dtype, - initializer=ConstantInitializer(0.0)) + initializer=Constant(0.0)) mean = helper.create_global_variable( dtype=input.dtype, shape=param_shape, persistable=True) - helper.set_variable_initializer( - var=mean, initializer=ConstantInitializer(0.0)) + helper.set_variable_initializer(var=mean, initializer=Constant(0.0)) variance = helper.create_global_variable( dtype=input.dtype, shape=param_shape, persistable=True) - helper.set_variable_initializer( - var=variance, initializer=ConstantInitializer(1.0)) + helper.set_variable_initializer(var=variance, initializer=Constant(1.0)) # create output # mean and mean_out share the same memory diff --git a/python/paddle/v2/fluid/nets.py b/python/paddle/v2/fluid/nets.py index 5e14ca594..05728ad75 100644 --- a/python/paddle/v2/fluid/nets.py +++ b/python/paddle/v2/fluid/nets.py @@ -1,4 +1,4 @@ -import paddle.v2.fluid.layers as layers +import layers __all__ = ["simple_img_conv_pool", "sequence_conv_pool"] diff --git a/python/paddle/v2/fluid/optimizer.py b/python/paddle/v2/fluid/optimizer.py index e82f0f060..934e02474 100644 --- a/python/paddle/v2/fluid/optimizer.py +++ b/python/paddle/v2/fluid/optimizer.py @@ -1,16 +1,13 @@ from collections import defaultdict -import paddle.v2.fluid.framework as framework -from paddle.v2.fluid.framework import unique_name, Program -from paddle.v2.fluid.backward import append_backward_ops -from paddle.v2.fluid.initializer import ConstantInitializer -from paddle.v2.fluid.regularizer import append_regularization_ops -from paddle.v2.fluid.layer_helper import LayerHelper +import framework +from backward import append_backward_ops +from framework import unique_name +from initializer import Constant +from layer_helper import LayerHelper +from regularizer import append_regularization_ops -__all__ = [ - 'SGDOptimizer', 'MomentumOptimizer', 'AdagradOptimizer', 'AdamOptimizer', - 'AdamaxOptimizer', 'DecayedAdagradOptimizer' -] +__all__ = ['SGD', 'Momentum', 'Adagrad', 'Adam', 'Adamax', 'DecayedAdagrad'] class Optimizer(object): @@ -48,7 +45,7 @@ class Optimizer(object): persistable=True) param_lr = param_lr * self._learning_rate self.helper.set_variable_initializer( - var=param_lr_var, initializer=ConstantInitializer(param_lr)) + var=param_lr_var, initializer=Constant(param_lr)) return param_lr_var def _create_accumulators(self, block, parameters): @@ -96,7 +93,7 @@ class Optimizer(object): type=param.type, shape=param.shape) self.helper.set_variable_initializer( - var, initializer=ConstantInitializer(value=float(fill_value))) + var, initializer=Constant(value=float(fill_value))) self._accumulators[name][param.name] = var def _get_accumulator(self, name, param): @@ -360,7 +357,7 @@ class AdamOptimizer(Optimizer): lod_level=0, persistable=True) self.helper.set_variable_initializer( - self._beta1_pow_acc, initializer=ConstantInitializer(self._beta1)) + self._beta1_pow_acc, initializer=Constant(self._beta1)) self._beta2_pow_acc = self.helper.create_global_variable( name=unique_name('beta2_pow_acc'), @@ -370,7 +367,7 @@ class AdamOptimizer(Optimizer): persistable=True) self.helper.set_variable_initializer( - self._beta2_pow_acc, initializer=ConstantInitializer(self._beta2)) + self._beta2_pow_acc, initializer=Constant(self._beta2)) # Create accumulator tensors for first and second moments for p in parameters: @@ -462,7 +459,7 @@ class AdamaxOptimizer(Optimizer): lod_level=0, persistable=True) self.helper.set_variable_initializer( - self._beta1_pow_acc, initializer=ConstantInitializer(self._beta1)) + self._beta1_pow_acc, initializer=Constant(self._beta1)) # Create accumulator tensors for first moment and infinity norm for p in parameters: @@ -559,3 +556,19 @@ class DecayedAdagradOptimizer(Optimizer): attrs={"epsilon": self._epsilon}) return decayed_adagrad_op + + +# We short the class name, since users will use the optimizer with the package +# name. The sample code: +# +# import paddle.fluid as fluid +# +# sgd = fluid.optimizer.SGD(...) +# +# It is no need to add an `Optimizer` as the class suffix +SGD = SGDOptimizer +Momentum = MomentumOptimizer +Adagrad = AdagradOptimizer +Adam = AdamOptimizer +Adamax = AdamaxOptimizer +DecayedAdagrad = DecayedAdagradOptimizer diff --git a/python/paddle/v2/fluid/regularizer.py b/python/paddle/v2/fluid/regularizer.py index 098cd0dd6..c2c18e195 100644 --- a/python/paddle/v2/fluid/regularizer.py +++ b/python/paddle/v2/fluid/regularizer.py @@ -1,8 +1,6 @@ -import paddle.v2.fluid.framework as framework +import framework -__all__ = [ - 'append_regularization_ops', 'L2DecayRegularizer', 'L1DecayRegularizer' -] +__all__ = ['append_regularization_ops', 'L1Decay', 'L2Decay'] def append_regularization_ops(parameters_and_grads): @@ -139,3 +137,16 @@ class L1DecayRegularizer(WeightDecayRegularizer): attrs={"scale": self._regularization_coeff}) return decay + + +# We short the class name, since users will use the regulaizer with the package +# name. The sample code: +# +# import paddle.fluid as fluid +# +# hidden = fluid.layers.fc(..., +# param_attr=ParamAttr(fluid.regularizer.Xavier())) +# +# It is no need to add a `Regularizer` as the class suffix +L1Decay = L1DecayRegularizer +L2Decay = L2DecayRegularizer diff --git a/python/paddle/v2/fluid/tests/book/test_fit_a_line.py b/python/paddle/v2/fluid/tests/book/test_fit_a_line.py index a899f1088..9f98493ad 100644 --- a/python/paddle/v2/fluid/tests/book/test_fit_a_line.py +++ b/python/paddle/v2/fluid/tests/book/test_fit_a_line.py @@ -1,23 +1,18 @@ import numpy as np import paddle.v2 as paddle -import paddle.v2.fluid.core as core -import paddle.v2.fluid.framework as framework -import paddle.v2.fluid.layers as layers -from paddle.v2.fluid.executor import Executor -from paddle.v2.fluid.io import save_persistables, load_persistables -from paddle.v2.fluid.optimizer import SGDOptimizer +import paddle.v2.fluid as fluid -x = layers.data(name='x', shape=[13], dtype='float32') +x = fluid.layers.data(name='x', shape=[13], dtype='float32') -y_predict = layers.fc(input=x, size=1, act=None) +y_predict = fluid.layers.fc(input=x, size=1, act=None) -y = layers.data(name='y', shape=[1], dtype='float32') +y = fluid.layers.data(name='y', shape=[1], dtype='float32') -cost = layers.square_error_cost(input=y_predict, label=y) -avg_cost = layers.mean(x=cost) +cost = fluid.layers.square_error_cost(input=y_predict, label=y) +avg_cost = fluid.layers.mean(x=cost) -sgd_optimizer = SGDOptimizer(learning_rate=0.001) -opts = sgd_optimizer.minimize(avg_cost) +sgd_optimizer = fluid.optimizer.SGD(learning_rate=0.001) +sgd_optimizer.minimize(avg_cost) BATCH_SIZE = 20 @@ -26,32 +21,24 @@ train_reader = paddle.batch( paddle.dataset.uci_housing.train(), buf_size=500), batch_size=BATCH_SIZE) -place = core.CPUPlace() -exe = Executor(place) +place = fluid.CPUPlace() +exe = fluid.Executor(place) -exe.run(framework.default_startup_program()) +exe.run(fluid.default_startup_program()) PASS_NUM = 100 for pass_id in range(PASS_NUM): - save_persistables(exe, "./fit_a_line.model/") - load_persistables(exe, "./fit_a_line.model/") + fluid.io.save_persistables(exe, "./fit_a_line.model/") + fluid.io.load_persistables(exe, "./fit_a_line.model/") for data in train_reader(): - x_data = np.array(map(lambda x: x[0], data)).astype("float32") - y_data = np.array(map(lambda x: x[1], data)).astype("float32") - - tensor_x = core.LoDTensor() - tensor_x.set(x_data, place) - # print tensor_x.get_dims() - - tensor_y = core.LoDTensor() - tensor_y.set(y_data, place) - # print tensor_y.get_dims() - outs = exe.run(framework.default_main_program(), - feed={'x': tensor_x, - 'y': tensor_y}, - fetch_list=[avg_cost]) - out = np.array(outs[0]) - - if out[0] < 10.0: + x_data = np.array(map(lambda _: _[0], data)).astype("float32") + y_data = np.array(map(lambda _: _[1], data)).astype("float32") + + avg_loss_value, = exe.run(fluid.default_main_program(), + feed={'x': x_data, + 'y': y_data}, + fetch_list=[avg_cost]) + + if avg_loss_value[0] < 10.0: exit(0) # if avg cost less than 10.0, we think our code is good. exit(1) diff --git a/python/paddle/v2/fluid/tests/book/test_image_classification_train.py b/python/paddle/v2/fluid/tests/book/test_image_classification_train.py index b555b49ab..690c53397 100644 --- a/python/paddle/v2/fluid/tests/book/test_image_classification_train.py +++ b/python/paddle/v2/fluid/tests/book/test_image_classification_train.py @@ -1,18 +1,12 @@ +from __future__ import print_function import numpy as np import paddle.v2 as paddle -import paddle.v2.fluid.core as core -import paddle.v2.fluid.framework as framework -import paddle.v2.fluid.layers as layers -import paddle.v2.fluid.nets as nets -import paddle.v2.fluid.evaluator as evaluator -from paddle.v2.fluid.executor import Executor -from paddle.v2.fluid.initializer import XavierInitializer -from paddle.v2.fluid.optimizer import AdamOptimizer +import paddle.v2.fluid as fluid def resnet_cifar10(input, depth=32): def conv_bn_layer(input, ch_out, filter_size, stride, padding, act='relu'): - tmp = layers.conv2d( + tmp = fluid.layers.conv2d( input=input, filter_size=filter_size, num_filters=ch_out, @@ -20,12 +14,11 @@ def resnet_cifar10(input, depth=32): padding=padding, act=None, bias_attr=False) - return layers.batch_norm(input=tmp, act=act) + return fluid.layers.batch_norm(input=tmp, act=act) - def shortcut(input, ch_in, ch_out, stride, program, init_program): + def shortcut(input, ch_in, ch_out, stride): if ch_in != ch_out: - return conv_bn_layer(input, ch_out, 1, stride, 0, None, program, - init_program) + return conv_bn_layer(input, ch_out, 1, stride, 0, None) else: return input @@ -33,7 +26,7 @@ def resnet_cifar10(input, depth=32): tmp = conv_bn_layer(input, ch_out, 3, stride, 1) tmp = conv_bn_layer(tmp, ch_out, 3, 1, 1, act=None) short = shortcut(input, ch_in, ch_out, stride) - return layers.elementwise_add(x=tmp, y=short, act='relu') + return fluid.layers.elementwise_add(x=tmp, y=short, act='relu') def layer_warp(block_func, input, ch_in, ch_out, count, stride): tmp = block_func(input, ch_in, ch_out, stride) @@ -48,14 +41,14 @@ def resnet_cifar10(input, depth=32): res1 = layer_warp(basicblock, conv1, 16, 16, n, 1) res2 = layer_warp(basicblock, res1, 16, 32, n, 2) res3 = layer_warp(basicblock, res2, 32, 64, n, 2) - pool = layers.pool2d( + pool = fluid.layers.pool2d( input=res3, pool_size=8, pool_type='avg', pool_stride=1) return pool def vgg16_bn_drop(input): def conv_block(input, num_filter, groups, dropouts): - return nets.img_conv_group( + return fluid.nets.img_conv_group( input=input, pool_size=2, pool_stride=2, @@ -72,26 +65,20 @@ def vgg16_bn_drop(input): conv4 = conv_block(conv3, 512, 3, [0.4, 0.4, 0]) conv5 = conv_block(conv4, 512, 3, [0.4, 0.4, 0]) - drop = layers.dropout(x=conv5, dropout_prob=0.5) - fc1 = layers.fc(input=drop, - size=512, - act=None, - param_attr={"initializer": XavierInitializer()}) - reshape1 = layers.reshape(x=fc1, shape=list(fc1.shape + (1, 1))) - bn = layers.batch_norm(input=reshape1, act='relu') - drop2 = layers.dropout(x=bn, dropout_prob=0.5) - fc2 = layers.fc(input=drop2, - size=512, - act=None, - param_attr={"initializer": XavierInitializer()}) + drop = fluid.layers.dropout(x=conv5, dropout_prob=0.5) + fc1 = fluid.layers.fc(input=drop, size=512, act=None) + reshape1 = fluid.layers.reshape(x=fc1, shape=list(fc1.shape + (1, 1))) + bn = fluid.layers.batch_norm(input=reshape1, act='relu') + drop2 = fluid.layers.dropout(x=bn, dropout_prob=0.5) + fc2 = fluid.layers.fc(input=drop2, size=512, act=None) return fc2 classdim = 10 data_shape = [3, 32, 32] -images = layers.data(name='pixel', shape=data_shape, dtype='float32') -label = layers.data(name='label', shape=[1], dtype='int64') +images = fluid.layers.data(name='pixel', shape=data_shape, dtype='float32') +label = fluid.layers.data(name='label', shape=[1], dtype='int64') # Add neural network config # option 1. resnet @@ -99,17 +86,14 @@ label = layers.data(name='label', shape=[1], dtype='int64') # option 2. vgg net = vgg16_bn_drop(images) -# print(program) +predict = fluid.layers.fc(input=net, size=classdim, act='softmax') +cost = fluid.layers.cross_entropy(input=predict, label=label) +avg_cost = fluid.layers.mean(x=cost) -predict = layers.fc(input=net, size=classdim, act='softmax') -cost = layers.cross_entropy(input=predict, label=label) -avg_cost = layers.mean(x=cost) - -# optimizer = SGDOptimizer(learning_rate=0.001) -optimizer = AdamOptimizer(learning_rate=0.001) +optimizer = fluid.optimizer.Adam(learning_rate=0.001) opts = optimizer.minimize(avg_cost) -accuracy = evaluator.Accuracy(input=predict, label=label) +accuracy = fluid.evaluator.Accuracy(input=predict, label=label) BATCH_SIZE = 128 PASS_NUM = 1 @@ -119,13 +103,12 @@ train_reader = paddle.batch( paddle.dataset.cifar.train10(), buf_size=128 * 10), batch_size=BATCH_SIZE) -place = core.CPUPlace() -exe = Executor(place) +place = fluid.CPUPlace() +exe = fluid.Executor(place) -exe.run(framework.default_startup_program()) +exe.run(fluid.default_startup_program()) for pass_id in range(PASS_NUM): - batch_id = 0 accuracy.reset(exe) for data in train_reader(): img_data = np.array(map(lambda x: x[0].reshape(data_shape), @@ -136,25 +119,13 @@ for pass_id in range(PASS_NUM): batch_size = batch_size * i y_data = y_data.reshape([batch_size, 1]) - tensor_img = core.LoDTensor() - tensor_y = core.LoDTensor() - tensor_img.set(img_data, place) - tensor_y.set(y_data, place) - - outs = exe.run(framework.default_main_program(), - feed={"pixel": tensor_img, - "label": tensor_y}, - fetch_list=[avg_cost] + accuracy.metrics) - - loss = np.array(outs[0]) - acc = np.array(outs[1]) + loss, acc = exe.run(fluid.default_main_program(), + feed={"pixel": img_data, + "label": y_data}, + fetch_list=[avg_cost] + accuracy.metrics) pass_acc = accuracy.eval(exe) - print("pass_id:" + str(pass_id) + " batch_id:" + str(batch_id) + - " loss:" + str(loss) + " acc:" + str(acc) + " pass_acc:" + str( - pass_acc)) - batch_id = batch_id + 1 - - if batch_id > 1: - # this model is slow, so if we can train two mini batch, we think it works properly. - exit(0) + print("loss:" + str(loss) + " acc:" + str(acc) + " pass_acc:" + str( + pass_acc)) + # this model is slow, so if we can train two mini batch, we think it works properly. + exit(0) exit(1) diff --git a/python/paddle/v2/fluid/tests/book/test_label_semantic_roles.py b/python/paddle/v2/fluid/tests/book/test_label_semantic_roles.py index 9c9064ba9..93987a2b8 100644 --- a/python/paddle/v2/fluid/tests/book/test_label_semantic_roles.py +++ b/python/paddle/v2/fluid/tests/book/test_label_semantic_roles.py @@ -1,11 +1,7 @@ import numpy as np import paddle.v2 as paddle import paddle.v2.dataset.conll05 as conll05 -import paddle.v2.fluid.core as core -import paddle.v2.fluid.framework as framework -import paddle.v2.fluid.layers as layers -from paddle.v2.fluid.executor import Executor, g_scope -from paddle.v2.fluid.optimizer import SGDOptimizer +import paddle.v2.fluid as fluid word_dict, verb_dict, label_dict = conll05.get_dict() word_dict_len = len(word_dict) @@ -34,23 +30,23 @@ def load_parameter(file_name, h, w): def db_lstm(): # 8 features - word = layers.data(name='word_data', shape=[1], dtype='int64') - predicate = layers.data(name='verb_data', shape=[1], dtype='int64') - ctx_n2 = layers.data(name='ctx_n2_data', shape=[1], dtype='int64') - ctx_n1 = layers.data(name='ctx_n1_data', shape=[1], dtype='int64') - ctx_0 = layers.data(name='ctx_0_data', shape=[1], dtype='int64') - ctx_p1 = layers.data(name='ctx_p1_data', shape=[1], dtype='int64') - ctx_p2 = layers.data(name='ctx_p2_data', shape=[1], dtype='int64') - mark = layers.data(name='mark_data', shape=[1], dtype='int64') - - predicate_embedding = layers.embedding( + word = fluid.layers.data(name='word_data', shape=[1], dtype='int64') + predicate = fluid.layers.data(name='verb_data', shape=[1], dtype='int64') + ctx_n2 = fluid.layers.data(name='ctx_n2_data', shape=[1], dtype='int64') + ctx_n1 = fluid.layers.data(name='ctx_n1_data', shape=[1], dtype='int64') + ctx_0 = fluid.layers.data(name='ctx_0_data', shape=[1], dtype='int64') + ctx_p1 = fluid.layers.data(name='ctx_p1_data', shape=[1], dtype='int64') + ctx_p2 = fluid.layers.data(name='ctx_p2_data', shape=[1], dtype='int64') + mark = fluid.layers.data(name='mark_data', shape=[1], dtype='int64') + + predicate_embedding = fluid.layers.embedding( input=predicate, size=[pred_len, word_dim], dtype='float32', is_sparse=IS_SPARSE, param_attr={'name': 'vemb'}) - mark_embedding = layers.embedding( + mark_embedding = fluid.layers.embedding( input=mark, size=[mark_dict_len, mark_dim], dtype='float32', @@ -58,7 +54,7 @@ def db_lstm(): word_input = [word, ctx_n2, ctx_n1, ctx_0, ctx_p1, ctx_p2] emb_layers = [ - layers.embedding( + fluid.layers.embedding( size=[word_dict_len, word_dim], input=x, param_attr={'name': embedding_name, @@ -68,12 +64,12 @@ def db_lstm(): emb_layers.append(mark_embedding) hidden_0_layers = [ - layers.fc(input=emb, size=hidden_dim) for emb in emb_layers + fluid.layers.fc(input=emb, size=hidden_dim) for emb in emb_layers ] - hidden_0 = layers.sums(input=hidden_0_layers) + hidden_0 = fluid.layers.sums(input=hidden_0_layers) - lstm_0 = layers.dynamic_lstm( + lstm_0 = fluid.layers.dynamic_lstm( input=hidden_0, size=hidden_dim, candidate_activation='relu', @@ -84,12 +80,12 @@ def db_lstm(): input_tmp = [hidden_0, lstm_0] for i in range(1, depth): - mix_hidden = layers.sums(input=[ - layers.fc(input=input_tmp[0], size=hidden_dim), - layers.fc(input=input_tmp[1], size=hidden_dim) + mix_hidden = fluid.layers.sums(input=[ + fluid.layers.fc(input=input_tmp[0], size=hidden_dim), + fluid.layers.fc(input=input_tmp[1], size=hidden_dim) ]) - lstm = layers.dynamic_lstm( + lstm = fluid.layers.dynamic_lstm( input=mix_hidden, size=hidden_dim, candidate_activation='relu', @@ -99,9 +95,9 @@ def db_lstm(): input_tmp = [mix_hidden, lstm] - feature_out = layers.sums(input=[ - layers.fc(input=input_tmp[0], size=label_dict_len), - layers.fc(input=input_tmp[1], size=label_dict_len) + feature_out = fluid.layers.sums(input=[ + fluid.layers.fc(input=input_tmp[0], size=label_dict_len), + fluid.layers.fc(input=input_tmp[1], size=label_dict_len) ]) return feature_out @@ -116,7 +112,7 @@ def to_lodtensor(data, place): lod.append(cur_len) flattened_data = np.concatenate(data, axis=0).astype("int64") flattened_data = flattened_data.reshape([len(flattened_data), 1]) - res = core.LoDTensor() + res = fluid.LoDTensor() res.set(flattened_data, place) res.set_lod([lod]) return res @@ -125,29 +121,29 @@ def to_lodtensor(data, place): def main(): # define network topology feature_out = db_lstm() - target = layers.data(name='target', shape=[1], dtype='int64') - crf_cost = layers.linear_chain_crf( + target = fluid.layers.data(name='target', shape=[1], dtype='int64') + crf_cost = fluid.layers.linear_chain_crf( input=feature_out, label=target, param_attr={"name": 'crfw', "learning_rate": mix_hidden_lr}) - avg_cost = layers.mean(x=crf_cost) + avg_cost = fluid.layers.mean(x=crf_cost) # TODO(qiao) # 1. add crf_decode_layer and evaluator # 2. use other optimizer and check why out will be NAN - sgd_optimizer = SGDOptimizer(learning_rate=0.0001) - opts = sgd_optimizer.minimize(avg_cost) + sgd_optimizer = fluid.optimizer.SGD(learning_rate=0.0001) + sgd_optimizer.minimize(avg_cost) train_data = paddle.batch( paddle.reader.shuffle( paddle.dataset.conll05.test(), buf_size=8192), batch_size=BATCH_SIZE) - place = core.CPUPlace() - exe = Executor(place) + place = fluid.CPUPlace() + exe = fluid.Executor(place) - exe.run(framework.default_startup_program()) + exe.run(fluid.default_startup_program()) - embedding_param = g_scope.find_var(embedding_name).get_tensor() + embedding_param = fluid.g_scope.find_var(embedding_name).get_tensor() embedding_param.set( load_parameter(conll05.get_embedding(), word_dict_len, word_dim), place) @@ -164,7 +160,7 @@ def main(): mark_data = to_lodtensor(map(lambda x: x[7], data), place) target = to_lodtensor(map(lambda x: x[8], data), place) - outs = exe.run(framework.default_main_program(), + outs = exe.run(fluid.default_main_program(), feed={ 'word_data': word_data, 'ctx_n2_data': ctx_n2_data, diff --git a/python/paddle/v2/fluid/tests/book/test_recognize_digits_conv.py b/python/paddle/v2/fluid/tests/book/test_recognize_digits_conv.py index 97f1f1272..ba686b56f 100644 --- a/python/paddle/v2/fluid/tests/book/test_recognize_digits_conv.py +++ b/python/paddle/v2/fluid/tests/book/test_recognize_digits_conv.py @@ -1,23 +1,18 @@ +from __future__ import print_function import numpy as np import paddle.v2 as paddle -import paddle.v2.fluid.core as core -import paddle.v2.fluid.evaluator as evaluator -import paddle.v2.fluid.framework as framework -import paddle.v2.fluid.layers as layers -import paddle.v2.fluid.nets as nets -from paddle.v2.fluid.executor import Executor -from paddle.v2.fluid.optimizer import AdamOptimizer +import paddle.v2.fluid as fluid -images = layers.data(name='pixel', shape=[1, 28, 28], dtype='float32') -label = layers.data(name='label', shape=[1], dtype='int64') -conv_pool_1 = nets.simple_img_conv_pool( +images = fluid.layers.data(name='pixel', shape=[1, 28, 28], dtype='float32') +label = fluid.layers.data(name='label', shape=[1], dtype='int64') +conv_pool_1 = fluid.nets.simple_img_conv_pool( input=images, filter_size=5, num_filters=20, pool_size=2, pool_stride=2, act="relu") -conv_pool_2 = nets.simple_img_conv_pool( +conv_pool_2 = fluid.nets.simple_img_conv_pool( input=conv_pool_1, filter_size=5, num_filters=50, @@ -25,13 +20,13 @@ conv_pool_2 = nets.simple_img_conv_pool( pool_stride=2, act="relu") -predict = layers.fc(input=conv_pool_2, size=10, act="softmax") -cost = layers.cross_entropy(input=predict, label=label) -avg_cost = layers.mean(x=cost) -optimizer = AdamOptimizer(learning_rate=0.01, beta1=0.9, beta2=0.999) -opts = optimizer.minimize(avg_cost) +predict = fluid.layers.fc(input=conv_pool_2, size=10, act="softmax") +cost = fluid.layers.cross_entropy(input=predict, label=label) +avg_cost = fluid.layers.mean(x=cost) +optimizer = fluid.optimizer.Adam(learning_rate=0.01) +optimizer.minimize(avg_cost) -accuracy = evaluator.Accuracy(input=predict, label=label) +accuracy = fluid.evaluator.Accuracy(input=predict, label=label) BATCH_SIZE = 50 PASS_NUM = 3 @@ -40,10 +35,10 @@ train_reader = paddle.batch( paddle.dataset.mnist.train(), buf_size=500), batch_size=BATCH_SIZE) -place = core.CPUPlace() -exe = Executor(place) +place = fluid.CPUPlace() +exe = fluid.Executor(place) -exe.run(framework.default_startup_program()) +exe.run(fluid.default_startup_program()) for pass_id in range(PASS_NUM): accuracy.reset(exe) @@ -53,17 +48,10 @@ for pass_id in range(PASS_NUM): y_data = np.array(map(lambda x: x[1], data)).astype("int64") y_data = y_data.reshape([BATCH_SIZE, 1]) - tensor_img = core.LoDTensor() - tensor_y = core.LoDTensor() - tensor_img.set(img_data, place) - tensor_y.set(y_data, place) - - outs = exe.run(framework.default_main_program(), - feed={"pixel": tensor_img, - "label": tensor_y}, - fetch_list=[avg_cost] + accuracy.metrics) - loss = np.array(outs[0]) - acc = np.array(outs[1]) + loss, acc = exe.run(fluid.default_main_program(), + feed={"pixel": img_data, + "label": y_data}, + fetch_list=[avg_cost] + accuracy.metrics) pass_acc = accuracy.eval(exe) print("pass_id=" + str(pass_id) + " acc=" + str(acc) + " pass_acc=" + str(pass_acc)) diff --git a/python/paddle/v2/fluid/tests/book/test_recognize_digits_mlp.py b/python/paddle/v2/fluid/tests/book/test_recognize_digits_mlp.py index 7dbb34f5d..c96d186ff 100644 --- a/python/paddle/v2/fluid/tests/book/test_recognize_digits_mlp.py +++ b/python/paddle/v2/fluid/tests/book/test_recognize_digits_mlp.py @@ -1,42 +1,39 @@ +from __future__ import print_function import numpy as np import paddle.v2 as paddle -import paddle.v2.fluid.core as core -import paddle.v2.fluid.framework as framework -import paddle.v2.fluid.layers as layers -import paddle.v2.fluid.evaluator as evaluator -from paddle.v2.fluid.io import get_inference_program -from paddle.v2.fluid.executor import Executor -from paddle.v2.fluid.initializer import UniformInitializer -from paddle.v2.fluid.optimizer import MomentumOptimizer -from paddle.v2.fluid.regularizer import L2DecayRegularizer +import paddle.v2.fluid as fluid BATCH_SIZE = 128 -image = layers.data(name='x', shape=[784], dtype='float32') +image = fluid.layers.data(name='x', shape=[784], dtype='float32') param_attr = { 'name': None, - 'initializer': UniformInitializer( - low=-1.0, high=1.0), - 'regularization': L2DecayRegularizer(0.0005 * BATCH_SIZE) + 'regularization': fluid.regularizer.L2Decay(0.0005 * BATCH_SIZE) } -hidden1 = layers.fc(input=image, size=128, act='relu', param_attr=param_attr) -hidden2 = layers.fc(input=hidden1, size=64, act='relu', param_attr=param_attr) +hidden1 = fluid.layers.fc(input=image, + size=128, + act='relu', + param_attr=param_attr) +hidden2 = fluid.layers.fc(input=hidden1, + size=64, + act='relu', + param_attr=param_attr) -predict = layers.fc(input=hidden2, - size=10, - act='softmax', - param_attr=param_attr) +predict = fluid.layers.fc(input=hidden2, + size=10, + act='softmax', + param_attr=param_attr) -label = layers.data(name='y', shape=[1], dtype='int64') +label = fluid.layers.data(name='y', shape=[1], dtype='int64') -cost = layers.cross_entropy(input=predict, label=label) -avg_cost = layers.mean(x=cost) +cost = fluid.layers.cross_entropy(input=predict, label=label) +avg_cost = fluid.layers.mean(x=cost) -optimizer = MomentumOptimizer(learning_rate=0.001, momentum=0.9) +optimizer = fluid.optimizer.Momentum(learning_rate=0.001, momentum=0.9) opts = optimizer.minimize(avg_cost) -accuracy = evaluator.Accuracy(input=predict, label=label) +accuracy = fluid.evaluator.Accuracy(input=predict, label=label) train_reader = paddle.batch( paddle.reader.shuffle( @@ -45,10 +42,10 @@ train_reader = paddle.batch( test_reader = paddle.batch(paddle.dataset.mnist.test(), batch_size=128) -place = core.CPUPlace() -exe = Executor(place) +place = fluid.CPUPlace() +exe = fluid.Executor(place) -exe.run(framework.default_startup_program()) +exe.run(fluid.default_startup_program()) PASS_NUM = 100 for pass_id in range(PASS_NUM): @@ -58,13 +55,13 @@ for pass_id in range(PASS_NUM): y_data = np.array(map(lambda x: x[1], data)).astype("int64") y_data = np.expand_dims(y_data, axis=1) - tensor_x = core.LoDTensor() + tensor_x = fluid.LoDTensor() tensor_x.set(x_data, place) - tensor_y = core.LoDTensor() + tensor_y = fluid.LoDTensor() tensor_y.set(y_data, place) - outs = exe.run(framework.default_main_program(), + outs = exe.run(fluid.default_main_program(), feed={'x': tensor_x, 'y': tensor_y}, fetch_list=[avg_cost] + accuracy.metrics) @@ -72,10 +69,10 @@ for pass_id in range(PASS_NUM): acc = np.array(outs[1]) pass_acc = accuracy.eval(exe) - test_accuracy = evaluator.Accuracy(input=predict, label=label) + test_accuracy = fluid.evaluator.Accuracy(input=predict, label=label) test_target = [avg_cost] + test_accuracy.metrics + test_accuracy.states - inference_program = get_inference_program(test_target) + inference_program = fluid.io.get_inference_program(test_target) test_accuracy.reset(exe) for data in test_reader(): @@ -83,18 +80,10 @@ for pass_id in range(PASS_NUM): y_data = np.array(map(lambda x: x[1], data)).astype("int64") y_data = np.expand_dims(y_data, axis=1) - tensor_x = core.LoDTensor() - tensor_x.set(x_data, place) - - tensor_y = core.LoDTensor() - tensor_y.set(y_data, place) - - outs = exe.run(inference_program, - feed={'x': tensor_x, - 'y': tensor_y}, - fetch_list=[avg_cost] + test_accuracy.metrics) - out = np.array(outs[0]) - acc = np.array(outs[1]) + out, acc = exe.run(inference_program, + feed={'x': x_data, + 'y': y_data}, + fetch_list=[avg_cost] + test_accuracy.metrics) test_pass_acc = test_accuracy.eval(exe) print("pass_id=" + str(pass_id) + " train_cost=" + str( diff --git a/python/paddle/v2/fluid/tests/book/test_understand_sentiment_conv.py b/python/paddle/v2/fluid/tests/book/test_understand_sentiment_conv.py index 054cdb324..be875a952 100644 --- a/python/paddle/v2/fluid/tests/book/test_understand_sentiment_conv.py +++ b/python/paddle/v2/fluid/tests/book/test_understand_sentiment_conv.py @@ -1,39 +1,34 @@ +from __future__ import print_function import numpy as np import paddle.v2 as paddle -import paddle.v2.fluid.core as core -import paddle.v2.fluid.evaluator as evaluator -import paddle.v2.fluid.framework as framework -import paddle.v2.fluid.layers as layers -import paddle.v2.fluid.nets as nets -from paddle.v2.fluid.executor import Executor -from paddle.v2.fluid.optimizer import AdamOptimizer +import paddle.v2.fluid as fluid def convolution_net(input_dim, class_dim=2, emb_dim=32, hid_dim=32): - data = layers.data(name="words", shape=[1], dtype="int64") - label = layers.data(name="label", shape=[1], dtype="int64") + data = fluid.layers.data(name="words", shape=[1], dtype="int64") + label = fluid.layers.data(name="label", shape=[1], dtype="int64") - emb = layers.embedding(input=data, size=[input_dim, emb_dim]) - conv_3 = nets.sequence_conv_pool( + emb = fluid.layers.embedding(input=data, size=[input_dim, emb_dim]) + conv_3 = fluid.nets.sequence_conv_pool( input=emb, num_filters=hid_dim, filter_size=3, act="tanh", pool_type="sqrt") - conv_4 = nets.sequence_conv_pool( + conv_4 = fluid.nets.sequence_conv_pool( input=emb, num_filters=hid_dim, filter_size=4, act="tanh", pool_type="sqrt") - prediction = layers.fc(input=[conv_3, conv_4], - size=class_dim, - act="softmax") - cost = layers.cross_entropy(input=prediction, label=label) - avg_cost = layers.mean(x=cost) - adam_optimizer = AdamOptimizer(learning_rate=0.002) + prediction = fluid.layers.fc(input=[conv_3, conv_4], + size=class_dim, + act="softmax") + cost = fluid.layers.cross_entropy(input=prediction, label=label) + avg_cost = fluid.layers.mean(x=cost) + adam_optimizer = fluid.optimizer.Adam(learning_rate=0.002) adam_optimizer.minimize(avg_cost) - accuracy = evaluator.Accuracy(input=prediction, label=label) + accuracy = fluid.evaluator.Accuracy(input=prediction, label=label) return avg_cost, accuracy, accuracy.metrics[0] @@ -46,7 +41,7 @@ def to_lodtensor(data, place): lod.append(cur_len) flattened_data = np.concatenate(data, axis=0).astype("int64") flattened_data = flattened_data.reshape([len(flattened_data), 1]) - res = core.LoDTensor() + res = fluid.LoDTensor() res.set(flattened_data, place) res.set_lod([lod]) return res @@ -67,10 +62,10 @@ def main(): paddle.reader.shuffle( paddle.dataset.imdb.train(word_dict), buf_size=1000), batch_size=BATCH_SIZE) - place = core.CPUPlace() - exe = Executor(place) + place = fluid.CPUPlace() + exe = fluid.Executor(place) - exe.run(framework.default_startup_program()) + exe.run(fluid.default_startup_program()) for pass_id in xrange(PASS_NUM): accuracy.reset(exe) @@ -80,15 +75,14 @@ def main(): label = np.array(map(lambda x: x[1], data)).astype("int64") label = label.reshape([BATCH_SIZE, 1]) - tensor_label = core.LoDTensor() + tensor_label = fluid.LoDTensor() tensor_label.set(label, place) - outs = exe.run(framework.default_main_program(), - feed={"words": tensor_words, - "label": tensor_label}, - fetch_list=[cost, acc_out]) - cost_val = np.array(outs[0]) - acc_val = np.array(outs[1]) + cost_val, acc_val = exe.run( + fluid.default_main_program(), + feed={"words": tensor_words, + "label": tensor_label}, + fetch_list=[cost, acc_out]) pass_acc = accuracy.eval(exe) print("cost=" + str(cost_val) + " acc=" + str(acc_val) + " pass_acc=" + str(pass_acc)) diff --git a/python/paddle/v2/fluid/tests/book/test_understand_sentiment_dynamic_lstm.py b/python/paddle/v2/fluid/tests/book/test_understand_sentiment_dynamic_lstm.py index 854ef8261..094a3cdcd 100644 --- a/python/paddle/v2/fluid/tests/book/test_understand_sentiment_dynamic_lstm.py +++ b/python/paddle/v2/fluid/tests/book/test_understand_sentiment_dynamic_lstm.py @@ -1,11 +1,6 @@ import numpy as np import paddle.v2 as paddle -import paddle.v2.fluid.core as core -import paddle.v2.fluid.evaluator as evaluator -import paddle.v2.fluid.framework as framework -import paddle.v2.fluid.layers as layers -from paddle.v2.fluid.executor import Executor -from paddle.v2.fluid.optimizer import AdamOptimizer +import paddle.v2.fluid as fluid def stacked_lstm_net(input_dim, @@ -14,35 +9,35 @@ def stacked_lstm_net(input_dim, hid_dim=512, stacked_num=3): assert stacked_num % 2 == 1 - data = layers.data(name="words", shape=[1], dtype="int64") - label = layers.data(name="label", shape=[1], dtype="int64") + data = fluid.layers.data(name="words", shape=[1], dtype="int64") + label = fluid.layers.data(name="label", shape=[1], dtype="int64") - emb = layers.embedding(input=data, size=[input_dim, emb_dim]) + emb = fluid.layers.embedding(input=data, size=[input_dim, emb_dim]) # add bias attr # TODO(qijun) linear act - fc1 = layers.fc(input=emb, size=hid_dim) - lstm1, cell1 = layers.dynamic_lstm(input=fc1, size=hid_dim) + fc1 = fluid.layers.fc(input=emb, size=hid_dim) + lstm1, cell1 = fluid.layers.dynamic_lstm(input=fc1, size=hid_dim) inputs = [fc1, lstm1] for i in range(2, stacked_num + 1): - fc = layers.fc(input=inputs, size=hid_dim) - lstm, cell = layers.dynamic_lstm( + fc = fluid.layers.fc(input=inputs, size=hid_dim) + lstm, cell = fluid.layers.dynamic_lstm( input=fc, size=hid_dim, is_reverse=(i % 2) == 0) inputs = [fc, lstm] - fc_last = layers.sequence_pool(input=inputs[0], pool_type='max') - lstm_last = layers.sequence_pool(input=inputs[1], pool_type='max') + fc_last = fluid.layers.sequence_pool(input=inputs[0], pool_type='max') + lstm_last = fluid.layers.sequence_pool(input=inputs[1], pool_type='max') - prediction = layers.fc(input=[fc_last, lstm_last], - size=class_dim, - act='softmax') - cost = layers.cross_entropy(input=prediction, label=label) - avg_cost = layers.mean(x=cost) - adam_optimizer = AdamOptimizer(learning_rate=0.002) + prediction = fluid.layers.fc(input=[fc_last, lstm_last], + size=class_dim, + act='softmax') + cost = fluid.layers.cross_entropy(input=prediction, label=label) + avg_cost = fluid.layers.mean(x=cost) + adam_optimizer = fluid.optimizer.Adam(learning_rate=0.002) adam_optimizer.minimize(avg_cost) - accuracy = evaluator.Accuracy(input=prediction, label=label) + accuracy = fluid.evaluator.Accuracy(input=prediction, label=label) return avg_cost, accuracy, accuracy.metrics[0] @@ -55,7 +50,7 @@ def to_lodtensor(data, place): lod.append(cur_len) flattened_data = np.concatenate(data, axis=0).astype("int64") flattened_data = flattened_data.reshape([len(flattened_data), 1]) - res = core.LoDTensor() + res = fluid.LoDTensor() res.set(flattened_data, place) res.set_lod([lod]) return res @@ -77,10 +72,10 @@ def main(): paddle.reader.shuffle( paddle.dataset.imdb.train(word_dict), buf_size=1000), batch_size=BATCH_SIZE) - place = core.CPUPlace() - exe = Executor(place) + place = fluid.CPUPlace() + exe = fluid.Executor(place) - exe.run(framework.default_startup_program()) + exe.run(fluid.default_startup_program()) for pass_id in xrange(PASS_NUM): accuracy.reset(exe) @@ -90,15 +85,14 @@ def main(): label = np.array(map(lambda x: x[1], data)).astype("int64") label = label.reshape([BATCH_SIZE, 1]) - tensor_label = core.LoDTensor() + tensor_label = fluid.LoDTensor() tensor_label.set(label, place) - outs = exe.run(framework.default_main_program(), - feed={"words": tensor_words, - "label": tensor_label}, - fetch_list=[cost, acc_out]) - cost_val = np.array(outs[0]) - acc_val = np.array(outs[1]) + cost_val, acc_val = exe.run( + fluid.default_main_program(), + feed={"words": tensor_words, + "label": tensor_label}, + fetch_list=[cost, acc_out]) pass_acc = accuracy.eval(exe) print("cost=" + str(cost_val) + " acc=" + str(acc_val) + " pass_acc=" + str(pass_acc)) diff --git a/python/paddle/v2/fluid/tests/book/test_understand_sentiment_lstm.py b/python/paddle/v2/fluid/tests/book/test_understand_sentiment_lstm.py index 8aebeba65..b24793203 100644 --- a/python/paddle/v2/fluid/tests/book/test_understand_sentiment_lstm.py +++ b/python/paddle/v2/fluid/tests/book/test_understand_sentiment_lstm.py @@ -1,40 +1,39 @@ import numpy as np import paddle.v2 as paddle -import paddle.v2.fluid.core as core -import paddle.v2.fluid.framework as framework -import paddle.v2.fluid.layers as layers -from paddle.v2.fluid.executor import Executor -from paddle.v2.fluid.optimizer import AdamOptimizer +import paddle.v2.fluid as fluid def lstm_net(dict_dim, class_dim=2, emb_dim=32, seq_len=80, batch_size=50): - data = layers.data( + data = fluid.layers.data( name="words", shape=[seq_len * batch_size, 1], append_batch_size=False, dtype="int64") - label = layers.data( + label = fluid.layers.data( name="label", shape=[batch_size, 1], append_batch_size=False, dtype="int64") - emb = layers.embedding(input=data, size=[dict_dim, emb_dim]) - emb = layers.reshape(x=emb, shape=[batch_size, seq_len, emb_dim]) - emb = layers.transpose(x=emb, axis=[1, 0, 2]) + emb = fluid.layers.embedding(input=data, size=[dict_dim, emb_dim]) + emb = fluid.layers.reshape(x=emb, shape=[batch_size, seq_len, emb_dim]) + emb = fluid.layers.transpose(x=emb, axis=[1, 0, 2]) - c_pre_init = layers.fill_constant( + c_pre_init = fluid.layers.fill_constant( dtype=emb.dtype, shape=[batch_size, emb_dim], value=0.0) - layer_1_out = layers.lstm(emb, c_pre_init=c_pre_init, hidden_dim=emb_dim) - layer_1_out = layers.transpose(x=layer_1_out, axis=[1, 0, 2]) + layer_1_out = fluid.layers.lstm( + emb, c_pre_init=c_pre_init, hidden_dim=emb_dim) + layer_1_out = fluid.layers.transpose(x=layer_1_out, axis=[1, 0, 2]) - prediction = layers.fc(input=layer_1_out, size=class_dim, act="softmax") - cost = layers.cross_entropy(input=prediction, label=label) + prediction = fluid.layers.fc(input=layer_1_out, + size=class_dim, + act="softmax") + cost = fluid.layers.cross_entropy(input=prediction, label=label) - avg_cost = layers.mean(x=cost) - adam_optimizer = AdamOptimizer(learning_rate=0.002) - opts = adam_optimizer.minimize(avg_cost) - acc = layers.accuracy(input=prediction, label=label) + avg_cost = fluid.layers.mean(x=cost) + adam_optimizer = fluid.optimizer.Adam(learning_rate=0.002) + adam_optimizer.minimize(avg_cost) + acc = fluid.layers.accuracy(input=prediction, label=label) return avg_cost, acc @@ -48,7 +47,7 @@ def to_lodtensor(data, place): lod.append(cur_len) flattened_data = np.concatenate(data, axis=0).astype("int64") flattened_data = flattened_data.reshape([len(flattened_data), 1]) - res = core.LoDTensor() + res = fluid.LoDTensor() res.set(flattened_data, place) res.set_lod([lod]) return res @@ -65,7 +64,7 @@ def prepare_feed_data(data, place): label = np.array(map(lambda x: x[1], data)).astype("int64") label = label.reshape([len(label), 1]) - tensor_label = core.LoDTensor() + tensor_label = fluid.LoDTensor() tensor_label.set(label, place) return tensor_words, tensor_label @@ -86,17 +85,17 @@ def main(): paddle.reader.shuffle( paddle.dataset.imdb.train(word_dict), buf_size=BATCH_SIZE * 10), batch_size=BATCH_SIZE) - place = core.CPUPlace() - exe = Executor(place) + place = fluid.CPUPlace() + exe = fluid.Executor(place) - exe.run(framework.default_startup_program()) + exe.run(fluid.default_startup_program()) for pass_id in xrange(PASS_NUM): for data in train_data(): chopped_data = chop_data(data) tensor_words, tensor_label = prepare_feed_data(chopped_data, place) - outs = exe.run(framework.default_main_program(), + outs = exe.run(fluid.default_main_program(), feed={"words": tensor_words, "label": tensor_label}, fetch_list=[cost, acc]) diff --git a/python/paddle/v2/fluid/tests/book/test_word2vec.py b/python/paddle/v2/fluid/tests/book/test_word2vec.py index 0629e1cab..b0cd1a518 100644 --- a/python/paddle/v2/fluid/tests/book/test_word2vec.py +++ b/python/paddle/v2/fluid/tests/book/test_word2vec.py @@ -1,10 +1,6 @@ import numpy as np import paddle.v2 as paddle -import paddle.v2.fluid.core as core -import paddle.v2.fluid.framework as framework -import paddle.v2.fluid.layers as layers -from paddle.v2.fluid.executor import Executor -from paddle.v2.fluid.optimizer import SGDOptimizer +import paddle.v2.fluid as fluid PASS_NUM = 100 EMBED_SIZE = 32 @@ -16,57 +12,57 @@ IS_SPARSE = True word_dict = paddle.dataset.imikolov.build_dict() dict_size = len(word_dict) -first_word = layers.data(name='firstw', shape=[1], dtype='int64') -second_word = layers.data(name='secondw', shape=[1], dtype='int64') -third_word = layers.data(name='thirdw', shape=[1], dtype='int64') -forth_word = layers.data(name='forthw', shape=[1], dtype='int64') -next_word = layers.data(name='nextw', shape=[1], dtype='int64') +first_word = fluid.layers.data(name='firstw', shape=[1], dtype='int64') +second_word = fluid.layers.data(name='secondw', shape=[1], dtype='int64') +third_word = fluid.layers.data(name='thirdw', shape=[1], dtype='int64') +forth_word = fluid.layers.data(name='forthw', shape=[1], dtype='int64') +next_word = fluid.layers.data(name='nextw', shape=[1], dtype='int64') -embed_first = layers.embedding( +embed_first = fluid.layers.embedding( input=first_word, size=[dict_size, EMBED_SIZE], dtype='float32', is_sparse=IS_SPARSE, param_attr={'name': 'shared_w'}) -embed_second = layers.embedding( +embed_second = fluid.layers.embedding( input=second_word, size=[dict_size, EMBED_SIZE], dtype='float32', is_sparse=IS_SPARSE, param_attr={'name': 'shared_w'}) -embed_third = layers.embedding( +embed_third = fluid.layers.embedding( input=third_word, size=[dict_size, EMBED_SIZE], dtype='float32', is_sparse=IS_SPARSE, param_attr={'name': 'shared_w'}) -embed_forth = layers.embedding( +embed_forth = fluid.layers.embedding( input=forth_word, size=[dict_size, EMBED_SIZE], dtype='float32', is_sparse=IS_SPARSE, param_attr={'name': 'shared_w'}) -concat_embed = layers.concat( +concat_embed = fluid.layers.concat( input=[embed_first, embed_second, embed_third, embed_forth], axis=1) -hidden1 = layers.fc(input=concat_embed, size=HIDDEN_SIZE, act='sigmoid') -predict_word = layers.fc(input=hidden1, size=dict_size, act='softmax') -cost = layers.cross_entropy(input=predict_word, label=next_word) -avg_cost = layers.mean(x=cost) -sgd_optimizer = SGDOptimizer(learning_rate=0.001) -opts = sgd_optimizer.minimize(avg_cost) +hidden1 = fluid.layers.fc(input=concat_embed, size=HIDDEN_SIZE, act='sigmoid') +predict_word = fluid.layers.fc(input=hidden1, size=dict_size, act='softmax') +cost = fluid.layers.cross_entropy(input=predict_word, label=next_word) +avg_cost = fluid.layers.mean(x=cost) +sgd_optimizer = fluid.optimizer.SGD(learning_rate=0.001) +sgd_optimizer.minimize(avg_cost) train_reader = paddle.batch( paddle.dataset.imikolov.train(word_dict, N), BATCH_SIZE) -place = core.CPUPlace() -exe = Executor(place) +place = fluid.CPUPlace() +exe = fluid.Executor(place) # fix https://github.com/PaddlePaddle/Paddle/issues/5434 then remove # below exit line. exit(0) -exe.run(framework.default_startup_program()) +exe.run(fluid.default_startup_program()) for pass_id in range(PASS_NUM): for data in train_reader(): @@ -74,36 +70,15 @@ for pass_id in range(PASS_NUM): input_data = map(lambda x: np.array(x).astype("int64"), input_data) input_data = map(lambda x: np.expand_dims(x, axis=1), input_data) - first_data = input_data[0] - first_tensor = core.LoDTensor() - first_tensor.set(first_data, place) - - second_data = input_data[1] - second_tensor = core.LoDTensor() - second_tensor.set(second_data, place) - - third_data = input_data[2] - third_tensor = core.LoDTensor() - third_tensor.set(third_data, place) - - forth_data = input_data[3] - forth_tensor = core.LoDTensor() - forth_tensor.set(forth_data, place) - - next_data = input_data[4] - next_tensor = core.LoDTensor() - next_tensor.set(next_data, place) - - outs = exe.run(framework.default_main_program(), - feed={ - 'firstw': first_tensor, - 'secondw': second_tensor, - 'thirdw': third_tensor, - 'forthw': forth_tensor, - 'nextw': next_tensor - }, - fetch_list=[avg_cost]) - out = np.array(outs[0]) - if out[0] < 10.0: + avg_cost_np = exe.run(fluid.default_main_program(), + feed={ + 'firstw': input_data[0], + 'secondw': input_data[1], + 'thirdw': input_data[2], + 'forthw': input_data[3], + 'nextw': input_data[4] + }, + fetch_list=[avg_cost]) + if avg_cost_np[0] < 10.0: exit(0) # if avg cost less than 10.0, we think our code is good. exit(1) -- GitLab