提交 71305e5f 编写于 作者: D dzhwinter

"polish code based on comment"

上级 6f009cf8
......@@ -290,12 +290,12 @@ class ExecutionContext {
return device_context_;
}
//! Get variables vector with same input name.
//! Get actual name vector for this input.
const std::vector<std::string>& Inputs(const std::string& name) const {
return op_.Inputs(name);
}
//! Get variables vector with same output name.
//! Get actual name vector for this output.
const std::vector<std::string>& Outputs(const std::string& name) const {
return op_.Outputs(name);
}
......
......@@ -30,6 +30,11 @@ class NCCLInitOp : public framework::OperatorBase {
"Can not find variable '%s' in the scope.", name);
std::vector<int> gpus = Attr<std::vector<int>>("gpus");
PADDLE_ENFORCE(!gpus.empty(), "Attr(gpus) should not be empty.");
if (scope.FindVar(name) == nullptr) {
PADDLE_THROW("Output(Communicator) is needed for ncclInit operator.");
}
platform::Communicator *comm =
scope.FindVar(name)->GetMutable<platform::Communicator>();
comm->InitAll(gpus);
......
......@@ -9,7 +9,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#define EIGEN_USE_GPU
#include <functional>
#include "paddle/framework/lod_tensor.h"
......@@ -60,7 +59,7 @@ class NCCLAllReduceKernel : public framework::OpKernel<T> {
} else if (reduction == "ncclProd") {
reduction_op_ = ncclProd;
} else {
PADDLE_ENFORCE(false, "Invalid reduction. default ncclSum.");
PADDLE_THROW("Invalid reduction. default ncclSum.");
}
auto* comm = ctx.Input<Communicator>("Communicator");
......@@ -113,7 +112,7 @@ class NCCLReduceKernel : public framework::OpKernel<T> {
} else if (reduction == "ncclProd") {
reduction_op_ = ncclProd;
} else {
PADDLE_ENFORCE(false, "Invalid reduction. default ncclSum.");
PADDLE_THROW("Invalid reduction. default ncclSum.");
}
int root = ctx.Attr<int>("root");
......
......@@ -12,8 +12,6 @@
See the License for the specific language governing permissions and
limitations under the License. */
#define EIGEN_USE_GPU
#include <glog/logging.h>
#include <gtest/gtest.h>
#include <algorithm>
......@@ -193,7 +191,7 @@ TEST_F(NCCLTester, ncclAllReduceOp) {
}
}
// ncclAReduceOp with desc
// ncclReduceOp with desc
TEST_F(NCCLTester, ncclReduceOp) {
std::unique_ptr<f::OpDescBind> op2(new f::OpDescBind);
const int kRoot = 0;
......@@ -201,7 +199,7 @@ TEST_F(NCCLTester, ncclReduceOp) {
op2->SetInput("X", {"st"});
op2->SetInput("Communicator", {"comm"});
op2->SetOutput("Out", {"rt"});
op2->SetAttr("root", {kRoot});
op2->SetAttr("root", kRoot);
std::vector<f::Scope *> dev_scopes;
......@@ -241,7 +239,7 @@ TEST_F(NCCLTester, ncclReduceOp) {
}
}
// // ncclBcastOp with desc
// ncclBcastOp with desc
TEST_F(NCCLTester, ncclBcastOp) {
std::unique_ptr<f::OpDescBind> op2(new f::OpDescBind);
const int kRoot = 5;
......@@ -249,7 +247,7 @@ TEST_F(NCCLTester, ncclBcastOp) {
op2->SetInput("X", {"st"});
op2->SetInput("Communicator", {"comm"});
op2->SetOutput("Out", {"rt"});
op2->SetAttr("root", {kRoot});
op2->SetAttr("root", kRoot);
std::vector<f::Scope *> dev_scopes;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册