diff --git a/doc/howto/dev/new_op_cn.md b/doc/howto/dev/new_op_cn.md
index 85c222163c8e65d1a4f3a7460cb9e7a7d27aff65..0d29865447f77601b66efe44383299f686a8bef9 100644
--- a/doc/howto/dev/new_op_cn.md
+++ b/doc/howto/dev/new_op_cn.md
@@ -30,8 +30,8 @@
 -------------- | :----------------------
 OpProtoMaker definition | `.cc` file; a backward Op does not need to define an OpProtoMaker
 Op definition | `.cc` file
-Kernel implementation | A Kernel shared by CPU and GPU goes in an `.h` file; otherwise the CPU version can go in a `.cc` file and the GPU version in a `.cu` file.
-Op registration | An Op is registered in a `.cc` file; Kernel registration goes in a `.cc` file for CPU and in a `.cu` file for GPU
+Kernel implementation | A Kernel implementation shared by CPU and GPU goes in an `.h` file; otherwise, the CPU implementation goes in a `.cc` file and the GPU implementation in a `.cu` file.
+Op registration | Op registration goes in a `.cc` file; Kernel registration goes in a `.cc` file for the CPU implementation and in a `.cu` file for the GPU implementation

 Every new op is added under the directory [paddle/operators](https://github.com/PaddlePaddle/Paddle/tree/develop/paddle/operators), with file names ending in `*_op.h` (if any), `*_op.cc`, and `*_op.cu` (if any).

@@ -171,7 +171,9 @@ class MulKernel : public framework::OpKernel {

`MulKernel` must override the `Compute` interface. Its parameter is `const framework::ExecutionContext& context`; compared with `InferShapeContext`, `ExecutionContext` additionally carries the device type, and the inputs, outputs, and attribute parameters can be obtained from it in the same way. The concrete computation is written inside `Compute`.
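For reference, a minimal sketch of such a `Compute` override, mirroring the forward `MulKernel` from the `mul_op.h` changes later in this diff (the `Tensor` alias and the `math::matmul` helper are assumed to be in scope):

```cpp
template <typename Place, typename T>
class MulKernel : public framework::OpKernel {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    // Inputs and outputs are fetched from the ExecutionContext by name.
    auto* x = context.Input<Tensor>("X");
    auto* y = context.Input<Tensor>("Y");
    auto* z = context.Output<Tensor>("Out");
    // Memory for the output is allocated on the device the context runs on.
    z->mutable_data<T>(context.GetPlace());
    auto* device_context =
        const_cast<platform::DeviceContext*>(context.device_context_);
    // Z = X * Y; dispatch on the Place template argument is what lets one
    // Compute body serve both the CPU and the GPU kernel registration.
    math::matmul<Place, T>(*x, false, *y, false, 1, z, 0, device_context);
  }
};
```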
-Note that different devices (CPU, GPU) share one Op definition; whether they also share the same `OpKernel` depends on whether the functions called by `Compute` support different devices. The CPU and GPU implementations of `MulOp` share the same `Kernel`; for an example where the `OpKernel` is not shared, see [`OnehotCrossEntropyOpKernel`](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/operators/cross_entropy_op.h#L43).
+Note that different devices (CPU, GPU) share one Op definition; whether they also share the same `OpKernel` depends on whether the functions called by `Compute` support different devices. The CPU and GPU implementations of `MulOp` share the same `Kernel`; for an example where the `OpKernel` is not shared, see [`OnehotCrossEntropyOpKernel`](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/operators/cross_entropy_op.h#L43).
+
+To keep the computation in an `OpKernel` simple to write, and to let the CPU and GPU code be reused, we usually implement it with the Eigen unsupported Tensor module. For how to use the Eigen library in Paddle, please refer to the corresponding [guide](https://github.com/PaddlePaddle/Paddle/blob/develop/doc/howto/dev/use_eigen_cn.md)

 At this point the forward Op is complete, and the op and its kernel need to be registered in the `.cc` file. The definitions of the backward Op class and its Kernel are similar to those of the forward Op and are not repeated here. Note, however, that a backward Op has no `ProtoMaker`.

@@ -191,9 +193,12 @@ REGISTER_OP_CPU_KERNEL(mul_grad,

- `REGISTER_OP_WITHOUT_GRADIENT` : registers an Op that has no backward pass.
- `REGISTER_OP_CPU_KERNEL` : registers the `ops::MulKernel` class, specializing its template parameters to `paddle::platform::CPUPlace` and the `float` type; `ops::MulGradKernel` is registered in the same way.
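As an illustration of how registration fits together, a sketch of the CPU-side registration block for an op that has a gradient (the exact `REGISTER_OP` argument list here is an assumption based on the mul operator at this revision; `REGISTER_OP_WITHOUT_GRADIENT` would take its place for a gradient-free op):

```cpp
namespace ops = paddle::operators;
// Register the forward op, its ProtoMaker, and the paired backward op.
REGISTER_OP(mul, ops::MulOp, ops::MulOpMaker, mul_grad, ops::MulOpGrad);
// Register the CPU kernels, specialized for CPUPlace and float.
REGISTER_OP_CPU_KERNEL(mul, ops::MulKernel<paddle::platform::CPUPlace, float>);
REGISTER_OP_CPU_KERNEL(mul_grad,
                       ops::MulGradKernel<paddle::platform::CPUPlace, float>);
```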
-Register the GPU Kernel in the `.cu` file.
+Register the GPU Kernel in the `.cu` file. Note that if the GPU Kernel is implemented with the Eigen unsupported module, the macro definition `#define EIGEN_USE_GPU` must be added at the very top of the `.cu` file.

```cpp
+// If the Eigen unsupported module is used, define this before including any header files.
+#define EIGEN_USE_GPU
+
 namespace ops = paddle::operators;
 REGISTER_OP_GPU_KERNEL(mul, ops::MulKernel<paddle::platform::GPUPlace, float>);
 REGISTER_OP_GPU_KERNEL(mul_grad,
@@ -280,28 +285,50 @@ class TestMulOp(unittest.TestCase):

 The backward Op unit test inherits from `GradientChecker`, which in turn inherits from `unittest.TestCase`, so the backward unit test functions must start with `test_`.

-  ```python
-  class MulGradOpTest(GradientChecker):
-      def test_mul(self):
-          op = create_op("mul")
-          inputs = {
+```python
+class TestMulGradOp(GradientChecker):
+    def setUp(self):
+        self.op = create_op("mul")
+        self.inputs = {
             'X': np.random.random((32, 84)).astype("float32"),
             'Y': np.random.random((84, 100)).astype("float32")
         }
-        self.compare_grad(op, inputs)
+
+    def test_cpu_gpu_compare(self):
+        self.compare_grad(self.op, self.inputs)
+
+    def test_normal(self):
         # mul op will enlarge the relative error
         self.check_grad(
-            op, inputs, set(["X", "Y"]), "Out", max_relative_error=0.5)
-  ```
+            self.op, self.inputs, ["X", "Y"], "Out", max_relative_error=0.5)
+
+    def test_ignore_x(self):
+        self.check_grad(
+            self.op,
+            self.inputs, ["Y"],
+            "Out",
+            max_relative_error=0.5,
+            no_grad_set={"X"})
+
+    def test_ignore_y(self):
+        self.check_grad(
+            self.op,
+            self.inputs, ["X"],
+            "Out",
+            max_relative_error=0.5,
+            no_grad_set={"Y"})
+```
+
+The key points are explained below:

 - Calling `create_op("mul")` creates the forward Op corresponding to the backward Op.
-  - Define the inputs `inputs`.
 - Calling `compare_grad` compares the CPU and GPU results.
-  - Calling `check_grad` checks gradient stability; numerical differentiation is used here to verify gradient correctness.
-    - The first argument `op` : the forward op.
-    - The second argument `inputs` : the input dict, whose keys must match the `ProtoMaker` definition.
-    - The third argument `set(["X", "Y"])` : check gradients for the input variables `X` and `Y`.
+  - `test_normal` calls `check_grad` to check gradient stability; numerical differentiation is used here to verify gradient correctness.
+    - The first argument `self.op` : the forward Op.
+    - The second argument `self.inputs` : the input dict, whose keys must match the `ProtoMaker` definition.
+    - The third argument `["X", "Y"]` : check gradients for the input variables `X` and `Y`.
     - The fourth argument `"Out"` : the final target output variable `Out` of the forward network.
+  - `test_ignore_x` and `test_ignore_y` test, respectively, the case where only a single input's gradient needs to be computed.

 ### Compiling and running the unit tests
diff --git a/doc/howto/dev/use_eigen_cn.md b/doc/howto/dev/use_eigen_cn.md
new file mode 100644
index 0000000000000000000000000000000000000000..1367323b71277984834d9d4f0d9bea0f69478479
--- /dev/null
+++ b/doc/howto/dev/use_eigen_cn.md
@@ -0,0 +1,146 @@
+## How to use Eigen in Paddle

Essentially, a neural network is a compute graph. The data needed by the computation is stored in `Tensor`s, and the computation itself is described by `Operator`s. At execution time, an `Operator` calls the `Compute` interface of its corresponding `OpKernel` to operate on the `Tensor`s.


### The Eigen Tensor module

The Eigen Tensor module provides strong support for element-wise computation, and a single piece of code can run on both CPU and GPU. However, Eigen Tensor is a module still under development, so its tests may be incomplete and its documentation is sparse.

For a detailed introduction to the Eigen Tensor module, see [document 1](https://github.com/RLovelett/eigen/blob/master/unsupported/Eigen/CXX11/src/Tensor/README.md) and [document 2](https://bitbucket.org/eigen/eigen/src/default/unsupported/Eigen/CXX11/src/Tensor/README.md)


### paddle::framework::Tensor

The Paddle Tensor is defined in the framework directory, and its main interface is as follows:

```cpp
class Tensor {
 public:
  /*! Return a pointer to mutable memory block. */
  template <typename T>
  inline T* data();

  /**
   * @brief  Return a pointer to mutable memory block.
   * @note   If not exist, then allocation.
   */
  template <typename T>
  inline T* mutable_data(platform::Place place);

  /**
   * @brief      Return a pointer to mutable memory block.
   *
   * @param[in] dims   The dimensions of the memory block.
   * @param[in] place  The place of the memory block.
   *
   * @note       If not exist, then allocation.
   */
  template <typename T>
  inline T* mutable_data(DDim dims, platform::Place place);

  /*! Resize the dimensions of the memory block. */
  inline Tensor& Resize(const DDim& dims);

  /*! Return the dimensions of the memory block. */
  inline const DDim& dims() const;

 private:
  /*! holds the memory block if allocated. */
  std::shared_ptr<Placeholder> holder_;

  /*! points to dimensions of memory block. */
  DDim dim_;
};
```

The role of `Placeholder` is to delay memory allocation; that is, we can first define a Tensor, then set its size with the Resize interface, and finally call the mutable_data interface to allocate the actual memory.

```cpp
paddle::framework::Tensor t;
paddle::platform::CPUPlace place;
// set size first
t.Resize({2, 3});
// allocate memory on CPU later
t.mutable_data<float>(place);
```

### Example usage of paddle::framework::Tensor
Below, AddOp illustrates how Tensor is used:

- InferShape

When running the compute graph of a neural network, we first call the `InferShape` interface of every `Operator` to set the sizes of the output Tensors from the sizes of the input Tensors; the `Resize` interface is called here.

```cpp
void InferShape(const framework::InferShapeContext &ctx) const override {
  PADDLE_ENFORCE_EQ(ctx.Input<Tensor>("X")->dims(),
                    ctx.Input<Tensor>("Y")->dims(),
                    "The two inputs of Add Op must have the same dimension.");
  ctx.Output<Tensor>("Out")->Resize(ctx.Input<Tensor>("X")->dims());
}
```


- Run

The `Run` interface of an `Operator` eventually calls the `Compute` interface of the corresponding `OpKernel`. The memory is actually allocated at this point, which is when the `mutable_data` interface is called.

```cpp
void Compute(const framework::ExecutionContext& context) const override {
  auto* input0 = context.Input<Tensor>("X");
  auto* input1 = context.Input<Tensor>("Y");
  auto* output = context.Output<Tensor>("Out");

  output->mutable_data<T>(context.GetPlace());

  auto x = EigenVector<T>::Flatten(*input0);
  auto y = EigenVector<T>::Flatten(*input1);
  auto z = EigenVector<T>::Flatten(*output);

  auto place = context.GetEigenDevice<Place>();

  z.device(place) = x + y;
}
```


### Converting paddle::framework::Tensor to EigenTensor

As shown in the previous section, before the actual computation we need to convert the input and output Tensors into a format that Eigen supports. In [eigen.h](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/framework/eigen.h) we provide some global functions that convert a paddle::framework::Tensor into an EigenTensor/EigenMatrix/EigenVector/EigenScalar.

Taking EigenTensor as an example:

```cpp
Tensor t;
float* p = t.mutable_data<float>(make_ddim({1, 2, 3}), platform::CPUPlace());
for (int i = 0; i < 1 * 2 * 3; i++) {
  p[i] = static_cast<float>(i);
}

EigenTensor<float, 3>::Type et = EigenTensor<float, 3>::From(t);
```

`From` is an interface provided by the EigenTensor template that converts a paddle::framework::Tensor into an EigenTensor. Since the rank of a Tensor is a template parameter, it must be specified explicitly during the conversion.

In Eigen, Tensors of different ranks are different types, and a Vector is a Tensor of rank 1. Note in particular that the `EigenVector::From` method converts a one-dimensional paddle Tensor into a one-dimensional Eigen Tensor, represented here by an EigenVector, whereas the `EigenVector::Flatten` method reshapes a paddle Tensor of any rank, flattening it into a one-dimensional Eigen Tensor; in both cases the resulting type is EigenVector.
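An illustrative sketch of this distinction (the tensor names here are hypothetical, and the `From`/`Flatten` signatures are assumed from the eigen.h referenced above):

```cpp
// From: the paddle Tensor must already be one-dimensional.
Tensor vec_t;
vec_t.mutable_data<float>(make_ddim({6}), platform::CPUPlace());
auto ev = EigenVector<float>::From(vec_t);  // maps the existing 1-D layout

// Flatten: any rank is accepted; the data is viewed as one flat vector.
Tensor mat_t;
mat_t.mutable_data<float>(make_ddim({2, 3}), platform::CPUPlace());
auto ef = EigenVector<float>::Flatten(mat_t);  // 2 x 3 viewed as 6 elements
```

Both calls return a view over the original memory; no data is copied.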
For more conversion methods, please refer to the [unit tests](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/framework/eigen_test.cc) in eigen_test.cc.



### Implementing the computation

To carry out the computation, the EigenTensor on the left-hand side of the assignment must call the `device` interface. Note that the operations between EigenTensors here only change the data held by the original Tensors; they do not change the shape information of the original Tensors.

```cpp
auto x = EigenVector<T>::Flatten(*input0);
auto y = EigenVector<T>::Flatten(*input1);
auto z = EigenVector<T>::Flatten(*output);
auto place = context.GetEigenDevice<Place>();
z.device(place) = x + y;
```

In this code, input0/input1/output can be Tensors of arbitrary rank. We call EigenVector's Flatten interface to view a Tensor of any rank as a one-dimensional EigenVector, and after the computation finishes, the original shape information of input0/input1/output is unchanged. To change the shape of the original Tensor, call the Resize interface.
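A small sketch of this behavior (hypothetical tensor names, using Eigen's default CPU device rather than an ExecutionContext, purely for illustration):

```cpp
Tensor a, out;
a.mutable_data<float>(make_ddim({2, 3}), platform::CPUPlace());
out.Resize(a.dims());
out.mutable_data<float>(platform::CPUPlace());

auto ea = EigenVector<float>::Flatten(a);
auto eo = EigenVector<float>::Flatten(out);
Eigen::DefaultDevice cpu;
eo.device(cpu) = ea * ea;  // element-wise square over the flattened views
// out.dims() is still {2, 3} here; only the data changed.
```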
Since the Eigen Tensor module is sparsely documented, the compute code of the relevant `OpKernel`s under TensorFlow's [kernels](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/core/kernels) module can serve as a reference.
diff --git a/paddle/framework/attribute.cc b/paddle/framework/attribute.cc
index 9eb07acdff1d00dd926f1cee9c24f9f151006d7e..27132eaa0b3b0666fc042faf052dac2e169ba9e7 100644
--- a/paddle/framework/attribute.cc
+++ b/paddle/framework/attribute.cc
@@ -43,6 +43,10 @@ template <>
 AttrType AttrTypeID<std::vector<std::string>>() {
   return STRINGS;
 }
+template <>
+AttrType AttrTypeID<std::vector<std::pair<int, int>>>() {
+  return INT_PAIRS;
+}

 Attribute GetAttrValue(const OpDesc::Attr& attr_desc) {
   switch (attr_desc.type()) {
@@ -76,6 +80,14 @@ Attribute GetAttrValue(const OpDesc::Attr& attr_desc) {
     }
     return val;
   }
+  case paddle::framework::AttrType::INT_PAIRS: {
+    std::vector<std::pair<int, int>> val(attr_desc.int_pairs_size());
+    for (int i = 0; i < attr_desc.int_pairs_size(); ++i) {
+      val[i].first = attr_desc.int_pairs(i).first();
+      val[i].second = attr_desc.int_pairs(i).second();
+    }
+    return val;
+  }
   }
   PADDLE_ENFORCE(false, "Unknown OpDesc::AttrDesc::type !");
   return boost::blank();
diff --git a/paddle/framework/attribute.h b/paddle/framework/attribute.h
index 08b47cabd4c2225c50022bd35734dcc2663324d6..071879a9d453377ccc2e9e71b62e8568a7ef1c9b 100644
--- a/paddle/framework/attribute.h
+++ b/paddle/framework/attribute.h
@@ -28,7 +28,8 @@ namespace paddle {
 namespace framework {

 typedef boost::variant<boost::blank, int, float, std::string, std::vector<int>,
-                       std::vector<float>, std::vector<std::string>>
+                       std::vector<float>, std::vector<std::string>,
+                       std::vector<std::pair<int, int>>>
     Attribute;

 typedef std::unordered_map<std::string, Attribute> AttributeMap;
diff --git a/paddle/framework/framework.proto b/paddle/framework/framework.proto
index ae44a1ffd45dacdc44a72edc630e771e7a2f2990..368136a9729dd2c745cc71bc391031e0a390fc87 100644
--- a/paddle/framework/framework.proto
+++ b/paddle/framework/framework.proto
@@ -22,8 +22,14 @@ enum AttrType {
   INTS = 3;
   FLOATS = 4;
   STRINGS = 5;
+  INT_PAIRS = 6;
 }

+message IntPair {
+  required int32 first = 1;
+  required int32 second = 2;
+};
+
 // OpDesc describes an instance of a C++ framework::OperatorBase
 // derived class type.
 message OpDesc {
@@ -37,6 +43,7 @@ message OpDesc {
     repeated int32 ints = 6;
     repeated float floats = 7;
     repeated string strings = 8;
+    repeated IntPair int_pairs = 9;
   };

   message Var {
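For context, a sketch of how an operator could declare an attribute of the new `INT_PAIRS` type (the op and attribute names here are hypothetical; the templated `AddAttr` form follows the `AddAttr<float>` usage in the ProtoMaker tests below):

```cpp
// Inside a hypothetical OpProtoAndCheckerMaker constructor:
AddAttr<std::vector<std::pair<int, int>>>(
    "paddings", "per-dimension (before, after) padding pairs");
```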
diff --git a/paddle/framework/op_registry_test.cc b/paddle/framework/op_registry_test.cc
index 50c45919c53af22665feeeebe753da283ded2b0c..b43f6a8cc56fdd2dc483bef303cf1213b171a5e4 100644
--- a/paddle/framework/op_registry_test.cc
+++ b/paddle/framework/op_registry_test.cc
@@ -174,36 +174,4 @@ TEST(OpRegistry, CustomChecker) {
   op->Run(scope, dev_ctx);
   int test_attr = op->GetAttr<int>("test_attr");
   ASSERT_EQ(test_attr, 4);
-}
-
-class TestAttrProtoMaker : public pd::OpProtoAndCheckerMaker {
- public:
-  TestAttrProtoMaker(pd::OpProto* proto, pd::OpAttrChecker* op_checker)
-      : OpProtoAndCheckerMaker(proto, op_checker) {
-    AddAttr<float>("scale", "scale of test op");
-    AddAttr<float>("scale", "scale of test op");
-  }
-};
-
-TEST(ProtoMaker, DuplicatedAttr) {
-  pd::OpProto op_proto;
-  pd::OpAttrChecker op_checker;
-  auto proto_maker = TestAttrProtoMaker(&op_proto, &op_checker);
-  ASSERT_THROW(proto_maker.Validate(), paddle::platform::EnforceNotMet);
-}
-
-class TestInOutProtoMaker : public pd::OpProtoAndCheckerMaker {
- public:
-  TestInOutProtoMaker(pd::OpProto* proto, pd::OpAttrChecker* op_checker)
-      : OpProtoAndCheckerMaker(proto, op_checker) {
-    AddInput("input", "input of test op");
-    AddInput("input", "input of test op");
-  }
-};
-
-TEST(ProtoMaker, DuplicatedInOut) {
-  pd::OpProto op_proto;
-  pd::OpAttrChecker op_checker;
-  auto proto_maker = TestInOutProtoMaker(&op_proto, &op_checker);
-  ASSERT_THROW(proto_maker.Validate(), paddle::platform::EnforceNotMet);
-}
+}
\ No newline at end of file
diff --git a/paddle/framework/operator_test.cc b/paddle/framework/operator_test.cc
index f7c9e6b196a9d63c91d83fb6d985472a4e8976c4..8a1970c7a8aa5f76abed49bfde445fc743544e66 100644
--- a/paddle/framework/operator_test.cc
+++ b/paddle/framework/operator_test.cc
@@ -263,4 +263,38 @@ TEST(Operator, Clone) {
   OperatorClone a("ABC", {}, {}, {});
   auto b = a.Clone();
   ASSERT_EQ(a.Type(), b->Type());
+}
+
+class TestAttrProtoMaker : public paddle::framework::OpProtoAndCheckerMaker {
+ public:
+  TestAttrProtoMaker(paddle::framework::OpProto* proto,
+                     paddle::framework::OpAttrChecker* op_checker)
+      : OpProtoAndCheckerMaker(proto, op_checker) {
+    AddAttr<float>("scale", "scale of test op");
+    AddAttr<float>("scale", "scale of test op");
+  }
+};
+
+TEST(ProtoMaker, DuplicatedAttr) {
+  paddle::framework::OpProto op_proto;
+  paddle::framework::OpAttrChecker op_checker;
+  auto proto_maker = TestAttrProtoMaker(&op_proto, &op_checker);
+  ASSERT_THROW(proto_maker.Validate(), paddle::platform::EnforceNotMet);
+}
+
+class TestInOutProtoMaker : public paddle::framework::OpProtoAndCheckerMaker {
+ public:
+  TestInOutProtoMaker(paddle::framework::OpProto* proto,
+                      paddle::framework::OpAttrChecker* op_checker)
+      : OpProtoAndCheckerMaker(proto, op_checker) {
+    AddInput("input", "input of test op");
+    AddInput("input", "input of test op");
+  }
+};
+
+TEST(ProtoMaker, DuplicatedInOut) {
+  paddle::framework::OpProto op_proto;
+  paddle::framework::OpAttrChecker op_checker;
+  auto proto_maker = TestInOutProtoMaker(&op_proto, &op_checker);
+  ASSERT_THROW(proto_maker.Validate(), paddle::platform::EnforceNotMet);
+}
\ No newline at end of file
diff --git a/paddle/gserver/layers/Conv3DLayer.cpp b/paddle/gserver/layers/Conv3DLayer.cpp
index 7cc9937cce37cbbc4640fbb88312841c23b757c0..3887aa58b283d319c5b9afec3a38ad676669a8d1 100644
--- a/paddle/gserver/layers/Conv3DLayer.cpp
+++ b/paddle/gserver/layers/Conv3DLayer.cpp
@@ -42,10 +42,10 @@ bool Conv3DLayer::init(const LayerMap &layerMap,
     if (sharedBiases_) {
       CHECK_EQ((size_t)numFilters_, biasParameter_->getSize());
       biases_ =
-          std::unique_ptr<Weight>(new Weight(1, numFilters_, biasParameter_));
+          std::unique_ptr<Weight>(new Weight(numFilters_, 1, biasParameter_));
     } else {
       biases_ =
-          std::unique_ptr<Weight>(new Weight(1, getSize(), biasParameter_));
+          std::unique_ptr<Weight>(new Weight(getSize(), 1, biasParameter_));
     }
   }
   return true;
@@ -224,20 +224,31 @@ void Conv3DLayer::bpropData(int i) {
 }

 void Conv3DLayer::bpropBiases() {
+  MatrixPtr biases = Matrix::create(biases_->getWGrad()->getData(),
+                                    1,
+                                    biases_->getWGrad()->getElementCnt(),
+                                    false,
+                                    useGpu_);
   MatrixPtr outGradMat = getOutputGrad();
+
   if (this->sharedBiases_) {
-    biases_->getWGrad()->collectSharedBias(*outGradMat, 1.0f);
+    biases->collectSharedBias(*outGradMat, 1.0f);
   } else {
-    biases_->getWGrad()->collectBias(*outGradMat, 1.0f);
+    biases->collectBias(*outGradMat, 1.0f);
   }
 }

 void Conv3DLayer::addBias() {
   MatrixPtr outMat = getOutputValue();
+  MatrixPtr bias = Matrix::create(biases_->getW()->getData(),
+                                  1,
+                                  biases_->getW()->getElementCnt(),
+                                  false,
+                                  useGpu_);
   if (this->sharedBiases_) {
-    outMat->addSharedBias(*(biases_->getW()), 1.0f);
+    outMat->addSharedBias(*(bias), 1.0f);
   } else {
-    outMat->addBias(*(biases_->getW()), 1.0f);
+    outMat->addBias(*(bias), 1.0f);
   }
 }
diff --git a/paddle/gserver/layers/DeConv3DLayer.cpp b/paddle/gserver/layers/DeConv3DLayer.cpp
index 7d5c772c89d260264a59f4cc4439bb8a44c605a4..2838980a973d3dbcce9716f21f2ea07e3a2fa660 100644
--- a/paddle/gserver/layers/DeConv3DLayer.cpp
+++ b/paddle/gserver/layers/DeConv3DLayer.cpp
@@ -42,10 +42,10 @@ bool DeConv3DLayer::init(const LayerMap &layerMap,
     if (sharedBiases_) {
       CHECK_EQ((size_t)numFilters_, biasParameter_->getSize());
       biases_ =
-          std::unique_ptr<Weight>(new Weight(1, numFilters_, biasParameter_));
+          std::unique_ptr<Weight>(new Weight(numFilters_, 1, biasParameter_));
     } else {
       biases_ =
-          std::unique_ptr<Weight>(new Weight(1, getSize(), biasParameter_));
+          std::unique_ptr<Weight>(new Weight(getSize(), 1, biasParameter_));
     }
   }
   return true;
@@ -191,21 +191,31 @@ void DeConv3DLayer::bpropWeights(int i) {}
 void DeConv3DLayer::bpropData(int i) {}

 void DeConv3DLayer::bpropBiases() {
+  MatrixPtr biases = Matrix::create(biases_->getWGrad()->getData(),
+                                    1,
+                                    biases_->getWGrad()->getElementCnt(),
+                                    false,
+                                    useGpu_);
   const MatrixPtr &outGradMat = getOutputGrad();

   if (this->sharedBiases_) {
-    biases_->getWGrad()->collectSharedBias(*outGradMat, 1.0f);
+    biases->collectSharedBias(*outGradMat, 1.0f);
   } else {
-    biases_->getWGrad()->collectBias(*outGradMat, 1.0f);
+    biases->collectBias(*outGradMat, 1.0f);
   }
 }

 void DeConv3DLayer::addBias() {
   MatrixPtr outMat = getOutputValue();
+  MatrixPtr bias = Matrix::create(biases_->getW()->getData(),
+                                  1,
+                                  biases_->getW()->getElementCnt(),
+                                  false,
+                                  useGpu_);
   if (this->sharedBiases_) {
-    outMat->addSharedBias(*(biases_->getW()), 1.0f);
+    outMat->addSharedBias(*(bias), 1.0f);
   } else {
-    outMat->addBias(*(biases_->getW()), 1.0f);
+    outMat->addBias(*(bias), 1.0f);
   }
 }
diff --git a/paddle/operators/lookup_table_op.h b/paddle/operators/lookup_table_op.h
index 4da8079b91624c3510cae89fd599a7035a4c7477..877b36cef4ea9cdaaaf37c97d5e5bfce55b91436 100644
--- a/paddle/operators/lookup_table_op.h
+++ b/paddle/operators/lookup_table_op.h
@@ -30,12 +30,12 @@ class LookupTableKernel : public framework::OpKernel {
     auto ids_t = context.Input<Tensor>("Ids");      // int tensor
     auto output_t = context.Output<Tensor>("Out");  // float tensor

-    size_t N = table_t->dims()[0];
-    size_t D = table_t->dims()[1];
+    int N = table_t->dims()[0];
+    int D = table_t->dims()[1];
     auto ids = ids_t->data<int32_t>();
     auto table = table_t->data<T>();
     auto output = output_t->mutable_data<T>(context.GetPlace());
-    for (size_t i = 0; i < product(ids_t->dims()); ++i) {
+    for (ssize_t i = 0; i < product(ids_t->dims()); ++i) {
       PADDLE_ENFORCE_LT(ids[i], N);
       PADDLE_ENFORCE_GE(ids[i], 0);
       memcpy(output + i * D, table + ids[i] * D, D * sizeof(T));
@@ -51,8 +51,8 @@ class LookupTableGradKernel : public framework::OpKernel {
     auto d_output_t = context.Input<Tensor>(framework::GradVarName("Out"));
     auto d_table_t = context.Output<Tensor>(framework::GradVarName("W"));

-    size_t N = d_table_t->dims()[0];
-    size_t D = d_table_t->dims()[1];
+    int N = d_table_t->dims()[0];
+    int D = d_table_t->dims()[1];
     auto ids = ids_t->data<int32_t>();
     const T* d_output = d_output_t->data<T>();
     T* d_table = d_table_t->mutable_data<T>(context.GetPlace());

     t.device(context.GetEigenDevice<platform::CPUPlace>()) =
         t.constant(static_cast<T>(0));

-    for (size_t i = 0; i < product(ids_t->dims()); ++i) {
+    for (ssize_t i = 0; i < product(ids_t->dims()); ++i) {
       PADDLE_ENFORCE_LT(ids[i], N);
       PADDLE_ENFORCE_GE(ids[i], 0);
-      for (size_t j = 0; j < D; ++j) {
+      for (int j = 0; j < D; ++j) {
         d_table[ids[i] * D + j] += d_output[i * D + j];
       }
     }
diff --git a/paddle/operators/mul_op.cc b/paddle/operators/mul_op.cc
index 5b8b5f6c118cbe213e1783256a940dff6fdccc46..28a47cdff2e9b7a965ff9f99e787bb8315010823 100644
--- a/paddle/operators/mul_op.cc
+++ b/paddle/operators/mul_op.cc
@@ -75,8 +75,8 @@ class MulOpGrad : public framework::OperatorWithKernel {
     PADDLE_ENFORCE(y_dims[1] == out_dims[1],
                    "Out@GRAD M X N must equal to Y dims 1, N ");

-    x_grad->Resize(x_dims);
-    y_grad->Resize(y_dims);
+    if (x_grad) x_grad->Resize(x_dims);
+    if (y_grad) y_grad->Resize(y_dims);
   }
 };
diff --git a/paddle/operators/mul_op.h b/paddle/operators/mul_op.h
index 8facc0281449785bf40726f23ca2fd5d166ff272..05a79e13b3470e39a5ebd0394ba05629553a5075 100644
--- a/paddle/operators/mul_op.h
+++ b/paddle/operators/mul_op.h
@@ -31,13 +31,13 @@ template <typename Place, typename T>
 class MulKernel : public framework::OpKernel {
  public:
   void Compute(const framework::ExecutionContext& context) const override {
-    auto* X = context.Input<Tensor>("X");
-    auto* Y = context.Input<Tensor>("Y");
-    auto* Z = context.Output<Tensor>("Out");
-    Z->mutable_data<T>(context.GetPlace());
+    auto* x = context.Input<Tensor>("X");
+    auto* y = context.Input<Tensor>("Y");
+    auto* z = context.Output<Tensor>("Out");
+    z->mutable_data<T>(context.GetPlace());
     auto* device_context =
         const_cast<platform::DeviceContext*>(context.device_context_);
-    math::matmul<Place, T>(*X, false, *Y, false, 1, Z, 0, device_context);
+    math::matmul<Place, T>(*x, false, *y, false, 1, z, 0, device_context);
   }
 };

@@ -45,20 +45,24 @@ template <typename Place, typename T>
 class MulGradKernel : public framework::OpKernel {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
-    auto* X = ctx.Input<Tensor>("X");
-    auto* Y = ctx.Input<Tensor>("Y");
-    auto* dOut = ctx.Input<Tensor>(framework::GradVarName("Out"));
+    auto* x = ctx.Input<Tensor>("X");
+    auto* y = ctx.Input<Tensor>("Y");
+    auto* dout = ctx.Input<Tensor>(framework::GradVarName("Out"));

-    auto* dX = ctx.Output<Tensor>(framework::GradVarName("X"));
-    auto* dY = ctx.Output<Tensor>(framework::GradVarName("Y"));
-    dX->mutable_data<T>(ctx.GetPlace());
-    dY->mutable_data<T>(ctx.GetPlace());
+    auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
+    auto* dy = ctx.Output<Tensor>(framework::GradVarName("Y"));
     auto* device_context =
         const_cast<platform::DeviceContext*>(ctx.device_context_);
-    // dX = dOut * Y'. dX: M x K, dOut : M x N, Y : K x N
-    math::matmul<Place, T>(*dOut, false, *Y, true, 1, dX, 0, device_context);
-    // dY = X' * dOut. dY: K x N, dOut : M x N, X : M x K
-    math::matmul<Place, T>(*X, true, *dOut, false, 1, dY, 0, device_context);
+    if (dx) {
+      dx->mutable_data<T>(ctx.GetPlace());
+      // dx = dout * y'. dx: M x K, dout : M x N, y : K x N
+      math::matmul<Place, T>(*dout, false, *y, true, 1, dx, 0, device_context);
+    }
+    if (dy) {
+      dy->mutable_data<T>(ctx.GetPlace());
+      // dy = x' * dout. dy: K x N, dout : M x N, x : M x K
+      math::matmul<Place, T>(*x, true, *dout, false, 1, dy, 0, device_context);
+    }
   }
 };
diff --git a/paddle/operators/rowwise_add_op.cc b/paddle/operators/rowwise_add_op.cc
index 6825dce332adc0dc11dda187d1bd367875b8603e..30b4b404315a9f041e21d79b75fd06307e33f7f9 100644
--- a/paddle/operators/rowwise_add_op.cc
+++ b/paddle/operators/rowwise_add_op.cc
@@ -64,8 +64,10 @@ class RowwiseAddGradOp : public framework::OperatorWithKernel {
     auto dims0 = ctx.Input<Tensor>("X")->dims();
     auto dims1 = ctx.Input<Tensor>("b")->dims();
     PADDLE_ENFORCE_EQ(1, dims1.size(), "b dims should be 1")
-    ctx.Output<Tensor>(framework::GradVarName("X"))->Resize(dims0);
-    ctx.Output<Tensor>(framework::GradVarName("b"))->Resize(dims1);
+    auto *dx = ctx.Output<Tensor>(framework::GradVarName("X"));
+    auto *db = ctx.Output<Tensor>(framework::GradVarName("b"));
+    if (dx) dx->Resize(dims0);
+    if (db) db->Resize(dims1);
   }
 };
diff --git a/paddle/operators/rowwise_add_op.h b/paddle/operators/rowwise_add_op.h
index 1cbd8bb31ad90a32d8a4e3bb59617d0b5384e470..4e926d9f2947f37b71e81c0fa592b0c66b19c640 100644
--- a/paddle/operators/rowwise_add_op.h
+++ b/paddle/operators/rowwise_add_op.h
@@ -51,20 +51,24 @@ template <typename Place, typename T>
 class RowwiseAddGradKernel : public framework::OpKernel {
  public:
   void Compute(const framework::ExecutionContext& context) const override {
-    auto* dOut = context.Input<Tensor>(framework::GradVarName("Out"));
-    auto* dX = context.Output<Tensor>(framework::GradVarName("X"));
+    auto* dout = context.Input<Tensor>(framework::GradVarName("Out"));
+    auto* dx = context.Output<Tensor>(framework::GradVarName("X"));
     auto* db = context.Output<Tensor>(framework::GradVarName("b"));
-    dX->mutable_data<T>(context.GetPlace());
-    db->mutable_data<T>(context.GetPlace());

-    auto OutGrad = EigenMatrix<T>::From(*dOut);
+    auto out_grad = EigenMatrix<T>::From(*dout);
     auto place = context.GetEigenDevice<Place>();
-    EigenMatrix<T>::From(*dX).device(place) = OutGrad;
+    if (dx) {
+      dx->mutable_data<T>(context.GetPlace());
+      EigenMatrix<T>::From(*dx).device(place) = out_grad;
+    }

-    // https://eigen.tuxfamily.org/dox/unsupported/TensorBase_8h_source.html
-    // colwise add
-    Eigen::array<int, 1> dims{{0}}; /* dimension to reduce */
-    EigenVector<T>::Flatten(*db).device(place) = OutGrad.sum(dims);
+    if (db) {
+      db->mutable_data<T>(context.GetPlace());
+      // https://eigen.tuxfamily.org/dox/unsupported/TensorBase_8h_source.html
+      // colwise add
+      Eigen::array<int, 1> dims{{0}}; /* dimension to reduce */
+      EigenVector<T>::Flatten(*db).device(place) = out_grad.sum(dims);
+    }
   }
 };
 }  // namespace operators
diff --git a/paddle/operators/softmax_op.cc b/paddle/operators/softmax_op.cc
index 40c51a64c49bc064f55975ef6ced1d54070f1291..7d062ad67c048bc6bef68121f86334eb3f1efe92 100644
--- a/paddle/operators/softmax_op.cc
+++ b/paddle/operators/softmax_op.cc
@@ -24,7 +24,7 @@ class SoftmaxOp : public framework::OperatorWithKernel {
  protected:
   void InferShape(const framework::InferShapeContext &ctx) const override {
     PADDLE_ENFORCE(ctx.Input<Tensor>("X")->dims().size() == 2UL,
-                   "The input of softmax op must be matrix");
+                   "The input of softmax op must be a matrix.");
     ctx.Output<Tensor>("Y")->Resize(ctx.Input<Tensor>("X")->dims());
   }
 };
@@ -34,9 +34,27 @@ class SoftmaxOpMaker : public framework::OpProtoAndCheckerMaker {
   SoftmaxOpMaker(framework::OpProto *proto,
                  framework::OpAttrChecker *op_checker)
       : OpProtoAndCheckerMaker(proto, op_checker) {
-    AddInput("X", "input of softmax");
-    AddOutput("Y", "output of softmax");
-    AddComment("Softmax Op");
+    AddInput("X",
+             "The input tensor of softmax. "
+             "2-D with shape [batch_size, input_feature_dimensions].");
+    AddOutput("Y", "The normalized values with the same shape as X.");
+    AddComment(R"DOC(
+The input of the softmax operator is a 2-D tensor with shape N x K (N is the
+batch_size, K is the dimension of the input feature). The output tensor has
+the same shape as the input tensor.
+
+For each row of the input tensor, the softmax operator squashes the
+K-dimensional vector of arbitrary real values to a K-dimensional vector of
+real values in the range [0, 1] that add up to 1. Specifically, it computes
+the exponential of each element of a row and then normalizes each element by
+the sum of exponentials over that row, so the output is the ratio of the
+exponential of a given element to the row-wise sum of exponentials.
+
+For each row `i` and each column `j` in X, we have:
+    Y[i, j] = exp(X[i, j]) / sum_j(exp(X[i, j]))
+
+)DOC");
   }
 };
diff --git a/python/paddle/trainer/PyDataProvider2.py b/python/paddle/trainer/PyDataProvider2.py
index 7e305e2cd9fbe306368a44d08f7f66b4185ae2d2..248da4ae8d1fb24652625ae8fc9ef314a028b912 100644
--- a/python/paddle/trainer/PyDataProvider2.py
+++ b/python/paddle/trainer/PyDataProvider2.py
@@ -27,6 +27,14 @@ class SequenceType(object):
     SEQUENCE = 1
     SUB_SEQUENCE = 2

+    @classmethod
+    def tostring(cls, value):
+        for k in cls.__dict__:
+            if not k.startswith('__'):
+                if getattr(cls, k) == value:
+                    return cls.__name__ + '.' + k
+        return 'INVALID(' + str(value) + ')'
+

 # TODO(yuyang18): Add string data type here.
 class DataType(object):
@@ -35,6 +43,14 @@ class DataType(object):
     SparseValue = 2
     Index = 3

+    @classmethod
+    def tostring(cls, value):
+        for k in cls.__dict__:
+            if not k.startswith('__'):
+                if getattr(cls, k) == value:
+                    return cls.__name__ + '.' + k
+        return 'INVALID(' + str(value) + ')'
+

 class CacheType(object):
     NO_CACHE = 0  # No cache at all
@@ -69,6 +85,26 @@ class InputType(object):
         self.seq_type = seq_type
         self.type = tp

+    def __repr__(self):
+        """
+        Return a human readable representation like 'InputType(dim=25921,
+        seq_type=SequenceType.NO_SEQUENCE, type=DataType.Dense)'
+        """
+        repr_str = type(self).__name__
+        repr_str += '('
+        serialize_func_map = {
+            'dim': repr,
+            'seq_type': SequenceType.tostring,
+            'type': DataType.tostring
+        }
+        for idx, k in enumerate(self.__slots__):
+            if idx != 0:
+                repr_str += ', '
+            repr_str += (
+                k + '=' + serialize_func_map.get(k, repr)(getattr(self, k)))
+        repr_str += ')'
+        return repr_str
+

 def dense_slot(dim, seq_type=SequenceType.NO_SEQUENCE):
     """
diff --git a/python/paddle/v2/framework/op.py b/python/paddle/v2/framework/op.py
index e7e932f6fe46f2db8de37cecfd9875b310bdcbe5..0349407a851ebb48f69d7daef7a318cf348aad5d 100644
--- a/python/paddle/v2/framework/op.py
+++ b/python/paddle/v2/framework/op.py
@@ -94,9 +94,14 @@ class OpDescCreationMethod(object):
                 new_attr.floats.extend(user_defined_attr)
             elif attr.type == framework_pb2.STRINGS:
                 new_attr.strings.extend(user_defined_attr)
+            elif attr.type == framework_pb2.INT_PAIRS:
+                for p in user_defined_attr:
+                    pair = new_attr.int_pairs.add()
+                    pair.first = p[0]
+                    pair.second = p[1]
             else:
                 raise NotImplementedError("Not support attribute type " +
-                                          attr.type)
+                                          str(attr.type))

         return op_desc
diff --git a/python/paddle/v2/framework/tests/gradient_checker.py b/python/paddle/v2/framework/tests/gradient_checker.py
index 518f828bacd60e7cb8375b22c6c3296f9bfeb5ea..b8d7e4ea4329e50a086ae51451364a348ff8b360 100644
--- a/python/paddle/v2/framework/tests/gradient_checker.py
+++ b/python/paddle/v2/framework/tests/gradient_checker.py
@@ -286,6 +286,9 @@ class GradientChecker(unittest.TestCase):
         for no_grad in no_grad_set:
             if no_grad not in in_names:
                 raise ValueError("no_grad should be in in_names")
+            if no_grad in inputs_to_check:
+                raise ValueError("no_grad should not be in inputs_to_check")
+
         backward_op = core.Operator.backward(forward_op, no_grad_set)

         places = [core.CPUPlace()]
@@ -301,7 +304,6 @@ class GradientChecker(unittest.TestCase):
         check_names = [grad_var_name(name) for name in inputs_to_check]
         for place in places:
-            # get analytical gradients according to different device
             analytic_grads = self.__get_gradient(forward_op, backward_op,
                                                  input_vars, check_names, place)
             self.__assert_is_close(numeric_grads, analytic_grads, check_names,
diff --git a/python/paddle/v2/framework/tests/test_mul_op.py b/python/paddle/v2/framework/tests/test_mul_op.py
index ee0d81a64efcb81bae8b11b856c201a86da274e9..b58e4266d1588a4b6151f5f896537ded6ddd3896 100644
--- a/python/paddle/v2/framework/tests/test_mul_op.py
+++ b/python/paddle/v2/framework/tests/test_mul_op.py
@@ -16,16 +16,37 @@ class TestMulOp(unittest.TestCase):
         self.outputs = {'Out': np.dot(self.inputs['X'], self.inputs['Y'])}


-class MulGradOpTest(GradientChecker):
-    def test_mul(self):
-        op = create_op("mul")
-        inputs = {
+class TestMulGradOp(GradientChecker):
+    def setUp(self):
+        self.op = create_op("mul")
+        self.inputs = {
             'X': np.random.random((32, 84)).astype("float32"),
             'Y': np.random.random((84, 100)).astype("float32")
         }
+
+    def test_cpu_gpu_compare(self):
+        self.compare_grad(self.op, self.inputs)
+
+    def test_normal(self):
         # mul op will enlarge the relative error
         self.check_grad(
-            op, inputs, set(["X", "Y"]), "Out", max_relative_error=0.5)
+            self.op, self.inputs, ["X", "Y"], "Out", max_relative_error=0.5)
+
+    def test_ignore_x(self):
+        self.check_grad(
+            self.op,
+            self.inputs, ["Y"],
+            "Out",
+            max_relative_error=0.5,
+            no_grad_set={"X"})
+
+    def test_ignore_y(self):
+        self.check_grad(
+            self.op,
+            self.inputs, ["X"],
+            "Out",
+            max_relative_error=0.5,
+            no_grad_set={"Y"})


 # TODO(dzh,qijun) : mulgrad test case need transpose feature of blas library
diff --git a/python/paddle/v2/framework/tests/test_rowwise_add_op.py b/python/paddle/v2/framework/tests/test_rowwise_add_op.py
index 45d569da29d13cf8e2a3cb9d67c2d01e8b365453..2ddb85e2e7a98a08bd1d6e24e6f812f6021142e8 100644
--- a/python/paddle/v2/framework/tests/test_rowwise_add_op.py
+++ b/python/paddle/v2/framework/tests/test_rowwise_add_op.py
@@ -16,14 +16,22 @@ class TestRowwiseAddOp(unittest.TestCase):
         self.outputs = {'Out': np.add(self.inputs['X'], self.inputs['b'])}


-class RowwiseAddGradOpTest(GradientChecker):
-    def test_rowwise_add(self):
-        op = create_op("rowwise_add")
-        inputs = {
+class TestRowwiseAddGradOp(GradientChecker):
+    def setUp(self):
+        self.op = create_op("rowwise_add")
+        self.inputs = {
             "X": np.random.uniform(0.1, 1, [5, 10]).astype("float32"),
             "b": np.random.uniform(0.1, 1, [10]).astype("float32")
         }
-        self.check_grad(op, inputs, set(["X", "b"]), "Out")
+
+    def test_normal(self):
+        self.check_grad(self.op, self.inputs, ["X", "b"], "Out")
+
+    def test_ignore_b(self):
+        self.check_grad(self.op, self.inputs, ["X"], "Out", no_grad_set={"b"})
+
+    def test_ignore_x(self):
+        self.check_grad(self.op, self.inputs, ["b"], "Out", no_grad_set={"X"})


 if __name__ == '__main__':