From 71305e5f90f87dcdf6fc0ab619f41da1763e74c7 Mon Sep 17 00:00:00 2001 From: dzhwinter Date: Sun, 29 Oct 2017 13:50:34 -0700 Subject: [PATCH] "polish code based on comment" --- paddle/framework/operator.h | 4 ++-- paddle/operators/nccl_op.cc | 5 +++++ paddle/operators/nccl_op.cu | 5 ++--- paddle/operators/nccl_op_test.cu | 10 ++++------ 4 files changed, 13 insertions(+), 11 deletions(-) diff --git a/paddle/framework/operator.h b/paddle/framework/operator.h index 323625036..a2544f1dc 100644 --- a/paddle/framework/operator.h +++ b/paddle/framework/operator.h @@ -290,12 +290,12 @@ class ExecutionContext { return device_context_; } - //! Get variables vector with same input name. + //! Get actual name vector for this input. const std::vector& Inputs(const std::string& name) const { return op_.Inputs(name); } - //! Get variables vector with same output name. + //! Get actual name vector for this output. const std::vector& Outputs(const std::string& name) const { return op_.Outputs(name); } diff --git a/paddle/operators/nccl_op.cc b/paddle/operators/nccl_op.cc index 3744d1b47..d39cb2fcf 100644 --- a/paddle/operators/nccl_op.cc +++ b/paddle/operators/nccl_op.cc @@ -30,6 +30,11 @@ class NCCLInitOp : public framework::OperatorBase { "Can not find variable '%s' in the scope.", name); std::vector gpus = Attr>("gpus"); PADDLE_ENFORCE(!gpus.empty(), "Attr(gpus) should not be empty."); + + if (scope.FindVar(name) == nullptr) { + PADDLE_THROW("Output(Communicator) is needed for ncclInit operator."); + } + platform::Communicator *comm = scope.FindVar(name)->GetMutable(); comm->InitAll(gpus); diff --git a/paddle/operators/nccl_op.cu b/paddle/operators/nccl_op.cu index f8b3b8a8b..86dee8ee8 100644 --- a/paddle/operators/nccl_op.cu +++ b/paddle/operators/nccl_op.cu @@ -9,7 +9,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#define EIGEN_USE_GPU #include #include "paddle/framework/lod_tensor.h" @@ -60,7 +59,7 @@ class NCCLAllReduceKernel : public framework::OpKernel { } else if (reduction == "ncclProd") { reduction_op_ = ncclProd; } else { - PADDLE_ENFORCE(false, "Invalid reduction. default ncclSum."); + PADDLE_THROW("Invalid reduction. default ncclSum."); } auto* comm = ctx.Input("Communicator"); @@ -113,7 +112,7 @@ class NCCLReduceKernel : public framework::OpKernel { } else if (reduction == "ncclProd") { reduction_op_ = ncclProd; } else { - PADDLE_ENFORCE(false, "Invalid reduction. default ncclSum."); + PADDLE_THROW("Invalid reduction. default ncclSum."); } int root = ctx.Attr("root"); diff --git a/paddle/operators/nccl_op_test.cu b/paddle/operators/nccl_op_test.cu index 63a286f60..80c50a28a 100644 --- a/paddle/operators/nccl_op_test.cu +++ b/paddle/operators/nccl_op_test.cu @@ -12,8 +12,6 @@ See the License for the specific language governing permissions and limitations under the License. */ -#define EIGEN_USE_GPU - #include #include #include @@ -193,7 +191,7 @@ TEST_F(NCCLTester, ncclAllReduceOp) { } } -// ncclAReduceOp with desc +// ncclReduceOp with desc TEST_F(NCCLTester, ncclReduceOp) { std::unique_ptr op2(new f::OpDescBind); const int kRoot = 0; @@ -201,7 +199,7 @@ TEST_F(NCCLTester, ncclReduceOp) { op2->SetInput("X", {"st"}); op2->SetInput("Communicator", {"comm"}); op2->SetOutput("Out", {"rt"}); - op2->SetAttr("root", {kRoot}); + op2->SetAttr("root", kRoot); std::vector dev_scopes; @@ -241,7 +239,7 @@ TEST_F(NCCLTester, ncclReduceOp) { } } -// // ncclBcastOp with desc +// ncclBcastOp with desc TEST_F(NCCLTester, ncclBcastOp) { std::unique_ptr op2(new f::OpDescBind); const int kRoot = 5; @@ -249,7 +247,7 @@ TEST_F(NCCLTester, ncclBcastOp) { op2->SetInput("X", {"st"}); op2->SetInput("Communicator", {"comm"}); op2->SetOutput("Out", {"rt"}); - op2->SetAttr("root", {kRoot}); + op2->SetAttr("root", kRoot); std::vector dev_scopes; -- GitLab