未验证 提交 a6b089c6 编写于 作者: H hutuxian 提交者: GitHub

Add macro guards to exclude NCCL code on Windows (#21422)

Remove NCCL-related code from Windows builds
上级 ebfb720a
...@@ -16,9 +16,11 @@ limitations under the License. */ ...@@ -16,9 +16,11 @@ limitations under the License. */
#include <string> #include <string>
#include "paddle/fluid/framework/data_layout.h" #include "paddle/fluid/framework/data_layout.h"
#include "paddle/fluid/operators/data_norm_op.h" #include "paddle/fluid/operators/data_norm_op.h"
#include "paddle/fluid/platform/collective_helper.h"
#include "paddle/fluid/platform/cuda_primitives.h" #include "paddle/fluid/platform/cuda_primitives.h"
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
#include "paddle/fluid/platform/collective_helper.h"
#include "paddle/fluid/platform/nccl_helper.h" #include "paddle/fluid/platform/nccl_helper.h"
#endif
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -176,6 +178,7 @@ class DataNormGradKernel<platform::CUDADeviceContext, T> ...@@ -176,6 +178,7 @@ class DataNormGradKernel<platform::CUDADeviceContext, T>
d_batch_sum, d_batch_square_sum); d_batch_sum, d_batch_square_sum);
if (need_sync_stats) { if (need_sync_stats) {
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
auto comm = platform::NCCLCommContext::Instance().Get(0, ctx.GetPlace()); auto comm = platform::NCCLCommContext::Instance().Get(0, ctx.GetPlace());
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclAllReduce( PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclAllReduce(
reinterpret_cast<const void *>(d_batch_size), reinterpret_cast<const void *>(d_batch_size),
...@@ -194,7 +197,13 @@ class DataNormGradKernel<platform::CUDADeviceContext, T> ...@@ -194,7 +197,13 @@ class DataNormGradKernel<platform::CUDADeviceContext, T>
LOG(FATAL) << "Fail to sync nccl stream: " LOG(FATAL) << "Fail to sync nccl stream: "
<< cudaGetErrorString(e_sync); << cudaGetErrorString(e_sync);
} }
#else
PADDLE_THROW(platform::errors::PreconditionNotMet(
"PaddlePaddle should compile with GPU, and need_sync_stats connot be "
"supported on windows now."));
#endif
} }
T *batch_size_data = T *batch_size_data =
ctx.Output<Tensor>("BatchSize")->mutable_data<T>(ctx.GetPlace()); ctx.Output<Tensor>("BatchSize")->mutable_data<T>(ctx.GetPlace());
T *batch_sum_data = T *batch_sum_data =
......
...@@ -287,6 +287,11 @@ class TestDataNormOpWithSyncStats(OpTest): ...@@ -287,6 +287,11 @@ class TestDataNormOpWithSyncStats(OpTest):
def test_sync_stats(self): def test_sync_stats(self):
if not core.is_compiled_with_cuda(): if not core.is_compiled_with_cuda():
return return
if os.name == 'nt':
print(
'Skip TestDataNormOpWithSyncStats because nccl is not supported on windows'
)
return
x = fluid.layers.data(name='x', shape=[1], dtype='int64', lod_level=0) x = fluid.layers.data(name='x', shape=[1], dtype='int64', lod_level=0)
emb = layers.embedding( emb = layers.embedding(
input=x, input=x,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册