Commit 31531ab5 authored by wanghaoshuang

Add backward kernel

Parent 8de04be7
@@ -110,7 +110,7 @@ Vector<size_t> repeat_lod(Vector<size_t> data, Vector<size_t> starts,
   size_t p = 0, start = 0, end = 0;
   if (is_first == true) {
     for (size_t i = 0; i < times.size(); ++i) {
-      result.push_back(data.back() + times[i] * (data[i + 1] - data[i]));
+      result.push_back(result.back() + times[i] * (data[i + 1] - data[i]));
     }
   } else {
     for (size_t i = 0; i < times.size(); ++i) {
......
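The one-line fix above swaps data.back() for result.back(): each newly pushed offset has to build on the offsets already accumulated in result, not on the tail of the source LoD. Below is a minimal standalone sketch of that accumulation for the level-0 branch; the helper name repeat_level0 and the sample values are illustrative only, not part of the framework.

#include <cstddef>
#include <iostream>
#include <vector>

// Hypothetical standalone version of the is_first branch of repeat_lod:
// the span data[i]..data[i+1] is repeated times[i] times, and each new
// offset extends result.back(), the running total of expanded lengths.
std::vector<size_t> repeat_level0(const std::vector<size_t>& data,
                                  const std::vector<size_t>& times) {
  std::vector<size_t> result{0};
  for (size_t i = 0; i < times.size(); ++i) {
    result.push_back(result.back() + times[i] * (data[i + 1] - data[i]));
  }
  return result;
}

int main() {
  // X.lod[0] = [0, 3, 4] with repeat = 2, matching the seq_expand DOC example.
  for (size_t v : repeat_level0({0, 3, 4}, {2, 2})) std::cout << v << " ";
  std::cout << "\n";  // prints: 0 6 8
  return 0;
}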
@@ -60,7 +60,8 @@ As an example:
 Given:
-X = [1, 2 , 3]
+X.data = [1, 2 , 3, 4]
+X.lod = [[0, 3, 4], [0, 1, 3, 4]]
 and
@@ -69,8 +70,8 @@ repeat = 2
 then we get
-Out.data = [1, 1, 2, 2, 3, 3]
-Out.lod = [[0, 2, 4, 6]]
+Out.data = [1, 2, 3, 1, 2, 3, 4, 4]
+Out.lod = [[0, 6, 8], [0, 3, 6, 7, 8], [0, 1, 3, 4, 6, 7, 8]]
 )DOC");
  }
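With the LoD fix in place, the updated DOC numbers can be checked by hand: every top-level sequence of X is copied repeat times and the lower LoD levels are re-offset to match. The standalone sketch below (not the operator's code) reproduces Out.data for this example.

#include <cstddef>
#include <iostream>
#include <vector>

int main() {
  // X.data = [1, 2, 3, 4], top-level LoD [0, 3, 4], repeat = 2,
  // as in the seq_expand DOC example above.
  std::vector<int> x_data = {1, 2, 3, 4};
  std::vector<size_t> lod0 = {0, 3, 4};
  const size_t repeat = 2;

  std::vector<int> out_data;
  for (size_t i = 0; i + 1 < lod0.size(); ++i) {
    // Copy the i-th top-level sequence `repeat` times.
    for (size_t r = 0; r < repeat; ++r) {
      out_data.insert(out_data.end(), x_data.begin() + lod0[i],
                      x_data.begin() + lod0[i + 1]);
    }
  }
  for (int v : out_data) std::cout << v << " ";
  std::cout << "\n";  // prints: 1 2 3 1 2 3 4 4
  return 0;
}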
@@ -83,6 +84,7 @@ class SeqExpandOpGrad : public framework::OperatorWithKernel {
  protected:
   void InferShape(framework::InferShapeContext* ctx) const override {
     PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null");
+    PADDLE_ENFORCE(ctx->HasInput("Out"), "Input(Out) should not be null");
     PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
                    "Input(Out@GRAD) should not be null");
     auto x_dims = ctx->GetInputDim("X");
@@ -93,30 +95,12 @@ class SeqExpandOpGrad : public framework::OperatorWithKernel {
   }
 };
 
-class SeqExpandOpGradMaker : public framework::SingleGradOpDescMaker {
- public:
-  using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
-
- protected:
-  std::unique_ptr<framework::OpDescBind> Apply() const override {
-    auto* bind = new framework::OpDescBind();
-    bind->SetInput("X", Input("X"));
-    bind->SetInput(framework::GradVarName("Out"), OutputGrad("Out"));
-    bind->SetOutput(framework::GradVarName("X"), InputGrad("X"));
-    bind->SetAttrMap(Attrs());
-    bind->SetType("seq_expand_grad");
-    return std::unique_ptr<framework::OpDescBind>(bind);
-  }
-};
-
 }  // namespace operators
 }  // namespace paddle
 
 namespace ops = paddle::operators;
-REGISTER_OPERATOR(seq_expand, ops::SeqExpandOp, ops::SeqExpandOpMaker,
-                  ops::SeqExpandOpGradMaker);
-REGISTER_OPERATOR(seq_expand_grad, ops::SeqExpandOpGrad);
+REGISTER_OP(seq_expand, ops::SeqExpandOp, ops::SeqExpandOpMaker,
+            seq_expand_grad, ops::SeqExpandOpGrad);
 REGISTER_OP_CPU_KERNEL(seq_expand,
                        ops::SeqExpandKernel<paddle::platform::CPUPlace, float>);
 REGISTER_OP_CPU_KERNEL(
......
@@ -16,6 +16,7 @@
 #include "paddle/framework/op_registry.h"
 #include "paddle/memory/memcpy.h"
+#include "unsupported/Eigen/CXX11/Tensor"
 
 namespace paddle {
 namespace operators {
@@ -93,9 +94,29 @@ template <typename Place, typename T>
 class SeqExpandGradKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& context) const override {
-    auto* d_out = context.Input<Tensor>(framework::GradVarName("Out"));
-    auto* d_x = context.Output<Tensor>(framework::GradVarName("X"));
-    d_x->mutable_data<T>(context.GetPlace());
+    auto* d_out = context.Input<LoDTensor>(framework::GradVarName("Out"));
+    auto* d_x = context.Output<LoDTensor>(framework::GradVarName("X"));
+    auto* x = context.Input<LoDTensor>("X");
+    auto* out = context.Input<LoDTensor>("Out");
+    auto out_lod = out->lod();
+    d_x->set_lod(x->lod());
+    const T* d_out_data = d_out->data<T>();
+    auto d_out_dims = d_out->dims();
+    T* d_x_data = d_x->mutable_data<T>(context.GetPlace());
+    size_t element_len = framework::product(d_out_dims) / d_out_dims[0];
+    for (size_t i = 0; i < out->NumElements(); ++i) {
+      size_t ele_count = out_lod[0][i + 1] - out_lod[0][i];
+      size_t repeat = out->NumElements(0, i);
+      Eigen::TensorMap<Eigen::Tensor<const T, 2>> d_out_t(
+          d_out_data, static_cast<int>(repeat),
+          static_cast<int>((ele_count * element_len) / repeat));
+      Eigen::TensorMap<Eigen::Tensor<T, 1>> d_x_t(
+          d_x_data, static_cast<int>((ele_count * element_len) / repeat));
+      auto place = context.GetEigenDevice<Place>();
+      d_x_t.device(place) = d_out_t.sum(Eigen::array<int, 1>({0}));
+      d_out_data += (ele_count * element_len);
+      d_x_data += ((ele_count * element_len) / repeat);
+    }
   }
 };
......
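The backward kernel views each expanded group of d_out as a repeat x cols block and reduces it over the repeat axis, so every element of d_x receives the sum of the gradients of all of its copies; this is also why InferShape above now requires Out as an input, since the kernel reads Out's LoD to recover the group boundaries. The plain-loop sketch below shows the same reduction without Eigen; the helper name sum_repeated and the sample numbers are made up for illustration.

#include <cstddef>
#include <iostream>
#include <vector>

// Sum the gradients of `repeat` stacked copies back into a single copy.
// d_out holds repeat * cols values laid out copy after copy; d_x gets cols values.
void sum_repeated(const std::vector<float>& d_out, size_t repeat, size_t cols,
                  std::vector<float>* d_x) {
  d_x->assign(cols, 0.0f);
  for (size_t r = 0; r < repeat; ++r) {
    for (size_t c = 0; c < cols; ++c) {
      (*d_x)[c] += d_out[r * cols + c];
    }
  }
}

int main() {
  // One group from the DOC example: the sequence [1, 2, 3] repeated twice,
  // so its slice of d_out has 2 x 3 values (element_len assumed to be 1).
  std::vector<float> d_out = {0.1f, 0.2f, 0.3f, 0.4f, 0.5f, 0.6f};
  std::vector<float> d_x;
  sum_repeated(d_out, /*repeat=*/2, /*cols=*/3, &d_x);
  for (float v : d_x) std::cout << v << " ";
  std::cout << "\n";  // prints: 0.5 0.7 0.9
  return 0;
}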
@@ -68,12 +68,12 @@ class SequenceConcatOpMaker : public framework::OpProtoAndCheckerMaker {
                  "The level should be less than the level number of inputs.")
         .SetDefault(0);
     AddComment(R"DOC(
 The sequence_concat operator concatenates multiple LoDTensors.
 It only supports sequence (LoD Tensor with level number is 1)
 or a nested sequence (LoD tensor with level number is 2) as its input.
 - Case1:
   If the axis is other than 0(here, axis is 1 and level is 1),
   each input should have the same LoD information and the LoD
   information of the output keeps the same as the input.
 
   LoD(x0) = {{0,2,4}, {0,1,2,3,4}}; Dims(x0) = (4,3,4)
@@ -81,7 +81,7 @@ class SequenceConcatOpMaker : public framework::OpProtoAndCheckerMaker {
   LoD(Out) = {{0,2,4}, {0,1,2,3,4}}; Dims(Out) = (4,7,4)
 - Case2:
   If the axis is 0(here, leve is 0), the inputs are concatenated along
   time steps, the LoD information of the output need to re-compute.
 
   LoD(x0) = {{0,2,4}, {0,1,2,3,4}}; Dims(x0) = (4,3,4)
@@ -94,7 +94,7 @@ class SequenceConcatOpMaker : public framework::OpProtoAndCheckerMaker {
   LoD(x0) = {{0,2,4}, {0,1,2,3,4}}; Dims(x0) = (4,3,4)
   LoD(x1) = {{0,3,5}, {0,1,3,4,5}}; Dims(x1) = (5,3,4)
   LoD(Out) = {{0,5,9}, {0,2,5,7,9}}; Dims(Out) = (9,3,4)
 
 NOTE: The levels of all the inputs should be the same.
 )DOC");
  }
......
@@ -246,9 +246,6 @@ class OpTest(unittest.TestCase):
         else:
             actual = np.array(self.scope.find_var(out_name).get_tensor())
             expect = self.outputs[out_name]
-            print "out_name: %s" % out_name
-            print "actual: %s" % actual
-            print "expcept: %s" % expect
             self.assertTrue(
                 np.allclose(
                     actual, expect, atol=atol),
......
@@ -62,9 +62,8 @@ class TestSeqExpand(OpTest):
     def test_check_output(self):
         self.check_output()
 
-    # def test_check_grad(self):
-    #     self.check_grad(["X"], "Out")
+    def test_check_grad(self):
+        self.check_grad(["X"], "Out")
 
 
 class TestSeqExpandCase1(TestSeqExpand):
......