未验证 提交 4a702ef3 编写于 作者: C Chen Weihang 提交者: GitHub

Support SelelctedRows allreduce in multi-cards imperative mode (#24690)

* support selectedrows allreduce in multi-cards dygraph, test=develop

* remove useless import modules in unittests, test=develop

* add nccl cmake to get nccl version, test=develop

* add if-condition to compiled correctly, test=develop

* add detail version parseing for old nccl, test=develop

* polish camke details, test=develop

* fix remove test cmake error, test=develop

* fix cmake condition, test=develop

* change unittest camke list, test=develop

* fix unittest cmake rule, test=develop, test=framep0
上级 14b85405
......@@ -143,6 +143,7 @@ endif()
if(WITH_NCCL)
add_definitions("-DPADDLE_WITH_NCCL")
include(nccl)
else()
if(WITH_GPU)
MESSAGE(WARNING "If the environment is multi-card, the WITH_NCCL option needs to be turned on, otherwise only a single card can be used.")
......
if(NOT WITH_GPU)
return()
endif()
# Now we don't support NCCL on windows
if(WIN32)
return()
endif()
set(NCCL_ROOT "/usr" CACHE PATH "NCCL ROOT")
find_path(NCCL_INCLUDE_DIR nccl.h
PATHS ${NCCL_ROOT} ${NCCL_ROOT}/include ${NCCL_ROOT}/local/include
$ENV{NCCL_ROOT} $ENV{NCCL_ROOT}/include $ENV{NCCL_ROOT}/local/include
NO_DEFAULT_PATH
)
if(WITH_NCCL)
file(READ ${NCCL_INCLUDE_DIR}/nccl.h NCCL_VERSION_FILE_CONTENTS)
string(REGEX MATCH "define NCCL_VERSION_CODE +([0-9]+)"
NCCL_VERSION "${NCCL_VERSION_FILE_CONTENTS}")
string(REGEX REPLACE "define NCCL_VERSION_CODE +([0-9]+)" "\\1"
NCCL_VERSION "${NCCL_VERSION}")
if("${NCCL_VERSION}" GREATER "2000")
message(STATUS "Current NCCL header is ${NCCL_INCLUDE_DIR}/nccl.h. "
"Current NCCL version is v${NCCL_VERSION}. ")
else()
# in old version nccl, it may not define NCCL_VERSION_CODE
string(REGEX MATCH "define NCCL_MAJOR +([0-9]+)" NCCL_MAJOR_VERSION
"${NCCL_VERSION_FILE_CONTENTS}")
string(REGEX REPLACE "define NCCL_MAJOR +([0-9]+)" "\\1"
NCCL_MAJOR_VERSION "${NCCL_MAJOR_VERSION}")
string(REGEX MATCH "define NCCL_MINOR +([0-9]+)" NCCL_MINOR_VERSION
"${NCCL_VERSION_FILE_CONTENTS}")
string(REGEX REPLACE "define NCCL_MINOR +([0-9]+)" "\\1"
NCCL_MINOR_VERSION "${NCCL_MINOR_VERSION}")
string(REGEX MATCH "define NCCL_PATCH +([0-9]+)"
NCCL_PATCH_VERSION "${NCCL_VERSION_FILE_CONTENTS}")
string(REGEX REPLACE "define NCCL_PATCH +([0-9]+)" "\\1"
NCCL_PATCH_VERSION "${NCCL_PATCH_VERSION}")
if(NOT NCCL_MAJOR_VERSION)
set(NCCL_VERSION "0")
else()
math(EXPR NCCL_VERSION
"${NCCL_MAJOR_VERSION} * 1000 +
${NCCL_MINOR_VERSION} * 100 + ${NCCL_PATCH_VERSION}")
endif()
add_definitions("-DNCCL_VERSION_CODE=$NCCL_VERSION")
message(STATUS "Current NCCL header is ${NCCL_INCLUDE_DIR}/nccl.h. "
"Current NCCL version is v${NCCL_MAJOR_VERSION}.${NCCL_MINOR_VERSION}.${NCCL_PATCH_VERSION} ")
endif()
endif()
......@@ -11,7 +11,8 @@ cc_library(engine SRCS basic_engine.cc partial_grad_engine.cc DEPS layer gradien
cc_library(imperative_profiler SRCS profiler.cc)
if(NOT WIN32)
if(WITH_NCCL)
cc_library(nccl_context SRCS nccl_context.cc DEPS collective_helper device_context)
cc_library(imperative_all_reduce SRCS all_reduce.cc DEPS collective_helper device_context selected_rows tensor)
cc_library(nccl_context SRCS nccl_context.cc DEPS collective_helper device_context imperative_all_reduce)
endif()
cc_library(data_loader SRCS data_loader.cc DEPS enforce)
endif(NOT WIN32)
......
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifdef PADDLE_WITH_NCCL
#include "paddle/fluid/imperative/all_reduce.h"
#include <string>
#include <utility>
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/nccl_helper.h"
#include "paddle/fluid/string/string_helper.h"
namespace paddle {
namespace imperative {
static void AllReduce(const framework::Tensor &src, framework::Tensor *dst,
const ParallelStrategy &strategy, cudaStream_t stream) {
const auto &place = src.place();
PADDLE_ENFORCE_EQ(
platform::is_gpu_place(place), true,
platform::errors::Unimplemented(
"Imperative mode does not support multi-CPU training yet."));
const void *src_ptr = src.data<void>();
dst->Resize(src.dims());
auto *dst_ptr = dst->mutable_data(src.place(), src.type());
auto nccl_dtype = platform::ToNCCLDataType(src.type());
auto comm = static_cast<platform::CUDADeviceContext *>(
platform::DeviceContextPool::Instance().Get(place))
->nccl_comm();
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclAllReduce(
src_ptr, dst_ptr, src.numel(), nccl_dtype, ncclSum, comm, stream));
}
#if NCCL_VERSION_CODE >= 2212
static void AllReduce(const framework::SelectedRows &src,
framework::SelectedRows *dst,
const ParallelStrategy &strategy, cudaStream_t stream) {
VLOG(0) << "SelectedRows AllReduce start";
const auto &src_tensor = src.value();
const auto &place = src_tensor.place();
PADDLE_ENFORCE_EQ(
platform::is_gpu_place(place), true,
platform::errors::Unimplemented(
"Imperative mode does not support multi-CPU training yet."));
auto dtype = src_tensor.type();
auto nccl_dtype = platform::ToNCCLDataType(dtype);
auto *dev_ctx = static_cast<platform::CUDADeviceContext *>(
platform::DeviceContextPool::Instance().Get(place));
auto comm = dev_ctx->nccl_comm();
// 1. Gather rows number from all workers. Here use ncclAllGather to do this,
// but we can use other ways to implement is in the future
const auto &src_rows = src.rows();
framework::Vector<int64_t> rows_num_vector(strategy.nranks_);
rows_num_vector[strategy.local_rank_] = static_cast<int64_t>(src_rows.size());
auto *gpu_rows_num_ptr = rows_num_vector.CUDAMutableData(place);
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclAllGather(
gpu_rows_num_ptr + strategy.local_rank_, gpu_rows_num_ptr, 1, ncclInt64,
comm, stream));
if (stream != dev_ctx->stream()) {
PADDLE_ENFORCE_CUDA_SUCCESS(cudaStreamSynchronize(stream));
}
const auto *cpu_rows_num_ptr = rows_num_vector.data();
auto rows_num =
std::accumulate(cpu_rows_num_ptr, cpu_rows_num_ptr + strategy.nranks_,
static_cast<int64_t>(0));
dst->set_height(src.height());
VLOG(0) << "Gather rows: " << string::join_strings(rows_num_vector, ',')
<< ", total rows number: " << rows_num
<< ", height: " << src.height();
PADDLE_ENFORCE_LE(
rows_num, src.height(),
platform::errors::Unimplemented(
"The gathered SelectedRows's rows number should less than or equal "
"to the SelectedRows's height, but the actual rows number is %d, the "
"SelectedRows's height is %d.",
rows_num, src.height()));
auto *dst_rows = dst->mutable_rows();
dst_rows->resize(rows_num);
auto *dst_rows_ptr = dst_rows->CUDAMutableData(place);
const auto *src_rows_ptr = src_rows.CUDAData(place);
auto *dst_tensor = dst->mutable_value();
auto dims = src_tensor.dims();
dims[0] = rows_num;
auto feature_size = framework::product(dims) / dims[0];
dst_tensor->Resize(dims);
auto *dst_tensor_ptr = dst_tensor->mutable_data(place, dtype);
const auto *src_tensor_ptr = src_tensor.data<void>();
auto sizeof_dtype = framework::SizeOfType(dtype);
int64_t row_offset = 0;
for (int i = 0; i < strategy.nranks_; ++i) {
if (cpu_rows_num_ptr[i] > 0) {
// 2. Broadcast the rows of SelectedRows
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclBroadcast(
src_rows_ptr, dst_rows_ptr + row_offset, cpu_rows_num_ptr[i],
ncclInt64, i, comm, stream));
// 3. Broadcast the tensor data of SelectedRows
auto *dst_tensor_ptr_i = reinterpret_cast<uint8_t *>(dst_tensor_ptr) +
row_offset * feature_size * sizeof_dtype;
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclBroadcast(
src_tensor_ptr, dst_tensor_ptr_i, cpu_rows_num_ptr[i] * feature_size,
nccl_dtype, i, comm, stream));
row_offset += cpu_rows_num_ptr[i];
}
}
VLOG(0) << "Original SelectedRows rows: "
<< string::join_strings(src_rows, ',');
VLOG(0) << "Result SelectedRows rows: "
<< string::join_strings(*dst_rows, ',');
}
#endif
void AllReduce(const framework::Variable &src, framework::Variable *dst,
const ParallelStrategy &strategy, cudaStream_t stream) {
if (src.IsType<framework::LoDTensor>()) {
if (!dst->IsType<framework::LoDTensor>()) {
dst->Clear();
}
AllReduce(src.Get<framework::LoDTensor>(),
dst->GetMutable<framework::LoDTensor>(), strategy, stream);
#if NCCL_VERSION_CODE >= 2212
} else if (src.IsType<framework::SelectedRows>()) {
if (&src != dst) {
if (!dst->IsType<framework::SelectedRows>()) {
dst->Clear();
}
AllReduce(src.Get<framework::SelectedRows>(),
dst->GetMutable<framework::SelectedRows>(), strategy, stream);
} else {
// SelectedRows cannot be allreduce in-place
framework::Variable tmp_dst;
AllReduce(src.Get<framework::SelectedRows>(),
tmp_dst.GetMutable<framework::SelectedRows>(), strategy,
stream);
*dst = std::move(tmp_dst);
}
#endif
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"Unsupported variable type %s for imperative allreduce, only "
"LoDTensor and SelectedRows are supported.",
platform::demangle(framework::ToTypeName(src.Type()))));
}
}
static const platform::Place &GetVarPlace(const framework::Variable &src) {
if (src.IsType<framework::LoDTensor>()) {
return src.Get<framework::LoDTensor>().place();
#if NCCL_VERSION_CODE >= 2212
} else if (src.IsType<framework::SelectedRows>()) {
return src.Get<framework::SelectedRows>().value().place();
#endif
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"Cannot get unsupported variable type %s for imperative allreduce, "
"only "
"LoDTensor and SelectedRows are supported.",
platform::demangle(framework::ToTypeName(src.Type()))));
}
}
void AllReduce(const framework::Variable &src, framework::Variable *dst,
const ParallelStrategy &strategy) {
const auto &place = GetVarPlace(src);
PADDLE_ENFORCE_EQ(
platform::is_gpu_place(place), true,
platform::errors::Unimplemented(
"Imperative mode does not support multi-CPU training yet."));
auto *dev_ctx = static_cast<platform::CUDADeviceContext *>(
platform::DeviceContextPool::Instance().Get(place));
auto stream = dev_ctx->stream();
AllReduce(src, dst, strategy, stream);
}
} // namespace imperative
} // namespace paddle
#endif
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#ifdef PADDLE_WITH_NCCL
#include <cuda.h>
#include <cuda_runtime.h>
#include <nccl.h>
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/framework/variable.h"
#include "paddle/fluid/imperative/nccl_context.h"
namespace paddle {
namespace imperative {
void AllReduce(const framework::Variable &src, framework::Variable *dst,
const ParallelStrategy &strategy);
} // namespace imperative
} // namespace paddle
#endif
......@@ -71,6 +71,11 @@ extern void* nccl_dso_handle;
NCCL_RAND_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_NCCL_WRAP)
#if NCCL_VERSION_CODE >= 2212
#define NCCL_RAND_ROUTINE_EACH_AFTER_2212(__macro) __macro(ncclBroadcast);
NCCL_RAND_ROUTINE_EACH_AFTER_2212(DECLARE_DYNAMIC_LOAD_NCCL_WRAP)
#endif
} // namespace dynload
} // namespace platform
} // namespace paddle
......@@ -25,6 +25,7 @@ limitations under the License. */
#include <unordered_map>
#include <utility>
#include <vector>
#include "paddle/fluid/imperative/all_reduce.h"
#include "paddle/fluid/imperative/backward_strategy.h"
#include "paddle/fluid/imperative/basic_engine.h"
#include "paddle/fluid/imperative/data_loader.h"
......@@ -758,6 +759,36 @@ void BindImperative(py::module *m_ptr) {
return std::shared_ptr<imperative::VarBase>(nullptr);
},
py::return_value_policy::copy)
.def("_is_sparse",
[](imperative::VarBase &self) {
return self.Var().IsType<framework::SelectedRows>();
})
.def("_allreduce",
[](imperative::VarBase &self,
const imperative::ParallelStrategy &strategy) {
if (strategy.nranks_ > 1) {
#ifdef PADDLE_WITH_NCCL
#if NCCL_VERSION_CODE >= 2212
imperative::AllReduce(self.Var(), self.MutableVar(), strategy);
#else
if (!self.Var().IsType<framework::SelectedRows>()) {
imperative::AllReduce(self.Var(), self.MutableVar(), strategy);
} else {
PADDLE_THROW(platform::errors::Unimplemented(
"Imperative SelectedRows allreduce is not supported when "
"paddle is compiled with NCCL verison lower than v2.2.12. "
"You can set is_sparse=False for the Layer containing "
"this argument, such as Embedding(is_sparse=False)."));
}
#endif // NCCL_VERSION_CODE
#else
PADDLE_THROW(platform::errors::Unimplemented(
"Imperative allreduce is not supported when paddle is "
"not compiled with NCCL."));
#endif // PADDLE_WITH_NCCL
}
},
py::call_guard<py::gil_scoped_release>())
.def("_copy_to",
[](const imperative::VarBase &self, const platform::CPUPlace &place,
bool blocking) { return self.NewVarBase(place, blocking); },
......
......@@ -19,7 +19,6 @@ from .. import core
from . import layers
from . import parallel_helper
from .. import framework
from ..layers import collective
from . import to_variable, no_grad
__all__ = ["prepare_context", "ParallelEnv", "DataParallel"]
......@@ -421,14 +420,23 @@ class DataParallel(layers.Layer):
grad_var_set = set()
grad_vars = []
sparse_grad_vars = []
for param in self._layers.parameters():
# NOTE(zcd): The grad_ivar maybe no generated.
if param.trainable and (param._grad_ivar() is not None):
g_var = param._grad_ivar()
if g_var._is_sparse():
sparse_grad_vars.append(g_var)
continue
grad_vars.append(g_var)
assert g_var not in grad_var_set
grad_var_set.add(g_var)
if sparse_grad_vars:
sparse_grad_vars.sort(key=lambda x: x.name)
for grad_var in sparse_grad_vars:
grad_var._allreduce(self._strategy)
# FIXME(zcd): the type of the var should be LoDTensor, i.e
# the gradients should be dense, otherwise, the following
# logic should be updated.
......@@ -450,9 +458,8 @@ class DataParallel(layers.Layer):
coalesced_grads_and_vars = self._coalesce_tensors(grad_var_groups)
for coalesced_grad, g_vars, g_shapes in coalesced_grads_and_vars:
collective._allreduce(
coalesced_grad, coalesced_grad, sync_mode=False)
for coalesced_grad, _, _ in coalesced_grads_and_vars:
coalesced_grad._allreduce(self._strategy)
self._split_tensors(coalesced_grads_and_vars)
......
......@@ -10,6 +10,8 @@ endif()
string(REPLACE ".py" "" DIST_TEST_OPS "${DIST_TEST_OPS}")
list(APPEND DIST_TEST_OPS test_parallel_dygraph_mnist)
list(APPEND DIST_TEST_OPS test_parallel_dygraph_se_resnext)
list(APPEND DIST_TEST_OPS test_parallel_dygraph_sparse_embedding)
list(APPEND DIST_TEST_OPS test_parallel_dygraph_transformer)
list(APPEND DIST_TEST_OPS test_listen_and_serv_op)
set(MIXED_DIST_TEST_OPS ${DIST_TEST_OPS})
#remove distribute unittests.
......@@ -57,10 +59,20 @@ if (NOT ${WITH_GPU})
LIST(REMOVE_ITEM TEST_OPS test_rank_attention_op) # TODO(shenliang03): rank_attention_op support CPU device in future
LIST(REMOVE_ITEM TEST_OPS test_batch_fc_op) # TODO(shenliang03): batch_fc_op support CPU device in future
LIST(REMOVE_ITEM TEST_OPS test_parallel_dygraph_mnist) # TODO(Yancey1989): parallel dygraph support CPU device in future
list(REMOVE_ITEM TEST_OPS test_parallel_dygraph_se_resnext)
LIST(REMOVE_ITEM TEST_OPS test_parallel_dygraph_sparse_embedding)
LIST(REMOVE_ITEM TEST_OPS test_parallel_dygraph_transformer)
elseif(${CUDNN_VERSION} VERSION_LESS 7100)
LIST(REMOVE_ITEM TEST_OPS test_conv2d_fusion_op)
endif()
if (WITH_NCCL)
if (${NCCL_VERSION} VERSION_LESS 2212)
LIST(REMOVE_ITEM DIST_TEST_OPS test_parallel_dygraph_sparse_embedding)
LIST(REMOVE_ITEM DIST_TEST_OPS test_parallel_dygraph_transformer)
endif()
endif()
if(NOT WITH_GPU OR WIN32)
LIST(REMOVE_ITEM TEST_OPS test_pipeline)
LIST(REMOVE_ITEM TEST_OPS test_boxps)
......@@ -176,7 +188,6 @@ function(bash_test_modules TARGET_NAME)
endfunction()
list(REMOVE_ITEM TEST_OPS test_warpctc_op)
list(REMOVE_ITEM TEST_OPS test_parallel_executor_crf)
list(REMOVE_ITEM TEST_OPS test_data_norm_op)
list(REMOVE_ITEM TEST_OPS test_parallel_executor_fetch_feed)
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import numpy as np
import paddle
import paddle.fluid as fluid
from paddle.fluid.dygraph.nn import Embedding
from paddle.fluid.dygraph.base import to_variable
from test_dist_base import runtime_main, TestParallelDyGraphRunnerBase
class SimpleNet(fluid.Layer):
def __init__(self,
hidden_size,
vocab_size,
num_steps=20,
init_scale=0.1,
is_sparse=False):
super(SimpleNet, self).__init__()
self.hidden_size = hidden_size
self.vocab_size = vocab_size
self.init_scale = init_scale
self.num_steps = num_steps
self.embedding = Embedding(
size=[self.vocab_size, self.hidden_size],
dtype='float32',
is_sparse=is_sparse,
param_attr=fluid.ParamAttr(
name='embedding_param',
initializer=fluid.initializer.UniformInitializer(
low=-init_scale, high=init_scale)))
self.softmax_weight = self.create_parameter(
attr=fluid.ParamAttr(),
shape=[self.hidden_size, self.vocab_size],
dtype="float32",
default_initializer=fluid.initializer.UniformInitializer(
low=-self.init_scale, high=self.init_scale))
self.softmax_bias = self.create_parameter(
attr=fluid.ParamAttr(),
shape=[self.vocab_size],
dtype="float32",
default_initializer=fluid.initializer.UniformInitializer(
low=-self.init_scale, high=self.init_scale))
def forward(self, input, label):
x_emb = self.embedding(input)
fc = fluid.layers.matmul(x_emb, self.softmax_weight)
fc = fluid.layers.elementwise_add(fc, self.softmax_bias)
projection = fluid.layers.reshape(fc, shape=[-1, self.vocab_size])
loss = fluid.layers.softmax_with_cross_entropy(
logits=projection, label=label, soft_label=False)
loss = fluid.layers.reshape(loss, shape=[-1, self.num_steps])
loss = fluid.layers.reduce_mean(loss, dim=[0])
loss = fluid.layers.reduce_sum(loss)
return loss
# global configs
batch_size = 4
batch_num = 200
hidden_size = 10
vocab_size = 1000
num_steps = 3
init_scale = 0.1
def fake_sample_reader():
def __reader__():
for i in range(batch_num):
x_data = np.arange(num_steps).astype('int64')
y_data = np.arange(1, 1 + num_steps).astype('int64')
yield x_data, y_data
return __reader__
class TestSparseEmbedding(TestParallelDyGraphRunnerBase):
def get_model(self):
model = SimpleNet(
hidden_size=hidden_size,
vocab_size=vocab_size,
num_steps=num_steps,
init_scale=init_scale,
is_sparse=True)
train_reader = paddle.batch(
fake_sample_reader(), batch_size=batch_size, drop_last=True)
optimizer = fluid.optimizer.SGD(learning_rate=0.001,
parameter_list=model.parameters())
return model, train_reader, optimizer
def run_one_loop(self, model, optimizer, batch):
x_data = np.array([x[0].reshape(3) for x in batch]).astype('int64')
y_data = np.array([x[1].reshape(3) for x in batch]).astype('int64')
x_data = x_data.reshape((-1, num_steps, 1))
y_data = y_data.reshape((-1, 1))
x = to_variable(x_data)
y = to_variable(y_data)
dy_loss = model(x, y)
return dy_loss
if __name__ == "__main__":
runtime_main(TestSparseEmbedding)
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import os
import unittest
import paddle.fluid as fluid
from test_dist_base import TestDistBase
flag_name = os.path.splitext(__file__)[0]
class TestParallelDygraphSparseEmdedding(TestDistBase):
def _setup_config(self):
self._sync_mode = False
self._nccl2_mode = True
self._dygraph = True
def test_sparse_embedding(self):
if fluid.core.is_compiled_with_cuda():
self.check_with_place(
"parallel_dygraph_sparse_embedding.py",
delta=1e-5,
check_error_log=True,
log_name=flag_name)
if __name__ == "__main__":
unittest.main()
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import os
import unittest
import paddle.fluid as fluid
from test_dist_base import TestDistBase
flag_name = os.path.splitext(__file__)[0]
class TestParallelDygraphTransformer(TestDistBase):
def _setup_config(self):
self._sync_mode = False
self._nccl2_mode = True
self._dygraph = True
def test_transformer(self):
if fluid.core.is_compiled_with_cuda():
self.check_with_place(
"parallel_dygraph_transformer.py",
delta=1e-5,
check_error_log=True,
log_name=flag_name)
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册