diff --git a/paddle/operators/multiplex_op.cc b/paddle/operators/multiplex_op.cc index 67e8e5f5d7914f159fca5d9789141e59e6b379c2..03559d0643ce6c75b84e7d1a08c2e9920a2a2f03 100644 --- a/paddle/operators/multiplex_op.cc +++ b/paddle/operators/multiplex_op.cc @@ -1,4 +1,3 @@ - /* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. Licensed under the Apache License, Version 2.0 (the "License"); @@ -19,6 +18,7 @@ namespace paddle { namespace operators { using Tensor = framework::Tensor; +using LoDTensor = framework::LoDTensor; class MultiplexOp : public framework::OperatorWithKernel { public: @@ -29,8 +29,12 @@ class MultiplexOp : public framework::OperatorWithKernel { protected: void InferShape(const framework::InferShapeContext &ctx) const override { + PADDLE_ENFORCE(!ctx.MultiInputVar("X").empty(), + "Input(X) should not be null"); + PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("Out"), + "Output(Out) shouldn't be null."); auto ins = ctx.MultiInput("X"); - auto *out = ctx.Output("Out"); + auto *out = ctx.Output("Out"); auto num_ins = ins.size(); PADDLE_ENFORCE(num_ins > 2, "multiplex operator should have more than 2 inputs."); @@ -53,7 +57,7 @@ class MultiplexOpMaker : public framework::OpProtoAndCheckerMaker { MultiplexOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker) : OpProtoAndCheckerMaker(proto, op_checker) { - AddInput("X", "The input tensor of multiplex operator.").AsDuplicable(); + AddInput("X", "The input tensors of multiplex operator.").AsDuplicable(); AddOutput("Out", "The output tensor of multiplex operator."); AddComment(R"DOC(Multiplex operator @@ -69,7 +73,7 @@ For each i-th row of output: y[i][j] = x_{k}[i][j], j = 0,1, ... , (x_{1}.width - 1) -where y is the output tensor. `x_{k}` is the k-th input layer +where y is the output tensor. `x_{k}` is the k-th input tensor and `k = x{0}[i] + 1`. )DOC"); @@ -86,13 +90,19 @@ class MultiplexGradOp : public framework::OperatorWithKernel { protected: void InferShape(const framework::InferShapeContext &ctx) const override { + PADDLE_ENFORCE(!ctx.MultiInputVar("X").empty(), + "Input(X) should not be null"); + PADDLE_ENFORCE(!ctx.MultiOutputVar(framework::GradVarName("X")).empty(), + "Output(X@Grad) should not be null"); PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Out")), "Input(Out@GRAD) shouldn't be null."); - auto d_ins = ctx.MultiOutput(framework::GradVarName("X")); + auto d_ins = ctx.MultiOutput(framework::GradVarName("X")); auto ins = ctx.MultiInput("X"); - for (size_t i = 0; i < ins.size(); i++) { - auto dims = ins[i]->dims(); - d_ins[i]->Resize(dims); + // don;t compute gradient for index + for (size_t i = 1; i < ins.size(); i++) { + if (d_ins[i]) { + d_ins[i]->Resize(ins[i]->dims()); + } } } }; diff --git a/paddle/operators/multiplex_op.cu b/paddle/operators/multiplex_op.cu index 81d637686b22410172b9cea59bcacc8e17d7a076..055e13d1834edf3c3dc7d7ee0922e58363dfeda2 100644 --- a/paddle/operators/multiplex_op.cu +++ b/paddle/operators/multiplex_op.cu @@ -18,13 +18,14 @@ namespace paddle { namespace operators { using Tensor = framework::Tensor; +using LoDTensor = framework::LoDTensor; template class MultiplexGPUKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const { auto ins = ctx.MultiInput("X"); - auto* out = ctx.Output("Out"); + auto* out = ctx.Output("Out"); out->mutable_data(ctx.GetPlace()); auto rows = ins[1]->dims()[0]; @@ -48,10 +49,13 @@ class MultiplexGradGPUKernel : public framework::OpKernel { auto* d_out = ctx.Input(framework::GradVarName("Out")); auto ins = ctx.MultiInput("X"); auto d_ins = ctx.MultiOutput(framework::GradVarName("X")); - for (auto d_in : d_ins) { - d_in->mutable_data(ctx.GetPlace()); - auto dims = d_in->dims(); - cudaMemset(d_in->data(), 0, framework::product(dims) * sizeof(T)); + for (size_t i = 1; i < d_ins.size(); ++i) { + if (d_ins[i]) { + d_ins[i]->mutable_data(ctx.GetPlace()); + auto dims = d_ins[i]->dims(); + cudaMemset(d_ins[i]->data(), 0, + framework::product(dims) * sizeof(T)); + } } auto rows = ins[1]->dims()[0]; @@ -62,8 +66,10 @@ class MultiplexGradGPUKernel : public framework::OpKernel { auto index = index_t_cpu.data(); for (auto i = 0; i < rows; i++) { int k = (int)index[i] + 1; - cudaMemcpy(d_ins[k]->data() + i * cols, d_out->data() + i * cols, - cols * sizeof(T), cudaMemcpyDeviceToDevice); + if (d_ins[k]) { + cudaMemcpy(d_ins[k]->data() + i * cols, d_out->data() + i * cols, + cols * sizeof(T), cudaMemcpyDeviceToDevice); + } } } }; diff --git a/paddle/operators/multiplex_op.h b/paddle/operators/multiplex_op.h index 7b627a83b3aa202edd691e1554cae0acd74f098d..82b4a2c4c75cc9c8dc11f2c8f0126d1611abc8f1 100644 --- a/paddle/operators/multiplex_op.h +++ b/paddle/operators/multiplex_op.h @@ -26,7 +26,7 @@ class MultiplexCPUKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const { auto ins = ctx.MultiInput("X"); - auto* out = ctx.Output("Out"); + auto* out = ctx.Output("Out"); out->mutable_data(ctx.GetPlace()); auto index = ins[0]->data(); @@ -48,10 +48,12 @@ class MultiplexGradCPUKernel : public framework::OpKernel { auto ins = ctx.MultiInput("X"); auto d_ins = ctx.MultiOutput(framework::GradVarName("X")); - for (auto d_in : d_ins) { - d_in->mutable_data(ctx.GetPlace()); - auto dims = d_in->dims(); - memset(d_in->data(), 0, framework::product(dims) * sizeof(T)); + for (size_t i = 1; i < d_ins.size(); i++) { + if (d_ins[i]) { + d_ins[i]->mutable_data(ctx.GetPlace()); + auto dims = d_ins[i]->dims(); + memset(d_ins[i]->data(), 0, framework::product(dims) * sizeof(T)); + } } auto index = ins[0]->data(); @@ -59,8 +61,10 @@ class MultiplexGradCPUKernel : public framework::OpKernel { auto cols = ins[1]->dims()[1]; for (auto i = 0; i < rows; i++) { int k = (int)index[i] + 1; - memcpy(d_ins[k]->data() + i * cols, d_out->data() + i * cols, - cols * sizeof(T)); + if (d_ins[k]) { + memcpy(d_ins[k]->data() + i * cols, d_out->data() + i * cols, + cols * sizeof(T)); + } } } }; diff --git a/python/paddle/v2/framework/tests/test_multiplex_op.py b/python/paddle/v2/framework/tests/test_multiplex_op.py index c42cb6f0fe24dfa365460cce7ceb510c6cc6eb9a..f2b3881cde24c7fb96c3d7f9411352bc62d55077 100644 --- a/python/paddle/v2/framework/tests/test_multiplex_op.py +++ b/python/paddle/v2/framework/tests/test_multiplex_op.py @@ -27,7 +27,16 @@ class TestMultiplexOp(OpTest): self.check_output() def test_check_grad(self): - self.check_grad(["x1"], "Out") + self.check_grad(['x1', 'x2', 'x3', 'x4'], 'Out') + + def test_check_grad_ignore_x1(self): + self.check_grad(['x2', 'x3', 'x4'], 'Out', no_grad_set=set('x1')) + + def test_check_grad_ignore_x1_x2(self): + self.check_grad(['x3', 'x4'], 'Out', no_grad_set=set(['x1', 'x2'])) + + def test_check_grad_ignore_x3(self): + self.check_grad(['x1', 'x2', 'x4'], 'Out', no_grad_set=set('x3')) if __name__ == '__main__':