diff --git a/paddle/framework/attr_checker.h b/paddle/framework/attr_checker.h index c0c33d81149ac2fc2a9a57d90931ef32375fe1d0..f2d88f3cb00e20f548a5cd412b515e843491a76d 100644 --- a/paddle/framework/attr_checker.h +++ b/paddle/framework/attr_checker.h @@ -4,6 +4,7 @@ #include #include #include +#include #include #include "paddle/framework/enforce.h" @@ -41,6 +42,35 @@ class DefaultValueSetter { T default_value_; }; +template +class EnumInContainer { + public: + explicit EnumInContainer(const std::unordered_set& c) : container_(c) {} + void operator()(T& val) const { + PADDLE_ENFORCE(container_.find(val) != container_.end(), + "Value %s is not in enum container %s", val, + ContainerDebugString()); + } + + private: + std::string ContainerDebugString() const { + std::ostringstream sout; + sout << "["; + size_t cnt = 0; + for (auto& v : container_) { + sout << v; + ++cnt; + if (cnt != container_.size()) { + sout << " ,"; + } + } + sout << "]"; + return sout.str(); + } + + std::unordered_set container_; +}; + // check whether a certain attribute fit its limits // an attribute can have more than one limits template @@ -50,6 +80,11 @@ class TypedAttrChecker { public: TypedAttrChecker(const std::string& attr_name) : attr_name_(attr_name) {} + TypedAttrChecker& InEnum(const std::unordered_set& range) { + value_checkers_.push_back(EnumInContainer(range)); + return *this; + } + TypedAttrChecker& LargerThan(const T& lower_bound) { value_checkers_.push_back(LargerThanChecker(lower_bound)); return *this; diff --git a/paddle/framework/net.cc b/paddle/framework/net.cc index c39c87fcd6dc7d8b78c8112b0f258774e2bf74d7..2abc5d341769a7f9c3acc570b8c5e01c37f1454c 100644 --- a/paddle/framework/net.cc +++ b/paddle/framework/net.cc @@ -33,7 +33,9 @@ std::shared_ptr AddBackwardOp(std::shared_ptr ForwardOps) { return grad_ops; } -void PlainNet::CompleteAddOp() { +void PlainNet::CompleteAddOp(bool calc) { + add_op_done_ = true; + if (!calc) return; std::unordered_set input_set; std::unordered_set output_set; std::unordered_set temp_output; @@ -66,7 +68,6 @@ void PlainNet::CompleteAddOp() { } attrs_["temporary_index"] = tmp_index; - add_op_done_ = true; } std::string PlainNet::DebugString() const { diff --git a/paddle/framework/net.h b/paddle/framework/net.h index 1103c8ef2b01697aa3a92402a3325a1a8e6c700b..60bfd3ef5e8a1cc48d7e583d5863a25609ce82b6 100644 --- a/paddle/framework/net.h +++ b/paddle/framework/net.h @@ -16,7 +16,6 @@ limitations under the License. */ #include #include -#include "paddle/framework/net_proto.pb.h" #include "paddle/framework/op_proto.pb.h" #include "paddle/framework/op_registry.h" #include "paddle/framework/scope.h" @@ -41,7 +40,7 @@ namespace framework { class Net : public OperatorBase { public: virtual void AddOp(const OperatorPtr& op) = 0; - virtual void CompleteAddOp() = 0; + virtual void CompleteAddOp(bool calc) = 0; }; using NetPtr = std::shared_ptr; @@ -86,7 +85,7 @@ class PlainNet : public Net { ops_.push_back(op); } - void CompleteAddOp() override; + void CompleteAddOp(bool calculate = true) override; std::string DebugString() const override; diff --git a/paddle/framework/tensor.h b/paddle/framework/tensor.h index 4f07350e59dea72431417876f41f172e51ea53f9..1dd421cdb681e15486e309ff912574af35b5a0c2 100644 --- a/paddle/framework/tensor.h +++ b/paddle/framework/tensor.h @@ -71,14 +71,14 @@ class Tensor { holder_.reset(new PlaceholderImpl( boost::get(place), product(dims_) * sizeof(T))); } else if (platform::is_gpu_place(place)) { -#ifdef __CUDACC__ +#ifdef PADDLE_ONLY_CPU + PADDLE_THROW("'GPUPlace' is not supported in CPU only device."); +#else holder_.reset(new PlaceholderImpl( boost::get(place), product(dims_) * sizeof(T))); -#else - PADDLE_ENFORCE(true, "'GPUPlace' is not supported in CPU only device."); #endif } else { - PADDLE_ENFORCE(true, "Unknown 'place'."); + PADDLE_THROW("Unknown 'place'."); } offset_ = 0; } diff --git a/paddle/gserver/layers/Layer.cpp b/paddle/gserver/layers/Layer.cpp index 4b92b5d163ad107c0783beae45f8c936112fcccf..d5621412caee843e24a0d0c9b7096402765738c7 100644 --- a/paddle/gserver/layers/Layer.cpp +++ b/paddle/gserver/layers/Layer.cpp @@ -359,12 +359,11 @@ void Layer::backwardActivation() { /* Do error clipping */ if (config_.error_clipping_threshold() > 0.0f) { if (FLAGS_log_error_clipping) { - CpuVector outGradVec(0, nullptr); - outGradVec.subVecFrom( - output_.grad->getData(), 0, output_.grad->getElementCnt()); - real maxAbsGrad = outGradVec.getAbsMax(); + VectorPtr outGradVec = Vector::create( + output_.grad->getData(), output_.grad->getElementCnt(), useGpu_); + real maxAbsGrad = outGradVec->getAbsMax(); if (maxAbsGrad > config_.error_clipping_threshold()) { - real avgAbsGrad = outGradVec.getAbsSum() / outGradVec.getSize(); + real avgAbsGrad = outGradVec->getAbsSum() / outGradVec->getSize(); LOG(INFO) << " layer=" << config_.name() << " need clipping," << " max error=" << maxAbsGrad << " avg error=" << avgAbsGrad; } diff --git a/paddle/operators/CMakeLists.txt b/paddle/operators/CMakeLists.txt index f47c3a42083f289d6c99fe6df62e3478e0363e31..bc64bfd7ec2ed27835e5a3f9135343aeb3d4a580 100644 --- a/paddle/operators/CMakeLists.txt +++ b/paddle/operators/CMakeLists.txt @@ -27,7 +27,8 @@ function(op_library TARGET) endif() list(LENGTH cu_srcs cu_srcs_len) - if (${cu_srcs_len} EQUAL 0) + list(LENGTH op_library_DEPS dep_len) + if (${cu_srcs_len} EQUAL 0 AND ${dep_len} EQUAL 0) message(WARNING "The op library ${TARGET} not support GPU!") endif() @@ -47,3 +48,6 @@ op_library(mul_op SRCS mul_op.cc mul_op.cu) op_library(rowwise_add_op SRCS rowwise_add_op.cu rowwise_add_op.cc) op_library(sigmoid_op SRCS sigmoid_op.cu sigmoid_op.cc) op_library(softmax_op SRCS softmax_op.cc softmax_op.cu) + +op_library(fc_op SRCS fc_op.cc DEPS mul_op rowwise_add_op sigmoid_op + softmax_op net) diff --git a/paddle/operators/fc_op.cc b/paddle/operators/fc_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..01e96f4c4817466e3266ca57a0d0ae2368b3e097 --- /dev/null +++ b/paddle/operators/fc_op.cc @@ -0,0 +1,76 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#include "paddle/framework/net.h" +#include "paddle/framework/op_registry.h" +#include "paddle/framework/operator.h" + +namespace paddle { +namespace operators { + +class FullyConnectedOp : public framework::PlainNet { +public: + void Init() override { + AddOp(framework::OpRegistry::CreateOp("mul", + { + Input("X"), Input("W"), + }, + {Output("before_act")}, + {})); + auto b = Input("b"); + if (b != framework::OperatorBase::EMPTY_VAR_NAME()) { + AddOp(framework::OpRegistry::CreateOp("rowwise_add", + {Output("before_act"), Input("b")}, + {Output("before_act")}, + {})); + } + + auto activation = GetAttr("activation"); + AddOp(framework::OpRegistry::CreateOp( + activation, {Output("before_act")}, {Output("Y")}, {})); + CompleteAddOp(false); + } +}; + +class FullyConnectedOpMaker : public framework::OpProtoAndCheckerMaker { +public: + FullyConnectedOpMaker(framework::OpProto *proto, + framework::OpAttrChecker *op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddInput("X", "the input of fc operator"); + AddInput("W", "the weight of fc operator"); + AddInput("b", "the bias of fc operator"); + + AddOutput("Y", "the output of fc operator"); + AddOutput( + "before_act", "the before activation output of fc operator", true); + AddAttr("activation", "The activation key for fc layer") + .SetDefault("sigmoid") + .InEnum({"sigmoid", "softmax"}); + + //! TODO(yuyang18): Complete comment; + AddComment("FullyConnected Operator"); + } +}; +} // namespace operators +} // namespace paddle + +USE_OP(mul); +USE_OP(rowwise_add); +USE_OP(sigmoid); +USE_OP(softmax); + +REGISTER_OP(fc, + paddle::operators::FullyConnectedOp, + paddle::operators::FullyConnectedOpMaker); diff --git a/paddle/pybind/CMakeLists.txt b/paddle/pybind/CMakeLists.txt index 00b14a94321990baef6de35df547eed04b3da04f..29fb29c7c14f699e6114cc25c265ea8d85bce4d7 100644 --- a/paddle/pybind/CMakeLists.txt +++ b/paddle/pybind/CMakeLists.txt @@ -1,2 +1,2 @@ cc_library(paddle_pybind SHARED SRCS pybind.cc DEPS pybind python - add_op mul_op rowwise_add_op sigmoid_op softmax_op) + add_op fc_op) diff --git a/paddle/pybind/pybind.cc b/paddle/pybind/pybind.cc index fc9c6544c3cbf5a804b2d052f738bd483d6bf41b..7e84550f770e8dba998ce7ff91b9d774acbffc3e 100644 --- a/paddle/pybind/pybind.cc +++ b/paddle/pybind/pybind.cc @@ -14,6 +14,7 @@ limitations under the License. */ #include #include +#include #include #include #include @@ -26,10 +27,7 @@ namespace py = pybind11; namespace pd = paddle::framework; USE_OP(add_two); -USE_OP(softmax); -USE_OP(mul); -USE_OP(rowwise_add); -USE_OP(sigmoid); +USE_OP_WITHOUT_KERNEL(fc); PYBIND11_PLUGIN(core) { py::module m("core", "C++ core of Paddle Paddle"); @@ -53,7 +51,9 @@ PYBIND11_PLUGIN(core) { self.mutable_data(paddle::platform::CPUPlace()); }) .def("set", paddle::pybind::PyTensorSetFromArray) - .def("set", paddle::pybind::PyTensorSetFromArray); + .def("set", paddle::pybind::PyTensorSetFromArray) + .def("shape", + [](pd::Tensor& self) { return pd::vectorize(self.dims()); }); py::class_(m, "Variable", R"DOC(Variable Class. @@ -83,15 +83,16 @@ All parameter, weight, gradient are variables in Paddle. //! @note: Be careful! PyBind will return std::string as an unicode, not //! Python str. If you want a str object, you should cast them in Python. - m.def("get_all_op_protos", []() -> std::vector { + m.def("get_all_op_protos", []() -> std::vector { auto& protos = pd::OpRegistry::protos(); - std::vector ret_values; + std::vector ret_values; for (auto it = protos.begin(); it != protos.end(); ++it) { PADDLE_ENFORCE(it->second.IsInitialized(), "OpProto must all be initialized"); - ret_values.emplace_back(); - PADDLE_ENFORCE(it->second.SerializeToString(&ret_values.back()), + std::string str; + PADDLE_ENFORCE(it->second.SerializeToString(&str), "Serialize OpProto Error. This could be a bug of Paddle."); + ret_values.push_back(py::bytes(str)); } return ret_values; }); @@ -101,17 +102,26 @@ All parameter, weight, gradient are variables in Paddle. .def("empty", pd::OperatorBase::EMPTY_VAR_NAME) .def("temp", pd::OperatorBase::TMP_VAR_NAME); + py::class_(m, "DeviceContext") + .def_static("cpu_context", []() -> paddle::platform::DeviceContext* { + return new paddle::platform::CPUDeviceContext(); + }); + py::class_(m, "Operator") .def("__str__", &pd::OperatorBase::DebugString) - .def_static("create", [](const std::string& protobin) { - pd::OpDesc desc; - PADDLE_ENFORCE(desc.ParsePartialFromString(protobin), - "Cannot parse user input to OpDesc"); - PADDLE_ENFORCE(desc.IsInitialized(), - "User OpDesc is not initialized, reason %s", - desc.InitializationErrorString()); - return pd::OpRegistry::CreateOp(desc); - }); + .def_static("create", + [](py::bytes protobin) { + pd::OpDesc desc; + PADDLE_ENFORCE(desc.ParsePartialFromString(protobin), + "Cannot parse user input to OpDesc"); + PADDLE_ENFORCE(desc.IsInitialized(), + "User OpDesc is not initialized, reason %s", + desc.InitializationErrorString()); + return pd::OpRegistry::CreateOp(desc); + }) + .def("infer_shape", &pd::OperatorBase::InferShape) + .def("run", &pd::OperatorBase::Run) + .def("outputs", [](const pd::OperatorPtr& op) { return op->outputs_; }); return m.ptr(); } diff --git a/paddle/scripts/travis/check_style.sh b/paddle/scripts/travis/check_style.sh index 8049aeb7b00870220e59c981addf6d70a66877c7..ec499a839ac6593bac788f4cca5e33afbed73010 100755 --- a/paddle/scripts/travis/check_style.sh +++ b/paddle/scripts/travis/check_style.sh @@ -1,7 +1,7 @@ #!/bin/bash function abort(){ echo "Your change doesn't follow PaddlePaddle's code style." 1>&2 - echo "Please use pre-commit to reformat your code and git push again." 1>&2 + echo "Please use pre-commit to check what is wrong." 1>&2 exit 1 } @@ -19,7 +19,8 @@ ln -sf $TRAVIS_BUILD_DIR $GOPATH/src/github.com/PaddlePaddle/Paddle cd $GOPATH/src/github.com/PaddlePaddle/Paddle/go; glide install; cd - if ! pre-commit run -a ; then - git diff --exit-code + git diff + exit 1 fi trap : 0 diff --git a/python/paddle/trainer/config_parser.py b/python/paddle/trainer/config_parser.py index 826ba2834a820d11e69feec5569ef3537194e3c3..ef3d81e4c0791ca7847dc607682fa39ff15967da 100644 --- a/python/paddle/trainer/config_parser.py +++ b/python/paddle/trainer/config_parser.py @@ -1575,7 +1575,13 @@ class MultiClassCrossEntropySelfNormCostLayer(LayerBase): @config_layer('fc') class FCLayer(LayerBase): - def __init__(self, name, size, inputs, bias=True, **xargs): + def __init__(self, + name, + size, + inputs, + bias=True, + error_clipping_threshold=None, + **xargs): super(FCLayer, self).__init__(name, 'fc', size, inputs=inputs, **xargs) for input_index in xrange(len(self.inputs)): input_layer = self.get_input_layer(input_index) @@ -1592,6 +1598,8 @@ class FCLayer(LayerBase): self.create_input_parameter(input_index, psize, dims, sparse, format) self.create_bias_parameter(bias, self.config.size) + if error_clipping_threshold is not None: + self.config.error_clipping_threshold = error_clipping_threshold @config_layer('selective_fc') diff --git a/python/paddle/v2/dataset/__init__.py b/python/paddle/v2/dataset/__init__.py index 2e4beb6882789249db09705f3f4d6c5c19e492cd..90830515c1e8e6f5260cfca631e02a3a52cedbe5 100644 --- a/python/paddle/v2/dataset/__init__.py +++ b/python/paddle/v2/dataset/__init__.py @@ -26,8 +26,9 @@ import sentiment import wmt14 import mq2007 import flowers +import voc2012 __all__ = [ 'mnist', 'imikolov', 'imdb', 'cifar', 'movielens', 'conll05', 'sentiment' - 'uci_housing', 'wmt14', 'mq2007', 'flowers' + 'uci_housing', 'wmt14', 'mq2007', 'flowers', 'voc2012' ] diff --git a/python/paddle/v2/dataset/tests/voc2012_test.py b/python/paddle/v2/dataset/tests/voc2012_test.py new file mode 100644 index 0000000000000000000000000000000000000000..31e72ebf5eac0508d12783f9ceaa6eef0fa6d353 --- /dev/null +++ b/python/paddle/v2/dataset/tests/voc2012_test.py @@ -0,0 +1,42 @@ +# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import paddle.v2.dataset.voc2012 +import unittest + + +class TestVOC(unittest.TestCase): + def check_reader(self, reader): + sum = 0 + label = 0 + for l in reader(): + self.assertEqual(l[0].size, 3 * l[1].size) + sum += 1 + return sum + + def test_train(self): + count = self.check_reader(paddle.v2.dataset.voc_seg.train()) + self.assertEqual(count, 2913) + + def test_test(self): + count = self.check_reader(paddle.v2.dataset.voc_seg.test()) + self.assertEqual(count, 1464) + + def test_val(self): + count = self.check_reader(paddle.v2.dataset.voc_seg.val()) + self.assertEqual(count, 1449) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/v2/dataset/voc2012.py b/python/paddle/v2/dataset/voc2012.py new file mode 100644 index 0000000000000000000000000000000000000000..617e212d67fbe37f9d9663e9c83c62045411fa77 --- /dev/null +++ b/python/paddle/v2/dataset/voc2012.py @@ -0,0 +1,85 @@ +# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Image dataset for segmentation. +The 2012 dataset contains images from 2008-2011 for which additional +segmentations have been prepared. As in previous years the assignment +to training/test sets has been maintained. The total number of images +with segmentation has been increased from 7,062 to 9,993. +""" + +import tarfile +import io +import numpy as np +from paddle.v2.dataset.common import download +from paddle.v2.image import * +from PIL import Image + +__all__ = ['train', 'test', 'val'] + +VOC_URL = 'http://host.robots.ox.ac.uk/pascal/VOC/voc2012/\ +VOCtrainval_11-May-2012.tar' + +VOC_MD5 = '6cd6e144f989b92b3379bac3b3de84fd' +SET_FILE = 'VOCdevkit/VOC2012/ImageSets/Segmentation/{}.txt' +DATA_FILE = 'VOCdevkit/VOC2012/JPEGImages/{}.jpg' +LABEL_FILE = 'VOCdevkit/VOC2012/SegmentationClass/{}.png' + +CACHE_DIR = 'voc2012' + + +def reader_creator(filename, sub_name): + + tarobject = tarfile.open(filename) + name2mem = {} + for ele in tarobject.getmembers(): + name2mem[ele.name] = ele + + def reader(): + set_file = SET_FILE.format(sub_name) + sets = tarobject.extractfile(name2mem[set_file]) + for line in sets: + line = line.strip() + data_file = DATA_FILE.format(line) + label_file = LABEL_FILE.format(line) + data = tarobject.extractfile(name2mem[data_file]).read() + label = tarobject.extractfile(name2mem[label_file]).read() + data = Image.open(io.BytesIO(data)) + label = Image.open(io.BytesIO(label)) + data = np.array(data) + label = np.array(label) + yield data, label + + return reader + + +def train(): + """ + Create a train dataset reader containing 2913 images in HWC order. + """ + return reader_creator(download(VOC_URL, CACHE_DIR, VOC_MD5), 'trainval') + + +def test(): + """ + Create a test dataset reader containing 1464 images in HWC order. + """ + return reader_creator(download(VOC_URL, CACHE_DIR, VOC_MD5), 'train') + + +def val(): + """ + Create a val dataset reader containing 1449 images in HWC order. + """ + return reader_creator(download(VOC_URL, CACHE_DIR, VOC_MD5), 'val') diff --git a/python/paddle/v2/framework/tests/CMakeLists.txt b/python/paddle/v2/framework/tests/CMakeLists.txt index 4ce2bef6fcc4b8ddf5a6de3809a1891bce590aab..b75b7442d1e7d0f1846db057ea8fd173b4ab7507 100644 --- a/python/paddle/v2/framework/tests/CMakeLists.txt +++ b/python/paddle/v2/framework/tests/CMakeLists.txt @@ -1,3 +1,3 @@ add_python_test(test_framework test_protobuf.py test_scope.py test_default_scope_funcs.py test_op_creation_methods.py - test_tensor.py) + test_tensor.py test_fc_op.py) diff --git a/python/paddle/v2/framework/tests/test_fc_op.py b/python/paddle/v2/framework/tests/test_fc_op.py new file mode 100644 index 0000000000000000000000000000000000000000..59e7e61249e2a7d49a17e5d87209f03b8f35f730 --- /dev/null +++ b/python/paddle/v2/framework/tests/test_fc_op.py @@ -0,0 +1,43 @@ +import paddle.v2.framework.core as core +import unittest +import numpy +import paddle.v2.framework.create_op_creation_methods as creation + + +class TestFc(unittest.TestCase): + def test_fc(self): + scope = core.Scope(None) + x = scope.create_var("X") + x_tensor = x.get_tensor() + x_tensor.set_dims([1000, 784]) + x_tensor.alloc_float() + + w = scope.create_var("W") + w_tensor = w.get_tensor() + w_tensor.set_dims([784, 100]) + w_tensor.alloc_float() + + w_tensor.set(numpy.random.random((784, 100)).astype("float32")) + + # Set a real numpy array here. + # x_tensor.set(numpy.array([])) + + op = creation.op_creations.fc(X="X", Y="Y", W="W") + + for out in op.outputs(): + if scope.get_var(out) is None: + scope.create_var(out).get_tensor() + + tensor = scope.get_var("Y").get_tensor() + op.infer_shape(scope) + self.assertEqual([1000, 100], tensor.shape()) + + ctx = core.DeviceContext.cpu_context() + + op.run(scope, ctx) + + # After complete all ops, check Y is expect or not. + + +if __name__ == '__main__': + unittest.main() diff --git a/python/setup.py.in b/python/setup.py.in index b1041f6102a56f5a200aa909e77729095c052f31..65a26940d4d703ea4fbb5022523a90716982ec10 100644 --- a/python/setup.py.in +++ b/python/setup.py.in @@ -20,6 +20,7 @@ setup_requires=["requests", "matplotlib", "rarfile", "scipy>=0.19.0", + "Pillow", "nltk"] if '${CMAKE_SYSTEM_PROCESSOR}' not in ['arm', 'armv7-a', 'aarch64']: