diff --git a/paddle/gserver/layers/DetectionOutputLayer.cpp b/paddle/gserver/layers/DetectionOutputLayer.cpp
index 8ab838e191314ab25469631626c0b0564d7fffda..0cf0a92bf4bd8f9b8eba2016b2377d9dfb18c70a 100644
--- a/paddle/gserver/layers/DetectionOutputLayer.cpp
+++ b/paddle/gserver/layers/DetectionOutputLayer.cpp
@@ -139,7 +139,13 @@ void DetectionOutputLayer::forward(PassType passType) {
                                    allDecodedBBoxes,
                                    &allIndices);
 
-  resetOutput(numKept, 7);
+  if (numKept > 0) {
+    resetOutput(numKept, 7);
+  } else {
+    MatrixPtr outV = getOutputValue();
+    outV = NULL;
+    return;
+  }
   MatrixPtr outV = getOutputValue();
   getDetectionOutput(confBuffer_->getData(),
                      numKept,
diff --git a/paddle/gserver/layers/DetectionUtil.cpp b/paddle/gserver/layers/DetectionUtil.cpp
index 3e61adc66e60c54250e4f323452aa13045310879..d83674f45a70212a8adc94a31ff58eb0e01baa00 100644
--- a/paddle/gserver/layers/DetectionUtil.cpp
+++ b/paddle/gserver/layers/DetectionUtil.cpp
@@ -469,7 +469,7 @@ size_t getDetectionIndices(
     const size_t numClasses,
     const size_t backgroundId,
     const size_t batchSize,
-    const size_t confThreshold,
+    const real confThreshold,
     const size_t nmsTopK,
     const real nmsThreshold,
     const size_t keepTopK,
diff --git a/paddle/gserver/layers/DetectionUtil.h b/paddle/gserver/layers/DetectionUtil.h
index fe4f9f075e4cf011c97f68f49598a828d62327b3..641ed873b4c8645b6455e5ef5e63593e3005b770 100644
--- a/paddle/gserver/layers/DetectionUtil.h
+++ b/paddle/gserver/layers/DetectionUtil.h
@@ -275,7 +275,7 @@ size_t getDetectionIndices(
     const size_t numClasses,
     const size_t backgroundId,
     const size_t batchSize,
-    const size_t confThreshold,
+    const real confThreshold,
     const size_t nmsTopK,
     const real nmsThreshold,
     const size_t keepTopK,
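Note on the confThreshold change above: a detection confidence threshold is a fraction in (0, 1), so declaring the parameter as size_t truncates it to zero and the confidence filter keeps every candidate box. A minimal Python sketch of that truncation (an illustrative analogue; the actual parameter is C++):

    conf_threshold = 0.01              # a typical fractional confidence threshold
    as_integral = int(conf_threshold)  # what passing it through size_t effectively did
    assert as_integral == 0            # threshold becomes 0, so no box is filtered out
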
diff --git a/paddle/gserver/layers/MKLDNNFcLayer.cpp b/paddle/gserver/layers/MKLDNNFcLayer.cpp
index 8318c8c519a4cec1610eadd28320ee5ce0b4147d..53433cef35a377a73f87b041fdcfadd848dd2ec9 100644
--- a/paddle/gserver/layers/MKLDNNFcLayer.cpp
+++ b/paddle/gserver/layers/MKLDNNFcLayer.cpp
@@ -77,24 +77,6 @@ void MKLDNNFcLayer::convertWeightsToPaddle() {
   wgtVal_->reorderDataTo(wgtVal_, dstFmt, targetDim);
 }
 
-void MKLDNNFcLayer::convertOutputToOtherDevice() {
-  copyOutputInfoToOtherDevice();
-  // find other cpu device and reorder output to cpu device
-  int cnt = 0;
-  for (size_t i = 0; i < outputOtherDevice_.size(); i++) {
-    if (outputOtherDevice_[i].deviceId == CPU_DEVICE) {
-      // fc cpu output value do not need convert
-      // just share point
-      outputOtherDevice_[i].value = output_.value;
-      ++cnt;
-    }
-  }
-
-  if (cnt > 1) {
-    LOG(WARNING) << "should not have more than one CPU devie";
-  }
-}
-
 void MKLDNNFcLayer::reshape() {
   const Argument& input = getInput(0, getPrev(0)->getDeviceId());
   int batchSize = input.getBatchSize();
@@ -155,7 +137,10 @@ void MKLDNNFcLayer::resetFwd() {
   // change original output value to mkldnn output value
   output_.value = std::dynamic_pointer_cast<Matrix>(outVal_);
   if (!outputIsOnlyMKLDNN()) {
-    convertOutputToOtherDevice();
+    copyOutputInfoToOtherDevice();
+    // the fc cpu output value does not need a conversion;
+    // just share the data pointer
+    getOutput(CPU_DEVICE).value->setData(output_.value->getData());
   }
 
   // create forward handle
@@ -235,13 +220,12 @@ void MKLDNNFcLayer::resetBwd() {
   pipelineBwd_.push_back(*bwdWgt_);
 
   /// backward data
-  device = inputIsOnlyMKLDNN() ? MKLDNN_DEVICE : CPU_DEVICE;
-  const MatrixPtr& in = getInputGrad(0, device);
+  const MatrixPtr& in = inputLayers_[0]->getOutput().grad;
   if (in == nullptr) {
     return;
   }
-  if (getInput(0, device).getAllCount() > 1) {
-    // TODO(TJ): use outputMaps_ ways when merge outgrad done
+  if (getInput(0, MKLDNN_DEVICE).getAllCount() > 1) {
+    // TODO(TJ): use outputMaps_ ways to get the inGrad_ when merging outgrad is done
   } else {
     inGrad_ = MKLDNNMatrix::create(in, inVal_->getPrimitiveDesc());
   }
@@ -258,13 +242,21 @@ void MKLDNNFcLayer::resetBwd() {
   pipelineBwd_.push_back(*bwdData_);
 }
 
+void MKLDNNFcLayer::updateInputData() {
+  if (inputLayers_[0]->getType() != "data") {
+    return;
+  }
+  real* iData = getInputValue(0, CPU_DEVICE)->getData();
+  inVal_->setData(iData);
+}
+
 void MKLDNNFcLayer::forward(PassType passType) {
   Layer::forward(passType);
   reshape();
 
   {
     REGISTER_TIMER_INFO("mkldnn_FwdTimer", getName().c_str());
-    syncInputValue();
+    updateInputData();
 
     // just submit forward pipeline
     stream_->submit(pipelineFwd_);
@@ -286,7 +278,6 @@ void MKLDNNFcLayer::backward(const UpdateCallback& callback) {
     REGISTER_TIMER_INFO("mkldnn_bwdTimer", getName().c_str());
     resetBwd();
 
-    syncOutputGrad();
     // just submit backward pipeline
     stream_->submit(pipelineBwd_);
   }
diff --git a/paddle/gserver/layers/MKLDNNFcLayer.h b/paddle/gserver/layers/MKLDNNFcLayer.h
index e138a6faf181c412949218458e7ecf800a0d6a07..4ad67a16e056a718c45a28babcf22a7cd571b15c 100644
--- a/paddle/gserver/layers/MKLDNNFcLayer.h
+++ b/paddle/gserver/layers/MKLDNNFcLayer.h
@@ -53,6 +53,8 @@ public:
 
   void backward(const UpdateCallback& callback) override;
 
+  void updateInputData() override;
+
 protected:
   /**
    * reshape the input image sizes
@@ -72,8 +74,6 @@ protected:
    * only would be called when needed
    */
   void resetBwd();
-
-  void convertOutputToOtherDevice() override;
 };
 
 }  // namespace paddle
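The forward path above now shares one buffer between the MKL-DNN output value and the CPU-device output instead of converting every pass, and updateInputData() re-points inVal_ at the data layer's current buffer before each submit. A NumPy analogy of that pointer sharing (analogy only; the real code swaps MKLDNNMatrix data handles via setData):

    import numpy as np

    mkldnn_buf = np.zeros(8, dtype=np.float32)  # stands in for output_.value's buffer
    cpu_view = mkldnn_buf[:]                    # shares the memory, no copy
    mkldnn_buf[0] = 3.0                         # a write through one handle...
    assert cpu_view[0] == 3.0                   # ...is visible through the other
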
diff --git a/paddle/gserver/layers/MKLDNNLayer.h b/paddle/gserver/layers/MKLDNNLayer.h
index b983b833d510b823c5d4cff0b9390173e4cefc89..543364edceff684bdcd002a8f4f10e7ce5e6953b 100644
--- a/paddle/gserver/layers/MKLDNNLayer.h
+++ b/paddle/gserver/layers/MKLDNNLayer.h
@@ -114,10 +114,10 @@ public:
   virtual void convertWeightsToPaddle() {}
 
   /**
-   * convert MKLDNN output to other device.
-   * only support CPU device yet
+   * Update the input value data when the input layer is of "data" type,
+   * since the input value data address might change.
    */
-  virtual void convertOutputToOtherDevice() {}
+  virtual void updateInputData() {}
 
   /**
    * print info about sizes
@@ -155,6 +155,7 @@ protected:
    * copy base info and do not copy data value
    */
   void copyOutputInfoToOtherDevice() {
+    int cnt = 0;
     for (size_t i = 0; i < outputOtherDevice_.size(); i++) {
       outputOtherDevice_[i].setFrameHeight(output_.getFrameHeight());
       outputOtherDevice_[i].setFrameWidth(output_.getFrameWidth());
@@ -163,6 +164,12 @@ protected:
       outputOtherDevice_[i].subSequenceStartPositions =
           output_.subSequenceStartPositions;
       outputOtherDevice_[i].cpuSequenceDims = output_.cpuSequenceDims;
+      if (outputOtherDevice_[i].deviceId == CPU_DEVICE) {
+        ++cnt;
+      }
+    }
+    if (cnt > 1) {
+      LOG(WARNING) << "should not have more than one CPU device";
     }
   }
 
@@ -193,32 +200,6 @@ protected:
     return outputOtherDevice_.size() == 0;
   }
 
-  /**
-   * Sync input value data
-   */
-  void syncInputValue() {
-    if (inputIsOnlyMKLDNN()) {
-      return;
-    }
-    real* iData = getInputValue(0, CPU_DEVICE)->getData();
-    // update input data
-    // since it might be changed if this is after data layer
-    inVal_->updateData(iData);
-  }
-
-  /**
-   * Sync output grad data
-   */
-  void syncOutputGrad() {
-    if (outputIsOnlyMKLDNN()) {
-      return;
-    }
-
-    // update diff
-    real* oDiff = getOutput(CPU_DEVICE).grad->getData();
-    outGrad_->updateData(oDiff);
-  }
-
   /**
    * Set deviceId of this layer.
    */
diff --git a/paddle/math/MKLDNNMatrix.cpp b/paddle/math/MKLDNNMatrix.cpp
index 0a355e2644cce572ce90ecf5c9d2a5b7b395bc61..c4063e5069854242d9f93886b66580385557ca73 100644
--- a/paddle/math/MKLDNNMatrix.cpp
+++ b/paddle/math/MKLDNNMatrix.cpp
@@ -33,14 +33,12 @@ MKLDNNMatrixPtr MKLDNNMatrix::create(MatrixPtr m, memory::primitive_desc pd) {
     size_t width = cnts / dims[0];
     m = Matrix::create(height, width, false, false);
   }
-
   CHECK(m) << " Matrix should not be empty";
+
   CpuMatrixPtr cpuMatrix = std::dynamic_pointer_cast<CpuMatrix>(m);
   CHECK(cpuMatrix) << "Only support create from CPU matrix yet";
-
-  CHECK_EQ(cnts, m->getElementCnt()) << "Count size does not match";
-  return std::make_shared<MKLDNNMatrix>(
-      m->getData(), m->getHeight(), m->getWidth(), pd);
+  CHECK_EQ(cpuMatrix->getElementCnt(), cnts) << "Count size does not match";
+  return std::make_shared<MKLDNNMatrix>(cpuMatrix, pd);
 }
 
 MKLDNNMatrixPtr MKLDNNMatrix::create(MatrixPtr m,
@@ -138,7 +136,7 @@ void MKLDNNMatrix::downSpatial() {
       mkldnn_primitive_create(&result, pd.get(), nullptr, nullptr),
       "could not create a memory primitive");
   reset(result);
-  set_data_handle(getData());
+  set_data_handle(data_);
 }
 
 }  // namespace paddle
diff --git a/paddle/math/MKLDNNMatrix.h b/paddle/math/MKLDNNMatrix.h
index e50f698b495713e6f15ab7a12a7ee7487662040f..eef3b429e6fa0087aeac3f5aed9dff983b06e826 100644
--- a/paddle/math/MKLDNNMatrix.h
+++ b/paddle/math/MKLDNNMatrix.h
@@ -30,11 +30,10 @@ typedef std::shared_ptr<MKLDNNMatrix> MKLDNNMatrixPtr;
  */
 class MKLDNNMatrix : public CpuMatrix, public mkldnn::memory {
 public:
-  MKLDNNMatrix(real* data,
-               size_t height,
-               size_t width,
-               mkldnn::memory::primitive_desc pd)
-      : CpuMatrix(data, height, width, false), mkldnn::memory(pd, data) {}
+  MKLDNNMatrix(CpuMatrixPtr m, mkldnn::memory::primitive_desc pd)
+      : CpuMatrix(m->getData(), m->getHeight(), m->getWidth(), false),
+        mkldnn::memory(pd, m->getData()),
+        m_(m) {}
 
   ~MKLDNNMatrix() {}
 
@@ -81,11 +80,29 @@ public:
   void downSpatial();
 
   /**
-   * Update the memory data handle.
+   * Set the memory data handle.
    * Caution: This will not check the buffer size of the data;
    * it should be covered by the user.
    */
-  void updateData(void* data) { set_data_handle(data); }
+  void setData(real* data) {
+    set_data_handle(data);
+    CpuMatrix::setData(data);
+    m_.reset();
+  }
+
+  /**
+   * override Matrix::getData
+   * check the data handle before returning
+   */
+  real* getData() override {
+    CHECK_EQ((void*)data_, get_data_handle());
+    return data_;
+  }
+
+  const real* getData() const override {
+    CHECK_EQ((void*)data_, get_data_handle());
+    return data_;
+  }
 
   /**
    * Get primitive descriptor.
@@ -143,6 +160,10 @@ protected:
                       memory::format srcFmt,
                       memory::format dstFmt,
                       memory::dims dm);
+
+private:
+  // save the CpuMatrixPtr in case the buffer is released outside
+  CpuMatrixPtr m_;
 };
 
 }  // namespace paddle
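The new m_ member is a lifetime guard: MKLDNNMatrix aliases the CpuMatrix's raw buffer, so the shared_ptr must be held until setData() installs a different handle (at which point m_.reset() drops it). A rough Python analogy, with reference counting standing in for shared_ptr ownership:

    import numpy as np

    owner = np.arange(4, dtype=np.float32)  # creator of the buffer
    holder = owner                          # like m_(m): a second owning reference
    del owner                               # the creator goes away
    assert holder.sum() == 6.0              # buffer stays valid via the held reference
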
diff --git a/paddle/operators/concat_op.cc b/paddle/operators/concat_op.cc
new file mode 100644
index 0000000000000000000000000000000000000000..0ebefbab26ec8fdf316f852fbb7f6d9f3bbc48eb
--- /dev/null
+++ b/paddle/operators/concat_op.cc
@@ -0,0 +1,79 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/operators/concat_op.h"
+#include <vector>
+
+namespace paddle {
+namespace operators {
+using framework::Tensor;
+
+class ConcatOp : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+ protected:
+  void InferShape(const framework::InferShapeContext &ctx) const override {
+    auto ins = ctx.MultiInput<Tensor>("X");
+    auto *out = ctx.Output<Tensor>("Out");
+    size_t axis = static_cast<size_t>(ctx.Attr<int>("axis"));
+    size_t n = ins.size();
+
+    PADDLE_ENFORCE_GT(n, 1, "Input tensors count should be greater than 1.");
+
+    auto out_dims = ins[0]->dims();
+    size_t in_zero_dims_size = out_dims.size();
+    for (size_t i = 1; i < n; i++) {
+      for (size_t j = 0; j < in_zero_dims_size; j++) {
+        if (j == axis) {
+          out_dims[axis] += ins[i]->dims()[j];
+          continue;
+        }
+        PADDLE_ENFORCE_EQ(out_dims[j], ins[i]->dims()[j],
+                          "Input tensors should have the same "
+                          "shape except along the specified axis.")
+      }
+    }
+    out->Resize(out_dims);
+  }
+};
+
+class ConcatOpMaker : public framework::OpProtoAndCheckerMaker {
+ public:
+  ConcatOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
+      : OpProtoAndCheckerMaker(proto, op_checker) {
+    AddInput("X", "the input tensors of concat operator.").AsDuplicable();
+    AddOutput("Out", "the output tensor of concat operator.");
+    AddComment(R"DOC(
+            Join the input tensors along the given axis.
+            Examples:
+              Input[0] = [[1,2],[3,4]]
+              Input[1] = [[5,6]]
+              axis = 0
+              Output = [[1,2],
+                        [3,4],
+                        [5,6]]
+        )DOC");
+    AddAttr<int>("axis", "The axis along which the inputs will be joined.")
+        .SetDefault(0);
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+REGISTER_OP_WITHOUT_GRADIENT(concat, ops::ConcatOp, ops::ConcatOpMaker)
+REGISTER_OP_CPU_KERNEL(concat,
+                       ops::ConcatKernel<paddle::platform::CPUPlace, float>)
diff --git a/paddle/operators/concat_op.cu b/paddle/operators/concat_op.cu
new file mode 100644
index 0000000000000000000000000000000000000000..38fee7473dbb2ba97fe95b6632db7a1749cf3bbe
--- /dev/null
+++ b/paddle/operators/concat_op.cu
@@ -0,0 +1,19 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#define EIGEN_USE_GPU
+#include "paddle/operators/concat_op.h"
+
+namespace ops = paddle::operators;
+// TODO(Yancey1989) Add GPU kernel
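The shape rule enforced by ConcatOp::InferShape, restated: every dimension of every input must match, except along axis, where the sizes are summed. Checked against NumPy, which implements the same rule:

    import numpy as np

    x0 = np.zeros((2, 3, 2, 5))
    x1 = np.zeros((2, 3, 3, 5))
    x2 = np.zeros((2, 3, 4, 5))
    out = np.concatenate((x0, x1, x2), axis=2)
    assert out.shape == (2, 3, 9, 5)  # 2 + 3 + 4 along axis 2; other dims unchanged
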
diff --git a/paddle/operators/concat_op.h b/paddle/operators/concat_op.h
new file mode 100644
index 0000000000000000000000000000000000000000..f977054fdf8aa0164db726b94a21c57f770dd674
--- /dev/null
+++ b/paddle/operators/concat_op.h
@@ -0,0 +1,64 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+
+#include <vector>
+#include "paddle/framework/op_registry.h"
+
+namespace paddle {
+namespace operators {
+
+template <typename Place, typename T>
+class ConcatKernel : public framework::OpKernel {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    auto ins = ctx.MultiInput<framework::Tensor>("X");
+    auto* out = ctx.Output<framework::Tensor>("Out");
+    int64_t axis = static_cast<int64_t>(ctx.Attr<int>("axis"));
+    size_t n = ins.size();
+    size_t output_axis_dim = 0;
+    size_t before = 1, after = 1;
+    for (size_t i = 0; i < n; i++) {
+      output_axis_dim += ins[i]->dims()[axis];
+    }
+    auto& input_zero = ins[0];
+    for (int64_t i = 0; i < input_zero->dims().size(); i++) {
+      if (i == axis) {
+        continue;
+      }
+      if (i < axis) {
+        before *= input_zero->dims()[i];
+      } else {
+        after *= input_zero->dims()[i];
+      }
+    }
+    size_t output_offset = 0;
+    for (size_t i = 0; i < n; i++) {
+      auto& in = ins[i];
+      auto axis_dim = in->dims()[axis];
+      for (size_t j = 0; j < before; j++) {
+        size_t len = axis_dim * after * sizeof(T);
+        const T* src = in->data<T>() + axis_dim * after * j;
+        T* out_data = out->mutable_data<T>(platform::CPUPlace());
+        T* dest = out_data + output_offset + output_axis_dim * after * j;
+        memcpy(dest, src, len);
+      }
+      output_offset += axis_dim * after;
+    }
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
diff --git a/paddle/platform/enforce.h b/paddle/platform/enforce.h
index 81448897e95eb05f4ce7de8683d98e05bade77cb..64fcbd93b6c4d5d9b36f2636c3ef4f7327f08d25 100644
--- a/paddle/platform/enforce.h
+++ b/paddle/platform/enforce.h
@@ -25,10 +25,6 @@ limitations under the License. */
 #include "paddle/string/printf.h"
 #include "paddle/string/to_string.h"
 
-#ifdef __GNUC__
-#include <cxxabi.h>  // for __cxa_demangle
-#endif
-
 #ifndef PADDLE_ONLY_CPU
 
 #include "paddle/platform/dynload/cublas.h"
@@ -46,19 +42,6 @@ limitations under the License. */
 namespace paddle {
 namespace platform {
 
-namespace {
-#ifdef __GNUC__
-inline std::string demangle(std::string name) {
-  int status = -4;  // some arbitrary value to eliminate the compiler warning
-  std::unique_ptr<char, void (*)(void*)> res{
-      abi::__cxa_demangle(name.c_str(), NULL, NULL, &status), std::free};
-  return (status == 0) ? res.get() : name;
-}
-#else
-inline std::string demangle(std::string name) { return name; }
-#endif
-}
-
 struct EnforceNotMet : public std::exception {
   std::exception_ptr exp_;
   std::string err_str_;
@@ -79,7 +62,7 @@ struct EnforceNotMet : public std::exception {
     Dl_info info;
     for (int i = 0; i < size; ++i) {
       if (dladdr(call_stack[i], &info)) {
-        auto demangled = demangle(info.dli_sname);
+        auto demangled = info.dli_sname;
         auto addr_offset = static_cast<char*>(call_stack[i]) -
                            static_cast<char*>(info.dli_saddr);
         sout << string::Sprintf("%-3d %*0p %s + %zd\n", i,
diff --git a/paddle/pybind/pybind.cc b/paddle/pybind/pybind.cc
index fe4af2c99ec46b46df7582dd98c2816d4d6be745..3958b53c22c383e5e2298bfdc4e8490d4148118f 100644
--- a/paddle/pybind/pybind.cc
+++ b/paddle/pybind/pybind.cc
@@ -50,6 +50,7 @@ USE_OP(minus);
 USE_OP(cos_sim);
 USE_CPU_ONLY_OP(gather);
 USE_CPU_ONLY_OP(scatter);
+USE_CPU_ONLY_OP(concat);
 USE_OP(top_k);
 USE_OP(squared_l2_distance);
 USE_OP(sum);
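The CPU kernel above treats each input as a (before, axis_dim * after) matrix and interleaves the row blocks into the output, where before is the product of the dimensions preceding axis and after the product of those following it. The same index arithmetic re-expressed in NumPy and verified against np.concatenate (a sketch of the layout math, not the kernel itself):

    import numpy as np

    def concat(ins, axis):
        out_axis_dim = sum(x.shape[axis] for x in ins)
        before = int(np.prod(ins[0].shape[:axis]))
        after = int(np.prod(ins[0].shape[axis + 1:]))
        out = np.empty(before * out_axis_dim * after, dtype=ins[0].dtype)
        offset = 0                             # mirrors output_offset
        for x in ins:
            axis_dim = x.shape[axis]
            rows = x.reshape(before, axis_dim * after)
            for j in range(before):            # mirrors the memcpy loop
                start = offset + out_axis_dim * after * j
                out[start:start + axis_dim * after] = rows[j]
            offset += axis_dim * after
        shape = ins[0].shape[:axis] + (out_axis_dim,) + ins[0].shape[axis + 1:]
        return out.reshape(shape)

    x0, x1 = np.random.rand(2, 3, 2, 5), np.random.rand(2, 3, 4, 5)
    assert np.allclose(concat([x0, x1], 2), np.concatenate((x0, x1), axis=2))
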
diff --git a/paddle/scripts/docker/build.sh b/paddle/scripts/docker/build.sh
index e57f793ac42b19037e9ca43a5e4a3ac5447dc34c..2ac455d771bf78377ce4ee7d921393d3b3958e3c 100644
--- a/paddle/scripts/docker/build.sh
+++ b/paddle/scripts/docker/build.sh
@@ -30,6 +30,8 @@ Configuring cmake in /paddle/build ...
       -DCMAKE_BUILD_TYPE=Release
       -DWITH_DOC=OFF
       -DWITH_GPU=${WITH_GPU:-OFF}
+      -DWITH_MKLDNN=${WITH_MKLDNN:-ON}
+      -DWITH_MKLML=${WITH_MKLML:-ON}
       -DWITH_AVX=${WITH_AVX:-OFF}
       -DWITH_GOLANG=${WITH_GOLANG:-ON}
       -DWITH_SWIG_PY=ON
@@ -50,6 +52,8 @@ cmake .. \
       -DCMAKE_BUILD_TYPE=Release \
       -DWITH_DOC=OFF \
       -DWITH_GPU=${WITH_GPU:-OFF} \
+      -DWITH_MKLDNN=${WITH_MKLDNN:-ON} \
+      -DWITH_MKLML=${WITH_MKLML:-ON} \
       -DWITH_AVX=${WITH_AVX:-OFF} \
       -DWITH_GOLANG=${WITH_GOLANG:-ON} \
       -DWITH_SWIG_PY=${WITH_SWIG_PY:-ON} \
diff --git a/python/paddle/trainer/config_parser.py b/python/paddle/trainer/config_parser.py
index 7e9112b43bf851575a3a798886d8b1b17e7c2017..356e1d8b6fa9173db33a340744afd8d513a83a96 100644
--- a/python/paddle/trainer/config_parser.py
+++ b/python/paddle/trainer/config_parser.py
@@ -3748,8 +3748,8 @@ class SwitchOrderLayer(LayerBase):
     def __init__(self, name, inputs, reshape, **xargs):
         super(SwitchOrderLayer, self).__init__(
             name, 'switch_order', 0, inputs=inputs, **xargs)
-        self.config.reshape_conf.heightAxis.extend(reshape['height'])
-        self.config.reshape_conf.widthAxis.extend(reshape['width'])
+        self.config.reshape_conf.height_axis.extend(reshape['height'])
+        self.config.reshape_conf.width_axis.extend(reshape['width'])
 
 
 # Deprecated, use a new layer specific class instead
diff --git a/python/paddle/trainer_config_helpers/layers.py b/python/paddle/trainer_config_helpers/layers.py
index dc68c213da66ac680e6b14266cb5038a5ba73ec2..4b1d80d3db924bfa2ad0e081f785d8f5dd719fce 100644
--- a/python/paddle/trainer_config_helpers/layers.py
+++ b/python/paddle/trainer_config_helpers/layers.py
@@ -1223,7 +1223,8 @@ def detection_output_layer(input_loc,
                            name=None):
     """
     Apply the NMS to the output of network and compute the predict bounding
-    box location.
+    box location. The output of this layer could be None if there is no valid
+    bounding box.
 
     :param name: The Layer Name.
     :type name: basestring
@@ -6460,6 +6461,7 @@ def switch_order_layer(input,
     return LayerOutput(
         name=name,
         layer_type=LayerType.SWITCH_ORDER_LAYER,
+        activation=act,
         parents=input,
         size=l.config.size)
diff --git a/python/paddle/v2/event.py b/python/paddle/v2/event.py
index 7589cc9917f26375d595e200245d5ba099bc38d7..e66bf67d7949057486eb54c46f39128fad5dae55 100644
--- a/python/paddle/v2/event.py
+++ b/python/paddle/v2/event.py
@@ -53,10 +53,13 @@ class BeginPass(object):
 class EndPass(WithMetric):
     """
     Event On One Pass Training Complete.
+    To get the output of a specific layer, call
+    "event.gm.getLayerOutputs('predict_layer')" in your event_handler callback.
     """
 
-    def __init__(self, pass_id, evaluator):
+    def __init__(self, pass_id, evaluator, gm):
         self.pass_id = pass_id
+        self.gm = gm
         WithMetric.__init__(self, evaluator)
 
@@ -73,10 +76,13 @@ class BeginIteration(object):
 class EndIteration(WithMetric):
     """
     Event On One Batch Training Complete.
+    To get the output of a specific layer, call
+    "event.gm.getLayerOutputs('predict_layer')" in your event_handler callback.
     """
 
-    def __init__(self, pass_id, batch_id, cost, evaluator):
+    def __init__(self, pass_id, batch_id, cost, evaluator, gm):
         self.pass_id = pass_id
         self.batch_id = batch_id
         self.cost = cost
+        self.gm = gm
         WithMetric.__init__(self, evaluator)
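With the gm field added above, layer outputs can be fetched from inside the callback as the docstrings describe. A minimal sketch of such a handler ('predict_layer' is a placeholder layer name, taken from the docstring):

    import paddle.v2 as paddle

    def event_handler(event):
        if isinstance(event, paddle.event.EndIteration):
            outs = event.gm.getLayerOutputs('predict_layer')  # the layer's output for this batch
            print("Pass %d, Batch %d, Cost %f" % (event.pass_id, event.batch_id,
                                                  event.cost))
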
diff --git a/python/paddle/v2/framework/op.py b/python/paddle/v2/framework/op.py
index 4e91924a50cf6401d4002510e940ddc84edbe61a..9e665adad2d3ad91d183c6815fbd7135ac4e8965 100644
--- a/python/paddle/v2/framework/op.py
+++ b/python/paddle/v2/framework/op.py
@@ -43,7 +43,6 @@ class OpDescCreationMethod(object):
         if len(args) != 0:
             raise ValueError("Only keyword arguments are supported.")
         op_desc = framework_pb2.OpDesc()
-
         for input_parameter in self.__op_proto__.inputs:
             input_arguments = kwargs.get(input_parameter.name, [])
             if is_str(input_arguments):
diff --git a/python/paddle/v2/framework/tests/CMakeLists.txt b/python/paddle/v2/framework/tests/CMakeLists.txt
index 2117fdf0d58520a008d2bd01d56d96dd248be025..2f6be105b6bce1200a29133f019523e6aee23895 100644
--- a/python/paddle/v2/framework/tests/CMakeLists.txt
+++ b/python/paddle/v2/framework/tests/CMakeLists.txt
@@ -35,4 +35,5 @@ py_test(test_lookup_table SRCS test_lookup_table.py)
 py_test(test_scale_and_identity_op SRCS test_scale_and_identity_op.py)
 py_test(test_sum_op SRCS test_sum_op.py)
 py_test(mnist SRCS mnist.py)
+py_test(test_concat_op SRCS test_concat_op.py)
 py_test(test_squared_l2_distance_op SRCS test_squared_l2_distance_op.py)
diff --git a/python/paddle/v2/framework/tests/gradient_checker.py b/python/paddle/v2/framework/tests/gradient_checker.py
index fdb06b7988935ebbe53f72f4eba89d75ac2502d4..ed838b5979d3bd5a3caa954074a692bbd24a5a07 100644
--- a/python/paddle/v2/framework/tests/gradient_checker.py
+++ b/python/paddle/v2/framework/tests/gradient_checker.py
@@ -15,7 +15,6 @@ def create_op(op_type):
         kwargs[in_name] = in_name
     for out_name in Operator.get_op_output_names(op_type):
         kwargs[out_name] = out_name
-
     return Operator(op_type, **kwargs)
diff --git a/python/paddle/v2/framework/tests/op_test_util.py b/python/paddle/v2/framework/tests/op_test_util.py
index 370f27eaf658dadbf7e82262c118140a10d15c41..99a114e45f1228551375982308b6f2cdc7aabb44 100644
--- a/python/paddle/v2/framework/tests/op_test_util.py
+++ b/python/paddle/v2/framework/tests/op_test_util.py
@@ -27,17 +27,27 @@ class OpTestMeta(type):
             places.append(core.GPUPlace(0))
 
         for place in places:
-            for in_name in Operator.get_op_input_names(self.type):
-                if hasattr(self, "inputs") and in_name in self.inputs:
-                    kwargs[in_name] = in_name
-                    var = scope.new_var(in_name).get_tensor()
-                    arr = self.inputs[in_name]
-                    var.set_dims(arr.shape)
-                    var.set(arr, place)
+            for in_name, in_dup in Operator.get_op_inputs(self.type):
+                if hasattr(self, 'inputs') and in_name in self.inputs:
+                    kwargs[in_name] = []
+                    if in_dup:
+                        arrays = self.inputs[in_name]
+                        for index, arr in enumerate(arrays):
+                            var = scope.new_var(in_name + str(index))
+                            tensor = var.get_tensor()
+                            tensor.set_dims(arr.shape)
+                            tensor.set(arr, place)
+                            kwargs[in_name].append(in_name + str(index))
+                    else:
+                        kwargs[in_name] = in_name
+                        var = scope.new_var(in_name).get_tensor()
+                        arr = self.inputs[in_name]
+                        var.set_dims(arr.shape)
+                        var.set(arr, place)
                 else:
                     kwargs[in_name] = "@EMPTY@"
 
-            for out_name in Operator.get_op_output_names(self.type):
+            for out_name, out_dup in Operator.get_op_outputs(self.type):
                 if not hasattr(self, "outputs"):
                     raise ValueError(
                         "The test op must set self.outputs dict.")
@@ -60,7 +70,7 @@ class OpTestMeta(type):
             ctx = core.DeviceContext.create(place)
             op.run(scope, ctx)
 
-            for out_name in Operator.get_op_output_names(self.type):
+            for out_name, out_dup in Operator.get_op_outputs(self.type):
                 actual = numpy.array(scope.find_var(out_name).get_tensor())
                 expect = self.outputs[out_name]
                 self.assertTrue(
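The op_test_util change above feeds a duplicable input by binding each array in the list to a numbered scope variable and handing the operator the list of generated names. The naming scheme, restated in plain Python (it mirrors the added loop):

    import numpy as np

    arrays = [np.zeros((2, 3, k, 5), 'float32') for k in (2, 3, 4)]  # e.g. inputs['X']
    names = ['X' + str(i) for i, _ in enumerate(arrays)]             # in_name + str(index)
    assert names == ['X0', 'X1', 'X2']
    # the operator is then created as Operator('concat', X=names, Out='Out', axis=2)
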
self.outputs dict.") @@ -60,7 +70,7 @@ class OpTestMeta(type): ctx = core.DeviceContext.create(place) op.run(scope, ctx) - for out_name in Operator.get_op_output_names(self.type): + for out_name, out_dup in Operator.get_op_outputs(self.type): actual = numpy.array(scope.find_var(out_name).get_tensor()) expect = self.outputs[out_name] self.assertTrue( diff --git a/python/paddle/v2/framework/tests/test_concat_op.py b/python/paddle/v2/framework/tests/test_concat_op.py new file mode 100644 index 0000000000000000000000000000000000000000..6bd4c309748369ed6823168f89a7490803c25dfd --- /dev/null +++ b/python/paddle/v2/framework/tests/test_concat_op.py @@ -0,0 +1,22 @@ +import unittest +import numpy as np +from gradient_checker import GradientChecker, create_op +from op_test_util import OpTestMeta + + +class TestConcatOp(unittest.TestCase): + __metaclass__ = OpTestMeta + + def setUp(self): + self.type = "concat" + x0 = np.random.random((2, 3, 2, 5)).astype('float32') + x1 = np.random.random((2, 3, 3, 5)).astype('float32') + x2 = np.random.random((2, 3, 4, 5)).astype('float32') + axis = 2 + self.inputs = {'X': [x0, x1, x2]} + self.attrs = {'axis': axis} + self.outputs = {'Out': np.concatenate((x0, x1, x2), axis=axis)} + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/v2/trainer.py b/python/paddle/v2/trainer.py index 0654a301049dcb347b79879076a869a0c14a07ae..ca95ef13bd440ac0ba3d46f6e4680d4d7aa94c42 100644 --- a/python/paddle/v2/trainer.py +++ b/python/paddle/v2/trainer.py @@ -174,13 +174,18 @@ class SGD(object): pass_id=pass_id, batch_id=batch_id, cost=cost, - evaluator=batch_evaluator)) + evaluator=batch_evaluator, + gm=self.__gradient_machine__)) self.__parameter_updater__.finishBatch(cost) batch_evaluator.finish() self.__parameter_updater__.finishPass() pass_evaluator.finish() - event_handler(v2_event.EndPass(pass_id, evaluator=pass_evaluator)) + event_handler( + v2_event.EndPass( + pass_id, + evaluator=pass_evaluator, + gm=self.__gradient_machine__)) self.__gradient_machine__.finish() def test(self, reader, feeding=None):