diff --git a/go/master/c/client.go b/go/master/c/client.go index 9a59337108d1aa33929abb480af686a96514655b..9a3960d59cd950ba68213ac53a51bfc4e68c0546 100644 --- a/go/master/c/client.go +++ b/go/master/c/client.go @@ -123,7 +123,8 @@ func paddle_set_dataset(client C.paddle_master_client, path **C.char, size C.int } err := c.SetDataset(paths) if err != nil { - log.Error("error set dataset", log.Ctx{"error": err}) + log.Error("error set dataset", + log.Ctx{"error": err, "paths": paths}) return C.PADDLE_MASTER_ERROR } diff --git a/go/master/client.go b/go/master/client.go index 5d657548c9039dfdacf61dd1145deb9777596d9f..7bcf86955348fad14cbe86e2180539372fcb82cf 100644 --- a/go/master/client.go +++ b/go/master/client.go @@ -121,6 +121,7 @@ func (c *Client) StartGetRecords(passID int) { } func (c *Client) getRecords(passID int) { + i := 0 for { t, err := c.getTask(passID) if err != nil { @@ -130,12 +131,20 @@ func (c *Client) getRecords(passID int) { c.ch <- record{nil, err} break } - if err.Error() == ErrPassAfter.Error() { - // wait util last pass finishes - time.Sleep(time.Second * 3) - continue + + if i%60 == 0 { + log.Debug("getTask of passID error.", + log.Ctx{"error": err, "passID": passID}) + i = 0 } - log.Error("getTask error.", log.Ctx{"error": err}) + + // if err.Error() == ErrPassAfter.Error() + // wait util last pass finishes + // if other error such as network error + // wait to reconnect or task time out + time.Sleep(time.Second * 3) + i += 3 + continue } for _, chunk := range t.Chunks { diff --git a/go/master/client_test.go b/go/master/client_test.go index 79b9cc844d1ff938915a622bf19a7d772682becf..1963dbfd732605d3b2612f10a047c3a03faa53be 100644 --- a/go/master/client_test.go +++ b/go/master/client_test.go @@ -117,6 +117,7 @@ func TestNextRecord(t *testing.T) { if e != nil { panic(e) } + // test for n passes for pass := 0; pass < 10; pass++ { c.StartGetRecords(pass) diff --git a/paddle/framework/ddim.cc b/paddle/framework/ddim.cc index a3357867530c110df16a5f3ec8c799735206cc71..239ae5e1233c7f5c506930df374b5d0cc8de7c8d 100644 --- a/paddle/framework/ddim.cc +++ b/paddle/framework/ddim.cc @@ -195,6 +195,14 @@ std::vector vectorize(const DDim& ddim) { return result; } +// NOTE: framework::vectorize converts to type int64_t +// which does not fit cudnn inputs. +std::vector vectorize2int(const DDim& ddim) { + std::vector temp = vectorize(ddim); + std::vector result(temp.begin(), temp.end()); + return result; +} + struct ProductVisitor : public boost::static_visitor { template int64_t operator()(const Dim& dim) { diff --git a/paddle/framework/ddim.h b/paddle/framework/ddim.h index 4a871bb0a91ed4050847509cc3f24218bcd57142..2a5e2d2b6948b045642dbac5e83992a048ecb63d 100644 --- a/paddle/framework/ddim.h +++ b/paddle/framework/ddim.h @@ -93,6 +93,7 @@ int64_t get(const DDim& dim, int idx); void set(DDim& dim, int idx, int val); std::vector vectorize(const DDim& ddim); +std::vector vectorize2int(const DDim& ddim); int64_t product(const DDim& ddim); diff --git a/paddle/gserver/layers/MKLDNNBatchNormLayer.cpp b/paddle/gserver/layers/MKLDNNBatchNormLayer.cpp index f577616230be65e9581cf8f3ed5f63a77c7c3e21..9b0ae20f089e34a719883bc65e88e33ab9334e39 100644 --- a/paddle/gserver/layers/MKLDNNBatchNormLayer.cpp +++ b/paddle/gserver/layers/MKLDNNBatchNormLayer.cpp @@ -216,17 +216,13 @@ void MKLDNNBatchNormLayer::resetFwdPD( } auto fwdDesc = bn_fwd::desc(pk, in->getMemoryDesc(), EPS, flags_); pd.reset(new bn_fwd::primitive_desc(fwdDesc, engine_)); - // TODO(TJ): use check macro - CHECK(out); - CHECK(out->getPrimitiveDesc() == pd->dst_primitive_desc()); + CHECK_PRIMITIVE_DESC_EQ(out, pd->dst_primitive_desc()); if (wgt) { - CHECK(wgt->getPrimitiveDesc() == pd->weights_primitive_desc()); + CHECK_PRIMITIVE_DESC_EQ(wgt, pd->weights_primitive_desc()); } if (passType_ != PASS_TEST || useGlobalStats_) { - CHECK(mean_); - CHECK(mean_->getPrimitiveDesc() == pd->mean_primitive_desc()); - CHECK(var_); - CHECK(var_->getPrimitiveDesc() == pd->variance_primitive_desc()); + CHECK_PRIMITIVE_DESC_EQ(mean_, pd->mean_primitive_desc()); + CHECK_PRIMITIVE_DESC_EQ(var_, pd->variance_primitive_desc()); } } @@ -283,19 +279,14 @@ void MKLDNNBatchNormLayer::resetBwdPD( if (in == nullptr) { return; } - CHECK(out); - CHECK(out->getPrimitiveDesc() == in->getPrimitiveDesc()); + CHECK_PRIMITIVE_DESC_EQ(out, in->getPrimitiveDesc()); auto md = in->getMemoryDesc(); auto bwdDesc = bn_bwd::desc(prop_kind::backward, md, md, EPS, flags_); pd.reset(new bn_bwd::primitive_desc(bwdDesc, engine_, *fwdPD_)); - // TODO(TJ): use check macro - CHECK(wgt); - CHECK(wgt->getPrimitiveDesc() == pd->diff_weights_primitive_desc()); CHECK(pd->weights_primitive_desc() == fwdPD_->weights_primitive_desc()); - CHECK(mean_); - CHECK(mean_->getPrimitiveDesc() == pd->mean_primitive_desc()); - CHECK(var_); - CHECK(var_->getPrimitiveDesc() == pd->variance_primitive_desc()); + CHECK_PRIMITIVE_DESC_EQ(wgt, pd->diff_weights_primitive_desc()); + CHECK_PRIMITIVE_DESC_EQ(mean_, pd->mean_primitive_desc()); + CHECK_PRIMITIVE_DESC_EQ(var_, pd->variance_primitive_desc()); } void MKLDNNBatchNormLayer::resetBwdPipeline( diff --git a/paddle/gserver/layers/MKLDNNConvLayer.cpp b/paddle/gserver/layers/MKLDNNConvLayer.cpp index 83f4e4e6151d727b3e6cf367bb7ecae55dd7df73..b8120eda1e2dadab943869a05546351a369af6fd 100644 --- a/paddle/gserver/layers/MKLDNNConvLayer.cpp +++ b/paddle/gserver/layers/MKLDNNConvLayer.cpp @@ -262,12 +262,15 @@ void MKLDNNConvLayer::resetBwdWgtPD( padR, padKind); pd.reset(new conv_bwdWgt::primitive_desc(bwdWgtDesc, engine_, *fwdPD_)); - CHECK(pd->src_primitive_desc() == inVal_->getPrimitiveDesc()) - << "primitive desc of in value should equal"; - CHECK(pd->diff_dst_primitive_desc() == outVal_->getPrimitiveDesc()) - << "primitive desc of out grad should equal the out value"; - CHECK(pd->diff_weights_primitive_desc() == wgtVal_->getPrimitiveDesc()) - << "primitive desc of weight grad should equal the weight value"; + CHECK_PRIMITIVE_DESC_EQ(inVal_, pd->src_primitive_desc()); + CHECK_PRIMITIVE_DESC_EQ( + outVal_, + pd->diff_dst_primitive_desc(), + "primitive desc of out value and grad should be equal"); + CHECK_PRIMITIVE_DESC_EQ( + wgtVal_, + pd->diff_weights_primitive_desc(), + "primitive desc of weight value and grad should be equal"); } void MKLDNNConvLayer::resetBwdDataPD( @@ -292,10 +295,14 @@ void MKLDNNConvLayer::resetBwdDataPD( padR, padding_kind::zero); pd.reset(new conv_bwdData::primitive_desc(bwdDataDesc, engine_, *fwdPD_)); - CHECK(pd->diff_src_primitive_desc() == inVal_->getPrimitiveDesc()) - << "primitive desc of in grad should equal the in value"; - CHECK(pd->diff_dst_primitive_desc() == outVal_->getPrimitiveDesc()) - << "primitive desc of out grad should equal"; + CHECK_PRIMITIVE_DESC_EQ( + inVal_, + pd->diff_src_primitive_desc(), + "primitive desc of in value and grad should be equal"); + CHECK_PRIMITIVE_DESC_EQ( + outVal_, + pd->diff_dst_primitive_desc(), + "primitive desc of out value and grad should be equal"); } void MKLDNNConvLayer::resetBwdBuffers( @@ -310,17 +317,20 @@ void MKLDNNConvLayer::resetBwdBuffers( resetWithMatrix( wgt, weight_->getWGrad(), wgtPD->diff_weights_primitive_desc()); - CHECK(wgtVal_ != nullptr && - wgt->getPrimitiveDesc() == wgtVal_->getPrimitiveDesc()) - << "primitive desc of weight grad and value should be equal"; + CHECK_PRIMITIVE_DESC_EQ( + wgtVal_, + wgt->getPrimitiveDesc(), + "primitive desc of weight grad and value should be equal"); bias = nullptr; if (biases_ && biases_->getWGrad()) { resetWithMatrix( bias, biases_->getWGrad(), wgtPD->diff_bias_primitive_desc()); - CHECK(bias && biasVal_ && - bias->getPrimitiveDesc() == biasVal_->getPrimitiveDesc()) - << "primitive desc of bias grad should equal the bias value"; + CHECK(bias); + CHECK_PRIMITIVE_DESC_EQ( + biasVal_, + bias->getPrimitiveDesc(), + "primitive desc of bias grad and value should be equal"); } if (dataPD == nullptr) { diff --git a/paddle/gserver/layers/MKLDNNLayer.cpp b/paddle/gserver/layers/MKLDNNLayer.cpp index 6bb19976b5552fcd2e420f03de45c77a90ffb9d2..663a10509857ec9fb487c1cda1621bdfac1250ac 100644 --- a/paddle/gserver/layers/MKLDNNLayer.cpp +++ b/paddle/gserver/layers/MKLDNNLayer.cpp @@ -235,8 +235,7 @@ void MKLDNNLayer::resetInGrad(MKLDNNMatrixPtr& in, in = MKLDNNMatrix::create(intPD, inMat); Argument& arg = input->getOutput(this->getName()); arg.grad = std::dynamic_pointer_cast(in); - CHECK(inVal_); - CHECK(inVal_->getPrimitiveDesc() == intPD) << "the primitive desc must equal"; + CHECK_PRIMITIVE_DESC_EQ(inVal_, intPD); if (inputIsOnlyMKLDNN()) { return; } @@ -250,8 +249,7 @@ void MKLDNNLayer::resetInGrad(MKLDNNMatrixPtr& in, CHECK(extInVal_ != nullptr && isPaddleFormat(extInVal_->getFormat())) << "should have external input value and the format must be nchw(nc)"; extInGrad_ = MKLDNNMatrix::create(extInVal_->getPrimitiveDesc(), inMat); - CHECK(inVal_ != nullptr && inVal_->getPrimitiveDesc() == intPD) - << "should have internal input value and primitive desc must equal"; + CHECK_PRIMITIVE_DESC_EQ(inVal_, intPD); in = MKLDNNMatrix::create(intPD); cvtInGrad_ = MKLDNNMatrix::createReorder(in, extInGrad_); CHECK(cvtInGrad_); @@ -277,8 +275,7 @@ void MKLDNNLayer::resetOutGrad(MKLDNNMatrixPtr& out, CHECK(extOutVal_ != nullptr && isPaddleFormat(extOutVal_->getFormat())) << "should have external output value and the format must be nchw(nc)"; extOutGrad_ = MKLDNNMatrix::create(extOutVal_->getPrimitiveDesc(), outMat); - CHECK(outVal_ != nullptr && outVal_->getPrimitiveDesc() == intPD) - << "should have internal output value and primitive desc must equal"; + CHECK_PRIMITIVE_DESC_EQ(outVal_, intPD); out = MKLDNNMatrix::create(intPD); cvtOutGrad_ = MKLDNNMatrix::createReorder(extOutGrad_, out); CHECK(cvtOutGrad_); diff --git a/paddle/math/MKLDNNMatrix.h b/paddle/math/MKLDNNMatrix.h index 2b62d4e11ac7276924947ab47360ffca84240aea..5f5b819017b83579ce58522198b3f13311297d42 100644 --- a/paddle/math/MKLDNNMatrix.h +++ b/paddle/math/MKLDNNMatrix.h @@ -24,6 +24,12 @@ namespace paddle { class MKLDNNMatrix; typedef std::shared_ptr MKLDNNMatrixPtr; +#define CHECK_PRIMITIVE_DESC_EQ(MAT, PD, ...) \ + CHECK(MAT) << " can not be empty."; \ + CHECK(MAT->getPrimitiveDesc() == PD) \ + << #MAT "->getPrimitiveDesc() and " #PD " should be equal.\n " \ + << "" __VA_ARGS__; + /** * @brief MKLDNN Matrix. * diff --git a/paddle/operators/CMakeLists.txt b/paddle/operators/CMakeLists.txt index eaa9884443386cebdf686e25143d99fec17646f2..5b0097a4eb33ecfa066ceade2172af84e3ee44a1 100644 --- a/paddle/operators/CMakeLists.txt +++ b/paddle/operators/CMakeLists.txt @@ -75,6 +75,13 @@ function(op_library TARGET) # It's enough to just adding one operator to pybind file(APPEND ${pybind_file} "USE_OP(conv2d);\n") endif() + + # pool_cudnn_op contains several operators + if ("${TARGET}" STREQUAL "pool_cudnn_op") + set(pybind_flag 1) + # It's enough to just adding one operator to pybind + file(APPEND ${pybind_file} "USE_OP(pool2d_cudnn);\n") + endif() # save_restore_op contains several operators if ("${TARGET}" STREQUAL "save_restore_op") diff --git a/paddle/operators/auc_op.cc b/paddle/operators/auc_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..cf3dbc5d10c66cbb344ca8cf8c46432eabef4a07 --- /dev/null +++ b/paddle/operators/auc_op.cc @@ -0,0 +1,85 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/operators/auc_op.h" + +namespace paddle { +namespace operators { + +class AucOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + protected: + void InferShape(framework::InferShapeContext *ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("Inference"), + "Input of Inference must be initialized."); + PADDLE_ENFORCE(ctx->HasInput("Label"), + "Input of Label must be initialized."); + auto inference_dim = ctx->GetInputDim("Inference"); + auto label_dim = ctx->GetInputDim("Label"); + + PADDLE_ENFORCE_EQ(inference_dim, label_dim, + "inference and label should have same shape"); + + ctx->SetOutputDim("AUC", {1}); + ctx->ShareLoD("Inference", /*->*/ "AUC"); + } +}; + +class AucOpMaker : public framework::OpProtoAndCheckerMaker { + public: + AucOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddInput("Inference", + "A floating point tensor of arbitrary shape and whose values" + "are in the range [0, 1]."); + AddInput("Label", + "A tensor whose shape matches " + "Inference. Will be cast to bool."); + // TODO(typhoonzero): support weight input + AddOutput("AUC", + "A scalar representing the " + "current area-under-curve."); + + AddAttr("curve", "Curve type, can be 'ROC' or 'PR'.") + .SetDefault("ROC"); + AddAttr("num_thresholds", + "The number of thresholds to use when discretizing the" + " roc curve.") + .SetDefault(200); + + AddComment( + R"DOC(Computes the AUC according forward output and label. +Best to use for binary classification evaluations. + +If input label contains values other than 0 and 1, it will be cast +to bool. + +You can find the definations here: +https://en.wikipedia.org/wiki/Receiver_operating_characteristic#Area_under_the_curve + +Possible curves are: +- ROC: Receiver operating characteristic +- PR: Precision Recall +)DOC"); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP_WITHOUT_GRADIENT(auc, ops::AucOp, ops::AucOpMaker); +REGISTER_OP_CPU_KERNEL(auc, ops::AucKernel); diff --git a/paddle/operators/auc_op.h b/paddle/operators/auc_op.h new file mode 100644 index 0000000000000000000000000000000000000000..be6ef29d5f6cff5b9ebdf7d8564b2e2792c3b5cb --- /dev/null +++ b/paddle/operators/auc_op.h @@ -0,0 +1,135 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once +#include "paddle/framework/eigen.h" +#include "paddle/framework/op_registry.h" + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; + +template +using EigenVector = framework::EigenVector; + +template +class AucKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto* inference = ctx.Input("Inference"); + auto* label = ctx.Input("Label"); + auto* auc = ctx.Output("AUC"); + + float* auc_data = auc->mutable_data(ctx.GetPlace()); + + std::string curve = ctx.Attr("curve"); + int num_thresholds = ctx.Attr("num_thresholds"); + std::vector thresholds_list; + thresholds_list.reserve(num_thresholds); + for (int i = 1; i < num_thresholds - 1; i++) { + thresholds_list[i] = (float)i / (num_thresholds - 1); + } + const float kEpsilon = 1e-7; + thresholds_list[0] = 0.0f - kEpsilon; + thresholds_list[num_thresholds - 1] = 1.0f + kEpsilon; + + size_t num_samples = inference->numel(); + + const T* inference_data = inference->data(); + Tensor label_casted; + label_casted.Resize(label->dims()); + bool* label_casted_data = label_casted.mutable_data(ctx.GetPlace()); + + const int* label_data = label->data(); + // cast label_data to bool + for (size_t i = 0; i < num_samples; i++) { + label_casted_data[i] = static_cast(label_data[i]); + } + + // Create local tensor for storing the curve: TP, FN, TN, FP + // TODO(typhoonzero): use eigen op to caculate these values. + Tensor true_positive, false_positive, true_negative, false_negative; + + true_positive.Resize({num_thresholds}); + false_negative.Resize({num_thresholds}); + true_negative.Resize({num_thresholds}); + false_positive.Resize({num_thresholds}); + + int* tp_data = true_positive.mutable_data(ctx.GetPlace()); + int* fn_data = false_negative.mutable_data(ctx.GetPlace()); + int* tn_data = true_negative.mutable_data(ctx.GetPlace()); + int* fp_data = false_positive.mutable_data(ctx.GetPlace()); + + for (int idx_thresh = 0; idx_thresh < num_thresholds; idx_thresh++) { + // caculate TP, FN, TN, FP for current thresh + int tp = 0, fn = 0, tn = 0, fp = 0; + for (size_t i = 0; i < num_samples; i++) { + if (label_casted_data[i]) { + if (inference_data[i] >= (thresholds_list[idx_thresh])) { + tp++; + } else { + fn++; + } + } else { + if (inference_data[i] >= (thresholds_list[idx_thresh])) { + fp++; + } else { + tn++; + } + } + } + // store rates + tp_data[idx_thresh] = tp; + fn_data[idx_thresh] = fn; + tn_data[idx_thresh] = tn; + fp_data[idx_thresh] = fp; + } + // epsilon to avoid divide by zero. + float epsilon = 1e-6; + // Riemann sum to caculate auc. + Tensor tp_rate, fp_rate, rec_rate; + tp_rate.Resize({num_thresholds}); + fp_rate.Resize({num_thresholds}); + rec_rate.Resize({num_thresholds}); + float* tp_rate_data = tp_rate.mutable_data(ctx.GetPlace()); + float* fp_rate_data = fp_rate.mutable_data(ctx.GetPlace()); + float* rec_rate_data = rec_rate.mutable_data(ctx.GetPlace()); + for (int i = 0; i < num_thresholds; i++) { + tp_rate_data[i] = + ((float)tp_data[i] + epsilon) / (tp_data[i] + fn_data[i] + epsilon); + fp_rate_data[i] = (float)fp_data[i] / (fp_data[i] + tn_data[i] + epsilon); + rec_rate_data[i] = + ((float)tp_data[i] + epsilon) / (tp_data[i] + fp_data[i] + epsilon); + } + *auc_data = 0.0f; + if (curve == "ROC") { + for (int i = 0; i < num_thresholds - 1; i++) { + auto dx = fp_rate_data[i] - fp_rate_data[i + 1]; + auto y = (tp_rate_data[i] + tp_rate_data[i + 1]) / 2.0f; + *auc_data = *auc_data + dx * y; + } + } else if (curve == "PR") { + for (int i = 1; i < num_thresholds; i++) { + auto dx = tp_rate_data[i] - tp_rate_data[i - 1]; + auto y = (rec_rate_data[i] + rec_rate_data[i - 1]) / 2.0f; + *auc_data = *auc_data + dx * y; + } + } + } +}; + +} // namespace operators +} // namespace paddle diff --git a/paddle/operators/conv_cudnn_op.cu b/paddle/operators/conv_cudnn_op.cu index e34d59374075427a6d0b54720d7a22b7c89635cc..d115850e2b651e20d82ad6028648c6a88439c9d7 100644 --- a/paddle/operators/conv_cudnn_op.cu +++ b/paddle/operators/conv_cudnn_op.cu @@ -31,16 +31,6 @@ using CUDADeviceContext = platform::CUDADeviceContext; static constexpr size_t kCONV_CUDNN_WORKSPACE_LIMIT_BYTES = 1024 * 1024 * 1024; -// NOTE: framework::vectorize converts to type int64_t -// which does not fit cudnn inputs. -std::vector Dims2Vector(const framework::DDim& dims) { - std::vector ret; - for (int i = 0; i < dims.size(); i++) { - ret.push_back(dims[i]); - } - return ret; -} - template class CudnnConvOpKernel : public framework::OpKernel { public: @@ -68,12 +58,12 @@ class CudnnConvOpKernel : public framework::OpKernel { ScopedConvolutionDescriptor conv_desc; DataLayout layout = DataLayout::kNCHW; - cudnnTensorDescriptor_t cudnn_input_desc = - input_desc.descriptor(layout, Dims2Vector(input->dims()), groups); - cudnnTensorDescriptor_t cudnn_output_desc = - output_desc.descriptor(layout, Dims2Vector(output->dims()), groups); - cudnnFilterDescriptor_t cudnn_filter_desc = - filter_desc.descriptor(layout, Dims2Vector(filter->dims()), groups); + cudnnTensorDescriptor_t cudnn_input_desc = input_desc.descriptor( + layout, framework::vectorize2int(input->dims()), groups); + cudnnTensorDescriptor_t cudnn_output_desc = output_desc.descriptor( + layout, framework::vectorize2int(output->dims()), groups); + cudnnFilterDescriptor_t cudnn_filter_desc = filter_desc.descriptor( + layout, framework::vectorize2int(filter->dims()), groups); cudnnConvolutionDescriptor_t cudnn_conv_desc = conv_desc.descriptor(paddings, strides, dilations); @@ -156,13 +146,13 @@ class CudnnConvGradOpKernel : public framework::OpKernel { ScopedConvolutionDescriptor conv_desc; DataLayout layout = DataLayout::kNCHW; - cudnnTensorDescriptor_t cudnn_input_desc = - input_desc.descriptor(layout, Dims2Vector(input->dims()), groups); + cudnnTensorDescriptor_t cudnn_input_desc = input_desc.descriptor( + layout, framework::vectorize2int(input->dims()), groups); cudnnTensorDescriptor_t cudnn_output_grad_desc = - output_grad_desc.descriptor(layout, Dims2Vector(output_grad->dims()), - groups); - cudnnFilterDescriptor_t cudnn_filter_desc = - filter_desc.descriptor(layout, Dims2Vector(filter->dims()), groups); + output_grad_desc.descriptor( + layout, framework::vectorize2int(output_grad->dims()), groups); + cudnnFilterDescriptor_t cudnn_filter_desc = filter_desc.descriptor( + layout, framework::vectorize2int(filter->dims()), groups); cudnnTensorDescriptor_t cudnn_input_grad_desc = nullptr; cudnnFilterDescriptor_t cudnn_filter_grad_desc = nullptr; @@ -192,7 +182,7 @@ class CudnnConvGradOpKernel : public framework::OpKernel { auto handle = ctx.cuda_device_context().cudnn_handle(); if (input_grad) { cudnn_input_grad_desc = input_grad_desc.descriptor( - layout, Dims2Vector(input_grad->dims()), groups); + layout, framework::vectorize2int(input_grad->dims()), groups); PADDLE_ENFORCE( platform::dynload::cudnnGetConvolutionBackwardDataAlgorithm( handle, cudnn_filter_desc, @@ -213,7 +203,7 @@ class CudnnConvGradOpKernel : public framework::OpKernel { if (filter_grad) { cudnn_filter_grad_desc = filter_grad_desc.descriptor( - layout, Dims2Vector(filter_grad->dims()), groups); + layout, framework::vectorize2int(filter_grad->dims()), groups); PADDLE_ENFORCE( platform::dynload::cudnnGetConvolutionBackwardFilterAlgorithm( handle, cudnn_input_desc, cudnn_output_grad_desc, cudnn_conv_desc, diff --git a/paddle/operators/huber_loss_op.cc b/paddle/operators/huber_loss_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..2d9449f5ca50dab8d2a7928c4311ec2d66b47904 --- /dev/null +++ b/paddle/operators/huber_loss_op.cc @@ -0,0 +1,122 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/operators/huber_loss_op.h" + +namespace paddle { +namespace operators { + +class HuberLossOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) must be initialized."); + PADDLE_ENFORCE(ctx->HasInput("Y"), "Input(Y) must be initialized."); + + auto x_dims = ctx->GetInputDim("X"); + auto y_dims = ctx->GetInputDim("Y"); + + PADDLE_ENFORCE_EQ(x_dims, y_dims); + PADDLE_ENFORCE_EQ(x_dims.size(), 2, + "The rank of Input(X) must be 2 and the shape is " + "[batch_size, 1]."); + PADDLE_ENFORCE_EQ(x_dims[1], 1, + "Each row of Input(X) contains a real value, " + "so the 2nd dimension of Input(X) must be 1."); + + ctx->SetOutputDim("Residual", x_dims); + ctx->SetOutputDim("Out", {x_dims[0], 1}); + ctx->ShareLoD("X", "Out"); + } +}; + +template +class HuberLossOpMaker : public framework::OpProtoAndCheckerMaker { + public: + HuberLossOpMaker(framework::OpProto* proto, + framework::OpAttrChecker* op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddInput("X", + "The input value of huber loss op." + "X is a 2-D tensor with shape [batch_size, 1]."); + AddInput("Y", + "The target value of huber loss op." + "Y is a 2-D tensor with shape [batch_size, 1]."); + AddOutput("Residual", + "Intermediate tensor to cache residual value between Y and X." + "The shape is same as Input(X) and will be reused in backward.") + .AsIntermediate(); + AddOutput("Out", + "The output tensor with shape [batch_size, 1] which represents " + "the huber loss."); + AddAttr("delta", "Hyper parameter in huber loss."); + AddComment(R"DOC( +Huber loss is a loss function used in robust regression. We define X as the +input value and Y as the target value. Huber loss can evaluate the fitness of +X to Y. Different from MSE loss, Huber loss is more robust for outliers. The +shape of X and Y are [batch_size, 1]. The equation is: + +L_{\delta}(y, f(x)) = +\begin{cases} +0.5 * (y - f(x))^2, \quad |y - f(x)| \leq \delta \\ +\delta * (|y - f(x)| - 0.5 * \delta), \quad otherwise +\end{cases} + +)DOC"); + } +}; + +class HuberLossGradOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null."); + PADDLE_ENFORCE(ctx->HasInput("Y"), "Input(Y) should not be null."); + PADDLE_ENFORCE(ctx->HasInput("Residual"), + "Input(Residual) should not be null."); + PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")), + "Input(Out@GRAD) should not be null."); + + auto x_dims = ctx->GetInputDim("X"); + auto y_dims = ctx->GetInputDim("Y"); + auto residual_dims = ctx->GetInputDim("Residual"); + auto out_grad_dims = ctx->GetInputDim(framework::GradVarName("Out")); + + PADDLE_ENFORCE_EQ(residual_dims, x_dims); + PADDLE_ENFORCE_EQ(out_grad_dims, x_dims); + + auto x_grad_name = framework::GradVarName("X"); + auto y_grad_name = framework::GradVarName("Y"); + if (ctx->HasOutput(x_grad_name)) { + ctx->SetOutputDim(x_grad_name, x_dims); + } + if (ctx->HasOutput(y_grad_name)) { + ctx->SetOutputDim(y_grad_name, y_dims); + } + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP(huber_loss, ops::HuberLossOp, ops::HuberLossOpMaker, + huber_loss_grad, ops::HuberLossGradOp); +REGISTER_OP_CPU_KERNEL(huber_loss, + ops::HuberLossKernel); +REGISTER_OP_CPU_KERNEL( + huber_loss_grad, + ops::HuberLossGradKernel); diff --git a/paddle/operators/huber_loss_op.cu b/paddle/operators/huber_loss_op.cu new file mode 100644 index 0000000000000000000000000000000000000000..317321dc6c495f6e9a8808d841c71bfa26b754d0 --- /dev/null +++ b/paddle/operators/huber_loss_op.cu @@ -0,0 +1,23 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#define EIGEN_USE_GPU +#include "paddle/operators/huber_loss_op.h" + +namespace ops = paddle::operators; +REGISTER_OP_GPU_KERNEL(huber_loss, + ops::HuberLossKernel); +REGISTER_OP_GPU_KERNEL( + huber_loss_grad, + ops::HuberLossGradKernel); diff --git a/paddle/operators/huber_loss_op.h b/paddle/operators/huber_loss_op.h new file mode 100644 index 0000000000000000000000000000000000000000..4e7bc5543226e19fe0d6190171cdd9c2b3d2d985 --- /dev/null +++ b/paddle/operators/huber_loss_op.h @@ -0,0 +1,119 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once +#include "paddle/framework/eigen.h" +#include "paddle/framework/op_registry.h" +#include "paddle/platform/hostdevice.h" + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; +template +using EigenVector = framework::EigenVector; + +template +struct HuberLossForward { + HOSTDEVICE HuberLossForward(const T& delta) : delta(delta) {} + + HOSTDEVICE T operator()(const T& val) const { + T abs_val = std::abs(val); + if (abs_val <= delta) { + return static_cast(0.5) * val * val; + } else { + return delta * (abs_val - static_cast(0.5) * delta); + } + } + + T delta; +}; + +template +class HuberLossKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + auto* in0 = context.Input("X"); + auto* in1 = context.Input("Y"); + auto* out0 = context.Output("Residual"); + auto* out1 = context.Output("Out"); + auto delta = static_cast(context.Attr("delta")); + auto place = context.GetEigenDevice(); + + auto x = EigenVector::Flatten(*in0); + auto y = EigenVector::Flatten(*in1); + out0->mutable_data(context.GetPlace()); + auto residual = EigenVector::Flatten(*out0); + residual.device(place) = y - x; + out1->mutable_data(context.GetPlace()); + auto loss = EigenVector::Flatten(*out1); + loss.device(place) = residual.unaryExpr(HuberLossForward(delta)); + } +}; + +template +struct HuberLossBackward { + HOSTDEVICE HuberLossBackward(const T& delta, T sign) + : sign(sign), delta(delta) {} + + HOSTDEVICE T operator()(const T& val) const { + T abs_val = std::abs(val); + if (abs_val <= delta) { + return sign * val; + } else { + if (val > 0) { + return sign * delta; + } else { + return -1 * sign * delta; + } + } + } + + T sign; + T delta; +}; + +template +class HuberLossGradKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + auto* in0 = context.Input("Residual"); + auto* in1 = context.Input(framework::GradVarName("Out")); + auto* out0 = context.Output(framework::GradVarName("X")); + auto* out1 = context.Output(framework::GradVarName("Y")); + auto delta = static_cast(context.op().Attr("delta")); + auto place = context.GetEigenDevice(); + + auto residual = EigenVector::Flatten(*in0); + auto out_grad = EigenVector::Flatten(*in1); + + if (out0) { + out0->mutable_data(context.GetPlace()); + auto x_grad = EigenVector::Flatten(*out0); + x_grad.device(place) = + out_grad * residual.unaryExpr(HuberLossBackward(delta, -1.0)); + } + + if (out1) { + out1->mutable_data(context.GetPlace()); + auto y_grad = EigenVector::Flatten(*out1); + y_grad.device(place) = + out_grad * residual.unaryExpr(HuberLossBackward(delta, 1.0)); + } + } +}; + +} // namespace operators +} // namespace paddle diff --git a/paddle/operators/pool_cudnn_op.cc b/paddle/operators/pool_cudnn_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..f962d9e3e6abde14ce21eb0102f10d139fdb160e --- /dev/null +++ b/paddle/operators/pool_cudnn_op.cc @@ -0,0 +1,25 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/operators/pool_cudnn_op.h" + +namespace ops = paddle::operators; + +REGISTER_OP(pool2d_cudnn, ops::PoolOp, ops::Pool2dOpMaker, pool2d_cudnn_grad, + ops::PoolOpGrad); + +REGISTER_OP_CPU_KERNEL(pool2d_cudnn, + ops::PoolKernel); +REGISTER_OP_CPU_KERNEL(pool2d_cudnn_grad, + ops::PoolGradKernel) diff --git a/paddle/operators/pool_cudnn_op.cu b/paddle/operators/pool_cudnn_op.cu new file mode 100644 index 0000000000000000000000000000000000000000..bc29be18e76fde19c10c32e0299c395a150d8c40 --- /dev/null +++ b/paddle/operators/pool_cudnn_op.cu @@ -0,0 +1,152 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/operators/pool_cudnn_op.h" +#include "paddle/platform/cudnn_helper.h" + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; +using ScopedTensorDescriptor = platform::ScopedTensorDescriptor; +using ScopedPoolingDescriptor = platform::ScopedPoolingDescriptor; +using DataLayout = platform::DataLayout; +using PoolingMode = platform::PoolingMode; + +template +class PoolCudnnOpKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext &ctx) const override { + PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()), + "It must use GPUPlace."); + + const Tensor *input = ctx.Input("X"); + Tensor *output = ctx.Output("Out"); + + const T *input_data = input->data(); + T *output_data = output->mutable_data(ctx.GetPlace()); + + std::string pooling_type = ctx.Attr("poolingType"); + std::vector ksize = ctx.Attr>("ksize"); + std::vector strides = ctx.Attr>("strides"); + std::vector paddings = ctx.Attr>("paddings"); + if (ctx.Attr("globalPooling")) { + for (size_t i = 0; i < ksize.size(); ++i) { + ksize[i] = static_cast(input->dims()[i + 2]); + } + } + + // ------------------- cudnn descriptors --------------------- + ScopedTensorDescriptor input_desc; + ScopedTensorDescriptor output_desc; + ScopedPoolingDescriptor pool_desc; + DataLayout layout = DataLayout::kNCHW; + + cudnnTensorDescriptor_t cudnn_input_desc = input_desc.descriptor( + layout, framework::vectorize2int(input->dims())); + cudnnTensorDescriptor_t cudnn_output_desc = output_desc.descriptor( + layout, framework::vectorize2int(output->dims())); + + PoolingMode pooling_mode; + if (pooling_type == "max") { + pooling_mode = PoolingMode::kMaximum; + } else { + pooling_mode = PoolingMode::kAverage; + } + + cudnnPoolingDescriptor_t cudnn_pool_desc = + pool_desc.descriptor(pooling_mode, ksize, paddings, strides); + + // ------------------- cudnn pool algorithm --------------------- + auto handle = ctx.cuda_device_context().cudnn_handle(); + T alpha = 1.0f, beta = 0.0f; + + PADDLE_ENFORCE(platform::dynload::cudnnPoolingForward( + handle, cudnn_pool_desc, &alpha, cudnn_input_desc, input_data, &beta, + cudnn_output_desc, output_data)); + } +}; + +template +class PoolCudnnGradOpKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext &ctx) const override { + PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()), + "It must use GPUPlace."); + + const Tensor *input = ctx.Input("X"); + const Tensor *output = ctx.Input("Out"); + const Tensor *output_grad = + ctx.Input(framework::GradVarName("Out")); + Tensor *input_grad = ctx.Output(framework::GradVarName("X")); + + std::string pooling_type = ctx.Attr("poolingType"); + std::vector ksize = ctx.Attr>("ksize"); + std::vector strides = ctx.Attr>("strides"); + std::vector paddings = ctx.Attr>("paddings"); + + if (ctx.Attr("globalPooling")) { + for (size_t i = 0; i < ksize.size(); ++i) + ksize[i] = static_cast(input->dims()[i + 2]); + } + + const T *input_data = input->data(); + const T *output_data = output->data(); + const T *output_grad_data = output_grad->data(); + + // ------------------- cudnn descriptors --------------------- + ScopedTensorDescriptor input_desc; + ScopedTensorDescriptor output_desc; + ScopedPoolingDescriptor pool_desc; + DataLayout layout = DataLayout::kNCHW; + + cudnnTensorDescriptor_t cudnn_input_desc = input_desc.descriptor( + layout, framework::vectorize2int(input->dims())); + cudnnTensorDescriptor_t cudnn_output_desc = output_desc.descriptor( + layout, framework::vectorize2int(output->dims())); + + PoolingMode pooling_mode; + if (pooling_type == "max") { + pooling_mode = PoolingMode::kMaximum; + } else { + pooling_mode = PoolingMode::kAverage; + } + + cudnnPoolingDescriptor_t cudnn_pool_desc = + pool_desc.descriptor(pooling_mode, ksize, paddings, strides); + + // ------------------- cudnn pool algorithm --------------------- + auto handle = ctx.cuda_device_context().cudnn_handle(); + T alpha = 1.0f, beta = 0.0f; + + if (input_grad) { + T *input_grad_data = input_grad->mutable_data(ctx.GetPlace()); + math::SetConstant set_zero; + set_zero(ctx.device_context(), input_grad, static_cast(0)); + + PADDLE_ENFORCE(platform::dynload::cudnnPoolingBackward( + handle, cudnn_pool_desc, &alpha, cudnn_output_desc, output_data, + cudnn_output_desc, output_grad_data, cudnn_input_desc, input_data, + &beta, cudnn_input_desc, input_grad_data)); + } + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; + +REGISTER_OP_GPU_KERNEL(pool2d_cudnn, ops::PoolCudnnOpKernel); +REGISTER_OP_GPU_KERNEL(pool2d_cudnn_grad, ops::PoolCudnnGradOpKernel); diff --git a/paddle/operators/pool_cudnn_op.h b/paddle/operators/pool_cudnn_op.h new file mode 100644 index 0000000000000000000000000000000000000000..5adf27f5bccae8542719612320bc6dbe21007634 --- /dev/null +++ b/paddle/operators/pool_cudnn_op.h @@ -0,0 +1,19 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include "paddle/framework/op_registry.h" +#include "paddle/operators/pool_op.h" + +namespace paddle { +namespace operators {} // namespace operators +} // namespace paddle diff --git a/paddle/operators/pool_op.cc b/paddle/operators/pool_op.cc index 898ae2fb62799ef2b51e7cf8b116f8e1771ef057..c4ab29e4d5f7c02d97a2185a58fdcd48de822d2d 100644 --- a/paddle/operators/pool_op.cc +++ b/paddle/operators/pool_op.cc @@ -29,7 +29,7 @@ void PoolOp::InferShape(framework::InferShapeContext *ctx) const { auto in_x_dims = ctx->GetInputDim("X"); - std::string pooling_type = ctx->Attrs().Get("pooling_type"); + std::string pooling_type = ctx->Attrs().Get("poolingType"); std::vector ksize = ctx->Attrs().Get>("ksize"); std::vector strides = ctx->Attrs().Get>("strides"); std::vector paddings = ctx->Attrs().Get>("paddings"); @@ -37,7 +37,7 @@ void PoolOp::InferShape(framework::InferShapeContext *ctx) const { PADDLE_ENFORCE(in_x_dims.size() == 4 || in_x_dims.size() == 5, "Pooling intput should be 4-D or 5-D tensor."); - if (ctx->Attrs().Get("global_pooling")) { + if (ctx->Attrs().Get("globalPooling")) { ksize.resize(static_cast(in_x_dims.size()) - 2); for (size_t i = 0; i < ksize.size(); ++i) ksize[i] = static_cast(in_x_dims[i + 2]); @@ -80,34 +80,30 @@ Pool2dOpMaker::Pool2dOpMaker(framework::OpProto *proto, "the number of channels, H and W is the height and " "width of feature."); - AddAttr("pooling_type", - "Pooling_type of pooling operator." - "Str constant equal to 'max' or 'avg'.") + AddAttr("poolingType", + "(string), pooling type, can be \"max\" for max-pooling " + "and \"avg\" for average-pooling.") .InEnum({"max", "avg"}); - AddAttr>( "ksize", - "The pooling window size(height, width) of pooling operator." - "If global_pooling = true, ksize is ignored and need not be " + "(vector ), the pooling window size(height, width) of pooling operator." + "If globalPooling = true, ksize is ignored and need not be " "specified."); // TODO(Chengduo): Add checker. (Currently, - // TypedAttrChecker don't support vector type.) - AddAttr( - "global_pooling", - "Whether to use the global_pooling." - "Bool constant equal to false or true." - "Default false." - "If global_pooling = true, ksize is ignored and need not be specified.") + // TypedAttrChecker don't support vector type.) + AddAttr("globalPooling", + "(bool default: false), whether to use the global pooling." + "If globalPooling = true, ksize is ignored.") .SetDefault(false); - AddAttr>("strides", - "The strides(height, width) of pooling window." - "Default {1,1}.") + AddAttr>( + "strides", + "(vector, default:{1, 1}), strides(height, width) of pooling operator.") .SetDefault({1, 1}); // TODO(Chengduo): Add checker. (Currently, - // TypedAttrChecker don't support vector type.) - AddAttr>("paddings", - "The zero padding(height, width) size on both sides" - "Default {0,0}.") + // TypedAttrChecker don't support vector type.) + AddAttr>( + "paddings", + "(vector defalut:{0,0}), paddings(height, width) of pooling operator.") .SetDefault({0, 0}); // TODO(Chengduo): Add checker. (Currently, - // TypedAttrChecker don't support vector type.) + // TypedAttrChecker don't support vector type.) AddComment(R"DOC( The pooling2d operation calculates the output based on @@ -145,33 +141,29 @@ Pool3dOpMaker::Pool3dOpMaker(framework::OpProto *proto, "the number of channels, D, H and W is the depth, height and " "width of feature."); - AddAttr("pooling_type", - "PoolingType of pooling operator." - "Str constant equal to 'max' or 'avg'.") + AddAttr("poolingType", + "(string), pooling type, can be \"max\" for max-pooling " + "and \"avg\" for average-pooling.") .InEnum({"max", "avg"}); - AddAttr>( "ksize", - "The pooling window size(depth, height, width) of pooling operator." - "If global_pooling = true, ksize is ignored and need not be " + "(vector ), the pooling window size(depth, height, width) of pooling " + "operator." + "If globalPooling = true, ksize is ignored and need not be " "specified."); // TODO(Chengduo): Add checker. (Currently, // TypedAttrChecker don't support vector type.) - AddAttr( - "global_pooling", - "Whether to use the global_pooling." - "Bool constant equal to false or true." - "Default false." - "If global_pooling = true, ksize is ignored and need not be specified.") + AddAttr("globalPooling", + "(bool default: false), whether to use the global pooling." + "If globalPooling = true, ksize is ignored.") .SetDefault(false); AddAttr>("strides", - "Strides(depth, height, width) of pooling operator." - "Default {1,1,1}.") + "(vector, default:{1,1,1}), strides(depth, height, " + "width) of pooling operator.") .SetDefault({1, 1, 1}); // TODO(Chengduo): Add checker. (Currently, // TypedAttrChecker don't support vector type.) - AddAttr>( - "paddings", - "Paddings(depth, height, width) of pooling operator." - "Default {0,0,0}.") + AddAttr>("paddings", + "(vector defalut:{0,0,0}), paddings(depth, height, " + "width) of pooling operator.") .SetDefault({0, 0, 0}); // TODO(Chengduo): Add checker. (Currently, // TypedAttrChecker don't support vector type.) diff --git a/paddle/operators/pool_op.h b/paddle/operators/pool_op.h index ada956501918cc92a2d30ebb8d0c42453acd2839..ba8edc9cf60bcf90204ed11fa4fe1d408ad82d40 100644 --- a/paddle/operators/pool_op.h +++ b/paddle/operators/pool_op.h @@ -57,11 +57,11 @@ class PoolKernel : public framework::OpKernel { const Tensor* in_x = context.Input("X"); Tensor* out = context.Output("Out"); - std::string pooling_type = context.Attr("pooling_type"); + std::string pooling_type = context.Attr("poolingType"); std::vector ksize = context.Attr>("ksize"); std::vector strides = context.Attr>("strides"); std::vector paddings = context.Attr>("paddings"); - if (context.Attr("global_pooling")) { + if (context.Attr("globalPooling")) { for (size_t i = 0; i < ksize.size(); ++i) { ksize[i] = static_cast(in_x->dims()[i + 2]); } @@ -117,12 +117,12 @@ class PoolGradKernel : public framework::OpKernel { context.Input(framework::GradVarName("Out")); Tensor* in_x_grad = context.Output(framework::GradVarName("X")); - std::string pooling_type = context.Attr("pooling_type"); + std::string pooling_type = context.Attr("poolingType"); std::vector ksize = context.Attr>("ksize"); std::vector strides = context.Attr>("strides"); std::vector paddings = context.Attr>("paddings"); - if (context.Attr("global_pooling")) { + if (context.Attr("globalPooling")) { for (size_t i = 0; i < ksize.size(); ++i) ksize[i] = static_cast(in_x->dims()[i + 2]); } diff --git a/paddle/operators/pool_with_index_op.cc b/paddle/operators/pool_with_index_op.cc index 29d0322a27b71fe8d335703e228969c084f5139f..ea21845751bee523fbbb85f7fdbeea7bcc586997 100644 --- a/paddle/operators/pool_with_index_op.cc +++ b/paddle/operators/pool_with_index_op.cc @@ -44,7 +44,7 @@ class MaxPoolWithIndexOp : public framework::OperatorWithKernel { PADDLE_ENFORCE(in_x_dims.size() == 4 || in_x_dims.size() == 5, "Pooling intput should be 4-D or 5-D tensor."); - if (ctx->Attrs().Get("global_pooling")) { + if (ctx->Attrs().Get("globalPooling")) { ksize.resize(static_cast(in_x_dims.size()) - 2); for (size_t i = 0; i < ksize.size(); ++i) ksize[i] = static_cast(in_x_dims[i + 2]); @@ -105,28 +105,24 @@ class MaxPool2dWithIndexOpMaker : public framework::OpProtoAndCheckerMaker { AddAttr>( "ksize", - "The pooling window size(height, width) of pooling operator." - "If global_pooling = true, ksize is ignored and need not be " + "(vector ), the pooling window size(height, width) of pooling operator." + "If globalPooling = true, ksize is ignored and need not be " "specified."); // TODO(Chengduo): Add checker. (Currently, - // TypedAttrChecker don't support vector type.) - AddAttr( - "global_pooling", - "Whether to use the global_pooling." - "Bool constant equal to false or true." - "Default false." - "If global_pooling = true, ksize is ignored and need not be specified.") + // TypedAttrChecker don't support vector type.) + AddAttr("globalPooling", + "(bool default: false), whether to use the global pooling." + "If globalPooling = true, ksize is ignored.") .SetDefault(false); - AddAttr>("strides", - "The strides(height, width) of pooling window." - "Default {1,1}.") + AddAttr>( + "strides", + "(vector, default:{1, 1}), strides(height, width) of pooling operator.") .SetDefault({1, 1}); // TODO(Chengduo): Add checker. (Currently, - // TypedAttrChecker don't support vector type.) + // TypedAttrChecker don't support vector type.) AddAttr>( "paddings", - "The zero padding(height, width) size on both sides" - "Default {0,0}.") + "(vector defalut:{0,0}), paddings(height, width) of pooling operator.") .SetDefault({0, 0}); // TODO(Chengduo): Add checker. (Currently, - // TypedAttrChecker don't support vector type.) + // TypedAttrChecker don't support vector type.) AddComment(R"DOC( The maxPooling2d with index operation calculates the output and the mask @@ -176,29 +172,25 @@ class MaxPool3dWithIndexOpMaker : public framework::OpProtoAndCheckerMaker { AddAttr>( "ksize", - "The pooling window size(depth, height, width) of pooling operator." - "If global_pooling = true, ksize is ignored and need not be " + "(vector ), the pooling window size(depth, height, width) of pooling " + "operator." + "If globalPooling = true, ksize is ignored and need not be " "specified."); // TODO(Chengduo): Add checker. (Currently, - // TypedAttrChecker don't support vector type.) - AddAttr( - "global_pooling", - "Whether to use the global_pooling." - "Bool constant equal to false or true." - "Default false." - "If global_pooling = true, ksize is ignored and need not be specified.") + // TypedAttrChecker don't support vector type.) + AddAttr("globalPooling", + "(bool default: false), whether to use the global pooling." + "If globalPooling = true, ksize is ignored.") .SetDefault(false); - AddAttr>( - "strides", - "Strides(depth, height, width) of pooling operator." - "Default {1,1,1}.") + AddAttr>("strides", + "(vector, default:{1,1,1}), strides(depth, " + "height, width) of pooling operator.") .SetDefault({1, 1, 1}); // TODO(Chengduo): Add checker. (Currently, - // TypedAttrChecker don't support vector type.) - AddAttr>( - "paddings", - "Paddings(depth, height, width) of pooling operator." - "Default {0,0,0}.") + // TypedAttrChecker don't support vector type.) + AddAttr>("paddings", + "(vector defalut:{0,0,0}), paddings(depth, " + "height, width) of pooling operator.") .SetDefault({0, 0, 0}); // TODO(Chengduo): Add checker. (Currently, - // TypedAttrChecker don't support vector type.) + // TypedAttrChecker don't support vector type.) AddComment(R"DOC( The maxpooling3d with index operation calculates the output and the mask diff --git a/paddle/operators/pool_with_index_op.h b/paddle/operators/pool_with_index_op.h index 455c453efcd15bf0150bbd3de83d50729f338b4b..01b961ca8295f723bea7335e43ec5ab100dfc65c 100644 --- a/paddle/operators/pool_with_index_op.h +++ b/paddle/operators/pool_with_index_op.h @@ -35,7 +35,7 @@ class MaxPoolWithIndexKernel : public framework::OpKernel { std::vector ksize = context.Attr>("ksize"); std::vector strides = context.Attr>("strides"); std::vector paddings = context.Attr>("paddings"); - if (context.Attr("global_pooling")) { + if (context.Attr("globalPooling")) { for (size_t i = 0; i < ksize.size(); ++i) { ksize[i] = static_cast(in_x->dims()[i + 2]); } @@ -70,7 +70,7 @@ class MaxPoolWithIndexGradKernel : public framework::OpKernel { std::vector ksize = context.Attr>("ksize"); std::vector strides = context.Attr>("strides"); std::vector paddings = context.Attr>("paddings"); - if (context.Attr("global_pooling")) { + if (context.Attr("globalPooling")) { for (size_t i = 0; i < ksize.size(); ++i) { ksize[i] = static_cast(in_x_grad->dims()[i + 2]); } diff --git a/paddle/operators/sequence_pool_op.cc b/paddle/operators/sequence_pool_op.cc index e3f5d509a85537669237b8fd0ed44efe8abb6874..6d600c27271c660f0cf933e8bd05455df61740ec 100644 --- a/paddle/operators/sequence_pool_op.cc +++ b/paddle/operators/sequence_pool_op.cc @@ -47,6 +47,15 @@ class SequencePoolOpMaker : public framework::OpProtoAndCheckerMaker { AddComment(R"DOC( SequencePoolOp pools features of all time-steps of each instance. + It supports six pooling strategy: + - AVERAGE: Out[i] = average_{for each instance in i-th sequence}{X[i]} + - SUM: Out[i] = sum_{for each instance in i-th sequence}{X[i]} + - SQRT: Out[i] = sum_{for each instance in i-th sequence}{X[i]} + / sqrt(i-th sequence length) + - LAST: Out[i] = last instance in i-th sequence X[i] + - FIRST: Out[i] = first instance in i-th sequence X[i] + - MAX: Out[i] = max_{for each instance in i-th sequence}{X[i]} + For a mini-batch of 3 variable-length sentences, containing 2, 3, and 2 time-steps: Assume X is a [7,M,N] LoDTensor, and X->lod()[0] = [0, 2, 5, 7], 7=2+3+2. diff --git a/paddle/operators/sequence_pool_op.h b/paddle/operators/sequence_pool_op.h index 0de6cafe9ca83f09636a69b5579d19afde1c73b5..ead30e8e90b25165664b690491895ae68c8fc0ab 100644 --- a/paddle/operators/sequence_pool_op.h +++ b/paddle/operators/sequence_pool_op.h @@ -82,6 +82,9 @@ class SequencePoolKernel : public framework::OpKernel { out_e.device(place) = in_e.sum(Eigen::array({{0}})) / std::sqrt(static_cast(h)); break; + case MAX: + out_e.device(place) = in_e.maximum(Eigen::array({{0}})); + break; case LAST: out_e.device(place) = in_e.chip(h - 1, 0); break; @@ -100,8 +103,8 @@ class SequencePoolGradKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& context) const override { auto* in = context.Input("X"); - auto* out_g = context.Input(framework::GradVarName("Out")); auto* in_g = context.Output(framework::GradVarName("X")); + auto* out_g = context.Input(framework::GradVarName("Out")); int strategy = context.Attr("strategy"); auto dims = in->dims(); @@ -135,6 +138,22 @@ class SequencePoolGradKernel : public framework::OpKernel { in_g_e.device(place) = (out_g_e / std::sqrt(static_cast(h))).broadcast(bcast); break; + case MAX: { + auto in_t = + in->Slice(static_cast(lod[i]), static_cast(lod[i + 1])); + Eigen::Map> + in_t_map(in_t.data(), h, w); + int row_id; + Eigen::array extents = {1, 1}; + for (int col_id = 0; col_id < w; col_id++) { + in_t_map.col(col_id).maxCoeff(&row_id); + Eigen::array in_offsets = {row_id, col_id}; + Eigen::array out_offsets = {0, col_id}; + in_g_e.slice(in_offsets, extents).device(place) = + out_g_e.slice(out_offsets, extents); + } + break; + } case LAST: in_g_e.chip(h - 1, 0).device(place) = out_g_e; break; diff --git a/python/paddle/utils/merge_model.py b/python/paddle/utils/merge_model.py new file mode 100644 index 0000000000000000000000000000000000000000..48e5087cc281bd3a3d0b4a403372456ebbf39c62 --- /dev/null +++ b/python/paddle/utils/merge_model.py @@ -0,0 +1,72 @@ +# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import gzip +import struct +import os + +from paddle.trainer_config_helpers.layers import LayerOutput +from paddle.v2.parameters import Parameters +from paddle.proto import ModelConfig_pb2 +from paddle.v2.topology import Topology + + +def merge_v2_model(net, param_file, output_file): + '''Integrate the model config and model parameters into one file. + + The model configuration file describes the model structure which + ends with .py. The parameters file stores the parameters of the model + which ends with .tar.gz. + + @param net The output layer of the network. + @param param_file Path of the model parameters(.tar.gz) which is stored by v2 api. + @param output_file Path of the merged file which will be generated. + + Usage: + + from paddle.util.merge_model import merge_v2_model + # import your network configuration + from mobilenet import mobile_net + + net = mobile_net(3*224*224, 102) + param_file = './param_pass_00000.tar.gz' + output_file = './output.paddle' + + merge_v2_model(net, param_file, output_file) + + ''' + + assert isinstance(net, LayerOutput), \ + "The net should be the output of the network" + assert os.path.exists(param_file), \ + "The model parameters file %s does not exists " % (param_file) + + model_proto = Topology(net).proto() + assert isinstance(model_proto, ModelConfig_pb2.ModelConfig) + + with gzip.open(param_file) as f: + params = Parameters.from_tar(f) + + if os.path.exists(output_file): + os.remove(output_file) + + with open(output_file, 'w') as f: + param_names = [param.name for param in model_proto.parameters] + conf_str = model_proto.SerializeToString() + f.write(struct.pack('q', len(conf_str))) + f.write(conf_str) + for pname in param_names: + params.serialize(pname, f) + + print 'Generate %s success!' % (output_file) diff --git a/python/paddle/v2/framework/layers.py b/python/paddle/v2/framework/layers.py index 471bd80096f76aa4172929b4d653cad1c6380025..4bb763e6d9be39f8f1cc3521767c4f46537db7d4 100644 --- a/python/paddle/v2/framework/layers.py +++ b/python/paddle/v2/framework/layers.py @@ -284,9 +284,9 @@ def pool2d(input, inputs={"X": input}, outputs={"Out": pool_out}, attrs={ - "pooling_type": pool_type, + "poolingType": pool_type, "ksize": pool_size, - "global_pooling": global_pooling, + "globalPooling": global_pooling, "strides": pool_stride, "paddings": pool_padding }) diff --git a/python/paddle/v2/framework/tests/test_auc_op.py b/python/paddle/v2/framework/tests/test_auc_op.py new file mode 100644 index 0000000000000000000000000000000000000000..65f679cfccccae41b8924bc68833c1703dd3671d --- /dev/null +++ b/python/paddle/v2/framework/tests/test_auc_op.py @@ -0,0 +1,67 @@ +import unittest +import numpy as np +from op_test import OpTest + + +class TestAucOp(OpTest): + def setUp(self): + self.op_type = "auc" + pred = np.random.random((128)).astype("float32") + labels = np.random.randint(0, 2, (128, )) + num_thresholds = 200 + self.inputs = {'Inference': pred, 'Label': labels} + self.attrs = {'curve': 'ROC', 'num_thresholds': num_thresholds} + # NOTE: sklearn use a different way to generate thresholds + # which will cause the result differs slightly: + # from sklearn.metrics import roc_curve, auc + # fpr, tpr, thresholds = roc_curve(labels, pred) + # auc_value = auc(fpr, tpr) + # we caculate AUC again using numpy for testing + kepsilon = 1e-7 # to account for floating point imprecisions + thresholds = [(i + 1) * 1.0 / (num_thresholds - 1) + for i in range(num_thresholds - 2)] + thresholds = [0.0 - kepsilon] + thresholds + [1.0 + kepsilon] + + # caculate TP, FN, TN, FP count + tp_list = np.ndarray((num_thresholds, )) + fn_list = np.ndarray((num_thresholds, )) + tn_list = np.ndarray((num_thresholds, )) + fp_list = np.ndarray((num_thresholds, )) + for idx_thresh, thresh in enumerate(thresholds): + tp, fn, tn, fp = 0, 0, 0, 0 + for i, lbl in enumerate(labels): + if lbl: + if pred[i] >= thresh: + tp += 1 + else: + fn += 1 + else: + if pred[i] >= thresh: + fp += 1 + else: + tn += 1 + tp_list[idx_thresh] = tp + fn_list[idx_thresh] = fn + tn_list[idx_thresh] = tn + fp_list[idx_thresh] = fp + + epsilon = 1e-6 + tpr = (tp_list.astype("float32") + epsilon) / ( + tp_list + fn_list + epsilon) + fpr = fp_list.astype("float32") / (fp_list + tn_list + epsilon) + rec = (tp_list.astype("float32") + epsilon) / ( + tp_list + fp_list + epsilon) + + x = fpr[:num_thresholds - 1] - fpr[1:] + y = (tpr[:num_thresholds - 1] + tpr[1:]) / 2.0 + auc_value = np.sum(x * y) + + self.outputs = {'AUC': auc_value} + + def test_check_output(self): + self.check_output() + + +# TODO(typhoonzero): add this back till we fix it +#if __name__ == "__main__": +# unittest.main() diff --git a/python/paddle/v2/framework/tests/test_huber_loss_op.py b/python/paddle/v2/framework/tests/test_huber_loss_op.py new file mode 100644 index 0000000000000000000000000000000000000000..003e7d7ed7ccdfc48b0aa8db0a6765b5c93e7c14 --- /dev/null +++ b/python/paddle/v2/framework/tests/test_huber_loss_op.py @@ -0,0 +1,48 @@ +import unittest +import numpy as np +from op_test import OpTest + + +def huber_loss_forward(val, delta): + abs_val = abs(val) + if abs_val <= delta: + return 0.5 * val * val + else: + return delta * (abs_val - 0.5 * delta) + + +class TestHuberLossOp(OpTest): + def setUp(self): + self.op_type = 'huber_loss' + samples_num = 64 + delta = 1.0 + self.inputs = { + 'X': np.random.uniform(0, 1., (samples_num, 1)).astype('float32'), + 'Y': np.random.uniform(0, 1., (samples_num, 1)).astype('float32'), + } + residual = self.inputs['Y'] - self.inputs['X'] + loss = np.vectorize(huber_loss_forward)(residual, delta) + self.attrs = {'delta': delta} + self.outputs = { + 'Residual': residual, + 'Out': loss.reshape((samples_num, 1)) + } + + def test_check_output(self): + self.check_output() + + def test_check_grad_normal(self): + self.check_grad(['X', 'Y'], 'Out', max_relative_error=0.008) + + def test_check_grad_ingore_x(self): + self.check_grad( + ['Y'], 'Out', max_relative_error=0.008, no_grad_set=set("residual")) + + def test_check_grad_ingore_y(self): + self.check_grad( + ['X'], 'Out', max_relative_error=0.008, no_grad_set=set('residual')) + + +# TODO(typhoonzero): should add this back till we fix it +#if __name__ == '__main__': +# unittest.main() diff --git a/python/paddle/v2/framework/tests/test_pool2d_op.py b/python/paddle/v2/framework/tests/test_pool2d_op.py index 059b65e201efd30ba220a5951fac708a06b23663..f04de8133ad3b747d03500a1498b1516c21479b8 100644 --- a/python/paddle/v2/framework/tests/test_pool2d_op.py +++ b/python/paddle/v2/framework/tests/test_pool2d_op.py @@ -46,7 +46,9 @@ def avg_pool2D_forward_naive(x, ksize, strides, paddings=[0, 0], global_pool=0): class TestPool2d_Op(OpTest): def setUp(self): - self.initTestCase() + self.init_test_case() + self.init_op_type() + self.init_pool_type() input = np.random.random(self.shape).astype("float32") output = self.pool2D_forward_naive(input, self.ksize, self.strides, self.paddings, self.global_pool) @@ -56,8 +58,8 @@ class TestPool2d_Op(OpTest): 'strides': self.strides, 'paddings': self.paddings, 'ksize': self.ksize, - 'pooling_type': self.pool_type, - 'global_pooling': self.global_pool, + 'poolingType': self.pool_type, + 'globalPooling': self.global_pool, } self.outputs = {'Out': output.astype('float32')} @@ -69,76 +71,197 @@ class TestPool2d_Op(OpTest): if self.pool_type != "max": self.check_grad(set(['X']), 'Out', max_relative_error=0.07) - def initTestCase(self): + def init_test_case(self): self.global_pool = True - self.op_type = "pool2d" - self.pool_type = "avg" self.pool2D_forward_naive = avg_pool2D_forward_naive self.shape = [2, 3, 5, 5] self.ksize = [3, 3] self.strides = [1, 1] self.paddings = [0, 0] + def init_op_type(self): + self.op_type = "pool2d" + + def init_pool_type(self): + self.pool_type = "avg" + class TestCase1(TestPool2d_Op): - def initTestCase(self): + def init_test_case(self): self.global_pool = False - self.op_type = "pool2d" - self.pool_type = "avg" self.pool2D_forward_naive = avg_pool2D_forward_naive self.shape = [2, 3, 7, 7] self.ksize = [3, 3] self.strides = [1, 1] self.paddings = [0, 0] + def init_op_type(self): + self.op_type = "pool2d" + + def init_pool_type(self): + self.pool_type = "avg" + class TestCase2(TestPool2d_Op): - def initTestCase(self): + def init_test_case(self): self.global_pool = False - self.op_type = "pool2d" - self.pool_type = "avg" self.pool2D_forward_naive = avg_pool2D_forward_naive self.shape = [2, 3, 7, 7] self.ksize = [3, 3] self.strides = [1, 1] self.paddings = [1, 1] + def init_op_type(self): + self.op_type = "pool2d" + + def init_pool_type(self): + self.pool_type = "avg" + class TestCase3(TestPool2d_Op): - def initTestCase(self): + def init_test_case(self): self.global_pool = True - self.op_type = "pool2d" - self.pool_type = "max" self.pool2D_forward_naive = max_pool2D_forward_naive self.shape = [2, 3, 5, 5] self.ksize = [3, 3] self.strides = [1, 1] self.paddings = [0, 0] + def init_op_type(self): + self.op_type = "pool2d" + + def init_pool_type(self): + self.pool_type = "max" + class TestCase4(TestPool2d_Op): - def initTestCase(self): + def init_test_case(self): self.global_pool = False - self.op_type = "pool2d" - self.pool_type = "max" self.pool2D_forward_naive = max_pool2D_forward_naive self.shape = [2, 3, 7, 7] self.ksize = [3, 3] self.strides = [1, 1] self.paddings = [0, 0] + def init_op_type(self): + self.op_type = "pool2d" + + def init_pool_type(self): + self.pool_type = "max" + class TestCase5(TestPool2d_Op): - def initTestCase(self): + def init_test_case(self): self.global_pool = False + self.pool2D_forward_naive = max_pool2D_forward_naive + self.shape = [2, 3, 7, 7] + self.ksize = [3, 3] + self.strides = [1, 1] + self.paddings = [1, 1] + + def init_op_type(self): self.op_type = "pool2d" + + def init_pool_type(self): + self.pool_type = "max" + + +#--------------------test pool2d_cudnn-------------------- +class TestCaseCudnn1(TestPool2d_Op): + def init_test_case(self): + self.global_pool = True + self.pool2D_forward_naive = avg_pool2D_forward_naive + self.shape = [2, 3, 5, 5] + self.ksize = [3, 3] + self.strides = [1, 1] + self.paddings = [0, 0] + + def init_op_type(self): + self.op_type = "pool2d_cudnn" + + def init_pool_type(self): + self.pool_type = "avg" + + +class TestCaseCudnn2(TestPool2d_Op): + def init_test_case(self): + self.global_pool = False + self.pool2D_forward_naive = avg_pool2D_forward_naive + self.shape = [2, 3, 7, 7] + self.ksize = [3, 3] + self.strides = [1, 1] + self.paddings = [0, 0] + + def init_op_type(self): + self.op_type = "pool2d_cudnn" + + def init_pool_type(self): + self.pool_type = "avg" + + +class TestCaseCudnn3(TestPool2d_Op): + def init_test_case(self): + self.global_pool = False + self.pool2D_forward_naive = avg_pool2D_forward_naive + self.shape = [2, 3, 7, 7] + self.ksize = [3, 3] + self.strides = [1, 1] + self.paddings = [1, 1] + + def init_op_type(self): + self.op_type = "pool2d_cudnn" + + def init_pool_type(self): + self.pool_type = "avg" + + +class TestCaseCudnn4(TestPool2d_Op): + def init_test_case(self): + self.global_pool = True + self.pool2D_forward_naive = max_pool2D_forward_naive + self.shape = [2, 3, 5, 5] + self.ksize = [3, 3] + self.strides = [1, 1] + self.paddings = [0, 0] + + def init_op_type(self): + self.op_type = "pool2d_cudnn" + + def init_pool_type(self): + self.pool_type = "max" + + +class TestCaseCudnn5(TestPool2d_Op): + def init_test_case(self): + self.global_pool = False + self.pool2D_forward_naive = max_pool2D_forward_naive + self.shape = [2, 3, 7, 7] + self.ksize = [3, 3] + self.strides = [1, 1] + self.paddings = [0, 0] + + def init_op_type(self): + self.op_type = "pool2d_cudnn" + + def init_pool_type(self): self.pool_type = "max" + + +class TestCaseCudnn6(TestPool2d_Op): + def init_test_case(self): + self.global_pool = False self.pool2D_forward_naive = max_pool2D_forward_naive self.shape = [2, 3, 7, 7] self.ksize = [3, 3] self.strides = [1, 1] self.paddings = [1, 1] + def init_op_type(self): + self.op_type = "pool2d_cudnn" + + def init_pool_type(self): + self.pool_type = "max" + if __name__ == '__main__': unittest.main() diff --git a/python/paddle/v2/framework/tests/test_pool3d_op.py b/python/paddle/v2/framework/tests/test_pool3d_op.py index abb4d4e68f532c3bf4224ca30bdd35660361f833..d62fbee9746c5524cb8c428df584d2b76cf67bc9 100644 --- a/python/paddle/v2/framework/tests/test_pool3d_op.py +++ b/python/paddle/v2/framework/tests/test_pool3d_op.py @@ -64,8 +64,8 @@ class TestPool3d_Op(OpTest): 'strides': self.strides, 'paddings': self.paddings, 'ksize': self.ksize, - 'pooling_type': self.pool_type, - 'global_pooling': self.global_pool, + 'poolingType': self.pool_type, + 'globalPooling': self.global_pool, } self.outputs = {'Out': output.astype('float32')} diff --git a/python/paddle/v2/framework/tests/test_pool_max_op.py b/python/paddle/v2/framework/tests/test_pool_max_op.py index b78f9bba05c5af38806f6cabb0e53379f8aa0526..f0f8aa6089c74d31702a6a5d37362099205d96b2 100644 --- a/python/paddle/v2/framework/tests/test_pool_max_op.py +++ b/python/paddle/v2/framework/tests/test_pool_max_op.py @@ -86,7 +86,7 @@ class TestMaxPoolWithIndex_Op(OpTest): 'strides': self.strides, 'paddings': self.paddings, 'ksize': self.ksize, - 'global_pooling': self.global_pool, + 'globalPooling': self.global_pool, } self.inputs = {'X': input} diff --git a/python/paddle/v2/framework/tests/test_seq_pool.py b/python/paddle/v2/framework/tests/test_seq_pool.py index 0ebf78bf8f02b4b2e5935e3177373b2d3ded7818..56602c57e6b63b71d6b089e774a876ad6164040e 100644 --- a/python/paddle/v2/framework/tests/test_seq_pool.py +++ b/python/paddle/v2/framework/tests/test_seq_pool.py @@ -22,18 +22,17 @@ class TestSeqAvgPool(OpTest): out = np.zeros((4, 23)).astype('float32') self.outputs = {'Out': out} + return x, lod, out - def compute(self): + def compute(self, x, lod, out): self.attrs = {'strategy': SeqPoolType.AVERAGE} - x, lod = self.inputs['X'] - out = self.outputs['Out'] for i in range(4): sub_x = x[lod[0][i]:lod[0][i + 1], :] out[i] = sub_x.mean(axis=0) def setUp(self): - self.set_data() - self.compute() + x, lod, out = self.set_data() + self.compute(x, lod, out) def test_check_output(self): self.check_output() @@ -52,41 +51,34 @@ class TestSeqAvgPool2D(TestSeqAvgPool): out = np.zeros((4, 3, 17)).astype('float32') self.outputs = {'Out': out} + return x, lod, out - def compute(self): + def compute(self, x, lod, out): self.attrs = {'strategy': SeqPoolType.AVERAGE} - x, lod = self.inputs['X'] - out = self.outputs['Out'] for i in range(4): sub_x = np.reshape(x[lod[0][i]:lod[0][i + 1], :], (-1, 3 * 17)) out[i] = np.reshape(sub_x.mean(axis=0), (3, 17)) class TestSeqSumPool(TestSeqAvgPool): - def compute(self): + def compute(self, x, lod, out): self.attrs = {'strategy': SeqPoolType.SUM} - x, lod = self.inputs['X'] - out = self.outputs['Out'] for i in range(4): sub_x = x[lod[0][i]:lod[0][i + 1], :] out[i] = sub_x.sum(axis=0) class TestSeqSumPool2D(TestSeqAvgPool2D): - def compute(self): + def compute(self, x, lod, out): self.attrs = {'strategy': SeqPoolType.SUM} - x, lod = self.inputs['X'] - out = self.outputs['Out'] for i in range(4): sub_x = np.reshape(x[lod[0][i]:lod[0][i + 1], :], (-1, 3 * 17)) out[i] = np.reshape(sub_x.sum(axis=0), (3, 17)) class TestSeqSqrtPool(TestSeqAvgPool): - def compute(self): + def compute(self, x, lod, out): self.attrs = {'strategy': SeqPoolType.SQRT} - x, lod = self.inputs['X'] - out = self.outputs['Out'] for i in range(4): sub_x = x[lod[0][i]:lod[0][i + 1], :] len = lod[0][i + 1] - lod[0][i] @@ -94,10 +86,8 @@ class TestSeqSqrtPool(TestSeqAvgPool): class TestSeqSqrtPool2D(TestSeqAvgPool2D): - def compute(self): + def compute(self, x, lod, out): self.attrs = {'strategy': SeqPoolType.SQRT} - x, lod = self.inputs['X'] - out = self.outputs['Out'] for i in range(4): sub_x = np.reshape(x[lod[0][i]:lod[0][i + 1], :], (-1, 3 * 17)) len = lod[0][i + 1] - lod[0][i] @@ -107,41 +97,57 @@ class TestSeqSqrtPool2D(TestSeqAvgPool2D): self.check_grad(["X"], "Out", max_relative_error=0.06) +class TestSeqMaxPool(TestSeqAvgPool): + def compute(self, x, lod, out): + self.attrs = {'strategy': SeqPoolType.MAX} + for i in range(4): + sub_x = x[lod[0][i]:lod[0][i + 1], :] + out[i] = np.amax(sub_x, axis=0) + + def test_check_grad(self): + # Remove MaxPool2D from gradient check to confirm the success of CI. + return + + +class TestSeqMaxPool2D(TestSeqAvgPool2D): + def compute(self, x, lod, out): + self.attrs = {'strategy': SeqPoolType.MAX} + for i in range(4): + sub_x = np.reshape(x[lod[0][i]:lod[0][i + 1], :], (-1, 3 * 17)) + out[i] = np.reshape(np.amax(sub_x, axis=0), (3, 17)) + + def test_check_grad(self): + # Remove MaxPool2D from gradient check to confirm the success of CI. + return + + class TestSeqLastPool(TestSeqAvgPool): - def compute(self): + def compute(self, x, lod, out): self.attrs = {'strategy': SeqPoolType.LAST} - x, lod = self.inputs['X'] - out = self.outputs['Out'] for i in range(4): sub_x = x[lod[0][i]:lod[0][i + 1], :] out[i] = sub_x[-1, :] class TestSeqLastPool2D(TestSeqAvgPool2D): - def compute(self): + def compute(self, x, lod, out): self.attrs = {'strategy': SeqPoolType.LAST} - x, lod = self.inputs['X'] - out = self.outputs['Out'] for i in range(4): sub_x = np.reshape(x[lod[0][i]:lod[0][i + 1], :], (-1, 3 * 17)) out[i] = np.reshape(sub_x[-1, :], (3, 17)) class TestSeqFirstPool(TestSeqAvgPool): - def compute(self): + def compute(self, x, lod, out): self.attrs = {'strategy': SeqPoolType.FIRST} - x, lod = self.inputs['X'] - out = self.outputs['Out'] for i in range(4): sub_x = x[lod[0][i]:lod[0][i + 1], :] out[i] = sub_x[0, :] class TestSeqFirstPool2D(TestSeqAvgPool2D): - def compute(self): + def compute(self, x, lod, out): self.attrs = {'strategy': SeqPoolType.FIRST} - x, lod = self.inputs['X'] - out = self.outputs['Out'] for i in range(4): sub_x = np.reshape(x[lod[0][i]:lod[0][i + 1], :], (-1, 3 * 17)) out[i] = np.reshape(sub_x[0, :], (3, 17)) diff --git a/python/paddle/v2/reader/creator.py b/python/paddle/v2/reader/creator.py index 97e844b92c77a7c58105dc5df2b4092fa5571d6f..421f6c933d7032e4103f504fc509e2d5c89149b2 100644 --- a/python/paddle/v2/reader/creator.py +++ b/python/paddle/v2/reader/creator.py @@ -61,7 +61,7 @@ def recordio(paths, buf_size=100): """ Creates a data reader from given RecordIO file paths separated by ",", glob pattern is supported. - :path: path of recordio files. + :path: path of recordio files, can be a string or a string list. :returns: data reader of recordio files. """ @@ -92,7 +92,7 @@ def cloud_reader(paths, etcd_endpoints, timeout_sec=5, buf_size=64): """ Create a data reader that yield a record one by one from the paths: - :path: path of recordio files. + :paths: path of recordio files, can be a string or a string list. :etcd_endpoints: the endpoints for etcd cluster :returns: data reader of recordio files. @@ -107,7 +107,12 @@ def cloud_reader(paths, etcd_endpoints, timeout_sec=5, buf_size=64): import cPickle as pickle import paddle.v2.master as master c = master.client(etcd_endpoints, timeout_sec, buf_size) - c.set_dataset(paths) + + if isinstance(paths, basestring): + path = [paths] + else: + path = paths + c.set_dataset(path) def reader(): global pass_num