diff --git a/doc/fluid/build_and_install/paddleci.png b/doc/fluid/build_and_install/paddleci.png
new file mode 120000
index 0000000000000000000000000000000000000000..c3eb1457acc77cab9360e654240d1e8f548035b4
--- /dev/null
+++ b/doc/fluid/build_and_install/paddleci.png
@@ -0,0 +1 @@
+../../v2/build_and_install/paddleci.png
\ No newline at end of file
diff --git a/doc/fluid/design/motivation/refactorization.md b/doc/fluid/design/motivation/refactorization.md
index 4e1d660cef6369f04db8e1e83360f6af25259f96..ad9d0f6d3f3ad9884f108826e8410871fffd51bf 100644
--- a/doc/fluid/design/motivation/refactorization.md
+++ b/doc/fluid/design/motivation/refactorization.md
@@ -125,12 +125,12 @@ Compile Time -> IR -> Runtime
 
 ## Operator/OpWithKernel/OpKernel
 
-![class_diagram](http://api.paddlepaddle.org/graphviz?dot=https://gist.githubusercontent.com/reyoung/53df507f6749762675dff3e7ce53372f/raw/49caf1fb70820fb4a6c217634317c9306f361f36/op_op_with_kern_class_diagram.dot)
+![class_diagram](https://raw.githubusercontent.com/PaddlePaddle/Paddle/develop/doc/fluid/images/op_op_with_kern_class_diagram.dot)
 
 ---
 
 ## Operator
 
-![class_diagram](http://api.paddlepaddle.org/graphviz?dot=https://gist.githubusercontent.com/reyoung/53df507f6749762675dff3e7ce53372f/raw/dd598e8f1976f5759f58af5e5ef94738a6b2e661/op.dot)
+![class_diagram](https://raw.githubusercontent.com/PaddlePaddle/Paddle/develop/doc/fluid/images/op.dot)
 
 * `Operator` is the fundamental building block of the user interface.
   * Operator stores input/output variable names and attributes.
@@ -141,7 +141,7 @@ Compile Time -> IR -> Runtime
 
 ## OpWithKernel/Kernel
 
-![class_diagram](http://api.paddlepaddle.org/graphviz?dot=https://gist.githubusercontent.com/reyoung/53df507f6749762675dff3e7ce53372f/raw/9d7f4eba185cf41c8e2fbfb40ae21890dbddcd39/op_with_kernel.dot)
+![class_diagram](https://raw.githubusercontent.com/PaddlePaddle/Paddle/develop/doc/fluid/images/op_with_kernel.dot)
 
 * `OpWithKernel` inherits `Operator`.
 * `OpWithKernel` contains a Kernel map.
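The bullets above, together with the three `.dot` diagrams added below, describe the core class layout of the refactored framework. The following C++ sketch mirrors that layout for orientation only — the names follow the diagram, but every signature is a simplified stand-in, not Paddle's actual `paddle/fluid/framework` headers:

```cpp
// Sketch of the hierarchy in the diagram -- simplified stand-ins only.
#include <cstdio>
#include <map>
#include <memory>
#include <string>
#include <vector>

struct Place {  // simplified device placement, e.g. CPU vs. a GPU id
  int device = -1;
  bool operator<(const Place &o) const { return device < o.device; }
};
using AttributeMap = std::map<std::string, std::string>;  // simplified

class Operator {  // the fundamental building block of the user interface
 public:
  virtual ~Operator() = default;
  virtual void InferShape() = 0;  // InferShape()=0 in the diagram
  virtual void Run() = 0;         // Run()=0 in the diagram

 protected:
  std::map<std::string, std::vector<std::string>> inputs_;   // slot -> var names
  std::map<std::string, std::vector<std::string>> outputs_;  // slot -> var names
  AttributeMap attrs_;
};

class OpKernel {  // device-specific computation
 public:
  virtual ~OpKernel() = default;
  virtual void Compute() = 0;
};

struct OpKernelKey {  // the diagram's "as map key" edge
  Place place;
  bool operator<(const OpKernelKey &o) const { return place < o.place; }
};

class OpWithKernel : public Operator {  // inherits Operator
 public:
  void Run() override {  // dispatch: pick the kernel registered for the place
    auto it = kernels_.find(OpKernelKey{Place{0}});
    if (it != kernels_.end()) it->second->Compute();
  }

 protected:
  // the "contains many" edge: one kernel per OpKernelKey
  std::map<OpKernelKey, std::unique_ptr<OpKernel>> kernels_;
};

class MulOpKernel : public OpKernel {  // stand-in for MulOpKernel<Place>
 public:
  void Compute() override { std::printf("mul kernel runs\n"); }
};

class MulOp : public OpWithKernel {  // the diagram's "register many" edge
 public:
  MulOp() { kernels_[OpKernelKey{Place{0}}] = std::make_unique<MulOpKernel>(); }
  void InferShape() override {}  // shape checks would live here
};

int main() {
  MulOp op;
  op.Run();  // selects MulOp's kernel from kernels_ and calls Compute()
}
```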
diff --git a/doc/fluid/images/op.dot b/doc/fluid/images/op.dot
new file mode 100644
index 0000000000000000000000000000000000000000..c8ad839cb88788e9b5906402257cc7bbc3ddcb54
--- /dev/null
+++ b/doc/fluid/images/op.dot
@@ -0,0 +1,4 @@
+digraph sample {
+    graph [rankdir=TD]; node [shape=record];
+    op [label="{Operator| InferShape()=0\lRun()=0\l | map<string, string[]> inputs_\lmap<string, string[]> outputs_ \l AttributeMap attrs_\l}"];
+}
\ No newline at end of file
diff --git a/doc/fluid/images/op_op_with_kern_class_diagram.dot b/doc/fluid/images/op_op_with_kern_class_diagram.dot
new file mode 100644
index 0000000000000000000000000000000000000000..8f24e9ea83acf879c7008f2d97113c0a4cc111c3
--- /dev/null
+++ b/doc/fluid/images/op_op_with_kern_class_diagram.dot
@@ -0,0 +1,38 @@
+digraph sample {
+    graph [rankdir=TD]; node [shape=record];
+    op [label="{Operator| InferShape()=0\lRun()=0\l | map<string, string[]> inputs_\lmap<string, string[]> outputs_ \l AttributeMap attrs_\l}"];
+    op_with_kern [label="{OpWithKernel | InferShape()=0\lRun()\l | map<OpKernelKey,OpKernel>kernels_ }"]
+    op_kernel [label="{OpKernel | Compute()=0}"]
+    op_kernel_key [label="{OpKernelKey| Place place\n...}"]
+
+    op -> op_with_kern [dir=back, arrowtail=onormal]
+    op_with_kern -> op_kernel [arrowhead=vee, label="contains many"]
+
+    {
+        rank=same;
+        op_with_kern
+        op_kernel
+    }
+
+    op_kernel -> op_kernel_key [style=invis]
+
+    {
+        rank=same;
+        op_kernel
+        op_kernel_key
+    }
+
+    op_with_kern -> op_kernel_key [arrowhead=vee, label ="\nas map key"]
+
+    mul_op [label="MulOp"]
+    op_with_kern -> mul_op [dir=back, arrowtail=onormal]
+    mul_kernel [label="template <typename Place>\lclass MulOpKernel\l"]
+    op_kernel -> mul_kernel [dir=back, arrowtail=onormal]
+    mul_op -> mul_kernel [arrowhead=vee, label="register many"]
+
+    {
+        rank=same;
+        mul_op;
+        mul_kernel;
+    }
+}
\ No newline at end of file
diff --git a/doc/fluid/images/op_with_kernel.dot b/doc/fluid/images/op_with_kernel.dot
new file mode 100644
index 0000000000000000000000000000000000000000..4f5af4f7b5f5a69693a058c99eb658900136077a
--- /dev/null
+++ b/doc/fluid/images/op_with_kernel.dot
@@ -0,0 +1,26 @@
+digraph sample {
+    graph [rankdir=TD]; node [shape=record];
+    op [label="{Operator}"];
+    op_with_kern [label="{OpWithKernel | InferShape()=0\lRun()\l | map<OpKernelKey,OpKernel>kernels_ }"]
+    op_kernel [label="{OpKernel | Compute()=0}"]
+    op_kernel_key [label="{OpKernelKey| Place place\n...}"]
+
+    op -> op_with_kern [dir=back, arrowtail=onormal]
+    op_with_kern -> op_kernel [arrowhead=vee, label="contains many"]
+
+    {
+        rank=same;
+        op_with_kern
+        op_kernel
+    }
+
+    op_kernel -> op_kernel_key [style=invis]
+
+    {
+        rank=same;
+        op_kernel
+        op_kernel_key
+    }
+
+    op_with_kern -> op_kernel_key [arrowhead=vee, label ="\nas map key"]
+}
\ No newline at end of file
diff --git a/doc/v2/api/config/layer.rst b/doc/v2/api/config/layer.rst
index 29388f5005bf779a1bfa63c0d46d35996c0c792d..1a6496968cae1fef88142ba9ca3f9e63a81b196d 100644
--- a/doc/v2/api/config/layer.rst
+++ b/doc/v2/api/config/layer.rst
@@ -142,7 +142,7 @@ gated_unit
 -----------
 .. autoclass:: paddle.v2.layer.gated_unit
     :noindex:
-    
+
 Recurrent Layer Group
 =====================
 
@@ -354,7 +354,7 @@ dropout
 --------
 .. autoclass:: paddle.v2.layer.dropout
     :noindex:
-    
+
 dot_prod
 ---------
 .. autoclass:: paddle.v2.layer.dot_prod
     :noindex:
@@ -460,6 +460,11 @@ multi_binary_label_cross_entropy_cost
 .. autoclass:: paddle.v2.layer.multi_binary_label_cross_entropy_cost
     :noindex:
 
+classification_cost
+-------------------
+.. autoclass:: paddle.v2.layer.classification_cost
+    :noindex:
+
 huber_regression_cost
 -------------------------
 .. autoclass:: paddle.v2.layer.huber_regression_cost
     :noindex:
@@ -534,7 +539,7 @@ detection_output
 ----------------
 .. autoclass:: paddle.v2.layer.detection_output
     :noindex:
-    
+
 Check Layer
 ============
diff --git a/paddle/fluid/framework/details/fetch_op_handle.cc b/paddle/fluid/framework/details/fetch_op_handle.cc
index 1e8ca20b51d43554cf1898b41b31c27b90e6c642..a3cae8c64cdff8594c8971b0458c443f54375f11 100644
--- a/paddle/fluid/framework/details/fetch_op_handle.cc
+++ b/paddle/fluid/framework/details/fetch_op_handle.cc
@@ -49,7 +49,9 @@ void FetchOpHandle::RunImpl() {
       platform::DeviceContextPool::Instance().Get(platform::CPUPlace());
   for (auto *input : inputs_) {
     auto *var = static_cast<VarHandle *>(input);
-    var->generated_op_->Wait(cpu_ctx);
+    if (var->generated_op_) {
+      var->generated_op_->Wait(cpu_ctx);
+    }
   }
   tensors_.resize(inputs_.size());
   auto *var_handle = static_cast<VarHandle *>(inputs_[0]);
diff --git a/paddle/fluid/framework/details/nccl_all_reduce_op_handle.cc b/paddle/fluid/framework/details/nccl_all_reduce_op_handle.cc
index b055bb48f608c9fd9cc671d175cb463d25dc489b..16aa5d067ab7a222af8fbb6ca8ec18222ecd799b 100644
--- a/paddle/fluid/framework/details/nccl_all_reduce_op_handle.cc
+++ b/paddle/fluid/framework/details/nccl_all_reduce_op_handle.cc
@@ -36,7 +36,9 @@ void NCCLAllReduceOpHandle::RunImpl() {
   // Wait input done
   for (auto *in : inputs_) {
     auto &p = static_cast<VarHandle *>(in)->place_;
-    in->generated_op_->Wait(dev_ctxes_[p]);
+    if (in->generated_op_) {
+      in->generated_op_->Wait(dev_ctxes_[p]);
+    }
   }
 
   auto &var_name = static_cast<VarHandle *>(this->inputs_[0])->name_;
diff --git a/paddle/fluid/framework/details/send_op_handle.cc b/paddle/fluid/framework/details/send_op_handle.cc
index 0763f92171e7813ec0ee8ca4f3aa42b76205130a..bd97c5260dbba935e422793e0aa6aac8b6875627 100644
--- a/paddle/fluid/framework/details/send_op_handle.cc
+++ b/paddle/fluid/framework/details/send_op_handle.cc
@@ -32,7 +32,9 @@ void SendOpHandle::RunImpl() {
     if (in->DebugString() == "dummy") {  // HACK
       continue;
     }
-    in->generated_op_->Wait(dev_ctxes_[p]);
+    if (in->generated_op_) {
+      in->generated_op_->Wait(dev_ctxes_[p]);
+    }
   }
   auto &tmp_scope = local_scope_->FindVar(kLocalExecScopeName)->Get<Scope *>();
   // FIXME(wuyi): can not use RunAndRecordEvent here, for it will cause dead
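The three hunks above apply the same fix: a variable handle that no upstream operator produced (for example fed data or a parameter) carries a null `generated_op_`, so calling `Wait()` through it unconditionally dereferences null. A minimal sketch of the guard, with stand-in types (the real ones are `VarHandleBase`/`OpHandleBase` in `paddle/fluid/framework/details`):

```cpp
// Stand-in types illustrating the null-producer guard added above.
struct DeviceContext {};

struct OpHandle {
  void Wait(DeviceContext *ctx) { /* block until this op's outputs are ready */ }
};

struct VarHandle {
  OpHandle *generated_op_ = nullptr;  // stays null if no op produced the var
};

// Before the patch: var->generated_op_->Wait(ctx) -- a null dereference
// whenever the variable is a graph source (fed data, parameters, ...).
void WaitForInput(VarHandle *var, DeviceContext *ctx) {
  if (var->generated_op_) {  // the guard each hunk above adds
    var->generated_op_->Wait(ctx);
  }
}

int main() {
  DeviceContext ctx;
  VarHandle fed_var;             // no producer: generated_op_ == nullptr
  WaitForInput(&fed_var, &ctx);  // safe with the guard, crashes without it
}
```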
diff --git a/paddle/fluid/inference/tests/book/CMakeLists.txt b/paddle/fluid/inference/tests/book/CMakeLists.txt
index cc179a86256e6b552c08a091402157bdcc86b383..dbb81462b8273bd701e9c9f530eaf69817abd6a1 100644
--- a/paddle/fluid/inference/tests/book/CMakeLists.txt
+++ b/paddle/fluid/inference/tests/book/CMakeLists.txt
@@ -36,5 +36,5 @@ inference_test(label_semantic_roles)
 inference_test(recognize_digits ARGS mlp conv)
 inference_test(recommender_system)
 #inference_test(rnn_encoder_decoder)
-inference_test(understand_sentiment ARGS conv)
+#inference_test(understand_sentiment ARGS conv)
 inference_test(word2vec)
diff --git a/paddle/fluid/operators/conv_op.h b/paddle/fluid/operators/conv_op.h
index c51898abb422663a6731a17e0717c62ebf0701f8..f462f00c0803c12ee2f2b0f94dc90afdca500da3 100644
--- a/paddle/fluid/operators/conv_op.h
+++ b/paddle/fluid/operators/conv_op.h
@@ -187,7 +187,8 @@ class GemmConvKernel : public framework::OpKernel<T> {
         // gemm
         Tensor out_slice = out_batch.Slice(g * out_step, (g + 1) * out_step);
         Tensor filter_slice = filter.Slice(g * out_step, (g + 1) * out_step);
-        blas.MatMul(filter_slice, col_matrix, &out_slice);
+        blas.MatMul(filter_slice, false, col_matrix, false, T(1.0),
+                    &out_slice, T(0.0));
       }
     }
   }
@@ -304,7 +305,8 @@ class GemmConvGradKernel : public framework::OpKernel<T> {
           col_matrix.ShareDataWith(in_grad_slice);
           col_matrix.Resize(col_matrix_shape);
         }
-        blas.MatMul(filter_slice, true, out_grad_slice, false, &col_matrix);
+        blas.MatMul(filter_slice, true, out_grad_slice, false, T(1.0),
+                    &col_matrix, T(0.0));
 
         if (is_expand && data_dim == 2U) {
           col2im(dev_ctx, col, dilations, strides,
@@ -351,8 +353,8 @@ class GemmConvGradKernel : public framework::OpKernel<T> {
           // gemm
           Tensor filter_grad_slice =
               filter_grad_.Slice(g * out_step, (g + 1) * out_step);
-          blas.MatMul(out_grad_slice, false, col_matrix, true,
-                      &filter_grad_slice);
+          blas.MatMul(out_grad_slice, false, col_matrix, true, T(1.0),
+                      &filter_grad_slice, T(1.0));
         }
       }
     }
diff --git a/paddle/fluid/operators/conv_transpose_op.h b/paddle/fluid/operators/conv_transpose_op.h
index 9276e5bfef71a58741c2dfa25b31c2bd07c309b8..898121412b17cd6fbbbeb57e9d63842e592703ac 100644
--- a/paddle/fluid/operators/conv_transpose_op.h
+++ b/paddle/fluid/operators/conv_transpose_op.h
@@ -135,7 +135,8 @@ class GemmConvTransposeKernel : public framework::OpKernel<T> {
 
       // col_matrix = filter * input_batch
       // of shape (c * k_h * k_w, h * w) or (c * k_d * k_h * k_w, d * h * w)
-      blas.MatMul(filter, true, input_batch, false, &col_matrix);
+      blas.MatMul(filter, true, input_batch, false, static_cast<T>(1.0),
+                  &col_matrix, static_cast<T>(0.0));
 
       if (data_dim == 2U) {
         // col2im: col_matrix -> dy
@@ -267,7 +268,8 @@ class GemmConvTransposeGradKernel : public framework::OpKernel<T> {
           // or
           // (m, c * k_d * k_h * k_w) * (c * k_d * k_h * k_w, d * h * w) -> (m,
           // d, h, w)
-          blas.MatMul(filter, false, col_matrix, false, &input_grad_batch);
+          blas.MatMul(filter, false, col_matrix, false, static_cast<T>(1.0),
+                      &input_grad_batch, static_cast<T>(0.0));
         }
         if (filter_grad) {
           // input batch
@@ -277,7 +279,8 @@ class GemmConvTransposeGradKernel : public framework::OpKernel<T> {
           // or
          // (m, d * h * w) * (d * h * w, c * k_d * k_h * k_w) -> (m, c * k_d *
           // k_h * k_w)
-          blas.MatMul(in_batch, false, col_matrix, true, &filter_grad_);
+          blas.MatMul(in_batch, false, col_matrix, true, static_cast<T>(1.0),
+                      &filter_grad_, static_cast<T>(1.0));
         }
       }
     }
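The conv_op.h and conv_transpose_op.h hunks above replace the short `MatMul` form with an explicit GEMM-style call, `C = alpha * op(A) * op(B) + beta * C`. The forward and input-gradient calls pass `beta = 0` (overwrite the output), while both filter-gradient calls pass `beta = 1`, so each slice's product accumulates into the gradient computed by earlier loop iterations instead of clobbering it. A plain-C++ reference for those semantics (illustrative only; the real `Blas::MatMul` dispatches to an optimized BLAS):

```cpp
#include <cstdio>
#include <vector>

// Reference semantics of a GEMM-style MatMul: C = alpha*op(A)*op(B) + beta*C.
// Row-major; op(X) is X or X^T depending on the trans flag. beta == 0
// overwrites C (forward pass above); beta == 1 accumulates (filter grads).
void MatMulRef(const std::vector<float> &A, bool trans_a, int a_rows, int a_cols,
               const std::vector<float> &B, bool trans_b, int b_rows, int b_cols,
               float alpha, std::vector<float> *C, float beta) {
  int M = trans_a ? a_cols : a_rows;  // rows of op(A)
  int K = trans_a ? a_rows : a_cols;  // shared inner dimension
  int N = trans_b ? b_rows : b_cols;  // cols of op(B)
  for (int i = 0; i < M; ++i) {
    for (int j = 0; j < N; ++j) {
      float acc = 0.f;
      for (int k = 0; k < K; ++k) {
        float a = trans_a ? A[k * a_cols + i] : A[i * a_cols + k];
        float b = trans_b ? B[j * b_cols + k] : B[k * b_cols + j];
        acc += a * b;
      }
      (*C)[i * N + j] = alpha * acc + beta * (*C)[i * N + j];
    }
  }
}

int main() {
  std::vector<float> A = {1, 2, 3, 4};  // 2x2
  std::vector<float> B = {1, 0, 0, 1};  // 2x2 identity
  std::vector<float> C(4, 0.f);
  MatMulRef(A, false, 2, 2, B, false, 2, 2, 1.f, &C, 0.f);  // C = A
  MatMulRef(A, false, 2, 2, B, false, 2, 2, 1.f, &C, 1.f);  // C += A -> 2A
  std::printf("C[0]=%g C[3]=%g\n", C[0], C[3]);              // 2 and 8
}
```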
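The cuda_device_function.h hunk below inserts a `__syncthreads()` between zeroing the shared staging buffer and the per-warp leaders storing their partial sums. Without the barrier, a warp leader can store into `shm[k]` before (or while) a warp-0 thread zeroes that slot, silently dropping a warp's contribution. A self-contained CUDA sketch of the corrected block-reduction pattern (hypothetical kernel, not Paddle's `reduceSum`):

```cpp
// block_reduce.cu -- compile with nvcc. Illustrates the barrier added below.
#include <cstdio>

__inline__ __device__ float WarpReduceSum(float val) {
  for (int offset = warpSize / 2; offset > 0; offset /= 2)
    val += __shfl_down_sync(0xffffffff, val, offset);  // intra-warp tree sum
  return val;  // lane 0 holds the warp's total
}

__global__ void BlockSum(const float *in, float *out, int n) {
  __shared__ float shm[32];  // one slot per warp (max 32 warps per block)
  int tid = threadIdx.x;
  int gid = blockIdx.x * blockDim.x + tid;
  float val = (gid < n) ? in[gid] : 0.f;

  val = WarpReduceSum(val);

  if (tid < warpSize) shm[tid] = 0.f;
  __syncthreads();  // the fix: every slot is zeroed before any warp leader
                    // stores its partial sum; otherwise a later zeroing can
                    // overwrite an already-stored partial sum (a data race)
  if (tid % warpSize == 0) shm[tid / warpSize] = val;
  __syncthreads();

  if (tid < warpSize) val = WarpReduceSum(shm[tid]);  // warp 0 sums the sums
  if (tid == 0) out[blockIdx.x] = val;
}

int main() {
  const int n = 1024;
  float *in, *out;
  cudaMallocManaged(&in, n * sizeof(float));
  cudaMallocManaged(&out, sizeof(float));
  for (int i = 0; i < n; ++i) in[i] = 1.0f;
  BlockSum<<<1, n>>>(in, out, n);
  cudaDeviceSynchronize();
  std::printf("sum = %f\n", *out);  // expect 1024
  cudaFree(in);
  cudaFree(out);
  return 0;
}
```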
diff --git a/paddle/fluid/platform/cuda_device_function.h b/paddle/fluid/platform/cuda_device_function.h
index e81c385727be5c2ba3f02bfbd86168cb4650dfda..ecec4178f2d9937920e52eb74bf9068b84e741a0 100644
--- a/paddle/fluid/platform/cuda_device_function.h
+++ b/paddle/fluid/platform/cuda_device_function.h
@@ -63,6 +63,7 @@ __device__ T reduceSum(T val, int tid, int len) {
     val += platform::CudaShuffleDownSync(mask, val, offset);
 
   if (tid < warpSize) shm[tid] = 0;
+  __syncthreads();
 
   if (tid % warpSize == 0) {
     shm[tid / warpSize] = val;
diff --git a/paddle/fluid/platform/profiler.cc b/paddle/fluid/platform/profiler.cc
index cfddd8e8711f8005e0eff7ef7a2980f535b2f851..50bc0aba6aa0f056dc0b2d49f6b3b745433e0756 100644
--- a/paddle/fluid/platform/profiler.cc
+++ b/paddle/fluid/platform/profiler.cc
@@ -463,7 +463,7 @@ void SetProfileListener() {
   std::mt19937 rng;
   rng.seed(std::random_device()());
   std::uniform_int_distribution<std::mt19937::result_type> dist6(
-      1, std::numeric_limits<std::mt19937::result_type>::max());
+      1, std::numeric_limits<int>::max());
   profiler_lister_id = dist6(rng);
 }
 int64_t ListenerId() { return profiler_lister_id; }
diff --git a/python/paddle/fluid/tests/book/test_understand_sentiment.py b/python/paddle/fluid/tests/book/notest_understand_sentiment.py
similarity index 100%
rename from python/paddle/fluid/tests/book/test_understand_sentiment.py
rename to python/paddle/fluid/tests/book/notest_understand_sentiment.py
diff --git a/python/paddle/fluid/tests/unittests/test_memory_optimization_transpiler.py b/python/paddle/fluid/tests/unittests/test_memory_optimization_transpiler.py
index f3dcca6b0107a9c4a6efcb0c0fd50324aaf92648..cfd6e63e12258a92447e68b4afbc7ead91b68cc1 100644
--- a/python/paddle/fluid/tests/unittests/test_memory_optimization_transpiler.py
+++ b/python/paddle/fluid/tests/unittests/test_memory_optimization_transpiler.py
@@ -18,7 +18,7 @@ import unittest
 import paddle.fluid.layers as layers
 import paddle.fluid.optimizer as optimizer
 from paddle.fluid.framework import Program, program_guard
-from paddle.fluid.memory_optimization_transpiler import memory_optimize
+from paddle.fluid.transpiler import memory_optimize
 
 
 class TestControlFlowGraph(unittest.TestCase):
diff --git a/python/paddle/fluid/tests/unittests/test_parallel_executor.py b/python/paddle/fluid/tests/unittests/test_parallel_executor.py
index 9056f5e66fceb42397c9a923d802320dd772725b..4eb25a6e00b7564ac17db568ec78c1c84933af43 100644
--- a/python/paddle/fluid/tests/unittests/test_parallel_executor.py
+++ b/python/paddle/fluid/tests/unittests/test_parallel_executor.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import numpy
+import numpy as np
 import unittest
 
 import paddle.fluid as fluid
@@ -243,7 +243,7 @@ class TestParallelExecutorBase(unittest.TestCase):
         begin = time.time()
         first_loss, = run_executor(
             exe=exe, feed=feed_dict, fetch_list=[loss.name])
-        first_loss = numpy.array(first_loss)
+        first_loss = np.array(first_loss)
 
         for i in xrange(iter):
             run_executor(exe=exe, feed=feed_dict, fetch_list=[])
@@ -256,7 +256,7 @@ class TestParallelExecutorBase(unittest.TestCase):
             print "%.4f Instance per second" % (
                 (batch_size * iter + 2) / (end - begin))
 
-        last_loss = numpy.array(last_loss)
+        last_loss = np.array(last_loss)
 
         print first_loss, last_loss
         # self.assertGreater(first_loss[0], last_loss[0])
@@ -284,8 +284,8 @@ class TestMNIST(TestParallelExecutorBase):
         self.check_network_convergence(simple_fc_net)
         self.check_network_convergence(simple_fc_net, allow_op_delay=True)
 
-        img = numpy.zeros(shape=[32, 784], dtype='float32')
-        label = numpy.ones(shape=[32, 1], dtype='int64')
+        img = np.zeros(shape=[32, 784], dtype='float32')
+        label = np.ones(shape=[32, 1], dtype='int64')
         self.check_network_convergence(
             simple_fc_net, feed_dict={"image": img,
                                       "label": label})
@@ -294,8 +294,8 @@ class TestMNIST(TestParallelExecutorBase):
         self.check_simple_fc_convergence()
 
     def check_simple_fc_parallel_accuracy(self):
-        img = numpy.zeros(shape=[32, 784], dtype='float32')
-        label = numpy.ones(shape=[32, 1], dtype='int64')
+        img = np.zeros(shape=[32, 784], dtype='float32')
+        label = np.ones(shape=[32, 1], dtype='int64')
         single_first_loss, single_last_loss = self.check_network_convergence(
             method=simple_fc_net,
             seed=1000,
@@ -319,8 +319,8 @@ class TestMNIST(TestParallelExecutorBase):
     def check_batchnorm_fc_convergence(self):
         self.check_network_convergence(fc_with_batchnorm)
 
-        img = numpy.zeros(shape=[32, 784], dtype='float32')
-        label = numpy.ones(shape=[32, 1], dtype='int64')
+        img = np.zeros(shape=[32, 784], dtype='float32')
+        label = np.ones(shape=[32, 1], dtype='int64')
         self.check_network_convergence(
             fc_with_batchnorm, feed_dict={"image": img,
                                           "label": label})
@@ -404,9 +404,6 @@ class ModelHyperParams(object):
     dropout = 0.1
 
 
-import numpy as np
-
-
 def prepare_batch_input(insts, src_pad_idx, trg_pad_idx, n_head):
     """
     Pad the instances to the max sequence length in batch, and generate the
@@ -533,9 +530,8 @@ class ParallelExecutorTestingDuringTraining(unittest.TestCase):
             opt.minimize(loss)
 
             batch_size = 32
-            image = numpy.random.normal(size=(batch_size,
-                                              784)).astype('float32')
-            label = numpy.random.randint(0, 10, (batch_size, 1), dtype="int64")
+            image = np.random.normal(size=(batch_size, 784)).astype('float32')
+            label = np.random.randint(0, 10, (batch_size, 1), dtype="int64")
 
             place = fluid.CUDAPlace(0)
             exe = fluid.Executor(place)
@@ -552,12 +548,12 @@ class ParallelExecutorTestingDuringTraining(unittest.TestCase):
 
             for i in xrange(5):
                 test_loss, = test_exe.run([loss.name], feed=feed_dict)
-                test_loss = numpy.array(test_loss)
+                test_loss = np.array(test_loss)
 
                 train_loss, = train_exe.run([loss.name], feed=feed_dict)
-                train_loss = numpy.array(train_loss)
+                train_loss = np.array(train_loss)
                 self.assertTrue(
-                    numpy.allclose(
+                    np.allclose(
                         train_loss, test_loss, atol=1e-8),
                     "Train loss: " + str(train_loss) + "\n Test loss:" +
                     str(test_loss))
@@ -712,7 +708,7 @@ class TestCRFModel(unittest.TestCase):
             data = train_data()
             for i in xrange(10):
                 cur_batch = next(data)
-                print map(numpy.array,
+                print map(np.array,
                           pe.run(feed=feeder.feed(cur_batch),
                                  fetch_list=[avg_cost.name]))[0]
 
@@ -721,3 +717,84 @@ class TestCRFModel(unittest.TestCase):
 
     def test_update_dense_parameter(self):
         self.check_network_convergence(is_sparse=False)
+
+
+# Test fetching all the variables of the global block.
+
+import paddle.dataset.flowers as flowers
+import math
+
+
+def Lenet(data, class_dim):
+    conv1 = fluid.layers.conv2d(data, 32, 5, 1, act=None)
+    bn1 = fluid.layers.batch_norm(conv1, act='relu')
+    pool1 = fluid.layers.pool2d(bn1, 2, 'max', 2)
+    conv2 = fluid.layers.conv2d(pool1, 50, 5, 1, act=None)
+    bn2 = fluid.layers.batch_norm(conv2, act='relu')
+    pool2 = fluid.layers.pool2d(bn2, 2, 'max', 2)
+
+    fc1 = fluid.layers.fc(pool2, size=500, act='relu')
+    fc2 = fluid.layers.fc(fc1, size=class_dim, act='softmax')
+
+    return fc2
+
+
+class TestFetchOp(unittest.TestCase):
+    def parallel_exe(self, train_inputs, seed):
+        main = fluid.Program()
+        startup = fluid.Program()
+        startup.random_seed = seed
+        with fluid.program_guard(main, startup):
+            data = fluid.layers.data(
+                name='image', shape=[3, 224, 224], dtype='float32')
+            label = fluid.layers.data(name='label', shape=[1], dtype='int64')
+            out = Lenet(data, class_dim=102)
+            loss = fluid.layers.cross_entropy(input=out, label=label)
+            loss = fluid.layers.mean(loss)
+
+            opt = fluid.optimizer.Momentum(
+                learning_rate=0.1,
+                momentum=0.9,
+                regularization=fluid.regularizer.L2Decay(1e-4))
+
+            opt.minimize(loss)
+
+        # TODO(zcd): I found that once the memory optimizer is enabled,
+        # parallel_exe doesn't fetch some variables, such as conv2d_0.b_0@GRAD,
+        # conv2d_1.b_0@GRAD. Those variables should not be pruned.
+        # fluid.memory_optimize(main)
+
+        place = fluid.CUDAPlace(0)
+        exe = fluid.Executor(place)
+        exe.run(startup)
+
+        feeder = fluid.DataFeeder(place=place, feed_list=[data, label])
+        pe = fluid.ParallelExecutor(
+            use_cuda=True, loss_name=loss.name, main_program=main)
+
+        fetch_list = []
+        all_vars = main.global_block().vars
+        for k, v in all_vars.iteritems():
+            if 'tmp' not in k and k[0] != '_' or v.persistable:
+                fetch_list.append(k)
+
+        for data in train_inputs:
+            ret = pe.run(fetch_list, feed=feeder.feed(data))
+            for i in range(len(fetch_list)):
+                assert not math.isnan(np.sum(ret[i])) and \
+                       not math.isinf(np.sum(ret[i]))
+
+    def test_update_sparse_parameter(self):
+        tst_reader = paddle.batch(flowers.test(use_xmap=False), batch_size=16)
+        tst_reader_iter = tst_reader()
+
+        iters = 3
+        train_inputs = []
+        for i in range(iters):
+            train_inputs.append(tst_reader_iter.next())
+
+        self.parallel_exe(train_inputs, seed=1)
+
+
+if __name__ == '__main__':
+    unittest.main()
diff --git a/python/paddle/fluid/tests/unittests/test_split_var.py b/python/paddle/fluid/tests/unittests/test_split_var.py
index 104ceb4fe7beb70b9016f57cef0ef895a3eb8ba6..79d387f0066672058d1640f4e5fd28ed8913fe4c 100644
--- a/python/paddle/fluid/tests/unittests/test_split_var.py
+++ b/python/paddle/fluid/tests/unittests/test_split_var.py
@@ -14,7 +14,7 @@
 
 import math
 import unittest
-from paddle.fluid.distribute_transpiler import split_dense_variable
+from paddle.fluid.transpiler.distribute_transpiler import split_dense_variable
 import paddle.fluid as fluid
 import paddle.fluid.core as core
 import random