diff --git a/CMakeLists.txt b/CMakeLists.txt
index c7d743e193e7d32dbc0b56f3bcb05b6c61f85f1d..b174831109372cb014741d63032fa6a470e74042 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -36,8 +36,8 @@ include(simd)
 ################################ Configurations #######################################
 option(WITH_GPU "Compile PaddlePaddle with NVIDIA GPU" ${CUDA_FOUND})
 option(WITH_AVX "Compile PaddlePaddle with AVX intrinsics" ${AVX_FOUND})
-option(WITH_MKLDNN "Compile PaddlePaddle with mkl-dnn support." OFF)
-option(WITH_MKLML "Compile PaddlePaddle with mklml package." OFF)
+option(WITH_MKLDNN "Compile PaddlePaddle with mkl-dnn support." ${AVX_FOUND})
+option(WITH_MKLML "Compile PaddlePaddle with mklml package." ${AVX_FOUND})
 option(WITH_DSO "Compile PaddlePaddle with dynamic linked CUDA" ON)
 option(WITH_TESTING "Compile PaddlePaddle with unit testing" ON)
 option(WITH_SWIG_PY "Compile PaddlePaddle with inference api" ON)
diff --git a/cmake/configure.cmake b/cmake/configure.cmake
index 69220e03fe8e337205f31cb1f45e3e19ae4f5d1e..2ac098954647d37e26ac2499e0675dae39910edc 100644
--- a/cmake/configure.cmake
+++ b/cmake/configure.cmake
@@ -74,8 +74,6 @@ if(WITH_MKLDNN)
     set(OPENMP_FLAGS "-fopenmp")
     set(CMAKE_C_CREATE_SHARED_LIBRARY_FORBIDDEN_FLAGS ${OPENMP_FLAGS})
     set(CMAKE_CXX_CREATE_SHARED_LIBRARY_FORBIDDEN_FLAGS ${OPENMP_FLAGS})
-    set(CMAKE_SHARED_LINKER_FLAGS "${CMAKE_SHARED_LINKER_FLAGS} -L${MKLDNN_IOMP_DIR} -liomp5 -Wl,--as-needed")
-    set(CMAKE_EXE_LINKER_FLAGS "${CMAKE_EXE_LINKER_FLAGS} -L${MKLDNN_IOMP_DIR} -liomp5 -Wl,--as-needed")
     set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} ${OPENMP_FLAGS}")
     set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${OPENMP_FLAGS}")
 else()
diff --git a/cmake/external/gflags.cmake b/cmake/external/gflags.cmake
index a0d0a892c4b3cc3743ac725f3cd90444f18abf34..16e5bef4cdb8d6513de51838e3c3c8398dbad60d 100644
--- a/cmake/external/gflags.cmake
+++ b/cmake/external/gflags.cmake
@@ -28,7 +28,14 @@ INCLUDE_DIRECTORIES(${GFLAGS_INCLUDE_DIR})
 ExternalProject_Add(
     extern_gflags
     ${EXTERNAL_PROJECT_LOG_ARGS}
-    GIT_REPOSITORY  "https://github.com/gflags/gflags.git"
+    # TODO(yiwang): The annoying warnings mentioned in
+    # https://github.com/PaddlePaddle/Paddle/issues/3277 are caused by
+    # gflags.  I filed a PR, https://github.com/gflags/gflags/pull/230,
+    # to fix it.  Until that PR is accepted by the gflags team, we
+    # temporarily use my personal fork, which contains the fix.  Let's
+    # change this back to the official GitHub repo once the PR is
+    # merged.
+    GIT_REPOSITORY  "https://github.com/wangkuiyi/gflags.git"
     PREFIX          ${GFLAGS_SOURCES_DIR}
     UPDATE_COMMAND  ""
     CMAKE_ARGS      -DCMAKE_CXX_COMPILER=${CMAKE_CXX_COMPILER}
+ GIT_REPOSITORY "https://github.com/wangkuiyi/gflags.git" PREFIX ${GFLAGS_SOURCES_DIR} UPDATE_COMMAND "" CMAKE_ARGS -DCMAKE_CXX_COMPILER=${CMAKE_CXX_COMPILER} diff --git a/paddle/framework/operator.cc b/paddle/framework/operator.cc index cb86e6be2be3624bf54ee28193ca5d4c7bafa0eb..beb6793289812cfaa6991d28379126ff29fa2547 100644 --- a/paddle/framework/operator.cc +++ b/paddle/framework/operator.cc @@ -22,14 +22,14 @@ namespace framework { template <> Eigen::DefaultDevice& ExecutionContext::GetEigenDevice< platform::CPUPlace, Eigen::DefaultDevice>() const { - return *device_context_.get_eigen_device(); + return *device_context_->get_eigen_device(); } #ifndef PADDLE_ONLY_CPU template <> Eigen::GpuDevice& ExecutionContext::GetEigenDevice() const { - return *device_context_.get_eigen_device(); + return *device_context_->get_eigen_device(); } #endif diff --git a/paddle/framework/operator.h b/paddle/framework/operator.h index d42e21c0a235791db42076555d0568ff8f4acbe2..b25362fef336fd84934e901108b6c8358463fe03 100644 --- a/paddle/framework/operator.h +++ b/paddle/framework/operator.h @@ -252,7 +252,7 @@ struct EigenDeviceConverter { class ExecutionContext : public OperatorContext { public: ExecutionContext(const OperatorBase* op, const Scope& scope, - const platform::DeviceContext& device_context) + const platform::DeviceContext* device_context) : OperatorContext(op, scope), device_context_(device_context) {} template ::EigenDeviceType> DeviceType& GetEigenDevice() const; - platform::Place GetPlace() const { return device_context_.GetPlace(); } + platform::Place GetPlace() const { return device_context_->GetPlace(); } - const platform::DeviceContext& device_context_; + const platform::DeviceContext* device_context_; }; class OpKernel { @@ -311,7 +311,7 @@ class OperatorWithKernel : public OperatorBase { void Run(const Scope& scope, const platform::DeviceContext& dev_ctx) const final { auto& opKernel = AllOpKernels().at(type_).at(OpKernelKey(dev_ctx)); - opKernel->Compute(ExecutionContext(this, scope, dev_ctx)); + opKernel->Compute(ExecutionContext(this, scope, &dev_ctx)); } static std::unordered_map& diff --git a/paddle/function/nnpack/NNPACKConvOp.cpp b/paddle/function/nnpack/NNPACKConvOp.cpp index f0ec77a5d00333993427fb8d0bc938c884e50c95..00d048eb216baf37c875c870a31cfd55a97f2974 100644 --- a/paddle/function/nnpack/NNPACKConvOp.cpp +++ b/paddle/function/nnpack/NNPACKConvOp.cpp @@ -49,9 +49,7 @@ class NNPACKConvFunction : public ConvFunctionBase { public: void init(const FuncConfig& config) override { ConvFunctionBase::init(config); - CHECK_EQ(groups_, (size_t)1); algorithm_ = get_nnp_convolution_algorithm(config.get("algo")); - // algorithm_ = nnp_convolution_algorithm_auto; transform_strategy_ = nnp_convolution_transform_strategy_compute; nnp_status status = nnp_initialize(); CHECK_EQ(status, nnp_status_success); @@ -67,8 +65,7 @@ public: } } - virtual void check(const BufferArgs& inputs, - const BufferArgs& outputs) override { + void check(const BufferArgs& inputs, const BufferArgs& outputs) override { const TensorShape& input = inputs[0].shape(); const TensorShape& filter = inputs[1].shape(); const TensorShape& output = outputs[0].shape(); @@ -91,8 +88,8 @@ public: size_t filterHeight = getFilterHeight(filter); size_t filterWidth = getFilterWidth(filter); size_t outputChannels = output[1]; - // size_t outputHeight = output[2]; - // size_t outputWidth = output[3]; + size_t outputHeight = output[2]; + size_t outputWidth = output[3]; nnp_size inputSize = {.width = inputWidth, .height = inputHeight}; 
diff --git a/paddle/function/nnpack/NNPACKConvOp.cpp b/paddle/function/nnpack/NNPACKConvOp.cpp
index f0ec77a5d00333993427fb8d0bc938c884e50c95..00d048eb216baf37c875c870a31cfd55a97f2974 100644
--- a/paddle/function/nnpack/NNPACKConvOp.cpp
+++ b/paddle/function/nnpack/NNPACKConvOp.cpp
@@ -49,9 +49,7 @@ class NNPACKConvFunction : public ConvFunctionBase {
 public:
   void init(const FuncConfig& config) override {
     ConvFunctionBase::init(config);
-    CHECK_EQ(groups_, (size_t)1);
     algorithm_ = get_nnp_convolution_algorithm(config.get<std::string>("algo"));
-    // algorithm_ = nnp_convolution_algorithm_auto;
     transform_strategy_ = nnp_convolution_transform_strategy_compute;
     nnp_status status = nnp_initialize();
     CHECK_EQ(status, nnp_status_success);
@@ -67,8 +65,7 @@ public:
     }
   }
 
-  virtual void check(const BufferArgs& inputs,
-                     const BufferArgs& outputs) override {
+  void check(const BufferArgs& inputs, const BufferArgs& outputs) override {
     const TensorShape& input = inputs[0].shape();
     const TensorShape& filter = inputs[1].shape();
     const TensorShape& output = outputs[0].shape();
@@ -91,8 +88,8 @@ public:
     size_t filterHeight = getFilterHeight(filter);
     size_t filterWidth = getFilterWidth(filter);
     size_t outputChannels = output[1];
-    // size_t outputHeight = output[2];
-    // size_t outputWidth = output[3];
+    size_t outputHeight = output[2];
+    size_t outputWidth = output[3];
 
     nnp_size inputSize = {.width = inputWidth, .height = inputHeight};
     nnp_padding padding = {.top = (size_t)paddingH(),
@@ -171,49 +168,58 @@ public:
       }
     }
 
+    size_t inputOffset = inputChannels / groups_ * inputHeight * inputWidth;
+    size_t outputOffset = outputChannels / groups_ * outputHeight * outputWidth;
+    size_t filterOffset = filter.getElements() / groups_;
+
     if (batchSize == 1) {
-      nnp_status status =
-          nnp_convolution_inference(algorithm_,
-                                    transform_strategy_,
-                                    inputChannels,
-                                    outputChannels,
-                                    inputSize,
-                                    padding,
-                                    kernelSize,
-                                    outputSubsampling,
-                                    inputData,
-                                    filterData,
-                                    nullptr, /* bias */
-                                    outputData,
-                                    bufferPtr,
-                                    sizePtr,
-                                    nnp_activation_identity,
-                                    nullptr,
-                                    threadpool_, /* threadpool */
-                                    nullptr);
-      CHECK_EQ(status, nnp_status_success);
+      for (size_t g = 0; g < groups_; g++) {
+        nnp_status status =
+            nnp_convolution_inference(algorithm_,
+                                      transform_strategy_,
+                                      inputChannels / groups_,
+                                      outputChannels / groups_,
+                                      inputSize,
+                                      padding,
+                                      kernelSize,
+                                      outputSubsampling,
+                                      inputData + inputOffset * g,
+                                      filterData + filterOffset * g,
+                                      nullptr, /* bias */
+                                      outputData + outputOffset * g,
+                                      bufferPtr,
+                                      sizePtr,
+                                      nnp_activation_identity,
+                                      nullptr,
+                                      threadpool_, /* threadpool */
+                                      nullptr);
+        CHECK_EQ(status, nnp_status_success);
+      }
     } else {
-      // only supports stride = 1
-      CHECK_EQ(strideH(), 1);
-      CHECK_EQ(strideW(), 1);
-      nnp_status status = nnp_convolution_output(algorithm_,
-                                                 batchSize,
-                                                 inputChannels,
-                                                 outputChannels,
-                                                 inputSize,
-                                                 padding,
-                                                 kernelSize,
-                                                 inputData,
-                                                 filterData,
-                                                 nullptr, /* bias */
-                                                 outputData,
-                                                 bufferPtr,
-                                                 sizePtr,
-                                                 nnp_activation_identity,
-                                                 nullptr,
-                                                 threadpool_, /* threadpool */
-                                                 nullptr);
-      CHECK_EQ(status, nnp_status_success);
+      for (size_t g = 0; g < groups_; g++) {
+        // only supports stride = 1
+        CHECK_EQ(strideH(), 1);
+        CHECK_EQ(strideW(), 1);
+        nnp_status status =
+            nnp_convolution_output(algorithm_,
+                                   batchSize,
+                                   inputChannels / groups_,
+                                   outputChannels / groups_,
+                                   inputSize,
+                                   padding,
+                                   kernelSize,
+                                   inputData + inputOffset * g,
+                                   filterData + filterOffset * g,
+                                   nullptr, /* bias */
+                                   outputData + outputOffset * g,
+                                   bufferPtr,
+                                   sizePtr,
+                                   nnp_activation_identity,
+                                   nullptr,
+                                   threadpool_, /* threadpool */
+                                   nullptr);
+        CHECK_EQ(status, nnp_status_success);
+      }
     }
   }
 }
diff --git a/paddle/gserver/layers/ExpandConvLayer.cpp b/paddle/gserver/layers/ExpandConvLayer.cpp
index 783e02e47cb91e28eb88b079f1e94439d34fa775..0ece2799318ea5ecc91f97f71289d4d07246dcaa 100644
--- a/paddle/gserver/layers/ExpandConvLayer.cpp
+++ b/paddle/gserver/layers/ExpandConvLayer.cpp
@@ -57,8 +57,7 @@ bool ExpandConvLayer::init(const LayerMap &layerMap,
     convGradFilterType = "GemmConvGradFilter";
   }
 
-  if (FLAGS_use_nnpack) {
-    CHECK_EQ(isDeconv_, false);
+  if (FLAGS_use_nnpack && !isDeconv_) {
     createFunction(forward_,
                    "NNPACKConv",
                    FuncConfig()
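The per-group pointer arithmetic in the NNPACKConvOp.cpp change above relies on NCHW layout, where each group's channels form one contiguous slab of the flat buffer. A minimal numpy sketch of the batchSize == 1 case (the shapes and names here are illustrative and mirror the C++ variables; this is not Paddle API):

import numpy as np

# Group g's slab of an NCHW buffer starts inputOffset * g elements into
# the flat array, which is what `inputData + inputOffset * g` points at.
in_c, h, w, groups = 4, 8, 8, 2
x = np.random.rand(1, in_c, h, w).astype("float32")  # batchSize == 1

input_offset = in_c // groups * h * w  # inputChannels / groups_ * H * W
flat_x = x.ravel()                     # C order keeps channel slabs contiguous
for g in range(groups):
    slab = flat_x[input_offset * g:input_offset * (g + 1)]
    channels = x[0, g * (in_c // groups):(g + 1) * (in_c // groups)]
    assert np.array_equal(slab, channels.ravel())

The same reasoning gives outputOffset from the output channels per group times the output spatial size, and filterOffset as the filter element count divided by the number of groups.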
diff --git a/python/paddle/v2/framework/tests/op_test_util.py b/python/paddle/v2/framework/tests/op_test_util.py
index 9ee66c2c5103811519c3a2c28653536f97009161..e6bc7d8a9b5ddd4582a5ef8a47cb63a7e5911892 100644
--- a/python/paddle/v2/framework/tests/op_test_util.py
+++ b/python/paddle/v2/framework/tests/op_test_util.py
@@ -33,23 +33,28 @@ class OpTestMeta(type):
             for place in places:
                 for in_name in func.all_input_args:
-                    if hasattr(self, in_name):
+                    if hasattr(self, "inputs") and in_name in self.inputs:
                         kwargs[in_name] = in_name
                         var = scope.new_var(in_name).get_tensor()
-                        arr = getattr(self, in_name)
+                        arr = self.inputs[in_name]
                         var.set_dims(arr.shape)
                         var.set(arr, place)
                     else:
                         kwargs[in_name] = "@EMPTY@"
 
                 for out_name in func.all_output_args:
-                    if hasattr(self, out_name):
-                        kwargs[out_name] = out_name
-                        scope.new_var(out_name).get_tensor()
+                    if not hasattr(self, "outputs"):
+                        raise ValueError(
+                            "The test op must set self.outputs dict.")
+                    if out_name not in self.outputs:
+                        raise ValueError("The %s is not in self.outputs dict." %
+                                         (out_name))
+                    kwargs[out_name] = out_name
+                    scope.new_var(out_name).get_tensor()
 
                 for attr_name in func.all_attr_args:
-                    if hasattr(self, attr_name):
-                        kwargs[attr_name] = getattr(self, attr_name)
+                    if hasattr(self, "attrs") and attr_name in self.attrs:
+                        kwargs[attr_name] = self.attrs[attr_name]
 
                 op = func(**kwargs)
 
                 op.infer_shape(scope)
@@ -60,7 +65,7 @@ class OpTestMeta(type):
 
             for out_name in func.all_output_args:
                 actual = numpy.array(scope.find_var(out_name).get_tensor())
-                expect = getattr(self, out_name)
+                expect = self.outputs[out_name]
                 numpy.isclose(actual, expect)
 
         obj.test_all = test_all
diff --git a/python/paddle/v2/framework/tests/test_add_two_op.py b/python/paddle/v2/framework/tests/test_add_two_op.py
index 6e6643201bf361fce1bad7de10b2562f0525e00a..8ef48f4727b0af46a696c6f463045d98e7a08800 100644
--- a/python/paddle/v2/framework/tests/test_add_two_op.py
+++ b/python/paddle/v2/framework/tests/test_add_two_op.py
@@ -12,9 +12,11 @@ class TestAddOp(unittest.TestCase):
 
     def setUp(self):
         self.type = "add_two"
-        self.X = numpy.random.random((102, 105)).astype("float32")
-        self.Y = numpy.random.random((102, 105)).astype("float32")
-        self.Out = self.X + self.Y
+        self.inputs = {
+            'X': numpy.random.random((102, 105)).astype("float32"),
+            'Y': numpy.random.random((102, 105)).astype("float32")
+        }
+        self.outputs = {'Out': self.inputs['X'] + self.inputs['Y']}
 
 
 class TestAddGradOp(unittest.TestCase):
diff --git a/python/paddle/v2/framework/tests/test_cross_entropy_op.py b/python/paddle/v2/framework/tests/test_cross_entropy_op.py
index 6d022f6bc0be60dbf2f796780a969bff0e8bfded..b26e25d58b59bd1cb16e9ba2a1cccd27799b15f2 100644
--- a/python/paddle/v2/framework/tests/test_cross_entropy_op.py
+++ b/python/paddle/v2/framework/tests/test_cross_entropy_op.py
@@ -7,15 +7,17 @@ class TestSGD(unittest.TestCase):
     __metaclass__ = OpTestMeta
 
     def setUp(self):
+        # TODO: this unit test does not pass yet
        self.type = "onehot_cross_entropy"
        batch_size = 100
        class_num = 10
-        self.X = numpy.random.random((batch_size, class_num)).astype("float32")
-        self.label = 5 * numpy.ones(batch_size).astype("int32")
+        X = numpy.random.random((batch_size, class_num)).astype("float32")
+        label = 5 * numpy.ones(batch_size).astype("int32")
+        self.inputs = {'X': X, 'label': label}
        Y = []
        for i in range(0, batch_size):
-            Y.append(-numpy.log(self.X[i][self.label[i]]))
-        self.Y = numpy.array(Y).astype("float32")
+            Y.append(-numpy.log(X[i][label[i]]))
+        self.outputs = {'Y': numpy.array(Y).astype("float32")}
 
 
 # TODO(superjom) add gradient check
diff --git a/python/paddle/v2/framework/tests/test_mean_op.py b/python/paddle/v2/framework/tests/test_mean_op.py
index 78fff1eeff998109a51ea662f963a102eff49d3a..b5d52b90567bcd0c9f376147145d8638049f7bab 100644
--- a/python/paddle/v2/framework/tests/test_mean_op.py
+++ b/python/paddle/v2/framework/tests/test_mean_op.py
@@ -8,8 +8,8 @@ class TestMeanOp(unittest.TestCase):
 
     def setUp(self):
         self.type = "mean"
-        self.X = np.random.random((32, 784)).astype("float32")
-        self.Out = np.mean(self.X)
+        self.inputs = {'X': np.random.random((32, 784)).astype("float32")}
+        self.outputs = {'Out': np.mean(self.inputs['X'])}
 
 
 if __name__ == '__main__':
diff --git a/python/paddle/v2/framework/tests/test_mul_op.py b/python/paddle/v2/framework/tests/test_mul_op.py
index e1ac66d3a4d23d617f7c5a4d97d070b2660954c8..ec0ac99156a546dd3fb7b27778032bece38ab5a9 100644
--- a/python/paddle/v2/framework/tests/test_mul_op.py
+++ b/python/paddle/v2/framework/tests/test_mul_op.py
@@ -8,9 +8,11 @@ class TestMulOp(unittest.TestCase):
 
     def setUp(self):
         self.type = "mul"
-        self.X = np.random.random((32, 84)).astype("float32")
-        self.Y = np.random.random((84, 100)).astype("float32")
-        self.Out = np.dot(self.X, self.Y)
+        self.inputs = {
+            'X': np.random.random((32, 84)).astype("float32"),
+            'Y': np.random.random((84, 100)).astype("float32")
+        }
+        self.outputs = {'Out': np.dot(self.inputs['X'], self.inputs['Y'])}
 
 
 if __name__ == '__main__':
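With the op_test_util.py change above, test cases no longer set one attribute per input, output, and attribute name; they populate self.inputs, self.attrs, and self.outputs dicts keyed by the op's declared argument names, and OpTestMeta looks the values up there. A complete hypothetical example of the convention (the "scale" op and its 'scale' attribute are made up for illustration and are not added by this patch):

import unittest
import numpy
from op_test_util import OpTestMeta


class TestScaleOp(unittest.TestCase):
    # Hypothetical test showing the dict-based convention: keys must match
    # the op's registered input/output/attr names.
    __metaclass__ = OpTestMeta

    def setUp(self):
        self.type = "scale"  # assumes a "scale" op is registered
        self.inputs = {'X': numpy.random.random((32, 84)).astype("float32")}
        self.attrs = {'scale': 0.5}
        self.outputs = {'Out': self.attrs['scale'] * self.inputs['X']}

The remaining test files in this patch are mechanical conversions to this convention.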
diff --git a/python/paddle/v2/framework/tests/test_rowwise_add_op.py b/python/paddle/v2/framework/tests/test_rowwise_add_op.py
index 04abc14ee198fe4e2307e009c696a2b40ec271b6..f8521eb517057fbeb104b28af7da4fffe54f37de 100644
--- a/python/paddle/v2/framework/tests/test_rowwise_add_op.py
+++ b/python/paddle/v2/framework/tests/test_rowwise_add_op.py
@@ -8,9 +8,11 @@ class TestRowwiseAddOp(unittest.TestCase):
 
     def setUp(self):
         self.type = "rowwise_add"
-        self.X = np.random.random((32, 84)).astype("float32")
-        self.b = np.random.random(84).astype("float32")
-        self.Out = np.add(self.X, self.b)
+        self.inputs = {
+            'X': np.random.random((32, 84)).astype("float32"),
+            'b': np.random.random(84).astype("float32")
+        }
+        self.outputs = {'Out': np.add(self.inputs['X'], self.inputs['b'])}
 
 
 if __name__ == '__main__':
diff --git a/python/paddle/v2/framework/tests/test_sgd_op.py b/python/paddle/v2/framework/tests/test_sgd_op.py
index ca03cc11abe2ceb31b33a87797aa752943dd2a7d..e5f9ef865e84f1a78e28884ad7e2e758f9ca8054 100644
--- a/python/paddle/v2/framework/tests/test_sgd_op.py
+++ b/python/paddle/v2/framework/tests/test_sgd_op.py
@@ -8,10 +8,13 @@ class TestSGD(unittest.TestCase):
 
     def setUp(self):
         self.type = "sgd"
-        self.param = numpy.random.random((102, 105)).astype("float32")
-        self.grad = numpy.random.random((102, 105)).astype("float32")
-        self.learning_rate = 0.1
-        self.param_out = self.param - self.learning_rate * self.grad
+        w = numpy.random.random((102, 105)).astype("float32")
+        g = numpy.random.random((102, 105)).astype("float32")
+        lr = 0.1
+
+        self.inputs = {'param': w, 'grad': g}
+        self.attrs = {'learning_rate': lr}
+        self.outputs = {'param_out': w - lr * g}
 
 
 if __name__ == "__main__":
diff --git a/python/paddle/v2/framework/tests/test_sigmoid_op.py b/python/paddle/v2/framework/tests/test_sigmoid_op.py
index 50044a122f1d66dd54a24f6cce76074a60ee2262..2610bcf16303d492dce3ce63c93b54b0c88f6bba 100644
--- a/python/paddle/v2/framework/tests/test_sigmoid_op.py
+++ b/python/paddle/v2/framework/tests/test_sigmoid_op.py
@@ -8,8 +8,8 @@ class TestSigmoidOp(unittest.TestCase):
 
     def setUp(self):
         self.type = "sigmoid"
-        self.X = np.random.random((32, 100)).astype("float32")
-        self.Y = 1 / (1 + np.exp(-self.X))
+        self.inputs = {'X': np.random.random((32, 100)).astype("float32")}
+        self.outputs = {'Y': 1 / (1 + np.exp(-self.inputs['X']))}
 
 
 if __name__ == '__main__':
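The sigmoid test above computes Y = 1 / (1 + exp(-X)) directly, which is safe here because the inputs lie in [0, 1). For wider input ranges, a split formulation keeps exp from overflowing; an illustrative sketch, not part of this patch:

import numpy as np

def stable_sigmoid(x):
    # exp only ever sees non-positive arguments, so it cannot overflow.
    out = np.empty_like(x)
    pos = x >= 0
    out[pos] = 1.0 / (1.0 + np.exp(-x[pos]))
    ex = np.exp(x[~pos])
    out[~pos] = ex / (1.0 + ex)
    return out

x = np.linspace(-30, 30, 7).astype("float32")
assert np.allclose(stable_sigmoid(x), 1 / (1 + np.exp(-x)))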
diff --git a/python/paddle/v2/framework/tests/test_softmax_op.py b/python/paddle/v2/framework/tests/test_softmax_op.py
index c80888128781d98e4ed30d845a30b39121f66459..98ca8ddc860c3825411b02b2f6ed612db46a18d7 100644
--- a/python/paddle/v2/framework/tests/test_softmax_op.py
+++ b/python/paddle/v2/framework/tests/test_softmax_op.py
@@ -19,8 +19,10 @@ class TestSoftmaxOp(unittest.TestCase):
 
     def setUp(self):
         self.type = "softmax"
-        self.X = np.random.random((32, 100)).astype("float32")
-        self.Y = np.apply_along_axis(stable_softmax, 1, self.X)
+        self.inputs = {'X': np.random.random((32, 100)).astype("float32")}
+        self.outputs = {
+            'Y': np.apply_along_axis(stable_softmax, 1, self.inputs['X'])
+        }
 
 
 class TestSoftmaxGradOp(unittest.TestCase):
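The stable_softmax that test_softmax_op.py applies row-wise is defined earlier in that file and falls outside this hunk. For reference, a self-contained version of the standard shift-by-max formulation it presumably uses:

import numpy as np

def stable_softmax(x):
    # Subtracting the max shifts exp's arguments to (-inf, 0],
    # guarding against overflow without changing the result.
    shifted = x - np.max(x)
    exps = np.exp(shifted)
    return exps / np.sum(exps)

x = np.random.random((32, 100)).astype("float32")
y = np.apply_along_axis(stable_softmax, 1, x)
assert np.allclose(y.sum(axis=1), 1.0, atol=1e-5)  # rows sum to one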