Commit 31531ab5 authored by W wanghaoshuang

Add backward kernel

Parent 8de04be7
@@ -110,7 +110,7 @@ Vector<size_t> repeat_lod(Vector<size_t> data, Vector<size_t> starts,
   size_t p = 0, start = 0, end = 0;
   if (is_first == true) {
     for (size_t i = 0; i < times.size(); ++i) {
-      result.push_back(data.back() + times[i] * (data[i + 1] - data[i]));
+      result.push_back(result.back() + times[i] * (data[i + 1] - data[i]));
     }
   } else {
     for (size_t i = 0; i < times.size(); ++i) {
......
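Note on the fix above: each new offset must accumulate from the offsets already written to result, not from the tail of the input data. With data.back() every pushed offset restarts from the same base, so the expanded LoD comes out wrong as soon as more than one sequence is emitted. A minimal standalone sketch of the intended accumulation (expand_level is a hypothetical helper with a simplified signature, not the actual repeat_lod API):

    #include <cstddef>
    #include <vector>

    // Expand one LoD level: sequence i contributes times[i] copies, so its
    // expanded span is times[i] * (data[i+1] - data[i]) elements, appended
    // after whatever `result` already ends with.
    std::vector<std::size_t> expand_level(const std::vector<std::size_t>& data,
                                          const std::vector<std::size_t>& times) {
      std::vector<std::size_t> result = {0};
      for (std::size_t i = 0; i < times.size(); ++i) {
        result.push_back(result.back() + times[i] * (data[i + 1] - data[i]));
      }
      return result;
    }

For X.lod[0] = [0, 3, 4] and times = [2, 2] this yields [0, 6, 8], matching the doc example below.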
@@ -60,7 +60,8 @@ As an example:
 Given:
-X = [1, 2 , 3]
+X.data = [1, 2 , 3, 4]
+X.lod = [[0, 3, 4], [0, 1, 3, 4]]
 and
@@ -69,8 +70,8 @@ repeat = 2
 then we get
-Out.data = [1, 1, 2, 2, 3, 3]
-Out.lod = [[0, 2, 4, 6]]
+Out.data = [1, 2, 3, 1, 2, 3, 4, 4]
+Out.lod = [[0, 6, 8], [0, 3, 6, 7, 8], [0, 1, 3, 4, 6, 7, 8]]
 )DOC");
   }
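Working through the updated doc example: repeat = 2 copies each coarsest-level sequence twice. X.lod = [[0, 3, 4], [0, 1, 3, 4]] describes two top-level sequences, elements [1, 2, 3] and [4], so Out.data stacks two copies of each: [1, 2, 3, 1, 2, 3, 4, 4]. The new first LoD level [0, 6, 8] groups the copies per original sequence (2*3 = 6 elements, then 2*1 = 2 more). The second level [0, 3, 6, 7, 8] marks the individual copies, and the last level [0, 1, 3, 4, 6, 7, 8] repeats X's inner pattern (lengths 1, 2 and 1) once per copy.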
@@ -83,6 +84,7 @@ class SeqExpandOpGrad : public framework::OperatorWithKernel {
  protected:
   void InferShape(framework::InferShapeContext* ctx) const override {
     PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null");
+    PADDLE_ENFORCE(ctx->HasInput("Out"), "Input(Out) should not be null");
     PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
                    "Input(Out@GRAD) should not be null");
     auto x_dims = ctx->GetInputDim("X");
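The extra Out check matches the backward kernel added in seq_expand_op.h: besides X and Out@GRAD, the gradient computation reads out->lod() to learn how many expanded copies each original sequence produced, so the forward output itself must be an input of the grad op.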
@@ -93,30 +95,12 @@ class SeqExpandOpGrad : public framework::OperatorWithKernel {
   }
 };
 
-class SeqExpandOpGradMaker : public framework::SingleGradOpDescMaker {
- public:
-  using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
-
- protected:
-  std::unique_ptr<framework::OpDescBind> Apply() const override {
-    auto* bind = new framework::OpDescBind();
-    bind->SetInput("X", Input("X"));
-    bind->SetInput(framework::GradVarName("Out"), OutputGrad("Out"));
-    bind->SetOutput(framework::GradVarName("X"), InputGrad("X"));
-    bind->SetAttrMap(Attrs());
-    bind->SetType("seq_expand_grad");
-    return std::unique_ptr<framework::OpDescBind>(bind);
-  }
-};
-
 }  // namespace operators
 }  // namespace paddle
 
 namespace ops = paddle::operators;
-REGISTER_OPERATOR(seq_expand, ops::SeqExpandOp, ops::SeqExpandOpMaker,
-                  ops::SeqExpandOpGradMaker);
-REGISTER_OPERATOR(seq_expand_grad, ops::SeqExpandOpGrad);
+REGISTER_OP(seq_expand, ops::SeqExpandOp, ops::SeqExpandOpMaker,
+            seq_expand_grad, ops::SeqExpandOpGrad);
 REGISTER_OP_CPU_KERNEL(seq_expand,
                        ops::SeqExpandKernel<paddle::platform::CPUPlace, float>);
 REGISTER_OP_CPU_KERNEL(
......
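The deleted SeqExpandOpGradMaker wired the grad op by hand (X and Out@GRAD in, X@GRAD out) but never forwarded Out, which the grad op now requires. Switching to the REGISTER_OP form registers seq_expand_grad alongside the forward op and, in this version of the framework, attaches the default gradient-desc maker, which forwards all of the forward op's inputs and outputs plus the output gradients, so Out reaches the backward kernel without custom maker code.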
@@ -16,6 +16,7 @@
 #include "paddle/framework/op_registry.h"
 #include "paddle/memory/memcpy.h"
+#include "unsupported/Eigen/CXX11/Tensor"
 
 namespace paddle {
 namespace operators {
@@ -93,9 +94,29 @@ template <typename Place, typename T>
 class SeqExpandGradKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& context) const override {
-    auto* d_out = context.Input<Tensor>(framework::GradVarName("Out"));
-    auto* d_x = context.Output<Tensor>(framework::GradVarName("X"));
-    d_x->mutable_data<T>(context.GetPlace());
+    auto* d_out = context.Input<LoDTensor>(framework::GradVarName("Out"));
+    auto* d_x = context.Output<LoDTensor>(framework::GradVarName("X"));
+    auto* x = context.Input<LoDTensor>("X");
+    auto* out = context.Input<LoDTensor>("Out");
+    auto out_lod = out->lod();
+    d_x->set_lod(x->lod());
+    const T* d_out_data = d_out->data<T>();
+    auto d_out_dims = d_out->dims();
+    T* d_x_data = d_x->mutable_data<T>(context.GetPlace());
+    size_t element_len = framework::product(d_out_dims) / d_out_dims[0];
+    for (size_t i = 0; i < out->NumElements(); ++i) {
+      size_t ele_count = out_lod[0][i + 1] - out_lod[0][i];
+      size_t repeat = out->NumElements(0, i);
+      Eigen::TensorMap<Eigen::Tensor<const T, 2>> d_out_t(
+          d_out_data, static_cast<int>(repeat),
+          static_cast<int>((ele_count * element_len) / repeat));
+      Eigen::TensorMap<Eigen::Tensor<T, 1>> d_x_t(
+          d_x_data, static_cast<int>((ele_count * element_len) / repeat));
+      auto place = context.GetEigenDevice<Place>();
+      d_x_t.device(place) = d_out_t.sum(Eigen::array<int, 1>({0}));
+      d_out_data += (ele_count * element_len);
+      d_x_data += ((ele_count * element_len) / repeat);
+    }
   }
 };
......
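The backward kernel undoes the copy performed by the forward pass: every element of X contributed to `repeat` adjacent copies in Out, so its gradient is the sum of the corresponding slices of Out@GRAD. The Eigen expression above views each group of copies as a repeat-by-n block and reduces over the repeat axis. A minimal standalone sketch of that reduction with the LoD bookkeeping stripped out (sum_over_copies is a hypothetical helper covering a single sequence group):

    #include <cstddef>
    #include <vector>

    // d_out holds `repeat` stacked copies of an n-element gradient block;
    // the gradient w.r.t. the original block sums the copies element-wise.
    std::vector<float> sum_over_copies(const std::vector<float>& d_out,
                                       std::size_t repeat) {
      std::size_t n = d_out.size() / repeat;
      std::vector<float> d_x(n, 0.0f);
      for (std::size_t r = 0; r < repeat; ++r) {
        for (std::size_t j = 0; j < n; ++j) {
          d_x[j] += d_out[r * n + j];  // accumulate the r-th copy's gradient
        }
      }
      return d_x;
    }

In the kernel itself n is (ele_count * element_len) / repeat, the pointers d_out_data and d_x_data advance past each processed group, and d_x inherits X's LoD via set_lod.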
@@ -246,9 +246,6 @@ class OpTest(unittest.TestCase):
             else:
                 actual = np.array(self.scope.find_var(out_name).get_tensor())
                 expect = self.outputs[out_name]
-                print "out_name: %s" % out_name
-                print "actual: %s" % actual
-                print "expcept: %s" % expect
                 self.assertTrue(
                     np.allclose(
                         actual, expect, atol=atol),
......
@@ -62,9 +62,8 @@ class TestSeqExpand(OpTest):
     def test_check_output(self):
         self.check_output()
 
-    # def test_check_grad(self):
-    #     self.check_grad(["X"], "Out")
+    def test_check_grad(self):
+        self.check_grad(["X"], "Out")
 
 
 class TestSeqExpandCase1(TestSeqExpand):
......
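With the backward kernel registered, the gradient check that was previously commented out is switched on: check_grad(["X"], "Out") compares the seq_expand_grad output against numeric gradients of the forward op, and the TestSeqExpandCase* subclasses rerun it over their own LoD configurations.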