提交 71305e5f 编写于 作者: D dzhwinter

"polish code based on comment"

上级 6f009cf8
...@@ -290,12 +290,12 @@ class ExecutionContext { ...@@ -290,12 +290,12 @@ class ExecutionContext {
return device_context_; return device_context_;
} }
//! Get variables vector with same input name. //! Get actual name vector for this input.
const std::vector<std::string>& Inputs(const std::string& name) const { const std::vector<std::string>& Inputs(const std::string& name) const {
return op_.Inputs(name); return op_.Inputs(name);
} }
//! Get variables vector with same output name. //! Get actual name vector for this output.
const std::vector<std::string>& Outputs(const std::string& name) const { const std::vector<std::string>& Outputs(const std::string& name) const {
return op_.Outputs(name); return op_.Outputs(name);
} }
......
...@@ -30,6 +30,11 @@ class NCCLInitOp : public framework::OperatorBase { ...@@ -30,6 +30,11 @@ class NCCLInitOp : public framework::OperatorBase {
"Can not find variable '%s' in the scope.", name); "Can not find variable '%s' in the scope.", name);
std::vector<int> gpus = Attr<std::vector<int>>("gpus"); std::vector<int> gpus = Attr<std::vector<int>>("gpus");
PADDLE_ENFORCE(!gpus.empty(), "Attr(gpus) should not be empty."); PADDLE_ENFORCE(!gpus.empty(), "Attr(gpus) should not be empty.");
if (scope.FindVar(name) == nullptr) {
PADDLE_THROW("Output(Communicator) is needed for ncclInit operator.");
}
platform::Communicator *comm = platform::Communicator *comm =
scope.FindVar(name)->GetMutable<platform::Communicator>(); scope.FindVar(name)->GetMutable<platform::Communicator>();
comm->InitAll(gpus); comm->InitAll(gpus);
......
...@@ -9,7 +9,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -9,7 +9,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#define EIGEN_USE_GPU
#include <functional> #include <functional>
#include "paddle/framework/lod_tensor.h" #include "paddle/framework/lod_tensor.h"
...@@ -60,7 +59,7 @@ class NCCLAllReduceKernel : public framework::OpKernel<T> { ...@@ -60,7 +59,7 @@ class NCCLAllReduceKernel : public framework::OpKernel<T> {
} else if (reduction == "ncclProd") { } else if (reduction == "ncclProd") {
reduction_op_ = ncclProd; reduction_op_ = ncclProd;
} else { } else {
PADDLE_ENFORCE(false, "Invalid reduction. default ncclSum."); PADDLE_THROW("Invalid reduction. default ncclSum.");
} }
auto* comm = ctx.Input<Communicator>("Communicator"); auto* comm = ctx.Input<Communicator>("Communicator");
...@@ -113,7 +112,7 @@ class NCCLReduceKernel : public framework::OpKernel<T> { ...@@ -113,7 +112,7 @@ class NCCLReduceKernel : public framework::OpKernel<T> {
} else if (reduction == "ncclProd") { } else if (reduction == "ncclProd") {
reduction_op_ = ncclProd; reduction_op_ = ncclProd;
} else { } else {
PADDLE_ENFORCE(false, "Invalid reduction. default ncclSum."); PADDLE_THROW("Invalid reduction. default ncclSum.");
} }
int root = ctx.Attr<int>("root"); int root = ctx.Attr<int>("root");
......
...@@ -12,8 +12,6 @@ ...@@ -12,8 +12,6 @@
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#define EIGEN_USE_GPU
#include <glog/logging.h> #include <glog/logging.h>
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include <algorithm> #include <algorithm>
...@@ -193,7 +191,7 @@ TEST_F(NCCLTester, ncclAllReduceOp) { ...@@ -193,7 +191,7 @@ TEST_F(NCCLTester, ncclAllReduceOp) {
} }
} }
// ncclAReduceOp with desc // ncclReduceOp with desc
TEST_F(NCCLTester, ncclReduceOp) { TEST_F(NCCLTester, ncclReduceOp) {
std::unique_ptr<f::OpDescBind> op2(new f::OpDescBind); std::unique_ptr<f::OpDescBind> op2(new f::OpDescBind);
const int kRoot = 0; const int kRoot = 0;
...@@ -201,7 +199,7 @@ TEST_F(NCCLTester, ncclReduceOp) { ...@@ -201,7 +199,7 @@ TEST_F(NCCLTester, ncclReduceOp) {
op2->SetInput("X", {"st"}); op2->SetInput("X", {"st"});
op2->SetInput("Communicator", {"comm"}); op2->SetInput("Communicator", {"comm"});
op2->SetOutput("Out", {"rt"}); op2->SetOutput("Out", {"rt"});
op2->SetAttr("root", {kRoot}); op2->SetAttr("root", kRoot);
std::vector<f::Scope *> dev_scopes; std::vector<f::Scope *> dev_scopes;
...@@ -241,7 +239,7 @@ TEST_F(NCCLTester, ncclReduceOp) { ...@@ -241,7 +239,7 @@ TEST_F(NCCLTester, ncclReduceOp) {
} }
} }
// // ncclBcastOp with desc // ncclBcastOp with desc
TEST_F(NCCLTester, ncclBcastOp) { TEST_F(NCCLTester, ncclBcastOp) {
std::unique_ptr<f::OpDescBind> op2(new f::OpDescBind); std::unique_ptr<f::OpDescBind> op2(new f::OpDescBind);
const int kRoot = 5; const int kRoot = 5;
...@@ -249,7 +247,7 @@ TEST_F(NCCLTester, ncclBcastOp) { ...@@ -249,7 +247,7 @@ TEST_F(NCCLTester, ncclBcastOp) {
op2->SetInput("X", {"st"}); op2->SetInput("X", {"st"});
op2->SetInput("Communicator", {"comm"}); op2->SetInput("Communicator", {"comm"});
op2->SetOutput("Out", {"rt"}); op2->SetOutput("Out", {"rt"});
op2->SetAttr("root", {kRoot}); op2->SetAttr("root", kRoot);
std::vector<f::Scope *> dev_scopes; std::vector<f::Scope *> dev_scopes;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册