提交 31531ab5 编写于 作者: W wanghaoshuang

Add backward kernel

上级 8de04be7
...@@ -110,7 +110,7 @@ Vector<size_t> repeat_lod(Vector<size_t> data, Vector<size_t> starts, ...@@ -110,7 +110,7 @@ Vector<size_t> repeat_lod(Vector<size_t> data, Vector<size_t> starts,
size_t p = 0, start = 0, end = 0; size_t p = 0, start = 0, end = 0;
if (is_first == true) { if (is_first == true) {
for (size_t i = 0; i < times.size(); ++i) { for (size_t i = 0; i < times.size(); ++i) {
result.push_back(data.back() + times[i] * (data[i + 1] - data[i])); result.push_back(result.back() + times[i] * (data[i + 1] - data[i]));
} }
} else { } else {
for (size_t i = 0; i < times.size(); ++i) { for (size_t i = 0; i < times.size(); ++i) {
......
...@@ -60,7 +60,8 @@ As an example: ...@@ -60,7 +60,8 @@ As an example:
Given: Given:
X = [1, 2 , 3] X.data = [1, 2 , 3, 4]
X.lod = [[0, 3, 4], [0, 1, 3, 4]]
and and
...@@ -69,8 +70,8 @@ repeat = 2 ...@@ -69,8 +70,8 @@ repeat = 2
then we get then we get
Out.data = [1, 1, 2, 2, 3, 3] Out.data = [1, 2, 3, 1, 2, 3, 4, 4]
Out.lod = [[0, 2, 4, 6]] Out.lod = [[0, 6, 8], [0, 3, 6, 7, 8], [0, 1, 3, 4, 6, 7, 8]]
)DOC"); )DOC");
} }
...@@ -83,6 +84,7 @@ class SeqExpandOpGrad : public framework::OperatorWithKernel { ...@@ -83,6 +84,7 @@ class SeqExpandOpGrad : public framework::OperatorWithKernel {
protected: protected:
void InferShape(framework::InferShapeContext* ctx) const override { void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null"); PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null");
PADDLE_ENFORCE(ctx->HasInput("Out"), "Input(Out) should not be null");
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")), PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
"Input(Out@GRAD) should not be null"); "Input(Out@GRAD) should not be null");
auto x_dims = ctx->GetInputDim("X"); auto x_dims = ctx->GetInputDim("X");
...@@ -93,30 +95,12 @@ class SeqExpandOpGrad : public framework::OperatorWithKernel { ...@@ -93,30 +95,12 @@ class SeqExpandOpGrad : public framework::OperatorWithKernel {
} }
}; };
class SeqExpandOpGradMaker : public framework::SingleGradOpDescMaker {
public:
using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
protected:
std::unique_ptr<framework::OpDescBind> Apply() const override {
auto* bind = new framework::OpDescBind();
bind->SetInput("X", Input("X"));
bind->SetInput(framework::GradVarName("Out"), OutputGrad("Out"));
bind->SetOutput(framework::GradVarName("X"), InputGrad("X"));
bind->SetAttrMap(Attrs());
bind->SetType("seq_expand_grad");
return std::unique_ptr<framework::OpDescBind>(bind);
}
};
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OP(seq_expand, ops::SeqExpandOp, ops::SeqExpandOpMaker,
REGISTER_OPERATOR(seq_expand, ops::SeqExpandOp, ops::SeqExpandOpMaker, seq_expand_grad, ops::SeqExpandOpGrad);
ops::SeqExpandOpGradMaker);
REGISTER_OPERATOR(seq_expand_grad, ops::SeqExpandOpGrad);
REGISTER_OP_CPU_KERNEL(seq_expand, REGISTER_OP_CPU_KERNEL(seq_expand,
ops::SeqExpandKernel<paddle::platform::CPUPlace, float>); ops::SeqExpandKernel<paddle::platform::CPUPlace, float>);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
......
...@@ -16,6 +16,7 @@ ...@@ -16,6 +16,7 @@
#include "paddle/framework/op_registry.h" #include "paddle/framework/op_registry.h"
#include "paddle/memory/memcpy.h" #include "paddle/memory/memcpy.h"
#include "unsupported/Eigen/CXX11/Tensor"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -93,9 +94,29 @@ template <typename Place, typename T> ...@@ -93,9 +94,29 @@ template <typename Place, typename T>
class SeqExpandGradKernel : public framework::OpKernel<T> { class SeqExpandGradKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
auto* d_out = context.Input<Tensor>(framework::GradVarName("Out")); auto* d_out = context.Input<LoDTensor>(framework::GradVarName("Out"));
auto* d_x = context.Output<Tensor>(framework::GradVarName("X")); auto* d_x = context.Output<LoDTensor>(framework::GradVarName("X"));
d_x->mutable_data<T>(context.GetPlace()); auto* x = context.Input<LoDTensor>("X");
auto* out = context.Input<LoDTensor>("Out");
auto out_lod = out->lod();
d_x->set_lod(x->lod());
const T* d_out_data = d_out->data<T>();
auto d_out_dims = d_out->dims();
T* d_x_data = d_x->mutable_data<T>(context.GetPlace());
size_t element_len = framework::product(d_out_dims) / d_out_dims[0];
for (size_t i = 0; i < out->NumElements(); ++i) {
size_t ele_count = out_lod[0][i + 1] - out_lod[0][i];
size_t repeat = out->NumElements(0, i);
Eigen::TensorMap<Eigen::Tensor<const T, 2>> d_out_t(
d_out_data, static_cast<int>(repeat),
static_cast<int>((ele_count * element_len) / repeat));
Eigen::TensorMap<Eigen::Tensor<T, 1>> d_x_t(
d_x_data, static_cast<int>((ele_count * element_len) / repeat));
auto place = context.GetEigenDevice<Place>();
d_x_t.device(place) = d_out_t.sum(Eigen::array<int, 1>({0}));
d_out_data += (ele_count * element_len);
d_x_data += ((ele_count * element_len) / repeat);
}
} }
}; };
......
...@@ -246,9 +246,6 @@ class OpTest(unittest.TestCase): ...@@ -246,9 +246,6 @@ class OpTest(unittest.TestCase):
else: else:
actual = np.array(self.scope.find_var(out_name).get_tensor()) actual = np.array(self.scope.find_var(out_name).get_tensor())
expect = self.outputs[out_name] expect = self.outputs[out_name]
print "out_name: %s" % out_name
print "actual: %s" % actual
print "expcept: %s" % expect
self.assertTrue( self.assertTrue(
np.allclose( np.allclose(
actual, expect, atol=atol), actual, expect, atol=atol),
......
...@@ -62,9 +62,8 @@ class TestSeqExpand(OpTest): ...@@ -62,9 +62,8 @@ class TestSeqExpand(OpTest):
def test_check_output(self): def test_check_output(self):
self.check_output() self.check_output()
def test_check_grad(self):
# def test_check_grad(self): self.check_grad(["X"], "Out")
# self.check_grad(["X"], "Out")
class TestSeqExpandCase1(TestSeqExpand): class TestSeqExpandCase1(TestSeqExpand):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册