提交 0c45eab7 编写于 作者: Y Yang Yang

no getmutable nccl_com

上级 0e2deaa5
...@@ -23,7 +23,6 @@ limitations under the License. */ ...@@ -23,7 +23,6 @@ limitations under the License. */
#include "paddle/fluid/framework/lod_tensor_array.h" #include "paddle/fluid/framework/lod_tensor_array.h"
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/reader.h" #include "paddle/fluid/framework/reader.h"
#include "paddle/fluid/operators/nccl/nccl_gpu_common.h" // platform::Communicator
#include "paddle/fluid/platform/place.h" #include "paddle/fluid/platform/place.h"
#include "paddle/fluid/platform/profiler.h" #include "paddle/fluid/platform/profiler.h"
...@@ -54,15 +53,15 @@ static void CreateTensor(Variable* var, proto::VarDesc::VarType var_type) { ...@@ -54,15 +53,15 @@ static void CreateTensor(Variable* var, proto::VarDesc::VarType var_type) {
var->GetMutable<LoDTensorArray>(); var->GetMutable<LoDTensorArray>();
} else if (var_type == proto::VarDesc::PLACE_LIST) { } else if (var_type == proto::VarDesc::PLACE_LIST) {
var->GetMutable<platform::PlaceList>(); var->GetMutable<platform::PlaceList>();
} else if (var_type == proto::VarDesc::NCCL_COM) {
var->GetMutable<platform::Communicator>();
} else if (var_type == proto::VarDesc::READER) { } else if (var_type == proto::VarDesc::READER) {
var->GetMutable<ReaderHolder>(); var->GetMutable<ReaderHolder>();
} else if (var_type == proto::VarDesc::NCCL_COM) {
// GetMutable will be called in ncclInit
} else { } else {
PADDLE_THROW( PADDLE_THROW(
"Variable type %d is not in " "Variable type %d is not in "
"[LOD_TENSOR, SELECTED_ROWS, FEED_MINIBATCH, FETCH_LIST, " "[LOD_TENSOR, SELECTED_ROWS, FEED_MINIBATCH, FETCH_LIST, "
"LOD_RANK_TABLE, PLACE_LIST, READER]", "LOD_RANK_TABLE, PLACE_LIST, READER, NCCL_COM]",
var_type); var_type);
} }
} }
......
...@@ -212,5 +212,5 @@ class ParallelOpTestMultipleInput(BaseParallelForTest): ...@@ -212,5 +212,5 @@ class ParallelOpTestMultipleInput(BaseParallelForTest):
fetch=['fc1.w@GRAD', 'fc2.w@GRAD', 'fc3.w@GRAD']) fetch=['fc1.w@GRAD', 'fc2.w@GRAD', 'fc3.w@GRAD'])
#if __name__ == '__main__': if __name__ == '__main__':
# unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册