Commit 31531ab5 authored by: wanghaoshuang

Add backward kernel

Parent: 8de04be7
@@ -110,7 +110,7 @@ Vector<size_t> repeat_lod(Vector<size_t> data, Vector<size_t> starts,
  size_t p = 0, start = 0, end = 0;
  if (is_first == true) {
    for (size_t i = 0; i < times.size(); ++i) {
-      result.push_back(data.back() + times[i] * (data[i + 1] - data[i]));
+      result.push_back(result.back() + times[i] * (data[i + 1] - data[i]));
    }
  } else {
    for (size_t i = 0; i < times.size(); ++i) {
......
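
The fix above makes each new offset accumulate from the last value already pushed into result rather than from data.back(), so the expanded LoD offsets grow correctly. A rough Python sketch of the intended first-level expansion (reference only, not part of this patch; the name repeat_lod_first_level is made up, and the sketch seeds result with offset 0):

def repeat_lod_first_level(data, times):
    # data: original LoD offsets, e.g. [0, 3, 4]
    # times[i]: how many times the span data[i]..data[i+1] is repeated
    result = [0]
    for i, t in enumerate(times):
        # accumulate from the last expanded offset, not from data.back()
        result.append(result[-1] + t * (data[i + 1] - data[i]))
    return result

# repeat_lod_first_level([0, 3, 4], [2, 2]) -> [0, 6, 8]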
@@ -60,7 +60,8 @@ As an example:
Given:
-X = [1, 2 , 3]
+X.data = [1, 2 , 3, 4]
+X.lod = [[0, 3, 4], [0, 1, 3, 4]]
and
@@ -69,8 +70,8 @@ repeat = 2
then we get
-Out.data = [1, 1, 2, 2, 3, 3]
-Out.lod = [[0, 2, 4, 6]]
+Out.data = [1, 2, 3, 1, 2, 3, 4, 4]
+Out.lod = [[0, 6, 8], [0, 3, 6, 7, 8], [0, 1, 3, 4, 6, 7, 8]]
)DOC");
}
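
Roughly, the documented expansion in Python (reference only, not part of this patch; seq_expand_ref is a made-up name, and only the data expansion is shown; the LoD levels are expanded as in the example above):

def seq_expand_ref(x_data, x_lod, repeat):
    # repeat each top-level sequence of X `repeat` times, back to back
    top = x_lod[0]
    out_data = []
    for i in range(len(top) - 1):
        out_data.extend(x_data[top[i]:top[i + 1]] * repeat)
    return out_data

# seq_expand_ref([1, 2, 3, 4], [[0, 3, 4], [0, 1, 3, 4]], 2)
# -> [1, 2, 3, 1, 2, 3, 4, 4], matching Out.data above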
@@ -83,6 +84,7 @@ class SeqExpandOpGrad : public framework::OperatorWithKernel {
 protected:
  void InferShape(framework::InferShapeContext* ctx) const override {
    PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null");
+    PADDLE_ENFORCE(ctx->HasInput("Out"), "Input(Out) should not be null");
    PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
                   "Input(Out@GRAD) should not be null");
    auto x_dims = ctx->GetInputDim("X");
@@ -93,30 +95,12 @@ class SeqExpandOpGrad : public framework::OperatorWithKernel {
  }
};
-class SeqExpandOpGradMaker : public framework::SingleGradOpDescMaker {
- public:
-  using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
- protected:
-  std::unique_ptr<framework::OpDescBind> Apply() const override {
-    auto* bind = new framework::OpDescBind();
-    bind->SetInput("X", Input("X"));
-    bind->SetInput(framework::GradVarName("Out"), OutputGrad("Out"));
-    bind->SetOutput(framework::GradVarName("X"), InputGrad("X"));
-    bind->SetAttrMap(Attrs());
-    bind->SetType("seq_expand_grad");
-    return std::unique_ptr<framework::OpDescBind>(bind);
-  }
-};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
-REGISTER_OPERATOR(seq_expand, ops::SeqExpandOp, ops::SeqExpandOpMaker,
-                  ops::SeqExpandOpGradMaker);
-REGISTER_OPERATOR(seq_expand_grad, ops::SeqExpandOpGrad);
+REGISTER_OP(seq_expand, ops::SeqExpandOp, ops::SeqExpandOpMaker,
+            seq_expand_grad, ops::SeqExpandOpGrad);
REGISTER_OP_CPU_KERNEL(seq_expand,
                       ops::SeqExpandKernel<paddle::platform::CPUPlace, float>);
REGISTER_OP_CPU_KERNEL(
......
@@ -16,6 +16,7 @@
#include "paddle/framework/op_registry.h"
#include "paddle/memory/memcpy.h"
+#include "unsupported/Eigen/CXX11/Tensor"
namespace paddle {
namespace operators {
@@ -93,9 +94,29 @@ template <typename Place, typename T>
class SeqExpandGradKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
-    auto* d_out = context.Input<Tensor>(framework::GradVarName("Out"));
-    auto* d_x = context.Output<Tensor>(framework::GradVarName("X"));
-    d_x->mutable_data<T>(context.GetPlace());
+    auto* d_out = context.Input<LoDTensor>(framework::GradVarName("Out"));
+    auto* d_x = context.Output<LoDTensor>(framework::GradVarName("X"));
+    auto* x = context.Input<LoDTensor>("X");
+    auto* out = context.Input<LoDTensor>("Out");
+    auto out_lod = out->lod();
+    d_x->set_lod(x->lod());
+    const T* d_out_data = d_out->data<T>();
+    auto d_out_dims = d_out->dims();
+    T* d_x_data = d_x->mutable_data<T>(context.GetPlace());
+    size_t element_len = framework::product(d_out_dims) / d_out_dims[0];
+    for (size_t i = 0; i < out->NumElements(); ++i) {
+      size_t ele_count = out_lod[0][i + 1] - out_lod[0][i];
+      size_t repeat = out->NumElements(0, i);
+      Eigen::TensorMap<Eigen::Tensor<const T, 2>> d_out_t(
+          d_out_data, static_cast<int>(repeat),
+          static_cast<int>((ele_count * element_len) / repeat));
+      Eigen::TensorMap<Eigen::Tensor<T, 1>> d_x_t(
+          d_x_data, static_cast<int>((ele_count * element_len) / repeat));
+      auto place = context.GetEigenDevice<Place>();
+      d_x_t.device(place) = d_out_t.sum(Eigen::array<int, 1>({0}));
+      d_out_data += (ele_count * element_len);
+      d_x_data += ((ele_count * element_len) / repeat);
+    }
  }
};
......
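
The backward pass reads Out as well (hence the new HasInput("Out") check): the output's level-0 LoD says how many expanded copies each original sequence produced, and the kernel sums the gradients of those copies back into one slice of X@GRAD. A rough NumPy sketch of the same reduction (reference only, not part of this patch; seq_expand_grad_ref is a made-up name, it assumes one value per time step and a uniform repeat count, whereas the kernel derives the count per sequence from out->NumElements(0, i)):

import numpy as np

def seq_expand_grad_ref(d_out, out_lod, repeat):
    top = out_lod[0]
    d_x = []
    for i in range(len(top) - 1):
        # gradient block covering all copies of original sequence i
        block = np.asarray(d_out[top[i]:top[i + 1]], dtype=float)
        # sum the copies back into a single slice of X@GRAD
        d_x.extend(block.reshape(repeat, -1).sum(axis=0))
    return d_x

# seq_expand_grad_ref([1] * 8, [[0, 6, 8]], 2) -> [2.0, 2.0, 2.0, 2.0]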
@@ -68,12 +68,12 @@ class SequenceConcatOpMaker : public framework::OpProtoAndCheckerMaker {
"The level should be less than the level number of inputs.")
.SetDefault(0);
AddComment(R"DOC(
The sequence_concat operator concatenates multiple LoDTensors.
It only supports sequence (LoD Tensor with level number is 1)
or a nested sequence (LoD tensor with level number is 2) as its input.
- Case1:
If the axis is other than 0(here, axis is 1 and level is 1),
each input should have the same LoD information and the LoD
information of the output keeps the same as the input.
LoD(x0) = {{0,2,4}, {0,1,2,3,4}}; Dims(x0) = (4,3,4)
@@ -81,7 +81,7 @@ class SequenceConcatOpMaker : public framework::OpProtoAndCheckerMaker {
LoD(Out) = {{0,2,4}, {0,1,2,3,4}}; Dims(Out) = (4,7,4)
- Case2:
If the axis is 0(here, leve is 0), the inputs are concatenated along
time steps, the LoD information of the output need to re-compute.
LoD(x0) = {{0,2,4}, {0,1,2,3,4}}; Dims(x0) = (4,3,4)
@@ -94,7 +94,7 @@ class SequenceConcatOpMaker : public framework::OpProtoAndCheckerMaker {
LoD(x0) = {{0,2,4}, {0,1,2,3,4}}; Dims(x0) = (4,3,4)
LoD(x1) = {{0,3,5}, {0,1,3,4,5}}; Dims(x1) = (5,3,4)
LoD(Out) = {{0,5,9}, {0,2,5,7,9}}; Dims(Out) = (9,3,4)
NOTE: The levels of all the inputs should be the same.
)DOC");
}
......
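
For the axis-0 cases above, each level of the output LoD can be rebuilt by summing, per matching sequence, the lengths taken from every input, as the numbers in the last example show. A rough sketch for one level and two inputs (reference only, not part of this patch; concat_lod_level is a made-up name):

def concat_lod_level(lod_a, lod_b):
    # out sequence i = (a sequence i) followed by (b sequence i)
    lens = [(lod_a[i + 1] - lod_a[i]) + (lod_b[i + 1] - lod_b[i])
            for i in range(len(lod_a) - 1)]
    out = [0]
    for n in lens:
        out.append(out[-1] + n)
    return out

# concat_lod_level([0, 2, 4], [0, 3, 5])             -> [0, 5, 9]
# concat_lod_level([0, 1, 2, 3, 4], [0, 1, 3, 4, 5]) -> [0, 2, 5, 7, 9]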
@@ -246,9 +246,6 @@ class OpTest(unittest.TestCase):
else:
    actual = np.array(self.scope.find_var(out_name).get_tensor())
expect = self.outputs[out_name]
-print "out_name: %s" % out_name
-print "actual: %s" % actual
-print "expcept: %s" % expect
self.assertTrue(
    np.allclose(
        actual, expect, atol=atol),
......
@@ -62,9 +62,8 @@ class TestSeqExpand(OpTest):
    def test_check_output(self):
        self.check_output()
-    # def test_check_grad(self):
-    #     self.check_grad(["X"], "Out")
+    def test_check_grad(self):
+        self.check_grad(["X"], "Out")
class TestSeqExpandCase1(TestSeqExpand):
......