diff --git a/benchmark/fluid/mnist.py b/benchmark/fluid/mnist.py
index 1e2185dfac1072d1f1046f4616a9d53a8fc76061..400200c4745017bd9d160bb9e415fde041c0a6c8 100644
--- a/benchmark/fluid/mnist.py
+++ b/benchmark/fluid/mnist.py
@@ -159,6 +159,7 @@ def run_benchmark(model, args):
         paddle.dataset.mnist.train(), batch_size=args.batch_size)

     accuracy = fluid.metrics.Accuracy()
+    train_exe = fluid.ParallelExecutor(use_cuda=True, loss_name=avg_cost.name)
     iters, num_samples, start_time = 0, 0, time.time()
     for pass_id in range(args.pass_num):
         accuracy.reset()
@@ -175,17 +176,20 @@ def run_benchmark(model, args):
             y_data = np.array(map(lambda x: x[1], data)).astype("int64")
             y_data = y_data.reshape([len(y_data), 1])

-            outs = exe.run(
-                fluid.default_main_program(),
+            outs = train_exe.run(
                 feed={"pixel": img_data,
                       "label": y_data},
-                fetch_list=[avg_cost, batch_acc, batch_size_tensor]
+                fetch_list=[
+                    avg_cost.name, batch_acc.name, batch_size_tensor.name
+                ]
             )  # The accuracy is the accumulation of batches, but not the current batch.
-            accuracy.update(value=outs[1], weight=outs[2])
+            accuracy.update(
+                value=np.array(np.mean(outs[1])),
+                weight=np.mean(np.array(outs[2])))
             iters += 1
             num_samples += len(y_data)
-            loss = np.array(outs[0])
-            acc = np.array(outs[1])
+            loss = np.mean(np.array(outs[0]))
+            acc = np.mean(np.array(outs[1]))
             train_losses.append(loss)
             train_accs.append(acc)
             print("Pass: %d, Iter: %d, Loss: %f, Accuracy: %f" %
diff --git a/benchmark/fluid/resnet.py b/benchmark/fluid/resnet.py
index 831fa2c019fc2868cd85b1ca7b2c8c76a2f1628c..0fd7258a804e7c93b0b03da140140394bf90004a 100644
--- a/benchmark/fluid/resnet.py
+++ b/benchmark/fluid/resnet.py
@@ -241,6 +241,7 @@ def run_benchmark(model, args):
     exe = fluid.Executor(place)
     exe.run(fluid.default_startup_program())
     accuracy = fluid.average.WeightedAverage()
+    train_exe = fluid.ParallelExecutor(use_cuda=True, loss_name=avg_cost.name)
     if args.use_fake_data:
         data = train_reader().next()
         image = np.array(map(lambda x: x[0].reshape(dshape), data)).astype(
@@ -264,14 +265,17 @@ def run_benchmark(model, args):
                 data)).astype('float32')
             label = np.array(map(lambda x: x[1], data)).astype('int64')
             label = label.reshape([-1, 1])
-            loss, acc, weight = exe.run(
-                fluid.default_main_program(),
+            loss, acc, weight = train_exe.run(
                 feed={'data': image,
                       'label': label},
-                fetch_list=[avg_cost, batch_acc, batch_size_tensor])
+                fetch_list=[
+                    avg_cost.name, batch_acc.name, batch_size_tensor.name
+                ])
             iters += 1
             num_samples += len(label)
-            accuracy.add(value=acc, weight=weight)
+            accuracy.add(value=np.array(np.mean(acc)), weight=np.mean(weight))
+            loss = np.mean(np.array(loss))
+            acc = np.mean(np.array(acc))
             train_losses.append(loss)
             train_accs.append(acc)
             print("Pass: %d, Iter: %d, Loss: %f, Accuracy: %f" %
diff --git a/benchmark/fluid/vgg.py b/benchmark/fluid/vgg.py
index 53e34e0cbd15914791c305db6797f826ebfae34e..2a9566a45c3804183e05db9298cec4f670225a6f 100644
--- a/benchmark/fluid/vgg.py
+++ b/benchmark/fluid/vgg.py
@@ -169,6 +169,7 @@ def main():

     iters, num_samples, start_time = 0, 0, time.time()
     accuracy = fluid.average.WeightedAverage()
+    train_exe = fluid.ParallelExecutor(use_cuda=True, loss_name=avg_cost.name)
     for pass_id in range(args.pass_num):
         accuracy.reset()
         train_accs = []
@@ -184,14 +185,17 @@ def main():
             y_data = np.array(map(lambda x: x[1], data)).astype("int64")
             y_data = y_data.reshape([-1, 1])

-            loss, acc, weight = exe.run(
-                fluid.default_main_program(),
+            loss, acc, weight = train_exe.run(
                 feed={"pixel": img_data,
                       "label": y_data},
-                fetch_list=[avg_cost, batch_acc, batch_size_tensor])
-            accuracy.add(value=acc, weight=weight)
+                fetch_list=[
+                    avg_cost.name, batch_acc.name, batch_size_tensor.name
+                ])
+            accuracy.add(value=np.array(np.mean(acc)), weight=np.mean(weight))
             iters += 1
             num_samples += len(y_data)
+            loss = np.mean(np.array(loss))
+            acc = np.mean(np.array(acc))
             print(
                 "Pass = %d, Iter = %d, Loss = %f, Accuracy = %f" %
                 (pass_id, iters, loss, acc)
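Note: the three benchmark scripts above switch from fluid.Executor to
fluid.ParallelExecutor, which forces two mechanical edits: fetch_list now
takes variable names rather than Variable objects, and each fetched entry
carries one value per device, so scalars are reduced with np.mean. A minimal
sketch of the new pattern (assuming a CUDA-enabled, 2018-era Fluid build;
avg_cost, batch_acc, img_data and y_data come from the surrounding benchmark
code):

    import numpy as np
    import paddle.fluid as fluid

    # Build the program as before, then wrap it in a ParallelExecutor.
    train_exe = fluid.ParallelExecutor(use_cuda=True, loss_name=avg_cost.name)

    # Fetch by name; each entry of `outs` holds one result per device.
    outs = train_exe.run(feed={"pixel": img_data, "label": y_data},
                         fetch_list=[avg_cost.name, batch_acc.name])
    loss = np.mean(np.array(outs[0]))  # reduce per-device losses to a scalar
    acc = np.mean(np.array(outs[1]))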
diff --git a/cmake/external/snappy.cmake b/cmake/external/snappy.cmake
index 80282329c6ac65fbd1493a6838efca4bd9cadaad..af09ed4d5d6e21cc50aba5198a7e9ea56f49451a 100644
--- a/cmake/external/snappy.cmake
+++ b/cmake/external/snappy.cmake
@@ -47,8 +47,6 @@ ExternalProject_Add(
                     -DCMAKE_INSTALL_LIBDIR:PATH=${SNAPPY_INSTALL_DIR}/lib
                     -DCMAKE_POSITION_INDEPENDENT_CODE:BOOL=ON
                     -DCMAKE_BUILD_TYPE:STRING=${THIRD_PARTY_BUILD_TYPE}
-                    BUILD_COMMAND make -j8
-                    INSTALL_COMMAND make install
 )

 add_library(snappy STATIC IMPORTED GLOBAL)
diff --git a/cmake/external/snappystream.cmake b/cmake/external/snappystream.cmake
index 20a96430823d07a07d4bb4602e7fc0cfe55c3bf2..6df636d7fa8757ade73892bda03a80ba9767472b 100644
--- a/cmake/external/snappystream.cmake
+++ b/cmake/external/snappystream.cmake
@@ -46,8 +46,6 @@ ExternalProject_Add(
                     -DCMAKE_INSTALL_PREFIX:PATH=${SNAPPYSTREAM_INSTALL_DIR}
                     -DCMAKE_INSTALL_LIBDIR:PATH=${SNAPPYSTREAM_INSTALL_DIR}/lib
                     -DCMAKE_BUILD_TYPE:STRING=${THIRD_PARTY_BUILD_TYPE}
-                    BUILD_COMMAND make -j8
-                    INSTALL_COMMAND make install
     DEPENDS snappy
 )
diff --git a/cmake/inference_lib.cmake b/cmake/inference_lib.cmake
index 06a7ae56827d5afe857ed0a98092210917a52430..807a48a41f72f17944dc1be5b793b0ca7d70c527 100644
--- a/cmake/inference_lib.cmake
+++ b/cmake/inference_lib.cmake
@@ -148,4 +148,10 @@ copy(string_lib
   DSTS ${dst_dir}/${module} ${dst_dir}/${module}/tinyformat
 )

+set(module "pybind")
+copy(pybind_lib
+  SRCS ${CMAKE_CURRENT_BINARY_DIR}/paddle/fluid/${module}/pybind.h
+  DSTS ${dst_dir}/${module}
+)
+
 add_custom_target(inference_lib_dist DEPENDS ${inference_lib_dist_dep})
diff --git a/doc/fluid/design/concepts/functions_operators_layers.md b/doc/fluid/design/concepts/functions_operators_layers.md
index 30bc488a18a28d349645d9d2502aae6691a69931..1f86b99e5197c3e0b85fd76fe704520ef21b06d3 100644
--- a/doc/fluid/design/concepts/functions_operators_layers.md
+++ b/doc/fluid/design/concepts/functions_operators_layers.md
@@ -40,7 +40,7 @@ template <typename T>
 class FCOp : public OperatorBase {
  public:
   void Run(...) {
-    add(mul(Input<T>("X"), Input<T>("W")), Input<T>("b");
+    add(mul(Input<T>("X"), Input<T>("W")), Input<T>("b"));
   }
 };
 REGISTER_OP(FCOp, "fc");
{ - add(mul(Input("X"), Input("W")), Input("b"); + add(mul(Input("X"), Input("W")), Input("b")); } }; REGISTER_OP(FCOp, "fc"); diff --git a/paddle/CMakeLists.txt b/paddle/CMakeLists.txt index 8b1ca5e16548334ed0c9a6d31b88e0805304579e..d722eec1892206ac44c49e7a12d92be0c54df8c0 100644 --- a/paddle/CMakeLists.txt +++ b/paddle/CMakeLists.txt @@ -24,6 +24,6 @@ if(NOT WITH_FLUID_ONLY) endif() add_subdirectory(testing) -if(NOT MOBILE_INFERENCE AND NOT RPI) +if(NOT MOBILE_INFERENCE AND NOT RPI AND NOT WITH_C_API) add_subdirectory(fluid) endif() diff --git a/paddle/fluid/framework/data_type.cc b/paddle/fluid/framework/data_type.cc index b9c90cb0c32f337ba82ce1eaa5b43199540491ef..b6b93cf422a60c1d8e9cb8b477efd562f9fe4758 100644 --- a/paddle/fluid/framework/data_type.cc +++ b/paddle/fluid/framework/data_type.cc @@ -58,6 +58,7 @@ static DataTypeMap* InitDataTypeMap() { RegType(bool, proto::VarType::BOOL); RegType(size_t, proto::VarType::SIZE_T); RegType(int16_t, proto::VarType::INT16); + RegType(uint8_t, proto::VarType::UINT8); #undef RegType return retv; diff --git a/paddle/fluid/framework/data_type.h b/paddle/fluid/framework/data_type.h index 4b9f572ec5f1cda71c8b8dd8fae54b42e9f16f7a..491413db8c8d66fd907801131e89d9303bdef9f2 100644 --- a/paddle/fluid/framework/data_type.h +++ b/paddle/fluid/framework/data_type.h @@ -47,8 +47,14 @@ inline void VisitDataType(proto::VarType::Type type, Visitor visitor) { case proto::VarType::BOOL: visitor.template operator()(); break; + case proto::VarType::UINT8: + visitor.template operator()(); + break; + case proto::VarType::INT16: + visitor.template operator()(); + break; default: - PADDLE_THROW("Not supported"); + PADDLE_THROW("Not supported %d", type); } } diff --git a/paddle/fluid/framework/details/fetch_op_handle.cc b/paddle/fluid/framework/details/fetch_op_handle.cc index b1c9dd0d15223f7d1bf6ea44144589f1de927e3e..224e8e1f6efd7a894591ac51c929517cae7539ce 100644 --- a/paddle/fluid/framework/details/fetch_op_handle.cc +++ b/paddle/fluid/framework/details/fetch_op_handle.cc @@ -48,17 +48,18 @@ void FetchOpHandle::RunImpl() { WaitInputVarGenerated(platform::CPUPlace()); tensors_.resize(inputs_.size()); - auto *var_handle = static_cast(inputs_[0]); - auto &var_name = var_handle->name_; platform::CPUPlace cpu; auto &scopes = *local_scopes_; - for (size_t i = 0; i < scopes.size(); ++i) { - auto &scope = scopes[i]; - auto *var = - scope->FindVar(kLocalExecScopeName)->Get()->FindVar(var_name); + for (size_t i = 0; i < inputs_.size(); ++i) { + auto *var_handle = static_cast(inputs_[i]); + auto &scope = scopes.at(var_handle->scope_idx_); + auto *var = scope->FindVar(kLocalExecScopeName) + ->Get() + ->FindVar(var_handle->name_); PADDLE_ENFORCE_NOT_NULL(var, "Cannot find variable %s in execution scope", - var_name); + var_handle->name_); + auto &t = var->Get(); if (platform::is_gpu_place(t.place())) { #ifdef PADDLE_WITH_CUDA diff --git a/paddle/fluid/framework/details/op_handle_base.h b/paddle/fluid/framework/details/op_handle_base.h index fe1735d05dde5f09d5c72c68e5002d16f0083eb5..8f94206a87dbae8a81727ca48718886bbabbe25c 100644 --- a/paddle/fluid/framework/details/op_handle_base.h +++ b/paddle/fluid/framework/details/op_handle_base.h @@ -70,6 +70,14 @@ class OpHandleBase { const std::vector &Inputs() const { return inputs_; } + size_t NoDupInputSize() const { + std::unordered_set res; + for (auto *var : inputs_) { + res.emplace(var); + } + return res.size(); + } + const std::vector &Outputs() const { return outputs_; } protected: diff --git 
diff --git a/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc b/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc
index ef263d82c5ec93f0673eb0ac70e4fb02904bff13..815f739371e77d953a28be99b38ec1b8ff26506c 100644
--- a/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc
+++ b/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc
@@ -174,7 +174,7 @@ void ThreadedSSAGraphExecutor::InsertFetchOps(
 void ThreadedSSAGraphExecutor::InsertPendingOp(
     std::unordered_map<OpHandleBase *, size_t> *pending_ops,
     OpHandleBase *op_instance) const {
-  pending_ops->insert({op_instance, op_instance->Inputs().size()});
+  pending_ops->insert({op_instance, op_instance->NoDupInputSize()});
 }

 void ThreadedSSAGraphExecutor::InsertPendingVar(
diff --git a/paddle/fluid/framework/framework.proto b/paddle/fluid/framework/framework.proto
index d2558f111f49139b33f921f7260b41830279edc8..d35125fe8c3c8018c38650dc87b2b1474ded6058 100644
--- a/paddle/fluid/framework/framework.proto
+++ b/paddle/fluid/framework/framework.proto
@@ -103,6 +103,7 @@ message VarType {
     FP64 = 6;
     // Tensor<size_t> is used in C++.
     SIZE_T = 19;
+    UINT8 = 20;

     // Other types that may need additional descriptions
     LOD_TENSOR = 7;
diff --git a/paddle/fluid/framework/lod_tensor_test.cc b/paddle/fluid/framework/lod_tensor_test.cc
index 77e5ec4c7dd14b7ebb6d606b8c401ee714259d40..2ceffc93319359683e87e7fec2d18784c9bf02f3 100644
--- a/paddle/fluid/framework/lod_tensor_test.cc
+++ b/paddle/fluid/framework/lod_tensor_test.cc
@@ -228,11 +228,12 @@ TEST(LoD, CheckAbsLoD) {
   ASSERT_FALSE(CheckAbsLoD(abs_lod0));
 }

-TEST(LoDTensor, RecordIO) {
+template <typename T>
+static void TestRecordIO() {
   LoDTensor tensor;
-  int* tmp = tensor.mutable_data<int>(make_ddim({4, 5}), platform::CPUPlace());
+  T* tmp = tensor.mutable_data<T>(make_ddim({4, 5}), platform::CPUPlace());
   for (int i = 0; i < 20; ++i) {
-    tmp[i] = i;
+    tmp[i] = static_cast<T>(i);
   }

   std::stringstream* stream = new std::stringstream();
@@ -247,7 +248,7 @@

   auto assert_tensor_ok = [](const LoDTensor& tensor) {
     for (int i = 0; i < 20; ++i) {
-      ASSERT_EQ(tensor.data<int>()[i], i);
+      ASSERT_EQ(tensor.data<T>()[i], static_cast<T>(i));
     }
   };

@@ -265,5 +266,13 @@
   }
 }

+TEST(LoDTensor, RecordIO) {
+  TestRecordIO<int>();
+  TestRecordIO<int16_t>();
+  TestRecordIO<uint8_t>();
+  TestRecordIO<float>();
+  TestRecordIO<double>();
+}
+
 }  // namespace framework
 }  // namespace paddle
diff --git a/paddle/fluid/framework/operator.h b/paddle/fluid/framework/operator.h
index a4eb6f706edab9479cbce436311eb96da8845646..2f480e00c100d579e100de17d3feb957f5ef6167 100644
--- a/paddle/fluid/framework/operator.h
+++ b/paddle/fluid/framework/operator.h
@@ -33,7 +33,6 @@ limitations under the License. */
 #include "paddle/fluid/framework/tensor.h"
 #include "paddle/fluid/platform/device_context.h"
 #include "paddle/fluid/platform/variant.h"
-#include "paddle/utils/Error.h"

 namespace paddle {
 namespace framework {
diff --git a/paddle/fluid/inference/tensorrt/convert/op_converter.h b/paddle/fluid/inference/tensorrt/convert/op_converter.h
index abc9ebf472498f6653d5bb1113ae2f3ce7e5a923..1cd3ed9a00acead2599420f88499bd0d74c2974b 100644
--- a/paddle/fluid/inference/tensorrt/convert/op_converter.h
+++ b/paddle/fluid/inference/tensorrt/convert/op_converter.h
@@ -49,7 +49,7 @@ class OpConverter {
   // convert fluid block to tensorrt network
   void ConvertBlock(const framework::proto::BlockDesc& block,
                     TensorRTEngine* engine) {
-    for (size_t i = 0; i < block.ops_size(); i++) {
+    for (int i = 0; i < block.ops_size(); i++) {
       const auto& op = block.ops(i);
       OpConverter::Run(op, engine);
     }
diff --git a/paddle/fluid/operators/math/math_function.cc b/paddle/fluid/operators/math/math_function.cc
index d62ea387cc55c7399973b6f35bace491a49666dc..d39154c6f88d6d17c1719eb9a5b048211f4bb52b 100644
--- a/paddle/fluid/operators/math/math_function.cc
+++ b/paddle/fluid/operators/math/math_function.cc
@@ -38,7 +38,9 @@ template struct SetConstant;
   template struct Transpose<platform::CPUDeviceContext, float, RANK>;   \
   template struct Transpose<platform::CPUDeviceContext, double, RANK>;  \
   template struct Transpose<platform::CPUDeviceContext, int, RANK>;     \
-  template struct Transpose<platform::CPUDeviceContext, int64_t, RANK>;
+  template struct Transpose<platform::CPUDeviceContext, int64_t, RANK>; \
+  template struct Transpose<platform::CPUDeviceContext, int16_t, RANK>; \
+  template struct Transpose<platform::CPUDeviceContext, uint8_t, RANK>;

 DEFINE_CPU_TRANS(1);
 DEFINE_CPU_TRANS(2);
diff --git a/paddle/fluid/operators/smooth_l1_loss_op.cc b/paddle/fluid/operators/smooth_l1_loss_op.cc
index c44c5f164b2d84616e9a85813e0ee5219b41df28..622420c1c33a62994c81ad9534c4fa37a4a1fa1a 100644
--- a/paddle/fluid/operators/smooth_l1_loss_op.cc
+++ b/paddle/fluid/operators/smooth_l1_loss_op.cc
@@ -105,7 +105,7 @@ class SmoothL1LossGradOp : public framework::OperatorWithKernel {
   using framework::OperatorWithKernel::OperatorWithKernel;

   void InferShape(framework::InferShapeContext* ctx) const override {
-    auto in_dims = ctx->GetInputDim("X");
+    auto in_dims = ctx->GetInputDim("Diff");
     auto out_dims = ctx->GetInputDim(framework::GradVarName("Out"));

     PADDLE_ENFORCE_GE(out_dims.size(), 2,
@@ -127,12 +127,33 @@ class SmoothL1LossGradOp : public framework::OperatorWithKernel {
   }
 };

+class SmoothL1LossGradMaker : public framework::SingleGradOpDescMaker {
+ public:
+  using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
+
+ protected:
+  std::unique_ptr<framework::OpDesc> Apply() const override {
+    auto* op = new framework::OpDesc();
+    op->SetType("smooth_l1_loss_grad");
+    op->SetInput("InsideWeight", Input("InsideWeight"));
+    op->SetInput("OutsideWeight", Input("OutsideWeight"));
+    op->SetInput("Diff", Output("Diff"));
+    op->SetInput(framework::GradVarName("Out"), OutputGrad("Out"));
+
+    op->SetAttrMap(Attrs());
+
+    op->SetOutput(framework::GradVarName("X"), InputGrad("X"));
+    op->SetOutput(framework::GradVarName("Y"), InputGrad("Y"));
+    return std::unique_ptr<framework::OpDesc>(op);
+  }
+};
+
 }  // namespace operators
 }  // namespace paddle

 namespace ops = paddle::operators;
 REGISTER_OPERATOR(smooth_l1_loss, ops::SmoothL1LossOp, ops::SmoothL1LossOpMaker,
-                  paddle::framework::DefaultGradOpDescMaker<true>);
+                  ops::SmoothL1LossGradMaker);
 REGISTER_OPERATOR(smooth_l1_loss_grad, ops::SmoothL1LossGradOp);
 REGISTER_OP_CPU_KERNEL(
     smooth_l1_loss,
diff --git a/paddle/fluid/platform/CMakeLists.txt b/paddle/fluid/platform/CMakeLists.txt
index 598fd4d419078a973647f2f8f20e8a12c8115a8b..79e3c26fef51b4d27520a8079de1074d72f89617 100644
--- a/paddle/fluid/platform/CMakeLists.txt
+++ b/paddle/fluid/platform/CMakeLists.txt
@@ -1,4 +1,4 @@
-proto_library(profiler_proto SRCS profiler.proto)
+proto_library(profiler_proto SRCS profiler.proto DEPS framework_proto)
 py_proto_compile(profiler_py_proto SRCS profiler.proto)
 add_custom_target(profiler_py_proto_init ALL COMMAND ${CMAKE_COMMAND} -E touch __init__.py)
diff --git a/paddle/scripts/paddle_build.sh b/paddle/scripts/paddle_build.sh
index 5bef232cd8fc44ded89ac56a790c8db0955b390a..928b95b4f5aaec9ac2fc3e2302ef63a5eb1dcf74 100755
--- a/paddle/scripts/paddle_build.sh
+++ b/paddle/scripts/paddle_build.sh
@@ -480,6 +480,7 @@ function main() {
       build)
         cmake_gen ${PYTHON_ABI:-""}
         build
+        gen_dockerfile
         ;;
       build_android)
         build_android
@@ -504,6 +505,7 @@ function main() {
         ;;
       capi)
         cmake_gen ${PYTHON_ABI:-""}
+        build
        gen_capi_package
         ;;
       fluid_inference_lib)
diff --git a/python/paddle/fluid/data_feeder.py b/python/paddle/fluid/data_feeder.py
index 0051b698471b40bffc12921f86dcde642714e07d..a44e078d0c13717643a6cfc6dd8bff5901ee9c97 100644
--- a/python/paddle/fluid/data_feeder.py
+++ b/python/paddle/fluid/data_feeder.py
@@ -54,9 +54,9 @@ class DataToLoDTensorConverter(object):
             self.data.append(data)
         else:
             cur_lod_len = len(data)
-            lod[-1].append(lod[-1][-1] + cur_lod_len)
+            lod[0].append(lod[0][-1] + cur_lod_len)
             for each_data in data:
-                self._feed_impl_(each_data, lod[:-1], lod_level - 1)
+                self._feed_impl_(each_data, lod[1:], lod_level - 1)

     def done(self):
         arr = numpy.array(self.data, dtype=self.dtype).reshape(self.shape)
diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py
index 1786be22fdcd0d074b45bc94b3b0c4e8c41b4e8a..561c8bd42f90911bf5a0c898fe01412d42d2c9b1 100644
--- a/python/paddle/fluid/layers/nn.py
+++ b/python/paddle/fluid/layers/nn.py
@@ -1329,6 +1329,8 @@ def sequence_pool(input, pool_type):
         sqrt : out.data = [2.82, 6.93, 4.24], where 2.82=(1+3)/sqrt(2),
                6.93=(2+4+6)/sqrt(3), 4.24=(5+1)/sqrt(2)
         max  : out.data = [3, 6, 5], where 3=max(1,3), 6=max(2,4,6), 5=max(5,1)
+        last : out.data = [3, 6, 1], where 3=last(1,3), 6=last(2,4,6), 1=last(5,1)
+        first: out.data = [1, 2, 5], where 1=first(1,3), 2=first(2,4,6), 5=first(5,1)

     Args:
         input(variable): The input variable which is a LoDTensor.
@@ -1348,6 +1350,8 @@ def sequence_pool(input, pool_type):
             sum_x = fluid.layers.sequence_pool(input=x, pool_type='sum')
             sqrt_x = fluid.layers.sequence_pool(input=x, pool_type='sqrt')
             max_x = fluid.layers.sequence_pool(input=x, pool_type='max')
+            last_x = fluid.layers.sequence_pool(input=x, pool_type='last')
+            first_x = fluid.layers.sequence_pool(input=x, pool_type='first')
     """
     helper = LayerHelper('sequence_pool', **locals())
     dtype = helper.input_dtype()
@@ -3263,35 +3267,35 @@ def smooth_l1(x, y, inside_weight=None, outside_weight=None, sigma=None):
     """
     **Smooth L1 Loss Operator. **

-    This operator computes the smooth l1 loss for X and Y.
+    This operator computes the smooth L1 loss for X and Y.
     The operator takes the first dimension of X and Y as batch size.
-    For each instance, it computes the smooth l1 loss element by element first
+    For each instance, it computes the smooth L1 loss element by element first
     and then sums all the losses. So the shape of Out is [batch_size, 1].

     Args:
         x (Variable): A tensor with rank at least 2. The input value of smooth
-            l1 loss op with shape [batch_size, dim1, ..., dimN].
+            L1 loss op with shape [batch_size, dim1, ..., dimN].
         y (Variable): A tensor with rank at least 2. The target value of smooth
-            l1 loss op with same shape as x.
+            L1 loss op with same shape as x.
         inside_weight (Variable|None):  A tensor with rank at least 2. This
             input is optional and should have same shape with x. If provided,
             the result of (x - y) will be multiplied by this tensor element by
            element.
         outside_weight (Variable|None): A tensor with rank at least 2. This
            input is optional and should have same shape with x. If provided,
-           the out smooth l1 loss will be multiplied by this tensor element
+           the out smooth L1 loss will be multiplied by this tensor element
            by element.
-        sigma (float|None): Hyper parameter of smooth l1 loss op. A float scalar
+        sigma (float|None): Hyper parameter of smooth L1 loss op. A float scalar
            with default value 1.0.
     Returns:
-        Variable: A tensor with rank be 2. The output smooth l1 loss with
+        Variable: A tensor with rank be 2. The output smooth L1 loss with
            shape [batch_size, 1].

     Examples:
         .. code-block:: python

            data = fluid.layers.data(name='data', shape=[128], dtype='float32')
-           label = fluid.layers.data(name='label', shape=[100], dtype='int64')
+           label = fluid.layers.data(name='label', shape=[100], dtype='float32')
            fc = fluid.layers.fc(input=data, size=100)
            out = fluid.layers.smooth_l1(x=fc, y=label)
     """
@@ -3769,13 +3773,13 @@ def label_smooth(label,
 def roi_pool(input, rois, pooled_height=1, pooled_width=1, spatial_scale=1.0):
     """
-    Region of interest pooling (also known as RoI pooling) is to perform 
+    Region of interest pooling (also known as RoI pooling) is to perform
         is to perform max pooling on inputs of nonuniform sizes to obtain
         fixed-size feature maps (e.g. 7*7).
-    The operator has three steps: 
-        1. Dividing each region proposal into equal-sized sections with 
-           the pooled_width and pooled_height 
-        2. Finding the largest value in each section 
+    The operator has three steps:
+        1. Dividing each region proposal into equal-sized sections with
+           the pooled_width and pooled_height
+        2. Finding the largest value in each section
         3. Copying these max values to the output buffer

     Args:
@@ -3783,8 +3787,8 @@ def roi_pool(input, rois, pooled_height=1, pooled_width=1, spatial_scale=1.0):
         rois (Variable): ROIs (Regions of Interest) to pool over. It should
                          be a 2-D one level LoTensor of shape [num_rois, 4].
                          The layout is [x1, y1, x2, y2], where (x1, y1)
-                         is the top left coordinates, and (x2, y2) is the 
-                         bottom right coordinates. The num_rois is the 
+                         is the top left coordinates, and (x2, y2) is the
+                         bottom right coordinates. The num_rois is the
                          total number of ROIs in this batch data.
         pooled_height (integer): The pooled output height. Default: 1
         pooled_width (integer): The pooled output width. Default: 1
@@ -3793,11 +3797,11 @@ def roi_pool(input, rois, pooled_height=1, pooled_width=1, spatial_scale=1.0):
                                  to the scale used when pooling. Default: 1.0

     Returns:
-        pool_out (Variable): The output is a 4-D tensor of the shape 
+        pool_out (Variable): The output is a 4-D tensor of the shape
                              (num_rois, channels, pooled_h, pooled_w).

     Examples:
-        pool_out = fluid.layers.roi_pool(input=x, rois=rois, 7, 7, 1.0) 
+        pool_out = fluid.layers.roi_pool(input=x, rois=rois, 7, 7, 1.0)
     """
     helper = LayerHelper('roi_pool', **locals())
     dtype = helper.input_dtype()
diff --git a/python/paddle/fluid/tests/book/test_label_semantic_roles.py b/python/paddle/fluid/tests/book/test_label_semantic_roles.py
index 09793760e5504c04ad4b0bfac5c5d7b7047cf85d..f1ee5dfd99e1c8b26280c010c1aaca05a004a5b6 100644
--- a/python/paddle/fluid/tests/book/test_label_semantic_roles.py
+++ b/python/paddle/fluid/tests/book/test_label_semantic_roles.py
@@ -182,12 +182,6 @@ def train(use_cuda, save_dirname=None, is_local=True):
     crf_decode = fluid.layers.crf_decoding(
         input=feature_out, param_attr=fluid.ParamAttr(name='crfw'))

-    chunk_evaluator = fluid.evaluator.ChunkEvaluator(
-        input=crf_decode,
-        label=target,
-        chunk_scheme="IOB",
-        num_chunk_types=int(math.ceil((label_dict_len - 1) / 2.0)))
-
     train_data = paddle.batch(
         paddle.reader.shuffle(
             paddle.dataset.conll05.test(), buf_size=8192),
@@ -203,7 +197,6 @@ def train(use_cuda, save_dirname=None, is_local=True):

     def train_loop(main_program):
         exe.run(fluid.default_startup_program())
-
         embedding_param = fluid.global_scope().find_var(
             embedding_name).get_tensor()
         embedding_param.set(
@@ -213,27 +206,19 @@ def train(use_cuda, save_dirname=None, is_local=True):
         start_time = time.time()
         batch_id = 0
         for pass_id in xrange(PASS_NUM):
-            chunk_evaluator.reset(exe)
             for data in train_data():
-                cost, precision, recall, f1_score = exe.run(
-                    main_program,
-                    feed=feeder.feed(data),
-                    fetch_list=[avg_cost] + chunk_evaluator.metrics)
-                pass_precision, pass_recall, pass_f1_score = chunk_evaluator.eval(
-                    exe)
+                cost = exe.run(main_program,
+                               feed=feeder.feed(data),
+                               fetch_list=[avg_cost])
+                cost = cost[0]

                 if batch_id % 10 == 0:
-                    print("avg_cost:" + str(cost) + " precision:" + str(
-                        precision) + " recall:" + str(recall) + " f1_score:" +
-                          str(f1_score) + " pass_precision:" + str(
-                              pass_precision) + " pass_recall:" + str(
-                                  pass_recall) + " pass_f1_score:" + str(
-                                      pass_f1_score))
+                    print("avg_cost:" + str(cost))
                     if batch_id != 0:
                         print("second per batch: " + str((time.time(
                         ) - start_time) / batch_id))
                     # Set the threshold low to speed up the CI test
-                    if float(pass_precision) > 0.01:
+                    if float(cost) < 60.0:
                         if save_dirname is not None:
                             # TODO(liuyiqun): Change the target to crf_decode
                             fluid.io.save_inference_model(save_dirname, [
diff --git a/python/paddle/fluid/tests/test_data_feeder.py b/python/paddle/fluid/tests/test_data_feeder.py
index 861dd3174a21d59fe12e0b794ecb2a934946ac71..ce3ba3ebc50d7b015f379b5e80b179463a7b231a 100644
--- a/python/paddle/fluid/tests/test_data_feeder.py
+++ b/python/paddle/fluid/tests/test_data_feeder.py
@@ -13,15 +13,62 @@
 # limitations under the License.

 import paddle.fluid as fluid
+import unittest


-def test_converter():
-    img = fluid.layers.data(name='image', shape=[1, 28, 28])
-    label = fluid.layers.data(name='label', shape=[1], dtype='int64')
-    feeder = fluid.DataFeeder([img, label], fluid.CPUPlace())
-    result = feeder.feed([[[0] * 784, [9]], [[1] * 784, [1]]])
-    print(result)
+class TestDataFeeder(unittest.TestCase):
+    def test_lod_level_0_converter(self):
+        img = fluid.layers.data(name='image', shape=[1, 28, 28])
+        label = fluid.layers.data(name='label', shape=[1], dtype='int64')
+        feeder = fluid.DataFeeder([img, label], fluid.CPUPlace())
+        result = feeder.feed([([0] * 784, [9]), ([1] * 784, [1])])
+        print(result)
+
+        self.assertEqual(result['image'].shape(), [2, 1, 28, 28])
+        self.assertEqual(result['label'].shape(), [2, 1])
+        self.assertEqual(result['image'].lod(), [])
+        self.assertEqual(result['label'].lod(), [])
+
+    def test_lod_level_1_converter(self):
+        # lod_level = 1
+        # each sentence has a different number of words
+        sentences = fluid.layers.data(
+            name='sentences', shape=[1], dtype='int64', lod_level=1)
+        label = fluid.layers.data(name='label', shape=[1], dtype='int64')
+        feeder = fluid.DataFeeder([sentences, label], fluid.CPUPlace())
+
+        # lod = [[0, 3, 5, 9]]
+        # data = [[1, 2, 3], [4, 5], [6, 7, 8, 9]]
+        # label = [1] * len(data)
+        result = feeder.feed(
+            [([1, 2, 3], [1]), ([4, 5], [1]), ([6, 7, 8, 9], [1])])
+        print(result)
+
+        self.assertEqual(result['sentences'].shape(), [9, 1])
+        self.assertEqual(result['label'].shape(), [3, 1])
+        self.assertEqual(result['sentences'].lod(), [[0, 3, 5, 9]])
+        self.assertEqual(result['label'].lod(), [])
+
+    def test_lod_level_2_converter(self):
+        # lod_level = 2
+        # paragraphs -> sentences -> words
+        paragraphs = fluid.layers.data(
+            name='paragraphs', shape=[1], dtype='int64', lod_level=2)
+        label = fluid.layers.data(name='label', shape=[1], dtype='int64')
+        feeder = fluid.DataFeeder([paragraphs, label], fluid.CPUPlace())
+
+        # lod = [[0, 2, 3], [0, 3, 5, 9]]
+        # data = [[[1, 2, 3], [4, 5]], [[6, 7, 8, 9]]]
+        # label = [1] * len(data)
+        result = feeder.feed(
+            [([[1, 2, 3], [4, 5]], [1]), ([[6, 7, 8, 9]], [1])])
+        print(result)
+
+        self.assertEqual(result['paragraphs'].shape(), [9, 1])
+        self.assertEqual(result['label'].shape(), [2, 1])
+        self.assertEqual(result['paragraphs'].lod(), [[0, 2, 3], [0, 3, 5, 9]])
+        self.assertEqual(result['label'].lod(), [])


 if __name__ == '__main__':
-    test_converter()
+    unittest.main()
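Note: the data_feeder.py fix earlier in this patch walks the LoD levels
front-to-back (lod[0] is the outermost level, recursing into lod[1:]) instead
of back-to-front, and the tests above pin down the resulting offset-based LoD.
As a plain-Python sketch of the convention being asserted, per-sequence
lengths 3, 2 and 4 become the offsets [0, 3, 5, 9]:

    def lengths_to_offsets(lengths):
        # An offset-based LoD level starts at 0 and accumulates lengths.
        offsets = [0]
        for n in lengths:
            offsets.append(offsets[-1] + n)
        return offsets

    assert lengths_to_offsets([3, 2, 4]) == [0, 3, 5, 9]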
diff --git a/python/paddle/fluid/tests/unittests/CMakeLists.txt b/python/paddle/fluid/tests/unittests/CMakeLists.txt
index d9190408e151283ece8460286dd67818dd39da3e..2ae9653953c2f5f6a399243bef2c7fb756f9692f 100644
--- a/python/paddle/fluid/tests/unittests/CMakeLists.txt
+++ b/python/paddle/fluid/tests/unittests/CMakeLists.txt
@@ -28,11 +28,11 @@ function(py_test_modules TARGET_NAME)
   if(WITH_TESTING)
     set(options "")
     set(oneValueArgs "")
-    set(multiValueArgs MODULES DEPS ARGS ENVS)
+    set(multiValueArgs MODULES DEPS ENVS)
     cmake_parse_arguments(py_test_modules "${options}" "${oneValueArgs}"
                           "${multiValueArgs}" ${ARGN})
     add_test(NAME ${TARGET_NAME}
             COMMAND env PYTHONPATH=${PADDLE_BINARY_DIR}/python ${py_test_modules_ENVS}
-            ${PYTHON_EXECUTABLE} -u -m unittest --verbose ${py_test_modules_MODULES} ${py_test_modules_ARGS}
+            ${PYTHON_EXECUTABLE} ${PADDLE_SOURCE_DIR}/tools/test_runner.py ${py_test_modules_MODULES}
             WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR})
   endif()
 endfunction()
@@ -66,6 +66,7 @@ list(REMOVE_ITEM TEST_OPS test_fetch_var)
 list(REMOVE_ITEM TEST_OPS test_parallel_op)
 list(REMOVE_ITEM TEST_OPS test_dynrnn_static_input)
 list(REMOVE_ITEM TEST_OPS test_dist_train)
+list(REMOVE_ITEM TEST_OPS test_network_with_dtype)

 # tests that can be bundled together in one python process for speed.
 if(WITH_FAST_BUNDLE_TEST)
@@ -83,6 +84,7 @@ py_test_modules(test_parallel_executor MODULES test_parallel_executor)
 py_test_modules(test_warpctc_op MODULES test_warpctc_op ENVS FLAGS_warpctc_dir=${WARPCTC_LIB_DIR})
 py_test_modules(test_train_dyn_rnn MODULES test_dyn_rnn)
 py_test_modules(test_mul_op MODULES test_mul_op)
+py_test_modules(test_network_with_dtype MODULES test_network_with_dtype)

 # tests that need to be run in separate process.
 py_test_modules(test_multihead_attention MODULES test_multihead_attention)
diff --git a/python/paddle/fluid/tests/unittests/test_network_with_dtype.py b/python/paddle/fluid/tests/unittests/test_network_with_dtype.py
index fe8aceb3ae42f73590bffe2a372c771654a372a9..d4835dd18405fc7a0d508a780a734922e0abd12c 100644
--- a/python/paddle/fluid/tests/unittests/test_network_with_dtype.py
+++ b/python/paddle/fluid/tests/unittests/test_network_with_dtype.py
@@ -24,33 +24,30 @@ BATCH_SIZE = 20


 class TestNetWithDtype(unittest.TestCase):
-    def set_network(self):
+    def setUp(self):
         self.dtype = "float64"
         self.init_dtype()
-        main = fluid.Program()
-        with fluid.program_guard(main):
-            self.x = fluid.layers.data(name='x', shape=[13], dtype=self.dtype)
-            self.y = fluid.layers.data(name='y', shape=[1], dtype=self.dtype)
-            y_predict = fluid.layers.fc(input=self.x, size=1, act=None)
-            cost = fluid.layers.square_error_cost(input=y_predict, label=self.y)
+
+    def run_net_on_place(self, place):
+        main = fluid.Program()
+        startup = fluid.Program()
+        with fluid.program_guard(main, startup):
+            x = fluid.layers.data(name='x', shape=[13], dtype=self.dtype)
+            y = fluid.layers.data(name='y', shape=[1], dtype=self.dtype)
+            y_predict = fluid.layers.fc(input=x, size=1, act=None)
+            cost = fluid.layers.square_error_cost(input=y_predict, label=y)
             avg_cost = fluid.layers.mean(cost)
-            self.program = main
-            self.fetch_list = [avg_cost]
+            sgd_optimizer = fluid.optimizer.SGD(learning_rate=0.001)
+            sgd_optimizer.minimize(avg_cost)

-            sgd_optimizer = fluid.optimizer.SGD(learning_rate=0.001)
-            sgd_optimizer.minimize(avg_cost)
-
-    def run_net_on_place(self, place):
+        fetch_list = [avg_cost]
         train_reader = paddle.batch(
             paddle.dataset.uci_housing.train(), batch_size=BATCH_SIZE)
-        feeder = fluid.DataFeeder(place=place, feed_list=[self.x, self.y])
+        feeder = fluid.DataFeeder(place=place, feed_list=[x, y])
         exe = fluid.Executor(place)
-        exe.run(fluid.default_startup_program())
+        exe.run(startup)
         for data in train_reader():
-            exe.run(self.program,
-                    feed=feeder.feed(data),
-                    fetch_list=self.fetch_list)
+            exe.run(main, feed=feeder.feed(data), fetch_list=fetch_list)
             # the main program is runable, the datatype is fully supported
             break

@@ -58,14 +55,12 @@ class TestNetWithDtype(unittest.TestCase):
         pass

     def test_cpu(self):
-        self.set_network()
         place = fluid.CPUPlace()
         self.run_net_on_place(place)

     def test_gpu(self):
         if not core.is_compiled_with_cuda():
             return
-        self.set_network()
         place = fluid.CUDAPlace(0)
         self.run_net_on_place(place)
diff --git a/python/paddle/fluid/tests/unittests/test_parallel_executor.py b/python/paddle/fluid/tests/unittests/test_parallel_executor.py
index 6dc016487fd81a9292f94042a20b7356bc50abe1..056f9e1781997aa1586d972874b652d5b725fe3f 100644
--- a/python/paddle/fluid/tests/unittests/test_parallel_executor.py
+++ b/python/paddle/fluid/tests/unittests/test_parallel_executor.py
@@ -775,7 +775,7 @@ class TestCRFModel(unittest.TestCase):
         build_strategy = fluid.BuildStrategy()
         build_strategy.reduce_strategy = fluid.BuildStrategy.ReduceStrategy.Reduce
         self.check_network_convergence(
-            is_sparse=False, build_strategy=build_strategy)
+            is_sparse=True, build_strategy=build_strategy)

     def test_update_dense_parameter_reduce(self):
         build_strategy = fluid.BuildStrategy()
@@ -849,8 +849,7 @@ class TestFetchOp(unittest.TestCase):
             assert not math.isnan(np.sum(ret[i])) and \
                    not math.isinf(np.sum(ret[i]))

-    @unittest.skip("this test is buggy")
-    def test_feed(self):
+    def test_fetch_op(self):
         tst_reader = paddle.batch(flowers.test(use_xmap=False), batch_size=16)
         tst_reader_iter = tst_reader()
diff --git a/python/paddle/fluid/trainer.py b/python/paddle/fluid/trainer.py
index 6cf19c547756b4a985b74b45b938c9c243537b42..d158d586321833fdf046e4e061bfa8460b9a31b5 100644
--- a/python/paddle/fluid/trainer.py
+++ b/python/paddle/fluid/trainer.py
@@ -139,7 +139,40 @@ class Trainer(object):
             # load params from param_path into scope
             io.load_persistables(exe, dirname=param_path)

+    def _transpile_nccl2_dist(self):
+        # PADDLE_TRAINER_IPS
+        if "PADDLE_TRAINER_IPS" not in os.environ:
+            self.nccl_id_var = None
+        else:
+            self.trainer_id = int(os.getenv("PADDLE_TRAINER_ID"))
+            port = os.getenv("PADDLE_PSERVER_PORT")
+            worker_ips = os.getenv("PADDLE_TRAINER_IPS")
+            worker_endpoints = []
+            for ip in worker_ips.split(","):
+                worker_endpoints.append(':'.join([ip, port]))
+            self.num_trainers = len(worker_endpoints)
+            current_endpoint = os.getenv("POD_IP") + ":" + port
+            worker_endpoints.remove(current_endpoint)
+            # TODO(wuyi): use self.nccl_id_var, self.num_trainers and self.trainer_id
+            # in ParallelExecutor to start
+            # distributed training using NCCL2
+            self.nccl_id_var = self.startup_program.global_block().create_var(
+                name="NCCLID", persistable=True, type=core.VarDesc.VarType.RAW)
+            self.startup_program.global_block().append_op(
+                type="gen_nccl_id",
+                inputs={},
+                outputs={"NCCLID": self.nccl_id_var},
+                attrs={
+                    "endpoint": current_endpoint,
+                    "endpoint_list": worker_endpoints,
+                    "trainer_id": self.trainer_id
+                })
+
     def _dist_transpile_if_necessary(self, optimize_ops, params_grads):
+        self._transpile_nccl2_dist()
+        if self.nccl_id_var != None:
+            return
+
         if "PADDLE_TRAINING_ROLE" not in os.environ:
             return
diff --git a/tools/test_runner.py b/tools/test_runner.py
new file mode 100644
index 0000000000000000000000000000000000000000..9dc750b89058cd73355a2f7984d577252c03526d
--- /dev/null
+++ b/tools/test_runner.py
@@ -0,0 +1,48 @@
+# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+import os
+import sys
+import paddle.fluid as fluid
+import importlib
+import cStringIO
+
+
+def main():
+    sys.path.append(os.getcwd())
+    some_test_failed = False
+    for module_name in sys.argv[1:]:
+        buffer = cStringIO.StringIO()
+        main = fluid.Program()
+        startup = fluid.Program()
+        scope = fluid.core.Scope()
+        with fluid.program_guard(main, startup):
+            with fluid.scope_guard(scope):
+                with fluid.unique_name.guard():
+                    test_loader = unittest.TestLoader()
+                    module = importlib.import_module(module_name)
+                    tests = test_loader.loadTestsFromModule(module)
+                    res = unittest.TextTestRunner(stream=buffer).run(tests)
+                    if not res.wasSuccessful():
+                        some_test_failed = True
+                        print >> sys.stderr, module_name, 'failed\n', buffer.getvalue(
+                        )
+
+    if some_test_failed:
+        exit(1)
+
+
+if __name__ == '__main__':
+    main()
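Note: with the unittests/CMakeLists.txt change above, py_test_modules now
drives tests through this runner instead of "python -m unittest", roughly
python tools/test_runner.py test_mul_op test_network_with_dtype (the actual
PYTHONPATH and working directory are supplied by CMake). Each module is loaded
and run inside a fresh Program, Scope and unique_name guard, so tests bundled
into one process cannot leak global state into each other, and a module's
output, captured in a StringIO buffer, is echoed to stderr only when that
module fails.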