提交 69fbc542 编写于 作者: F fengjiayi

Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into dev_add_axis

......@@ -227,6 +227,12 @@ make mul_op
USE_CPU_ONLY_OP(gather);
```
如果OP不带Kernel,则使用`USE_NO_KENREL_OP`:
```
USE_NO_KENREL_OP(recurrent);
```
使用`USE_OP`告知编译器需要链接该Op的目标文件,具体解释参考[代码注释](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/framework/op_registry.h#L81)。
......@@ -280,28 +286,50 @@ class TestMulOp(unittest.TestCase):
反向Op单测继承自`GradientChecker`,而`GradientChecker`集成自`unittest.TestCase`,所以反向单测函数需要`test_`开头。
```
class MulGradOpTest(GradientChecker):
def test_mul(self):
op = create_op("mul")
inputs = {
```
class TestMulGradOp(GradientChecker):
def setUp(self):
self.op = create_op("mul")
self.inputs = {
'X': np.random.random((32, 84)).astype("float32"),
'Y': np.random.random((84, 100)).astype("float32")
}
self.compare_grad(op, inputs)
def test_cpu_gpu_compare(self):
self.compare_grad(self.op, self.inputs)
def test_normal(self):
# mul op will enlarge the relative error
self.check_grad(
op, inputs, set(["X", "Y"]), "Out", max_relative_error=0.5)
```
self.op, self.inputs, ["X", "Y"], "Out", max_relative_error=0.5)
def test_ignore_x(self):
self.check_grad(
self.op,
self.inputs, ["Y"],
"Out",
max_relative_error=0.5,
no_grad_set={"X"})
def test_ignore_y(self):
self.check_grad(
self.op,
self.inputs, ["X"],
"Out",
max_relative_error=0.5,
no_grad_set={"Y"})
```
下面解释一些关键的地方:
- 调用`create_op("mul")`创建反向Op对应的前向Op。
- 定义输入`inputs`
- 调用`compare_grad`函数对比CPU、GPU计算结果。
- 调用`check_grad`检查梯度稳定性,这里采用数值法检测梯度正确性。
- 第一个参数`op` : 前向op。
- 第二个参数`inputs` : 输入词典,词典的Key和`ProtoMaker`定义保持一致。
- 第三个参数`set(["X", "Y"])` : 指定对输入变量`X``Y`做梯度检测。
- `test_normal`调用`check_grad`检查梯度稳定性,这里采用数值法检测梯度正确性。
- 第一个参数`self.op` : 前向Op。
- 第二个参数`self.inputs` : 输入词典,词典的Key和`ProtoMaker`定义保持一致。
- 第三个参数`["X", "Y"]` : 指定对输入变量`X``Y`做梯度检测。
- 第四个参数`"Out"` : 指定前向网络最终的输出目标变量`Out`
- `test_ignore_x``test_ignore_y`分支测试只需要计算一个输入梯度的情况。
### 编译和执行
......
......@@ -182,7 +182,7 @@ static std::unique_ptr<OperatorBase> BackwardRecursive(
});
// process recurrent gradient op as a special operator.
if (forwardOp.Type() == "recurrent_op") {
if (forwardOp.Type() == "recurrent") {
// NOTE clean up cycle call somewhere (RNN's stepnet constains itself), or
// this will result in infinite loop.
const auto& rnnop =
......
......@@ -199,6 +199,8 @@ class OpKernelRegistrar : public Registrar {
USE_OP_DEVICE_KERNEL(op_type, GPU)
#endif
#define USE_NO_KERNEL_OP(op_type) USE_OP_ITSELF(op_type);
#define USE_CPU_ONLY_OP(op_type) \
USE_OP_ITSELF(op_type); \
USE_OP_DEVICE_KERNEL(op_type, CPU);
......
......@@ -42,10 +42,10 @@ bool Conv3DLayer::init(const LayerMap &layerMap,
if (sharedBiases_) {
CHECK_EQ((size_t)numFilters_, biasParameter_->getSize());
biases_ =
std::unique_ptr<Weight>(new Weight(1, numFilters_, biasParameter_));
std::unique_ptr<Weight>(new Weight(numFilters_, 1, biasParameter_));
} else {
biases_ =
std::unique_ptr<Weight>(new Weight(1, getSize(), biasParameter_));
std::unique_ptr<Weight>(new Weight(getSize(), 1, biasParameter_));
}
}
return true;
......@@ -224,20 +224,31 @@ void Conv3DLayer::bpropData(int i) {
}
void Conv3DLayer::bpropBiases() {
MatrixPtr biases = Matrix::create(biases_->getWGrad()->getData(),
1,
biases_->getWGrad()->getElementCnt(),
false,
useGpu_);
MatrixPtr outGradMat = getOutputGrad();
if (this->sharedBiases_) {
biases_->getWGrad()->collectSharedBias(*outGradMat, 1.0f);
biases->collectSharedBias(*outGradMat, 1.0f);
} else {
biases_->getWGrad()->collectBias(*outGradMat, 1.0f);
biases->collectBias(*outGradMat, 1.0f);
}
}
void Conv3DLayer::addBias() {
MatrixPtr outMat = getOutputValue();
MatrixPtr bias = Matrix::create(biases_->getW()->getData(),
1,
biases_->getW()->getElementCnt(),
false,
useGpu_);
if (this->sharedBiases_) {
outMat->addSharedBias(*(biases_->getW()), 1.0f);
outMat->addSharedBias(*(bias), 1.0f);
} else {
outMat->addBias(*(biases_->getW()), 1.0f);
outMat->addBias(*(bias), 1.0f);
}
}
......
......@@ -42,10 +42,10 @@ bool DeConv3DLayer::init(const LayerMap &layerMap,
if (sharedBiases_) {
CHECK_EQ((size_t)numFilters_, biasParameter_->getSize());
biases_ =
std::unique_ptr<Weight>(new Weight(1, numFilters_, biasParameter_));
std::unique_ptr<Weight>(new Weight(numFilters_, 1, biasParameter_));
} else {
biases_ =
std::unique_ptr<Weight>(new Weight(1, getSize(), biasParameter_));
std::unique_ptr<Weight>(new Weight(getSize(), 1, biasParameter_));
}
}
return true;
......@@ -191,21 +191,31 @@ void DeConv3DLayer::bpropWeights(int i) {}
void DeConv3DLayer::bpropData(int i) {}
void DeConv3DLayer::bpropBiases() {
MatrixPtr biases = Matrix::create(biases_->getWGrad()->getData(),
1,
biases_->getWGrad()->getElementCnt(),
false,
useGpu_);
const MatrixPtr &outGradMat = getOutputGrad();
if (this->sharedBiases_) {
biases_->getWGrad()->collectSharedBias(*outGradMat, 1.0f);
biases->collectSharedBias(*outGradMat, 1.0f);
} else {
biases_->getWGrad()->collectBias(*outGradMat, 1.0f);
biases->collectBias(*outGradMat, 1.0f);
}
}
void DeConv3DLayer::addBias() {
MatrixPtr outMat = getOutputValue();
MatrixPtr bias = Matrix::create(biases_->getW()->getData(),
1,
biases_->getW()->getElementCnt(),
false,
useGpu_);
if (this->sharedBiases_) {
outMat->addSharedBias(*(biases_->getW()), 1.0f);
outMat->addSharedBias(*(bias), 1.0f);
} else {
outMat->addBias(*(biases_->getW()), 1.0f);
outMat->addBias(*(bias), 1.0f);
}
}
......
......@@ -79,7 +79,7 @@ class MinusGradOp : public NetOp {
} // namespace paddle
USE_OP(scale);
USE_OP_ITSELF(identity);
USE_NO_KERNEL_OP(identity);
namespace ops = paddle::operators;
REGISTER_OP(minus, ops::MinusOp, ops::MinusOpMaker, minus_grad,
ops::MinusGradOp<float>);
......
......@@ -107,8 +107,8 @@ class MulOpGrad : public framework::OperatorWithKernel {
"The second dimension of Out@GRAD must equal to the second "
"dimension of the second operand.");
x_grad->Resize(x_dims);
y_grad->Resize(y_dims);
if (x_grad) x_grad->Resize(x_dims);
if (y_grad) y_grad->Resize(y_dims);
}
};
......
......@@ -2,13 +2,13 @@
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
you may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
WITHOUT WARRANTIES OR CONDITIONS OF ANy KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
......@@ -31,24 +31,24 @@ template <typename Place, typename T>
class MulKernel : public framework::OpKernel {
public:
void Compute(const framework::ExecutionContext& context) const override {
const Tensor* X = context.Input<Tensor>("X");
const Tensor* Y = context.Input<Tensor>("Y");
const Tensor* x = context.Input<Tensor>("X");
const Tensor* y = context.Input<Tensor>("Y");
Tensor* Z = context.Output<Tensor>("Out");
const Tensor X_matrix =
X->dims().size() > 2
const Tensor x_matrix =
x->dims().size() > 2
? framework::FlattenToMatrix<T>(
*X, context.template GetAttr<int>("x_num_row_dims"))
: *X;
const Tensor Y_matrix =
Y->dims().size() > 2
*x, context.template GetAttr<int>("x_num_row_dims"))
: *x;
const Tensor y_matrix =
y->dims().size() > 2
? framework::FlattenToMatrix<T>(
*Y, context.template GetAttr<int>("y_num_row_dims"))
: *Y;
*y, context.template GetAttr<int>("y_num_row_dims"))
: *y;
Z->mutable_data<T>(context.GetPlace());
auto* device_context =
const_cast<platform::DeviceContext*>(context.device_context_);
math::matmul<Place, T>(X_matrix, false, Y_matrix, false, 1, Z, 0,
math::matmul<Place, T>(x_matrix, false, y_matrix, false, 1, Z, 0,
device_context);
}
};
......@@ -59,34 +59,38 @@ class MulGradKernel : public framework::OpKernel {
void Compute(const framework::ExecutionContext& ctx) const override {
int x_num_row_dims = ctx.template GetAttr<int>("x_num_row_dims");
int y_num_row_dims = ctx.template GetAttr<int>("y_num_row_dims");
const Tensor* X = ctx.Input<Tensor>("X");
const Tensor* Y = ctx.Input<Tensor>("Y");
const Tensor X_matrix =
X->dims().size() > 2 ? framework::FlattenToMatrix<T>(*X, x_num_row_dims)
: *X;
const Tensor Y_matrix =
Y->dims().size() > 2 ? framework::FlattenToMatrix<T>(*Y, y_num_row_dims)
: *Y;
const Tensor* dOut = ctx.Input<Tensor>(framework::GradVarName("Out"));
const Tensor* x = ctx.Input<Tensor>("X");
const Tensor* y = ctx.Input<Tensor>("Y");
const Tensor x_matrix =
x->dims().size() > 2 ? framework::FlattenToMatrix<T>(*x, x_num_row_dims)
: *x;
const Tensor y_matrix =
y->dims().size() > 2 ? framework::FlattenToMatrix<T>(*y, y_num_row_dims)
: *y;
const Tensor* dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
Tensor* dX = ctx.Output<Tensor>(framework::GradVarName("X"));
Tensor* dY = ctx.Output<Tensor>(framework::GradVarName("Y"));
dX->mutable_data<T>(ctx.GetPlace());
dY->mutable_data<T>(ctx.GetPlace());
Tensor dX_matrix = dX->dims().size() > 2
? framework::FlattenToMatrix<T>(*dX, x_num_row_dims)
: *dX;
Tensor dY_matrix = dY->dims().size() > 2
? framework::FlattenToMatrix<T>(*dY, y_num_row_dims)
: *dY;
Tensor* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
Tensor* dy = ctx.Output<Tensor>(framework::GradVarName("Y"));
auto* device_context =
const_cast<platform::DeviceContext*>(ctx.device_context_);
// dX = dOut * Y'. dX: M x K, dOut : M x N, Y : K x N
math::matmul<Place, T>(*dOut, false, Y_matrix, true, 1, &dX_matrix, 0,
device_context);
// dY = X' * dOut. dY: K x N, dOut : M x N, X : M x K
math::matmul<Place, T>(X_matrix, true, *dOut, false, 1, &dY_matrix, 0,
device_context);
if (dx) {
dx->mutable_data<T>(ctx.GetPlace());
Tensor dx_matrix = dx->dims().size() > 2 ? framework::FlattenToMatrix<T>(
*dx, x_num_row_dims)
: *dx;
// dx = dout * y'. dx: M x K, dout : M x N, y : K x N
math::matmul<Place, T>(*dout, false, y_matrix, true, 1, &dx_matrix, 0,
device_context);
}
if (dy) {
dy->mutable_data<T>(ctx.GetPlace());
Tensor dy_matrix = dy->dims().size() > 2 ? framework::FlattenToMatrix<T>(
*dy, y_num_row_dims)
: *dy;
// dy = x' * dout. dy K x N, dout : M x N, x : M x K
math::matmul<Place, T>(x_matrix, true, *dout, false, 1, &dy_matrix, 0,
device_context);
}
}
};
......
......@@ -235,5 +235,5 @@ RecurrentGradientOp::RecurrentGradientOp(
} // namespace paddle
REGISTER_OP_WITHOUT_GRADIENT(
recurrent_op, paddle::operators::RecurrentOp,
recurrent, paddle::operators::RecurrentOp,
paddle::operators::RecurrentAlgorithmProtoAndCheckerMaker);
......@@ -64,8 +64,10 @@ class RowwiseAddGradOp : public framework::OperatorWithKernel {
auto dims0 = ctx.Input<Tensor>("X")->dims();
auto dims1 = ctx.Input<Tensor>("b")->dims();
PADDLE_ENFORCE_EQ(1, dims1.size(), "b dims should be 1")
ctx.Output<Tensor>(framework::GradVarName("X"))->Resize(dims0);
ctx.Output<Tensor>(framework::GradVarName("b"))->Resize(dims1);
auto *dx = ctx.Output<Tensor>(framework::GradVarName("X"));
auto *db = ctx.Output<Tensor>(framework::GradVarName("b"));
if (dx) dx->Resize(dims0);
if (db) db->Resize(dims1);
}
};
......
......@@ -51,20 +51,24 @@ template <typename Place, typename T>
class RowwiseAddGradKernel : public framework::OpKernel {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* dOut = context.Input<Tensor>(framework::GradVarName("Out"));
auto* dX = context.Output<Tensor>(framework::GradVarName("X"));
auto* dout = context.Input<Tensor>(framework::GradVarName("Out"));
auto* dx = context.Output<Tensor>(framework::GradVarName("X"));
auto* db = context.Output<Tensor>(framework::GradVarName("b"));
dX->mutable_data<T>(context.GetPlace());
db->mutable_data<T>(context.GetPlace());
auto OutGrad = EigenMatrix<T>::From(*dOut);
auto out_grad = EigenMatrix<T>::From(*dout);
auto place = context.GetEigenDevice<Place>();
EigenMatrix<T>::From(*dX).device(place) = OutGrad;
if (dx) {
dx->mutable_data<T>(context.GetPlace());
EigenMatrix<T>::From(*dx).device(place) = out_grad;
}
// https://eigen.tuxfamily.org/dox/unsupported/TensorBase_8h_source.html
// colwise add
Eigen::array<int, 1> dims{{0}}; /* dimension to reduce */
EigenVector<T>::Flatten(*db).device(place) = OutGrad.sum(dims);
if (db) {
db->mutable_data<T>(context.GetPlace());
// https://eigen.tuxfamily.org/dox/unsupported/TensorBase_8h_source.html
// colwise add
Eigen::array<int, 1> dims{{0}}; /* dimension to reduce */
EigenVector<T>::Flatten(*db).device(place) = out_grad.sum(dims);
}
}
};
} // namespace operators
......
......@@ -39,12 +39,12 @@ USE_OP(sigmoid);
USE_OP(softmax);
USE_OP(rowwise_add);
USE_OP(fill_zeros_like);
USE_OP_ITSELF(recurrent_op);
USE_NO_KERNEL_OP(recurrent);
USE_OP(gaussian_random);
USE_OP(uniform_random);
USE_OP(lookup_table);
USE_OP(scale);
USE_OP_ITSELF(identity);
USE_NO_KERNEL_OP(identity);
USE_OP(minus);
USE_CPU_ONLY_OP(gather);
USE_CPU_ONLY_OP(scatter);
......
......@@ -179,7 +179,7 @@ class OperatorFactory(object):
class __RecurrentOp__(object):
__proto__ = None
type = 'recurrent_op'
type = 'recurrent'
def __init__(self):
# cache recurrent_op's proto
......
......@@ -286,6 +286,9 @@ class GradientChecker(unittest.TestCase):
for no_grad in no_grad_set:
if no_grad not in in_names:
raise ValueError("no_grad should be in in_names")
if no_grad in inputs_to_check:
raise ValueError("no_grad should not be in inputs_to_check")
backward_op = core.Operator.backward(forward_op, no_grad_set)
places = [core.CPUPlace()]
......@@ -301,7 +304,6 @@ class GradientChecker(unittest.TestCase):
check_names = [grad_var_name(name) for name in inputs_to_check]
for place in places:
# get analytical gradients according to different device
analytic_grads = self.__get_gradient(forward_op, backward_op,
input_vars, check_names, place)
self.__assert_is_close(numeric_grads, analytic_grads, check_names,
......
......@@ -16,16 +16,37 @@ class TestMulOp(unittest.TestCase):
self.outputs = {'Out': np.dot(self.inputs['X'], self.inputs['Y'])}
class MulGradOpTest(GradientChecker):
def test_mul(self):
op = create_op("mul")
inputs = {
class TestMulGradOp(GradientChecker):
def setUp(self):
self.op = create_op("mul")
self.inputs = {
'X': np.random.random((32, 84)).astype("float32"),
'Y': np.random.random((84, 100)).astype("float32")
}
def test_cpu_gpu_compare(self):
self.compare_grad(self.op, self.inputs)
def test_normal(self):
# mul op will enlarge the relative error
self.check_grad(
op, inputs, set(["X", "Y"]), "Out", max_relative_error=0.5)
self.op, self.inputs, ["X", "Y"], "Out", max_relative_error=0.5)
def test_ignore_x(self):
self.check_grad(
self.op,
self.inputs, ["Y"],
"Out",
max_relative_error=0.5,
no_grad_set={"X"})
def test_ignore_y(self):
self.check_grad(
self.op,
self.inputs, ["X"],
"Out",
max_relative_error=0.5,
no_grad_set={"Y"})
# TODO(dzh,qijun) : mulgrad test case need transpose feature of blas library
......
......@@ -16,14 +16,22 @@ class TestRowwiseAddOp(unittest.TestCase):
self.outputs = {'Out': np.add(self.inputs['X'], self.inputs['b'])}
class RowwiseAddGradOpTest(GradientChecker):
def test_rowwise_add(self):
op = create_op("rowwise_add")
inputs = {
class TestRowwiseAddGradOp(GradientChecker):
def setUp(self):
self.op = create_op("rowwise_add")
self.inputs = {
"X": np.random.uniform(0.1, 1, [5, 10]).astype("float32"),
"b": np.random.uniform(0.1, 1, [10]).astype("float32")
}
self.check_grad(op, inputs, set(["X", "b"]), "Out")
def test_normal(self):
self.check_grad(self.op, self.inputs, ["X", "b"], "Out")
def test_ignore_b(self):
self.check_grad(self.op, self.inputs, ["X"], "Out", no_grad_set={"b"})
def test_ignore_x(self):
self.check_grad(self.op, self.inputs, ["b"], "Out", no_grad_set={"X"})
if __name__ == '__main__':
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册