(convBwdFilterAlgo)));
-
CHECK_CUDNN(dynload::cudnnGetConvolutionBackwardFilterWorkspaceSize(
t_resource.cudnn_handle,
bwd_filter_src_desc,
@@ -603,7 +609,9 @@ void hl_create_convolution_descriptor(hl_convolution_descriptor* conv,
int padding_height,
int padding_width,
int stride_height,
- int stride_width) {
+ int stride_width,
+ int dilation_h,
+ int dilation_w) {
CHECK_NOTNULL(conv);
cudnn_convolution_descriptor hl_conv = (cudnn_convolution_descriptor)malloc(
@@ -625,18 +633,24 @@ void hl_create_convolution_descriptor(hl_convolution_descriptor* conv,
padding_width,
stride_height,
stride_width,
- 1,
- 1,
+ dilation_h,
+ dilation_w,
mode,
data_type));
#else
+ if (dilation_h > 1 || dilation_w > 1) {
+ LOG(FATAL)
+ << "Current cuDNN version does't support for dilation convolution. "
+ << "The dilation convolution requires cuDNN >= v6.0.";
+ }
+
CHECK_CUDNN(dynload::cudnnSetConvolution2dDescriptor(hl_conv->desc,
padding_height,
padding_width,
stride_height,
stride_width,
- 1,
- 1,
+ dilation_h,
+ dilation_w,
mode));
#endif
@@ -659,7 +673,9 @@ void hl_reset_convolution_descriptor(hl_convolution_descriptor conv,
int padding_height,
int padding_width,
int stride_height,
- int stride_width) {
+ int stride_width,
+ int dilation_h,
+ int dilation_w) {
CHECK_NOTNULL(conv);
CHECK_NOTNULL(image);
CHECK_NOTNULL(filter);
@@ -678,8 +694,8 @@ void hl_reset_convolution_descriptor(hl_convolution_descriptor conv,
padding_width,
stride_height,
stride_width,
- 1,
- 1,
+ dilation_h,
+ dilation_w,
mode,
data_type));
#else
@@ -688,8 +704,8 @@ void hl_reset_convolution_descriptor(hl_convolution_descriptor conv,
padding_width,
stride_height,
stride_width,
- 1,
- 1,
+ dilation_h,
+ dilation_w,
mode));
#endif
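
The dilation changes above (and the layer and test updates below) all rely on the same effective-kernel-size arithmetic. A minimal standalone sketch of that arithmetic (plain C++, independent of the Paddle API):

```cpp
#include <cassert>

// Effective size of a dilated kernel: (k - 1) * d + 1, i.e. d - 1 holes
// between adjacent taps.
int effectiveFilterSize(int k, int d) { return (k - 1) * d + 1; }

// Caffe-mode output size, matching the outputSize() calls in this patch.
int convOutputSize(int in, int k, int pad, int stride, int d) {
  return (in + 2 * pad - effectiveFilterSize(k, d)) / stride + 1;
}

int main() {
  // Like the cudnn_conv test below: 16x16 image, 2x2 filter, pad 1,
  // stride 2, dilation 2 -> effective 3x3 filter -> 8x8 output.
  assert(effectiveFilterSize(2, 2) == 3);
  assert(convOutputSize(16, 2, 1, 2, 2) == 8);
  return 0;
}
```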
diff --git a/paddle/framework/backward.md b/paddle/framework/backward.md
index 74c001b06a9e7b2279abf998604f2acf1b1168e4..c8fa3fefe5632a36d9044b4bccfd3dbb7c64dbf6 100644
--- a/paddle/framework/backward.md
+++ b/paddle/framework/backward.md
@@ -21,18 +21,32 @@ grad_op_builder(fengjiayi)
given a forward network, it generates the backward network. We only care about the Gradients—`OutputGradients`,`InputGradients`.
-1. bla bla bla (yuyang)
+1. Op
+
+   When the input forward network is an Op, return its gradient Operator immediately.
2. NetOp
- when the input forward network is a NetOp, it need to call the sub NetOp/Operators backward function recursively and ensure them done. During the process, we need to collect the `OutputGradients` name.
+   When the input forward network is a NetOp, it needs to call the backward function of each sub NetOp/Operator recursively. During this process, we need to collect the `OutputGradients` names according to the forward NetOp.
+
+   **Shared variable**. As illustrated in the pictures, two operators' `Output` `Gradient` will overwrite their shared input variable.
+
+
+
+
+   1. Shared variable in two operators.
+
+
+
+      Sharing a variable between operators, or using the same input variable in multiple operators, leads to a duplicate gradient variable. As the demo above shows, we need to rename the gradient names recursively and add a generic add operator to replace the overwrite links.
+
+
+
- We share variable in the same scope, as a result, duplicate operator `OutputGradients` will overwirte then duplicate variable.
+   2. Replace the shared variable's gradient with an `Add` operator.
- ![./images/duplicate_op]()
+
- Share variable between operators or same input variable used in multiple operators lead to a duplicate gradient variable. As demo show above, we need to rename gradient name recursively, and add a generic add operator instead.
-![./images/duplicate_op2]()
- Then collect the sub graph OutputGradients/InputGradients as the NetOp's and return it.
+   Then collect the sub-graph's `OutputGradients`/`InputGradients` as the NetOp's and return them.
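
A runnable toy model of the renaming-plus-`Add` scheme described above (the struct and names are illustrative, not the real `grad_op_builder` API):

```cpp
#include <iostream>
#include <map>
#include <string>
#include <vector>

// Toy model of the duplicate-gradient handling; not Paddle's real types.
struct Op {
  std::string type;
  std::vector<std::string> outputs;  // gradient variables this op writes
};

int main() {
  // Two gradient ops both write the shared gradient "x@GRAD".
  std::vector<Op> net = {{"fc_grad", {"x@GRAD"}}, {"mul_grad", {"x@GRAD"}}};

  // Pass 1: count the writers of every gradient variable.
  std::map<std::string, int> writers;
  for (const auto& op : net)
    for (const auto& name : op.outputs) ++writers[name];

  // Pass 2: rename duplicate outputs, then merge them with one add op.
  std::map<std::string, int> next;
  std::vector<Op> adds;
  for (auto& op : net)
    for (auto& name : op.outputs)
      if (writers[name] > 1) {
        std::string shared = name;
        name += "@RENAME@" + std::to_string(next[shared]++);
        if (next[shared] == writers[shared])
          adds.push_back({"add", {shared}});  // inputs: the renamed copies
      }
  net.insert(net.end(), adds.begin(), adds.end());

  for (const auto& op : net)
    std::cout << op.type << " -> " << op.outputs[0] << "\n";
  return 0;
}
```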
diff --git a/paddle/gserver/gradientmachines/RecurrentGradientMachine.cpp b/paddle/gserver/gradientmachines/RecurrentGradientMachine.cpp
index d00d408ab8d2b1e6951e9fe58981ba85b9077908..9f29b97466910f1daf88e3ca86f92d10661462c5 100644
--- a/paddle/gserver/gradientmachines/RecurrentGradientMachine.cpp
+++ b/paddle/gserver/gradientmachines/RecurrentGradientMachine.cpp
@@ -1344,7 +1344,7 @@ void RecurrentGradientMachine::fillGenOutputs() {
CHECK(!finalPaths_[i].empty());
Path& path = finalPaths_[i][0];
generator_.ids.insert(
- generator_.ids.begin(), path.ids.begin(), path.ids.end());
+ generator_.ids.end(), path.ids.begin(), path.ids.end());
starts[i + 1] = starts[i] + path.ids.size();
}
}
diff --git a/paddle/gserver/layers/ConvBaseLayer.cpp b/paddle/gserver/layers/ConvBaseLayer.cpp
index e161d89c38a290000a2cbdb2905e56901ae4c144..a5328ef8343e1050352fc48530e041fb6ce12a8b 100644
--- a/paddle/gserver/layers/ConvBaseLayer.cpp
+++ b/paddle/gserver/layers/ConvBaseLayer.cpp
@@ -32,9 +32,11 @@ bool ConvBaseLayer::init(const LayerMap& layerMap,
const ConvConfig& conf = inputConfig.conv_conf();
padding_.push_back(conf.padding());
stride_.push_back(conf.stride());
+ dilation_.push_back(conf.dilation());
filterSize_.push_back(conf.filter_size());
paddingY_.push_back(conf.padding_y());
strideY_.push_back(conf.stride_y());
+ dilationY_.push_back(conf.dilation_y());
filterSizeY_.push_back(conf.filter_size_y());
filterPixels_.push_back(filterSize_.back() * filterSizeY_.back());
channels_.push_back(conf.channels());
@@ -89,7 +91,11 @@ size_t ConvBaseLayer::calOutputSize() {
size_t layerSize = 0;
auto setLayerSize = [&](IntV& inH, IntV& inW, IntV& outH, IntV& outW) {
+ size_t filterSizeY;
+ size_t filterSize;
for (size_t i = 0; i < inputLayers_.size(); i++) {
+ filterSizeY = (filterSizeY_[i] - 1) * dilationY_[i] + 1;
+ filterSize = (filterSize_[i] - 1) * dilation_[i] + 1;
inH.push_back(inputLayers_[i]->getOutput().getFrameHeight());
inW.push_back(inputLayers_[i]->getOutput().getFrameWidth());
const ConvConfig& conf = config_.inputs(i).conv_conf();
@@ -98,17 +104,17 @@ size_t ConvBaseLayer::calOutputSize() {
inH[i] = conf.has_output_y() ? conf.output_y() : conf.output_x();
if (inW[i] == 0) inW[i] = conf.output_x();
outH.push_back(imageSize(
- inH[i], filterSizeY_[i], paddingY_[i], strideY_[i], caffeMode_));
- outW.push_back(imageSize(
- inW[i], filterSize_[i], padding_[i], stride_[i], caffeMode_));
+ inH[i], filterSizeY, paddingY_[i], strideY_[i], caffeMode_));
+ outW.push_back(
+ imageSize(inW[i], filterSize, padding_[i], stride_[i], caffeMode_));
} else {
if (inH[i] == 0)
inH[i] = conf.has_img_size_y() ? conf.img_size_y() : conf.img_size();
if (inW[i] == 0) inW[i] = conf.img_size();
outH.push_back(outputSize(
- inH[i], filterSizeY_[i], paddingY_[i], strideY_[i], caffeMode_));
+ inH[i], filterSizeY, paddingY_[i], strideY_[i], caffeMode_));
outW.push_back(outputSize(
- inW[i], filterSize_[i], padding_[i], stride_[i], caffeMode_));
+ inW[i], filterSize, padding_[i], stride_[i], caffeMode_));
}
CHECK_EQ(outH[i], outH[0]);
CHECK_EQ(outW[i], outW[0]);
diff --git a/paddle/gserver/layers/ConvBaseLayer.h b/paddle/gserver/layers/ConvBaseLayer.h
index e9d15d94f806a5d2e6f11cbbfc29e291dfe8538f..223bce8e296d748c8e17eb105aa67e8a1c1219b6 100644
--- a/paddle/gserver/layers/ConvBaseLayer.h
+++ b/paddle/gserver/layers/ConvBaseLayer.h
@@ -40,6 +40,10 @@ protected:
IntV stride_;
/// The y dimension of the stride.
IntV strideY_;
+ /// The x dimension of the dilation.
+ IntV dilation_;
+ /// The y dimension of the dilation.
+ IntV dilationY_;
/// The x dimension of a filter kernel.
IntV filterSize_;
/// The y dimension of a filter kernel.
diff --git a/paddle/gserver/layers/ConvBaseOperator.cpp b/paddle/gserver/layers/ConvBaseOperator.cpp
index 5c231986292d2cd26ee30ccc122142fccd5b4949..5469c41c87468001232f7bae0d5b6bf26693b9e0 100644
--- a/paddle/gserver/layers/ConvBaseOperator.cpp
+++ b/paddle/gserver/layers/ConvBaseOperator.cpp
@@ -59,7 +59,8 @@ void ConvBaseOperator::allocConvWorkSpace() {
&bwdDataAlgo_,
&bwdDataLimitBytes_,
&bwdFilterAlgo_,
- &bwdFilterLimitBytes_);
+ &bwdFilterLimitBytes_,
+ /*useDilation*/ false);
size_t maxWorkSpace = 0;
maxWorkSpace = std::max(fwdLimitBytes_, bwdDataLimitBytes_);
diff --git a/paddle/gserver/layers/ConvBaseProjection.cpp b/paddle/gserver/layers/ConvBaseProjection.cpp
index eb6b0445c95a9e9a7acd5d693ecdb11a263f41fd..08f36c516cfdadd42e9333c1c5a7a247df1f263e 100644
--- a/paddle/gserver/layers/ConvBaseProjection.cpp
+++ b/paddle/gserver/layers/ConvBaseProjection.cpp
@@ -41,6 +41,11 @@ void ConvBaseProjection::getConvParams() {
strideH_ = conf.stride_y();
strideW_ = conf.stride();
+ dilationH_ = conf.dilation_y();
+ dilationW_ = conf.dilation();
+ CHECK_GT(dilationH_, 0);
+ CHECK_GT(dilationW_, 0);
+
filterH_ = conf.filter_size_y();
filterW_ = conf.filter_size();
@@ -77,7 +82,9 @@ void ConvBaseProjection::initCudnn() {
paddingH_,
paddingW_,
strideH_,
- strideW_);
+ strideW_,
+ dilationH_,
+ dilationW_);
// initialize all to default algorithms
fwdAlgo_ = 0;
@@ -131,7 +138,9 @@ void ConvBaseProjection::reshapeTensorDesc(int batchSize) {
paddingH_,
paddingW_,
strideH_,
- strideW_);
+ strideW_,
+ dilationH_,
+ dilationW_);
}
void ConvBaseProjection::reshape(int batchSize) {
@@ -140,6 +149,10 @@ void ConvBaseProjection::reshape(int batchSize) {
CHECK_EQ(calInputSize(), in_->value->getWidth());
reshapeTensorDesc(batchSize);
+ bool useDilation = false;
+ if (dilationH_ > 1 || dilationW_ > 1) {
+ useDilation = true;
+ }
hl_conv_workspace(imageDesc_,
outputDesc_,
filterDesc_,
@@ -149,7 +162,8 @@ void ConvBaseProjection::reshape(int batchSize) {
&bwdDataAlgo_,
&bwdDataLimitBytes_,
&bwdFilterAlgo_,
- &bwdFilterLimitBytes_);
+ &bwdFilterLimitBytes_,
+ useDilation);
size_t maxWorkSpace = 0;
maxWorkSpace = std::max(fwdLimitBytes_, bwdDataLimitBytes_);
diff --git a/paddle/gserver/layers/ConvBaseProjection.h b/paddle/gserver/layers/ConvBaseProjection.h
index e9d9f8f1b2937b3a3b7323c43ef5608ffc5f82ca..ebdb57845bb36ac607b1e4c8e02f9d20b6e82a36 100644
--- a/paddle/gserver/layers/ConvBaseProjection.h
+++ b/paddle/gserver/layers/ConvBaseProjection.h
@@ -63,6 +63,7 @@ protected:
int configChannels_, configNumFilters_;
int paddingH_, paddingW_;
int strideH_, strideW_;
+ int dilationH_, dilationW_;
int filterH_, filterW_;
/// One group offset of input data.
int inputOffset_;
diff --git a/paddle/gserver/layers/ConvProjection.cpp b/paddle/gserver/layers/ConvProjection.cpp
index 5b7ecc5560c1e7431305b34a331fe1fbc96c6b06..6f0106b713d93494ba9baa5c7afa0a6b1f167262 100644
--- a/paddle/gserver/layers/ConvProjection.cpp
+++ b/paddle/gserver/layers/ConvProjection.cpp
@@ -25,12 +25,12 @@ size_t ConvProjection::calOutputSize() {
if (imageH_ == 0) imageH_ = configImgH_;
if (imageW_ == 0) imageW_ = configImgW_;
outputH_ = outputSize(imageH_,
- filterH_,
+ (filterH_ - 1) * dilationH_ + 1,
paddingH_,
strideH_,
/* caffeMode */ true);
outputW_ = outputSize(imageW_,
- filterW_,
+ (filterW_ - 1) * dilationW_ + 1,
paddingW_,
strideW_,
/* caffeMode */ true);
diff --git a/paddle/gserver/layers/SequenceSliceLayer.cpp b/paddle/gserver/layers/SequenceSliceLayer.cpp
index 6f6577445f56c17d2c43d21a19a086c985714658..d3a83fad276a384ab3fddd5349912c56be6f3cc0 100644
--- a/paddle/gserver/layers/SequenceSliceLayer.cpp
+++ b/paddle/gserver/layers/SequenceSliceLayer.cpp
@@ -130,6 +130,8 @@ void SequenceSliceLayer::calSelectedRows(const MatrixPtr starts,
CHECK(starts || ends) << "At least one of the start or end indices "
<< "should be given.";
+ bool hasSubseq = getInput(0).hasSubseq();
+
outSeqStartPos_.resize(1, 0);
outSubSeqStartPos_.resize(1, 0);
selectedRows_.clear();
@@ -151,14 +153,13 @@ void SequenceSliceLayer::calSelectedRows(const MatrixPtr starts,
int seqLen = endPos - begPos + 1;
CHECK_GT(seqLen, 0U);
for (int m = begPos; m <= endPos; ++m) selectedRows_.push_back(m);
- inputSeqInfoVec_.size() > 1
+ hasSubseq
? outSubSeqStartPos_.push_back(outSubSeqStartPos_.back() + seqLen)
: outSeqStartPos_.push_back(outSeqStartPos_.back() + seqLen);
}
rowIdx++;
}
- if (inputSeqInfoVec_.size() > 1)
- outSeqStartPos_.push_back(outSubSeqStartPos_.back());
+ if (hasSubseq) outSeqStartPos_.push_back(outSubSeqStartPos_.back());
}
if (useGpu_) {
@@ -175,7 +176,7 @@ void SequenceSliceLayer::calSelectedRows(const MatrixPtr starts,
output_.sequenceStartPositions->copyFrom(
outSeqStartPos_.data(), outSeqStartPos_.size(), false);
- if (inputSeqInfoVec_.size() > 1) {
+ if (hasSubseq) {
ICpuGpuVector::resizeOrCreate(
output_.subSequenceStartPositions, outSubSeqStartPos_.size(), false);
output_.subSequenceStartPositions->copyFrom(
@@ -204,10 +205,11 @@ void SequenceSliceLayer::forward(PassType passType) {
copySliceIdsToCpu();
}
- // calculate the selected row indices in a batch,
- // and build the output sequence information.
- calSelectedRows(startIdsOnCpu_ ? startIdsOnCpu_ : nullptr,
- endIdsOnCpu_ ? endIdsOnCpu_ : nullptr);
+ /*
+ * calculate the selected row indices in a batch, and build the output
+ * sequence information.
+ */
+ calSelectedRows(startIdsOnCpu_, endIdsOnCpu_);
resetOutput(selectedRows_.size(), getSize());
diff --git a/paddle/gserver/tests/test_LayerGrad.cpp b/paddle/gserver/tests/test_LayerGrad.cpp
index b974dc5d573884fd099b9755a7e60202e9cfeb6c..93b6e3cc5bd7a87aa854052277772904d70de802 100644
--- a/paddle/gserver/tests/test_LayerGrad.cpp
+++ b/paddle/gserver/tests/test_LayerGrad.cpp
@@ -12,6 +12,9 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
+#ifndef PADDLE_ONLY_CPU
+#include <cudnn.h>
+#endif
#include <gtest/gtest.h>
#include <string>
#include <vector>
@@ -189,10 +192,16 @@ TEST(Projection, scaling) {
void testProjectionConv(size_t groups, bool isDeconv) {
const int NUM_FILTERS = 18;
const int FILTER_SIZE = 2;
- const int FILTER_SIZE_Y = 4;
+ const int FILTER_SIZE_Y = 2;
const int CHANNELS = 3;
const int IMAGE_SIZE = 16;
+#if CUDNN_VERSION >= 6000
+ const int DILATION = 2;
+#else
+ const int DILATION = 1;
+#endif
+
ProjectionConfig conf;
if (isDeconv) {
conf.set_type("convt");
@@ -209,6 +218,8 @@ void testProjectionConv(size_t groups, bool isDeconv) {
conv->set_padding_y(1);
conv->set_stride(2);
conv->set_stride_y(2);
+ conv->set_dilation(DILATION);
+ conv->set_dilation_y(DILATION);
conv->set_groups(groups);
if (isDeconv) {
conv->set_filter_channels(NUM_FILTERS / conv->groups());
@@ -217,12 +228,12 @@ void testProjectionConv(size_t groups, bool isDeconv) {
}
conv->set_img_size(IMAGE_SIZE);
int output_x = outputSize(conv->img_size(),
- conv->filter_size(),
+ (conv->filter_size() - 1) * DILATION + 1,
conv->padding(),
conv->stride(),
/* caffeMode */ true);
int output_y = outputSize(conv->img_size(),
- conv->filter_size_y(),
+ (conv->filter_size_y() - 1) * DILATION + 1,
conv->padding_y(),
conv->stride_y(),
/* caffeMode */ true);
@@ -424,27 +435,38 @@ void testConvLayer(const string& type, bool trans, bool useGpu) {
config.layerConfig.set_partial_sum(1);
config.layerConfig.set_shared_biases(true);
- config.inputDefs.push_back({INPUT_DATA, "layer_0", 384, 288});
+ int dilation = 1;
+ if (type == "cudnn_conv") {
+#if CUDNN_VERSION >= 6000
+ dilation = 2;
+#else
+ dilation = 1;
+#endif
+ }
+
+ config.inputDefs.push_back({INPUT_DATA, "layer_0", 768, 192});
LayerInputConfig* input = config.layerConfig.add_inputs();
ConvConfig* conv = input->mutable_conv_conf();
conv->set_filter_size(2);
- conv->set_filter_size_y(3);
+ conv->set_filter_size_y(2);
conv->set_channels(3);
conv->set_padding(0);
conv->set_padding_y(1);
conv->set_stride(2);
conv->set_stride_y(2);
+ conv->set_dilation(dilation);
+ conv->set_dilation_y(dilation);
conv->set_groups(1);
conv->set_filter_channels(conv->channels() / conv->groups());
conv->set_img_size(16);
- conv->set_img_size_y(8);
+ conv->set_img_size_y(16);
conv->set_output_x(outputSize(conv->img_size(),
- conv->filter_size(),
+ (conv->filter_size() - 1) * dilation + 1,
conv->padding(),
conv->stride(),
/* caffeMode */ true));
conv->set_output_y(outputSize(conv->img_size_y(),
- conv->filter_size_y(),
+ (conv->filter_size_y() - 1) * dilation + 1,
conv->padding_y(),
conv->stride_y(),
/* caffeMode */ true));
diff --git a/paddle/gserver/tests/test_SeqSliceLayerGrad.cpp b/paddle/gserver/tests/test_SeqSliceLayerGrad.cpp
index d560ca650bc5b156de280a2a0d698b67eb032907..e1d4ae16176433b898ba88dd60550e44b4fe37be 100644
--- a/paddle/gserver/tests/test_SeqSliceLayerGrad.cpp
+++ b/paddle/gserver/tests/test_SeqSliceLayerGrad.cpp
@@ -30,6 +30,8 @@ const int MAX_SEQ_NUM = 17;
const int MAX_SEQ_LEN = 23;
const int MAX_BEAM_SIZE = 13;
+const size_t SEED = (size_t)(time(NULL));
+
vector<int> randSampling(real range, int n) {
CHECK_GE(range, n);
vector<int> num(range);
@@ -46,7 +48,7 @@ void genSeqInfo(vector<int>& seqStartPos, vector<int>& subSeqStartPos) {
seqStartPos.resize(1, 0);
subSeqStartPos.resize(1, 0);
- srand((size_t)(time(NULL)));
+ srand(SEED);
int seqNum = 1 + (rand() % MAX_SEQ_NUM);
for (int i = 0; i < seqNum; ++i) {
int subSeqNum = 1 + (rand() % MAX_SEQ_NUM);
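
Hoisting the seed to file scope makes every `genSeqInfo` call in a test run draw from one deterministic stream instead of reseeding from the clock. The property relied on, in isolation (standard C library only):

```cpp
#include <cassert>
#include <cstdlib>

int main() {
  srand(42);
  int first = rand();
  srand(42);
  // Same seed, same stream: repeated runs see identical "random" data.
  assert(rand() == first);
  return 0;
}
```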
diff --git a/paddle/operators/CMakeLists.txt b/paddle/operators/CMakeLists.txt
index b56a45b6bd1e4d834a3c11da989b4a0707a24bf6..f0fd12f1b5276d033ea086c60c80616fb1be7585 100644
--- a/paddle/operators/CMakeLists.txt
+++ b/paddle/operators/CMakeLists.txt
@@ -42,10 +42,12 @@ function(op_library TARGET)
endfunction()
add_subdirectory(math)
+
cc_test(gather_test SRCS gather_test.cc DEPS tensor)
op_library(gather_op SRCS gather_op.cc gather_op.cu)
cc_test(scatter_test SRCS scatter_test.cc DEPS tensor)
+op_library(scatter_op SRCS scatter_op.cc scatter_op.cu)
cc_library(net_op SRCS net_op.cc DEPS op_registry)
cc_test(net_op_test SRCS net_op_test.cc DEPS net_op)
@@ -67,7 +69,7 @@ op_library(sgd_op SRCS sgd_op.cc sgd_op.cu)
op_library(recurrent_op SRCS recurrent_op.cc rnn/recurrent_op_utils.cc
DEPS framework_proto tensor op_registry operator net_op)
-op_library(uniform_random_op
- SRCS uniform_random_op.cc uniform_random_op.cu)
+op_library(uniform_random_op SRCS uniform_random_op.cc uniform_random_op.cu)
+op_library(lookup_table_op SRCS lookup_table_op.cc lookup_table_op.cu)
op_library(scale_op SRCS scale_op.cc scale_op.cu DEPS net_op)
op_library(minus_op SRCS minus_op.cc minus_op.cu DEPS scale_op)
diff --git a/paddle/operators/fill_zeros_like_op.h b/paddle/operators/fill_zeros_like_op.h
index fd380ca8514b0ac50f39613368a4836bd485668b..969998ce2eae02b8ad057c6259703e51559bf98a 100644
--- a/paddle/operators/fill_zeros_like_op.h
+++ b/paddle/operators/fill_zeros_like_op.h
@@ -26,7 +26,7 @@ class FillZerosLikeKernel : public framework::OpKernel {
auto* output = context.Output<framework::Tensor>("Dst");
output->mutable_data<T>(context.GetPlace());
auto t = framework::EigenVector<T>::Flatten(*output);
- t.device(context.GetEigenDevice<Place>()) = t.constant(T(0));
+ t.device(context.GetEigenDevice<Place>()) = t.constant(static_cast<T>(0));
}
};
diff --git a/paddle/operators/lookup_table_op.cc b/paddle/operators/lookup_table_op.cc
new file mode 100644
index 0000000000000000000000000000000000000000..94d40890a765413e88a35a6ad995ca97ac84dcda
--- /dev/null
+++ b/paddle/operators/lookup_table_op.cc
@@ -0,0 +1,72 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License. */
+
+#include "paddle/operators/lookup_table_op.h"
+
+namespace paddle {
+namespace operators {
+
+class LookupTableOp : public framework::OperatorWithKernel {
+ public:
+ using framework::OperatorWithKernel::OperatorWithKernel;
+
+ protected:
+ void InferShape(const framework::InferShapeContext &context) const override {
+    auto table_t = context.Input<Tensor>("W");
+    auto ids_t = context.Input<Tensor>("Ids");
+    auto output_t = context.Output<Tensor>("Out");
+
+ output_t->Resize({ids_t->dims()[0], table_t->dims()[1]});
+ }
+};
+
+class LookupTableOpMaker : public framework::OpProtoAndCheckerMaker {
+ public:
+ LookupTableOpMaker(framework::OpProto *proto,
+ framework::OpAttrChecker *op_checker)
+ : OpProtoAndCheckerMaker(proto, op_checker) {
+ AddInput("W",
+ "An input represents embedding tensors,"
+ " which is a learnable parameter.");
+ AddInput("Ids",
+ "An input with type int32 or int64"
+ "contains the ids to be looked up in W.");
+ AddOutput("Out", "The lookup results, which have the same type with W.");
+    AddComment(
+        "This operator is used to perform lookups on the parameter W, "
+        "then concatenates the results into a dense tensor.");
+ }
+};
+
+class LookupTableOpGrad : public framework::OperatorWithKernel {
+ public:
+ using framework::OperatorWithKernel::OperatorWithKernel;
+
+ protected:
+ void InferShape(const framework::InferShapeContext &context) const override {
+    auto table = context.Input<Tensor>("W");
+    auto d_table = context.Output<Tensor>(framework::GradVarName("W"));
+ d_table->Resize(table->dims());
+ }
+};
+
+} // namespace operators
+} // namespace paddle
+
+namespace ops = paddle::operators;
+REGISTER_OP(lookup_table, ops::LookupTableOp, ops::LookupTableOpMaker,
+ lookup_table_grad, ops::LookupTableOpGrad);
+
+REGISTER_OP_CPU_KERNEL(lookup_table, ops::LookupTableKernel<float>);
+REGISTER_OP_CPU_KERNEL(lookup_table_grad, ops::LookupTableGradKernel<float>);
diff --git a/paddle/operators/lookup_table_op.cu b/paddle/operators/lookup_table_op.cu
new file mode 100644
index 0000000000000000000000000000000000000000..27eee3436af8107cef2aa3577ea238be49edf1af
--- /dev/null
+++ b/paddle/operators/lookup_table_op.cu
@@ -0,0 +1,116 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License. */
+
+#include "paddle/framework/eigen.h"
+#include "paddle/framework/op_registry.h"
+#include "paddle/platform/assert.h"
+#include "paddle/platform/cuda_helper.h"
+
+namespace paddle {
+namespace operators {
+
+using Tensor = framework::Tensor;
+
+template <typename T, int BlockDimX, int BlockDimY, int GridDimX>
+__global__ void LookupTable(T* output, const T* table, const int32_t* ids,
+ const int N, const int K, const int D) {
+ int idx = threadIdx.x;
+ int idy = blockIdx.x + threadIdx.y * GridDimX;
+
+ while (idy < K) {
+ int id = ids[idy];
+ PADDLE_ASSERT(id >= 0);
+ PADDLE_ASSERT(id < N);
+ T* out = output + idy * D;
+ const T* tab = table + id * D;
+ for (int i = idx; i < D; i += BlockDimX) {
+ out[i] = tab[i];
+ }
+ idy += BlockDimY * GridDimX;
+ }
+}
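+
+// Indexing note: threadIdx.y picks a row of `ids` within the block,
+// blockIdx.x strides across blocks, and the BlockDimX x-threads copy one
+// D-wide row cooperatively; the while-loop walks all K rows grid-wide.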
+
+template <typename T, int BlockDimX, int BlockDimY, int GridDimX>
+__global__ void LookupTableGrad(T* table, const T* output, const int32_t* ids,
+ const int N, const int K, const int D) {
+ int idx = threadIdx.x;
+ int idy = blockIdx.x + threadIdx.y * GridDimX;
+
+ while (idy < K) {
+ int id = ids[idy];
+ PADDLE_ASSERT(id >= 0);
+ PADDLE_ASSERT(id < N);
+ const T* out = output + idy * D;
+ T* tab = table + id * D;
+ for (int i = idx; i < D; i += BlockDimX) {
+ paddle::platform::CudaAtomicAdd(&tab[i], out[i]);
+ }
+ idy += BlockDimY * GridDimX;
+ }
+}
+
+template <typename T>
+class LookupTableCUDAKernel : public framework::OpKernel {
+ public:
+ void Compute(const framework::ExecutionContext& context) const override {
+    auto table_t = context.Input<Tensor>("W");
+    auto ids_t = context.Input<Tensor>("Ids");
+    auto output_t = context.Output<Tensor>("Out");
+
+ size_t N = table_t->dims()[0];
+ size_t D = table_t->dims()[1];
+ size_t K = product(ids_t->dims());
+    auto ids = ids_t->data<int32_t>();
+    auto table = table_t->data<T>();
+    auto output = output_t->mutable_data<T>(context.GetPlace());
+
+ dim3 threads(128, 8);
+ dim3 grids(8, 1);
+    LookupTable<T, 128, 8, 8><<<grids, threads>>>(output, table, ids, N, K, D);
+ }
+};
+
+template <typename T>
+class LookupTableGradCUDAKernel : public framework::OpKernel {
+ public:
+ void Compute(const framework::ExecutionContext& context) const override {
+    auto ids_t = context.Input<Tensor>("Ids");
+    auto d_output_t = context.Input<Tensor>(framework::GradVarName("Out"));
+    auto d_table_t = context.Output<Tensor>(framework::GradVarName("W"));
+
+ int N = d_table_t->dims()[0];
+ int D = d_table_t->dims()[1];
+ int K = product(ids_t->dims());
+    const int32_t* ids = ids_t->data<int32_t>();
+    const T* d_output = d_output_t->data<T>();
+    T* d_table = d_table_t->mutable_data<T>(context.GetPlace());
+
+    auto t = framework::EigenVector<T>::Flatten(*d_table_t);
+    t.device(context.GetEigenDevice<platform::GPUPlace>()) =
+        t.constant(static_cast<T>(0));
+
+ dim3 threads(128, 8);
+ dim3 grids(8, 1);
+    LookupTableGrad<T, 128, 8, 8><<<grids, threads>>>(d_table, d_output,
+                                                      ids, N, K, D);
+ }
+};
+
+} // namespace operators
+} // namespace paddle
+
+namespace ops = paddle::operators;
+REGISTER_OP_GPU_KERNEL(lookup_table, ops::LookupTableCUDAKernel<float>);
+REGISTER_OP_GPU_KERNEL(lookup_table_grad,
+                       ops::LookupTableGradCUDAKernel<float>);
diff --git a/paddle/operators/lookup_table_op.h b/paddle/operators/lookup_table_op.h
new file mode 100644
index 0000000000000000000000000000000000000000..4da8079b91624c3510cae89fd599a7035a4c7477
--- /dev/null
+++ b/paddle/operators/lookup_table_op.h
@@ -0,0 +1,75 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License. */
+
+#pragma once
+
+#include "paddle/framework/eigen.h"
+#include "paddle/framework/op_registry.h"
+
+namespace paddle {
+namespace operators {
+
+using Tensor = framework::Tensor;
+
+template <typename T>
+class LookupTableKernel : public framework::OpKernel {
+ public:
+ void Compute(const framework::ExecutionContext& context) const override {
+    auto table_t = context.Input<Tensor>("W");      // float tensor
+    auto ids_t = context.Input<Tensor>("Ids");      // int tensor
+    auto output_t = context.Output<Tensor>("Out");  // float tensor
+
+ size_t N = table_t->dims()[0];
+ size_t D = table_t->dims()[1];
+    auto ids = ids_t->data<int32_t>();
+    auto table = table_t->data<T>();
+    auto output = output_t->mutable_data<T>(context.GetPlace());
+ for (size_t i = 0; i < product(ids_t->dims()); ++i) {
+ PADDLE_ENFORCE_LT(ids[i], N);
+ PADDLE_ENFORCE_GE(ids[i], 0);
+ memcpy(output + i * D, table + ids[i] * D, D * sizeof(T));
+ }
+ }
+};
+
+template <typename T>
+class LookupTableGradKernel : public framework::OpKernel {
+ public:
+ void Compute(const framework::ExecutionContext& context) const override {
+    auto ids_t = context.Input<Tensor>("Ids");
+    auto d_output_t = context.Input<Tensor>(framework::GradVarName("Out"));
+    auto d_table_t = context.Output<Tensor>(framework::GradVarName("W"));
+
+ size_t N = d_table_t->dims()[0];
+ size_t D = d_table_t->dims()[1];
+    auto ids = ids_t->data<int32_t>();
+ const T* d_output = d_output_t->data();
+ T* d_table = d_table_t->mutable_data(context.GetPlace());
+
+    auto t = framework::EigenVector<T>::Flatten(*d_table_t);
+    t.device(context.GetEigenDevice<platform::CPUPlace>()) =
+        t.constant(static_cast<T>(0));
+
+ for (size_t i = 0; i < product(ids_t->dims()); ++i) {
+ PADDLE_ENFORCE_LT(ids[i], N);
+ PADDLE_ENFORCE_GE(ids[i], 0);
+ for (size_t j = 0; j < D; ++j) {
+ d_table[ids[i] * D + j] += d_output[i * D + j];
+ }
+ }
+ }
+};
+
+} // namespace operators
+} // namespace paddle
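
The `+=` in the gradient kernel above matters because `Ids` may repeat: several output rows then map back onto one table row and their gradients must accumulate (the CUDA kernel uses `CudaAtomicAdd` for the same reason). The accumulation in isolation (plain C++, illustrative sizes):

```cpp
#include <cassert>
#include <vector>

int main() {
  const int D = 2;                                   // embedding width
  std::vector<double> d_table(3 * D, 0.0);           // 3-row table grad
  std::vector<int> ids = {1, 1};                     // row 1 looked up twice
  std::vector<double> d_out = {1.0, 1.0, 2.0, 2.0};  // two D-wide rows
  for (size_t i = 0; i < ids.size(); ++i)
    for (int j = 0; j < D; ++j)
      d_table[ids[i] * D + j] += d_out[i * D + j];  // sum, don't overwrite
  assert(d_table[1 * D] == 3.0);                    // 1.0 + 2.0 accumulated
  return 0;
}
```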
diff --git a/paddle/operators/scatter_op.cc b/paddle/operators/scatter_op.cc
new file mode 100644
index 0000000000000000000000000000000000000000..f901edefa22dc9a252e87116df756d04767a7162
--- /dev/null
+++ b/paddle/operators/scatter_op.cc
@@ -0,0 +1,86 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/operators/scatter_op.h"
+#include "paddle/framework/ddim.h"
+
+namespace paddle {
+namespace operators {
+
+class ScatterOp : public framework::OperatorWithKernel {
+ public:
+ using framework::OperatorWithKernel::OperatorWithKernel;
+
+ protected:
+ void InferShape(const framework::InferShapeContext &ctx) const override {
+ PADDLE_ENFORCE_EQ(ctx.Input("Index")->dims().size(), 1,
+ "Update Index should be 1-D.");
+ PADDLE_ENFORCE_EQ(ctx.Input("Ref")->dims().size(),
+ ctx.Input("Updates")->dims().size(),
+ "Reference and Updates should have the same shape size");
+ PADDLE_ENFORCE_EQ(ctx.Input("Updates")->dims()[0],
+ ctx.Input("Index")->dims()[0],
+ "Updates and Index should have same batch-size.");
+ framework::DDim data_dim(ctx.Input("Updates")->dims());
+ for (int i = 1; i < data_dim.size(); ++i)
+ PADDLE_ENFORCE_EQ(data_dim[i], ctx.Input("Updates")->dims()[i]);
+ ctx.Output("Out")->Resize(ctx.Input("Ref")->dims());
+ }
+};
+
+class ScatterGradOp : public framework::OperatorWithKernel {
+ public:
+ using framework::OperatorWithKernel::OperatorWithKernel;
+
+ protected:
+ void InferShape(const framework::InferShapeContext &ctx) const override {
+    auto *dUpdates = ctx.Output<Tensor>(framework::GradVarName("Updates"));
+    auto *Updates = ctx.Input<Tensor>("Updates");
+    auto *dRef = ctx.Output<Tensor>(framework::GradVarName("Ref"));
+    auto *Ref = ctx.Input<Tensor>("Ref");
+
+ dRef->Resize(Ref->dims());
+ dUpdates->Resize(Updates->dims());
+ }
+};
+
+class ScatterOpMaker : public framework::OpProtoAndCheckerMaker {
+ public:
+ ScatterOpMaker(framework::OpProto *proto,
+ framework::OpAttrChecker *op_checker)
+ : OpProtoAndCheckerMaker(proto, op_checker) {
+ AddInput("Ref", "The source input of scatter op");
+ AddInput("Index",
+ "The index input of scatter op where Ref will be updated");
+ AddInput("Updates", "The updated value of updates op");
+ AddOutput("Out", "The output of add op");
+ AddComment(R"DOC(
+Scatter Operator, updating Ref by selecting rows from the first axis:
+
+Out = Ref
+Out[Index] = Ref[Index] + Updates
+)DOC");
+ }
+};
+} // namespace operators
+} // namespace paddle
+
+namespace ops = paddle::operators;
+REGISTER_OP(scatter, ops::ScatterOp, ops::ScatterOpMaker, scatter_grad,
+ ops::ScatterGradOp);
+REGISTER_OP_CPU_KERNEL(
+    scatter, ops::ScatterOpKernel<paddle::platform::CPUPlace, float>);
+REGISTER_OP_CPU_KERNEL(
+    scatter_grad,
+    ops::ScatterGradientOpKernel<paddle::platform::CPUPlace, float>);
diff --git a/paddle/operators/scatter_op.cu b/paddle/operators/scatter_op.cu
new file mode 100644
index 0000000000000000000000000000000000000000..6716b478833ff3adb6112cdb1ee25b7f1744ea1f
--- /dev/null
+++ b/paddle/operators/scatter_op.cu
@@ -0,0 +1,20 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#define EIGEN_USE_GPU
+#include "paddle/operators/scatter_op.h"
+
+namespace ops = paddle::operators;
+REGISTER_OP_GPU_KERNEL(
+    scatter, ops::ScatterOpKernel<paddle::platform::GPUPlace, float>);
diff --git a/paddle/operators/scatter_op.h b/paddle/operators/scatter_op.h
new file mode 100644
index 0000000000000000000000000000000000000000..e9595638a86a4a4536ddad4e6f20fd80a54b1608
--- /dev/null
+++ b/paddle/operators/scatter_op.h
@@ -0,0 +1,60 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+#include "gather.h"
+#include "paddle/framework/eigen.h"
+#include "paddle/framework/op_registry.h"
+#include "scatter.h"
+
+namespace paddle {
+namespace operators {
+
+using Tensor = framework::Tensor;
+
+template <typename Place, typename T>
+class ScatterOpKernel : public framework::OpKernel {
+ public:
+ void Compute(const framework::ExecutionContext &ctx) const override {
+    auto *Ref = ctx.Input<Tensor>("Ref");
+    auto *Index = ctx.Input<Tensor>("Index");
+    auto *Updates = ctx.Input<Tensor>("Updates");
+    auto *Out = ctx.Output<Tensor>("Out");
+
+ // In place output: Out = Ref, Out[Index] += Updates
+    Out->ShareDataWith<T>(*Ref);
+ // Apply ScatterUpdate: Out[index] += Updates[:]
+    ScatterUpdate<T>(ctx.GetPlace(), Updates, Index, Out);
+ }
+};
+
+template <typename Place, typename T>
+class ScatterGradientOpKernel : public framework::OpKernel {
+ public:
+ void Compute(const framework::ExecutionContext &ctx) const override {
+    auto *dRef = ctx.Output<Tensor>(framework::GradVarName("Ref"));
+    auto *dUpdates = ctx.Output<Tensor>(framework::GradVarName("Updates"));
+    auto *Index = ctx.Input<Tensor>("Index");
+    auto *dOut = ctx.Input<Tensor>(framework::GradVarName("Out"));
+
+ // In place gradient: dRef = dO
+    dRef->ShareDataWith<T>(*dOut);
+    dUpdates->mutable_data<T>(ctx.GetPlace());
+ // Gradient by Gather: dUpdates += dO[Index]
+    Gather<T>(ctx.GetPlace(), dOut, Index, dUpdates);
+ }
+};
+
+} // namespace operators
+} // namespace paddle
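
The forward semantics the kernel implements, written out in isolation (plain C++ using copies, whereas the kernel above aliases `Ref` through `ShareDataWith`; that aliasing is why the Python gradient test below passes `in_place=True`):

```cpp
#include <cassert>
#include <vector>

// Scatter semantics: Out = Ref; Out[Index[i]] += Updates[i], row-wise.
int main() {
  std::vector<std::vector<double>> ref(3, std::vector<double>(2, 1.0));
  std::vector<int> index = {1, 2};
  std::vector<std::vector<double>> updates = {{0.5, 0.5}, {2.0, 2.0}};
  auto out = ref;  // a copy here; the kernel aliases instead
  for (size_t i = 0; i < index.size(); ++i)
    for (size_t j = 0; j < out[index[i]].size(); ++j)
      out[index[i]][j] += updates[i][j];
  assert(out[1][0] == 1.5 && out[2][1] == 3.0);
  return 0;
}
```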
diff --git a/paddle/platform/cuda_helper.h b/paddle/platform/cuda_helper.h
new file mode 100644
index 0000000000000000000000000000000000000000..6feec0d7f8bd5d32d9e5eedee962fcbeff655f1c
--- /dev/null
+++ b/paddle/platform/cuda_helper.h
@@ -0,0 +1,51 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+#include <cuda.h>
+
+namespace paddle {
+namespace platform {
+
+#define CUDA_ATOMIC_WRAPPER(op, T) \
+ __device__ __forceinline__ T CudaAtomic##op(T* address, const T val)
+
+#define USE_CUDA_ATOMIC(op, T) \
+ CUDA_ATOMIC_WRAPPER(op, T) { return atomic##op(address, val); }
+
+// For atomicAdd.
+USE_CUDA_ATOMIC(Add, float);
+
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 600
+USE_CUDA_ATOMIC(Add, double);
+#else
+CUDA_ATOMIC_WRAPPER(Add, double) {
+  unsigned long long int* address_as_ull =
+      reinterpret_cast<unsigned long long int*>(address);
+ unsigned long long int old = *address_as_ull, assumed;
+
+ do {
+ assumed = old;
+ old = atomicCAS(address_as_ull, assumed,
+ __double_as_longlong(val + __longlong_as_double(assumed)));
+
+ // Note: uses integer comparison to avoid hang in case of NaN
+ } while (assumed != old);
+
+ return __longlong_as_double(old);
+}
+#endif
+
+} // namespace platform
+} // namespace paddle
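
A usage sketch for the wrapper (an illustrative kernel, not part of this change): on devices older than sm_60 the double overload falls back to the `atomicCAS` loop above.

```cpp
#include "paddle/platform/cuda_helper.h"

// Many threads accumulate into one double; correctness comes from the
// atomicAdd (or its CAS emulation) in cuda_helper.h.
__global__ void SumKernel(double* acc, const double* x, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) paddle::platform::CudaAtomicAdd(acc, x[i]);
}
```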
diff --git a/paddle/pybind/CMakeLists.txt b/paddle/pybind/CMakeLists.txt
index 40db811767a9c273f073f4715e6ddfbf05887730..37e186a408ff5f560b5878e3e51ea81ca5810bc7 100644
--- a/paddle/pybind/CMakeLists.txt
+++ b/paddle/pybind/CMakeLists.txt
@@ -4,6 +4,7 @@ cc_library(paddle_pybind SHARED
DEPS pybind python backward
sgd_op
gather_op
+ scatter_op
add_op
mul_op
rowwise_add_op
@@ -15,6 +16,7 @@ cc_library(paddle_pybind SHARED
uniform_random_op
gaussian_random_op
fill_zeros_like_op
+ lookup_table_op
scale_op
minus_op)
endif(WITH_PYTHON)
diff --git a/paddle/pybind/pybind.cc b/paddle/pybind/pybind.cc
index 27b98e77db80505f7498deb75164e184b900262b..3bc150ccb7af2885439cc2344aa0db9ba3b1ca03 100644
--- a/paddle/pybind/pybind.cc
+++ b/paddle/pybind/pybind.cc
@@ -42,10 +42,12 @@ USE_OP(fill_zeros_like);
USE_OP_ITSELF(recurrent_op);
USE_OP(gaussian_random);
USE_OP(uniform_random);
+USE_OP(lookup_table);
USE_OP(scale);
USE_OP_ITSELF(identity);
USE_OP(minus);
USE_CPU_ONLY_OP(gather);
+USE_CPU_ONLY_OP(scatter);
namespace paddle {
namespace framework {
diff --git a/paddle/scripts/docker/build.sh b/paddle/scripts/docker/build.sh
index 2941662f349baf57d1fe8188e88ce21d5de07750..17986420220fec173bbf3ecff240d4c504f8adbd 100644
--- a/paddle/scripts/docker/build.sh
+++ b/paddle/scripts/docker/build.sh
@@ -38,7 +38,7 @@ Configuring cmake in /paddle/build ...
-DWITH_SWIG_PY=${WITH_SWIG_PY:-ON}
-DCUDNN_ROOT=/usr/
-DWITH_STYLE_CHECK=${WITH_STYLE_CHECK:-OFF}
- -DWITH_TESTING=${WITH_TESTING:-OFF}
+ -DWITH_TESTING=${WITH_TESTING:-ON}
-DCMAKE_EXPORT_COMPILE_COMMANDS=ON
========================================
EOF
@@ -56,19 +56,18 @@ cmake .. \
-DWITH_C_API=${WITH_C_API:-OFF} \
-DWITH_PYTHON=${WITH_PYTHON:-ON} \
-DCUDNN_ROOT=/usr/ \
- -DWITH_STYLE_CHECK=${WITH_STYLE_CHECK:-OFF} \
- -DWITH_TESTING=${WITH_TESTING:-OFF} \
+ -DWITH_STYLE_CHECK=${WITH_STYLE_CHECK:-ON} \
+ -DWITH_TESTING=${WITH_TESTING:-ON} \
-DCMAKE_EXPORT_COMPILE_COMMANDS=ON
cat <<EOF
diff --git a/python/paddle/trainer_config_helpers/layers.py b/python/paddle/trainer_config_helpers/layers.py
--- a/python/paddle/trainer_config_helpers/layers.py
+++ b/python/paddle/trainer_config_helpers/layers.py
+    if dilation > 1 or dilation_y > 1:
+ assert layer_type in ["cudnn_conv", "cudnn_convt"]
if trans:
assert layer_type in ["exconvt", "cudnn_convt"]
else:
@@ -2486,11 +2502,13 @@ def img_conv_layer(input,
conv=Conv(
filter_size=filter_size,
padding=padding,
+ dilation=dilation,
stride=stride,
channels=num_channels,
groups=groups,
filter_size_y=filter_size_y,
padding_y=padding_y,
+ dilation_y=dilation_y,
stride_y=stride_y),
**param_attr.attr),
active_type=act.name,
@@ -2591,15 +2609,15 @@ def img_pool_layer(input,
assert input.num_filters is not None
num_channels = input.num_filters
- assert type(pool_type) in [AvgPooling, MaxPooling, CudnnAvgPooling,
- CudnnMaxPooling], \
- "only (Cudnn)AvgPooling, (Cudnn)MaxPooling are supported"
-
if pool_type is None:
pool_type = MaxPooling()
elif isinstance(pool_type, AvgPooling):
pool_type.name = 'avg'
+ assert type(pool_type) in [AvgPooling, MaxPooling, CudnnAvgPooling,
+ CudnnMaxPooling], \
+ "only (Cudnn)AvgPooling, (Cudnn)MaxPooling are supported"
+
type_name = pool_type.name + '-projection' \
if (
isinstance(pool_type, AvgPooling) or isinstance(pool_type, MaxPooling)) \
diff --git a/python/paddle/trainer_config_helpers/tests/configs/img_layers.py b/python/paddle/trainer_config_helpers/tests/configs/img_layers.py
index 9fda16a5407a1fe0af8c5986023a8368e5b87222..01d31ef3fad827bfd103ee00f4ddd1bde14e0f82 100644
--- a/python/paddle/trainer_config_helpers/tests/configs/img_layers.py
+++ b/python/paddle/trainer_config_helpers/tests/configs/img_layers.py
@@ -12,6 +12,7 @@ img_conv = img_conv_layer(
num_filters=64,
filter_size=(32, 32),
padding=(1, 1),
+ dilation=(1, 1),
stride=(1, 1),
act=LinearActivation())
img_bn = batch_norm_layer(input=img_conv, act=ReluActivation())
diff --git a/python/paddle/v2/framework/tests/CMakeLists.txt b/python/paddle/v2/framework/tests/CMakeLists.txt
index 2849ee7c8d0404432fcf6156552f40657d094983..661ebd89648feec77367c278e5f045b8238e1dc1 100644
--- a/python/paddle/v2/framework/tests/CMakeLists.txt
+++ b/python/paddle/v2/framework/tests/CMakeLists.txt
@@ -14,6 +14,7 @@ py_test(test_sigmoid_op SRCS test_sigmoid_op.py)
py_test(test_softmax_op SRCS test_softmax_op.py)
py_test(test_cross_entropy_op SRCS test_cross_entropy_op.py)
py_test(test_gather_op SRCS test_gather_op.py)
+py_test(test_scatter_op SRCS test_scatter_op.py)
py_test(test_fill_zeros_like_op SRCS test_fill_zeros_like_op.py)
py_test(gradient_checker SRCS gradient_checker.py)
@@ -28,5 +29,6 @@ py_test(test_uniform_random_op SRCS test_uniform_random_op.py)
py_test(test_recurrent_op SRCS test_recurrent_op.py)
py_test(test_sgd_op SRCS test_sgd_op.py)
py_test(test_gradient_checker SRCS test_gradient_checker.py)
+py_test(test_lookup_table SRCS test_lookup_table.py)
py_test(test_scale_and_identity_op SRCS test_scale_and_identity_op.py)
py_test(mnist SRCS mnist.py)
diff --git a/python/paddle/v2/framework/tests/gradient_checker.py b/python/paddle/v2/framework/tests/gradient_checker.py
index c22c6f8831b2551d9a83747bc0d15789a78a101e..9a7a7fbf5e63d4e433576f8e980c41c72fa26cab 100644
--- a/python/paddle/v2/framework/tests/gradient_checker.py
+++ b/python/paddle/v2/framework/tests/gradient_checker.py
@@ -23,12 +23,17 @@ def grad_var_name(var_name):
return var_name + "@GRAD"
+def empty_var_name():
+ return "@EMPTY@"
+
+
def get_numeric_gradient(op,
input_values,
output_name,
input_to_check,
delta=0.005,
- local_scope=None):
+ local_scope=None,
+ in_place=False):
"""
Get Numeric Gradient for an operator's input.
@@ -77,6 +82,11 @@ def get_numeric_gradient(op,
def product(dim):
return reduce(lambda a, b: a * b, dim, 1)
+ def restore_inputs():
+ for var_name in input_values:
+ tensor_ = local_scope.find_var(var_name).get_tensor()
+ tensor_.set(numpy.copy(input_values[var_name]), core.CPUPlace())
+
# get the input tensor that we want to get it's numeric gradient.
tensor_to_check = local_scope.find_var(input_to_check).get_tensor()
tensor_size = product(tensor_to_check.get_dims())
@@ -86,6 +96,8 @@ def get_numeric_gradient(op,
# we only compute gradient of one element each time.
# we use a for loop to compute the gradient of every element.
for i in xrange(tensor_size):
+ if in_place:
+ restore_inputs()
# get one input element throw it's index i.
origin = tensor_to_check.get_float_element(i)
@@ -95,6 +107,8 @@ def get_numeric_gradient(op,
y_pos = get_output()
# plus delta to this element, run op and get the sum of the result tensor.
+ if in_place:
+ restore_inputs()
x_neg = origin - delta
tensor_to_check.set_float_element(i, x_neg)
y_neg = get_output()
@@ -176,7 +190,7 @@ class GradientChecker(unittest.TestCase):
]
return outs
- def compare_grad(self, forward_op, input_value):
+ def compare_grad(self, forward_op, input_value, no_grad_set=None):
""" Compare the input gradients between CPU and GPU for the given forward
operator.
@@ -184,15 +198,20 @@ class GradientChecker(unittest.TestCase):
:type forward_op: Operator
:param input_value: input values.
:type input_value: dict{string:numpy.array}
+ :param no_grad_set: the set of variables names without gradients.
+ :type no_grad_set: a set of string
:raises: AssertionError, there is different gradient value.
"""
- backward_op = core.Operator.backward(forward_op, set())
+ if no_grad_set is None:
+ no_grad_set = set()
+ backward_op = core.Operator.backward(forward_op, no_grad_set)
# return if not compile with GPU or not implementing GPU kernel
if not (core.is_compile_gpu() and backward_op.support_gpu()):
return
outputs = backward_op.outputs()
out_names = [item for k in outputs for item in outputs[k]]
+ out_names = filter(lambda x: x != empty_var_name(), out_names)
cpu_grads = self.__get_gradient(forward_op, backward_op, input_value,
out_names, core.CPUPlace())
gpu_grads = self.__get_gradient(forward_op, backward_op, input_value,
@@ -242,6 +261,7 @@ class GradientChecker(unittest.TestCase):
output_name,
no_grad_set=None,
only_cpu=False,
+ in_place=False,
max_relative_error=0.005):
"""
:param forward_op: used to create backward_op
@@ -274,7 +294,8 @@ class GradientChecker(unittest.TestCase):
# get numerical gradients
numeric_grads = [
- get_numeric_gradient(forward_op, input_vars, output_name, name)
+ get_numeric_gradient(
+ forward_op, input_vars, output_name, name, in_place=in_place)
for name in inputs_to_check
]
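
For reference, `get_numeric_gradient` perturbs one element at a time and applies a central difference, `(y_pos - y_neg) / (2 * delta)`; the new `restore_inputs` hook exists because an in-place op mutates its inputs between the two evaluations. The scheme in isolation (plain C++):

```cpp
#include <cstdio>
#include <functional>

// Central-difference numeric gradient, the same per-element scheme the
// checker uses: perturb by +/-delta and difference the outputs.
double numericGrad(const std::function<double(double)>& f, double x,
                   double delta = 0.005) {
  return (f(x + delta) - f(x - delta)) / (2 * delta);
}

int main() {
  auto f = [](double x) { return x * x; };
  std::printf("%f\n", numericGrad(f, 3.0));  // ~6.0, the analytic gradient
  return 0;
}
```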
diff --git a/python/paddle/v2/framework/tests/mnist.py b/python/paddle/v2/framework/tests/mnist.py
index 9a0b109850e92c66e69f74c5cd0853a09b5551a1..a68f302f9c344bf6d63e8d9b48836d69338c3d0b 100644
--- a/python/paddle/v2/framework/tests/mnist.py
+++ b/python/paddle/v2/framework/tests/mnist.py
@@ -181,7 +181,7 @@ images = data_layer(name='pixel', dims=[BATCH_SIZE, 784])
labels = data_layer(name='label', dims=[BATCH_SIZE])
fc1 = fc_layer(net=forward_net, input=images, size=100, act="sigmoid")
fc2 = fc_layer(net=forward_net, input=fc1, size=100, act="sigmoid")
-predict = fc_layer(net=forward_net, input=fc2, size=100, act="softmax")
+predict = fc_layer(net=forward_net, input=fc2, size=10, act="softmax")
cost = cross_entropy_layer(net=forward_net, input=predict, label=labels)
init_net.complete_add_op(True)
diff --git a/python/paddle/v2/framework/tests/test_gather_op.py b/python/paddle/v2/framework/tests/test_gather_op.py
index e86898304252d08be718e40fed46c5e921596af7..e3de3fd0a1dddb3edb0de5987bd71d8a176d97ef 100644
--- a/python/paddle/v2/framework/tests/test_gather_op.py
+++ b/python/paddle/v2/framework/tests/test_gather_op.py
@@ -21,12 +21,9 @@ class TestGatherOp(unittest.TestCase):
class TestGatherGradOp(GradientChecker):
def test_gather_grad(self):
- print 'creating op'
op = create_op("gather")
- print 'creating op done'
xnp = numpy.random.random((10, 20)).astype("float32")
inputs = {'X': xnp, 'Index': numpy.array([1, 3, 5]).astype("int32")}
- print 'correct before check gradient'
self.check_grad(op, inputs, set("X"), "Out")
diff --git a/python/paddle/v2/framework/tests/test_lookup_table.py b/python/paddle/v2/framework/tests/test_lookup_table.py
new file mode 100644
index 0000000000000000000000000000000000000000..19eb464baa555fb67a994f3cfb4d3ed628367c73
--- /dev/null
+++ b/python/paddle/v2/framework/tests/test_lookup_table.py
@@ -0,0 +1,31 @@
+import unittest
+import numpy as np
+from op_test_util import OpTestMeta
+from gradient_checker import GradientChecker, create_op
+
+
+class TestLookupTableOp(unittest.TestCase):
+ __metaclass__ = OpTestMeta
+
+ def setUp(self):
+ self.type = 'lookup_table'
+ table = np.random.random((17, 31)).astype('float32')
+ ids = np.random.randint(0, 17, 4).astype('int32')
+ self.inputs = {'W': table, 'Ids': ids}
+ self.outputs = {'Out': table[ids]}
+
+
+class TestLookupTableGradOp(GradientChecker):
+ def test_grad(self):
+ op = create_op('lookup_table')
+ table = np.random.random((17, 31)).astype('float32')
+ ids = np.random.randint(0, 17, 4).astype('int32')
+ inputs = {'W': table, 'Ids': ids}
+        # compare gradients
+ self.compare_grad(op, inputs, set(['Ids']))
+ # check gradients
+ self.check_grad(op, inputs, set('W'), 'Out')
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/python/paddle/v2/framework/tests/test_scatter_op.py b/python/paddle/v2/framework/tests/test_scatter_op.py
new file mode 100644
index 0000000000000000000000000000000000000000..c1f9444889372104e39ded78fc7207a59b80a293
--- /dev/null
+++ b/python/paddle/v2/framework/tests/test_scatter_op.py
@@ -0,0 +1,38 @@
+import unittest
+from op_test_util import OpTestMeta
+from gradient_checker import GradientChecker, create_op
+import numpy
+import paddle.v2.framework.core as core
+from paddle.v2.framework.op import Operator
+
+
+class TestScatterOp(unittest.TestCase):
+ __metaclass__ = OpTestMeta
+
+ def setUp(self):
+ self.type = "scatter"
+ ref_np = numpy.ones((3, 3)).astype("float32")
+ index_np = numpy.array([1, 2]).astype("int32")
+ updates_np = numpy.random.random((2, 3)).astype("float32")
+ output_np = numpy.copy(ref_np)
+ output_np[index_np] += updates_np
+ self.inputs = {'Ref': ref_np, 'Index': index_np, 'Updates': updates_np}
+ self.outputs = {'Out': output_np}
+
+
+class TestScatterGradOp(GradientChecker):
+ def test_scatter_grad(self):
+ op = create_op("scatter")
+ # test data setup
+ ref_np = numpy.ones((3, 10)).astype("float32")
+ index_np = numpy.array([1, 2]).astype("int32")
+ updates_np = numpy.random.random((2, 10)).astype("float32")
+ output_np = numpy.copy(ref_np)
+ output_np[index_np] += updates_np
+ inputs = {'Ref': ref_np, 'Index': index_np, 'Updates': updates_np}
+ self.check_grad(
+ op, inputs, set(["Updates", "Ref"]), "Out", in_place=True)
+
+
+if __name__ == "__main__":
+ unittest.main()