提交 9da5192f 编写于 作者: Y Yibing Liu

adapt multiplex_op to the dev of framework

上级 18dc201b
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. /* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License"); Licensed under the Apache License, Version 2.0 (the "License");
...@@ -19,6 +18,7 @@ namespace paddle { ...@@ -19,6 +18,7 @@ namespace paddle {
namespace operators { namespace operators {
using Tensor = framework::Tensor; using Tensor = framework::Tensor;
using LoDTensor = framework::LoDTensor;
class MultiplexOp : public framework::OperatorWithKernel { class MultiplexOp : public framework::OperatorWithKernel {
public: public:
...@@ -29,8 +29,12 @@ class MultiplexOp : public framework::OperatorWithKernel { ...@@ -29,8 +29,12 @@ class MultiplexOp : public framework::OperatorWithKernel {
protected: protected:
void InferShape(const framework::InferShapeContext &ctx) const override { void InferShape(const framework::InferShapeContext &ctx) const override {
PADDLE_ENFORCE(!ctx.MultiInputVar("X").empty(),
"Input(X) should not be null");
PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("Out"),
"Output(Out) shouldn't be null.");
auto ins = ctx.MultiInput<Tensor>("X"); auto ins = ctx.MultiInput<Tensor>("X");
auto *out = ctx.Output<Tensor>("Out"); auto *out = ctx.Output<LoDTensor>("Out");
auto num_ins = ins.size(); auto num_ins = ins.size();
PADDLE_ENFORCE(num_ins > 2, PADDLE_ENFORCE(num_ins > 2,
"multiplex operator should have more than 2 inputs."); "multiplex operator should have more than 2 inputs.");
...@@ -53,7 +57,7 @@ class MultiplexOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -53,7 +57,7 @@ class MultiplexOpMaker : public framework::OpProtoAndCheckerMaker {
MultiplexOpMaker(framework::OpProto *proto, MultiplexOpMaker(framework::OpProto *proto,
framework::OpAttrChecker *op_checker) framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) { : OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "The input tensor of multiplex operator.").AsDuplicable(); AddInput("X", "The input tensors of multiplex operator.").AsDuplicable();
AddOutput("Out", "The output tensor of multiplex operator."); AddOutput("Out", "The output tensor of multiplex operator.");
AddComment(R"DOC(Multiplex operator AddComment(R"DOC(Multiplex operator
...@@ -69,7 +73,7 @@ For each i-th row of output: ...@@ -69,7 +73,7 @@ For each i-th row of output:
y[i][j] = x_{k}[i][j], j = 0,1, ... , (x_{1}.width - 1) y[i][j] = x_{k}[i][j], j = 0,1, ... , (x_{1}.width - 1)
where y is the output tensor. `x_{k}` is the k-th input layer where y is the output tensor. `x_{k}` is the k-th input tensor
and `k = x{0}[i] + 1`. and `k = x{0}[i] + 1`.
)DOC"); )DOC");
...@@ -86,13 +90,19 @@ class MultiplexGradOp : public framework::OperatorWithKernel { ...@@ -86,13 +90,19 @@ class MultiplexGradOp : public framework::OperatorWithKernel {
protected: protected:
void InferShape(const framework::InferShapeContext &ctx) const override { void InferShape(const framework::InferShapeContext &ctx) const override {
PADDLE_ENFORCE(!ctx.MultiInputVar("X").empty(),
"Input(X) should not be null");
PADDLE_ENFORCE(!ctx.MultiOutputVar(framework::GradVarName("X")).empty(),
"Output(X@Grad) should not be null");
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Out")), PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Out")),
"Input(Out@GRAD) shouldn't be null."); "Input(Out@GRAD) shouldn't be null.");
auto d_ins = ctx.MultiOutput<Tensor>(framework::GradVarName("X")); auto d_ins = ctx.MultiOutput<LoDTensor>(framework::GradVarName("X"));
auto ins = ctx.MultiInput<Tensor>("X"); auto ins = ctx.MultiInput<Tensor>("X");
for (size_t i = 0; i < ins.size(); i++) { // don;t compute gradient for index
auto dims = ins[i]->dims(); for (size_t i = 1; i < ins.size(); i++) {
d_ins[i]->Resize(dims); if (d_ins[i]) {
d_ins[i]->Resize(ins[i]->dims());
}
} }
} }
}; };
......
...@@ -18,13 +18,14 @@ namespace paddle { ...@@ -18,13 +18,14 @@ namespace paddle {
namespace operators { namespace operators {
using Tensor = framework::Tensor; using Tensor = framework::Tensor;
using LoDTensor = framework::LoDTensor;
template <typename T> template <typename T>
class MultiplexGPUKernel : public framework::OpKernel { class MultiplexGPUKernel : public framework::OpKernel {
public: public:
void Compute(const framework::ExecutionContext& ctx) const { void Compute(const framework::ExecutionContext& ctx) const {
auto ins = ctx.MultiInput<Tensor>("X"); auto ins = ctx.MultiInput<Tensor>("X");
auto* out = ctx.Output<Tensor>("Out"); auto* out = ctx.Output<LoDTensor>("Out");
out->mutable_data<T>(ctx.GetPlace()); out->mutable_data<T>(ctx.GetPlace());
auto rows = ins[1]->dims()[0]; auto rows = ins[1]->dims()[0];
...@@ -48,10 +49,13 @@ class MultiplexGradGPUKernel : public framework::OpKernel { ...@@ -48,10 +49,13 @@ class MultiplexGradGPUKernel : public framework::OpKernel {
auto* d_out = ctx.Input<Tensor>(framework::GradVarName("Out")); auto* d_out = ctx.Input<Tensor>(framework::GradVarName("Out"));
auto ins = ctx.MultiInput<Tensor>("X"); auto ins = ctx.MultiInput<Tensor>("X");
auto d_ins = ctx.MultiOutput<Tensor>(framework::GradVarName("X")); auto d_ins = ctx.MultiOutput<Tensor>(framework::GradVarName("X"));
for (auto d_in : d_ins) { for (size_t i = 1; i < d_ins.size(); ++i) {
d_in->mutable_data<T>(ctx.GetPlace()); if (d_ins[i]) {
auto dims = d_in->dims(); d_ins[i]->mutable_data<T>(ctx.GetPlace());
cudaMemset(d_in->data<T>(), 0, framework::product(dims) * sizeof(T)); auto dims = d_ins[i]->dims();
cudaMemset(d_ins[i]->data<T>(), 0,
framework::product(dims) * sizeof(T));
}
} }
auto rows = ins[1]->dims()[0]; auto rows = ins[1]->dims()[0];
...@@ -62,10 +66,12 @@ class MultiplexGradGPUKernel : public framework::OpKernel { ...@@ -62,10 +66,12 @@ class MultiplexGradGPUKernel : public framework::OpKernel {
auto index = index_t_cpu.data<T>(); auto index = index_t_cpu.data<T>();
for (auto i = 0; i < rows; i++) { for (auto i = 0; i < rows; i++) {
int k = (int)index[i] + 1; int k = (int)index[i] + 1;
if (d_ins[k]) {
cudaMemcpy(d_ins[k]->data<T>() + i * cols, d_out->data<T>() + i * cols, cudaMemcpy(d_ins[k]->data<T>() + i * cols, d_out->data<T>() + i * cols,
cols * sizeof(T), cudaMemcpyDeviceToDevice); cols * sizeof(T), cudaMemcpyDeviceToDevice);
} }
} }
}
}; };
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
......
...@@ -26,7 +26,7 @@ class MultiplexCPUKernel : public framework::OpKernel { ...@@ -26,7 +26,7 @@ class MultiplexCPUKernel : public framework::OpKernel {
public: public:
void Compute(const framework::ExecutionContext& ctx) const { void Compute(const framework::ExecutionContext& ctx) const {
auto ins = ctx.MultiInput<framework::Tensor>("X"); auto ins = ctx.MultiInput<framework::Tensor>("X");
auto* out = ctx.Output<framework::Tensor>("Out"); auto* out = ctx.Output<framework::LoDTensor>("Out");
out->mutable_data<T>(ctx.GetPlace()); out->mutable_data<T>(ctx.GetPlace());
auto index = ins[0]->data<T>(); auto index = ins[0]->data<T>();
...@@ -48,10 +48,12 @@ class MultiplexGradCPUKernel : public framework::OpKernel { ...@@ -48,10 +48,12 @@ class MultiplexGradCPUKernel : public framework::OpKernel {
auto ins = ctx.MultiInput<framework::Tensor>("X"); auto ins = ctx.MultiInput<framework::Tensor>("X");
auto d_ins = auto d_ins =
ctx.MultiOutput<framework::Tensor>(framework::GradVarName("X")); ctx.MultiOutput<framework::Tensor>(framework::GradVarName("X"));
for (auto d_in : d_ins) { for (size_t i = 1; i < d_ins.size(); i++) {
d_in->mutable_data<T>(ctx.GetPlace()); if (d_ins[i]) {
auto dims = d_in->dims(); d_ins[i]->mutable_data<T>(ctx.GetPlace());
memset(d_in->data<T>(), 0, framework::product(dims) * sizeof(T)); auto dims = d_ins[i]->dims();
memset(d_ins[i]->data<T>(), 0, framework::product(dims) * sizeof(T));
}
} }
auto index = ins[0]->data<T>(); auto index = ins[0]->data<T>();
...@@ -59,10 +61,12 @@ class MultiplexGradCPUKernel : public framework::OpKernel { ...@@ -59,10 +61,12 @@ class MultiplexGradCPUKernel : public framework::OpKernel {
auto cols = ins[1]->dims()[1]; auto cols = ins[1]->dims()[1];
for (auto i = 0; i < rows; i++) { for (auto i = 0; i < rows; i++) {
int k = (int)index[i] + 1; int k = (int)index[i] + 1;
if (d_ins[k]) {
memcpy(d_ins[k]->data<T>() + i * cols, d_out->data<T>() + i * cols, memcpy(d_ins[k]->data<T>() + i * cols, d_out->data<T>() + i * cols,
cols * sizeof(T)); cols * sizeof(T));
} }
} }
}
}; };
} }
} }
...@@ -27,7 +27,16 @@ class TestMultiplexOp(OpTest): ...@@ -27,7 +27,16 @@ class TestMultiplexOp(OpTest):
self.check_output() self.check_output()
def test_check_grad(self): def test_check_grad(self):
self.check_grad(["x1"], "Out") self.check_grad(['x1', 'x2', 'x3', 'x4'], 'Out')
def test_check_grad_ignore_x1(self):
self.check_grad(['x2', 'x3', 'x4'], 'Out', no_grad_set=set('x1'))
def test_check_grad_ignore_x1_x2(self):
self.check_grad(['x3', 'x4'], 'Out', no_grad_set=set(['x1', 'x2']))
def test_check_grad_ignore_x3(self):
self.check_grad(['x1', 'x2', 'x4'], 'Out', no_grad_set=set('x3'))
if __name__ == '__main__': if __name__ == '__main__':
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册