diff --git a/paddle/fluid/operators/distributed/CMakeLists.txt b/paddle/fluid/operators/distributed/CMakeLists.txt index 6a61a8d78614472965638ce42447ab97b7a62944..cb492f999532fff3562050135c3a1abdcda06ad5 100644 --- a/paddle/fluid/operators/distributed/CMakeLists.txt +++ b/paddle/fluid/operators/distributed/CMakeLists.txt @@ -37,7 +37,7 @@ else() variable_response.cc collective_client.cc collective_server.cc ${BRPC_SRCS} - PROTO ${CMAKE_CURRENT_BINARY_DIR}/send_recv.proto + PROTO send_recv.proto DEPS lod_tensor selected_rows memory) set(RPC_DEPS sendrecvop_rpc brpc ssl crypto protobuf leveldb snappystream snappy zlib) diff --git a/paddle/fluid/operators/group_norm_op.cu b/paddle/fluid/operators/group_norm_op.cu index 6e460c470be71bfaaa37b4ef796027c2e2b9e376..3bf8586254e9867c7f5151178db866655df11535 100644 --- a/paddle/fluid/operators/group_norm_op.cu +++ b/paddle/fluid/operators/group_norm_op.cu @@ -21,20 +21,20 @@ namespace operators { enum GroupNormKernelFlags { kHasScale = 1, kHasBias = 2 }; -#define CHECK_CASE(i, flags, kernel_name, args...) \ - if (i == flags) { \ - kernel_name<<>>(args); \ +#define CHECK_CASE(i, flags, kernel_name, ...) \ + if (i == flags) { \ + kernel_name<<>>(__VA_ARGS__); \ } // 0 for no scale, no bias // 1 for has scale, no bias // 2 for no scale, has bias // 3 for has scale, has bias -#define UNROLL_ALL_CASES(flags, kernel_name, args...) \ - CHECK_CASE(0, flags, kernel_name, args) \ - CHECK_CASE(1, flags, kernel_name, args) \ - CHECK_CASE(2, flags, kernel_name, args) \ - CHECK_CASE(3, flags, kernel_name, args) +#define UNROLL_ALL_CASES(flags, kernel_name, ...) \ + CHECK_CASE(0, flags, kernel_name, __VA_ARGS__) \ + CHECK_CASE(1, flags, kernel_name, __VA_ARGS__) \ + CHECK_CASE(2, flags, kernel_name, __VA_ARGS__) \ + CHECK_CASE(3, flags, kernel_name, __VA_ARGS__) template __device__ __inline__ void CudaAtomicAddWithWarp(T* sum, T value) { diff --git a/python/paddle/fluid/executor.py b/python/paddle/fluid/executor.py index 20aa6054fe4b7d6a1b1454292954237bdfbe045e..d3ff14a17955990bff851e95bd61fbc370ea7aa5 100644 --- a/python/paddle/fluid/executor.py +++ b/python/paddle/fluid/executor.py @@ -305,7 +305,9 @@ class Executor(object): def __init__(self, place): self.place = place self.program_caches = dict() - self.executor = None + p = core.Place() + p.set_place(self.place) + self._default_executor = core.Executor(p) self._closed = False def _get_program_cache(self, program_cache_key): @@ -397,12 +399,13 @@ class Executor(object): >>> ... >>> exe.close() """ - if not self._closed and self.executor: - self.executor.close() + if not self._closed: + self._default_executor.close() self._closed = True def _run_parallel(self, program, scope, feed, fetch_list, fetch_var_name, return_numpy): + exe = program._executor if isinstance(feed, dict): feed_tensor_dict = dict() for feed_name in feed: @@ -414,8 +417,7 @@ class Executor(object): feed_tensor.set(feed[feed_name], core.CPUPlace()) feed_tensor_dict[feed_name] = feed_tensor - self.executor.feed_and_split_tensor_into_local_scopes( - feed_tensor_dict) + exe.feed_and_split_tensor_into_local_scopes(feed_tensor_dict) elif isinstance(feed, list) or isinstance(feed, tuple): if len(feed) != len(program._places): raise ValueError( @@ -436,10 +438,10 @@ class Executor(object): tensor = tmp res_dict[feed_name] = tensor res.append(res_dict) - self.executor.feed_tensors_into_local_scopes(res) + exe.feed_tensors_into_local_scopes(res) fetch_var_names = list(map(_to_name_str, fetch_list)) - self.executor.run(fetch_var_names, fetch_var_name) + exe.run(fetch_var_names, fetch_var_name) arr = scope.find_var(fetch_var_name).get_lod_tensor_array() if return_numpy: @@ -511,12 +513,9 @@ class Executor(object): compiled = isinstance(program, compiler.CompiledProgram) # For backward compatibility, run directly. if not compiled: - if not self.executor: - p = core.Place() - p.set_place(self.place) - self.executor = core.Executor(p) return self._run( program, + self._default_executor, feed=feed, fetch_list=fetch_list, feed_var_name=feed_var_name, @@ -526,7 +525,6 @@ class Executor(object): use_program_cache=use_program_cache) program._compile(scope, self.place) - self.executor = program._executor if program._is_data_parallel: return self._run_parallel( program, @@ -536,12 +534,13 @@ class Executor(object): fetch_var_name=fetch_var_name, return_numpy=return_numpy) elif program._is_inference: - return self._run_inference(program, feed) + return self._run_inference(program._executor, feed) else: # TODO(panyx0718): Can compile program to optimize executor # performance. return self._run( program._program, + self._default_executor, feed=feed, fetch_list=fetch_list, feed_var_name=feed_var_name, @@ -550,8 +549,8 @@ class Executor(object): return_numpy=return_numpy, use_program_cache=use_program_cache) - def _run(self, program, feed, fetch_list, feed_var_name, fetch_var_name, - scope, return_numpy, use_program_cache): + def _run(self, program, exe, feed, fetch_list, feed_var_name, + fetch_var_name, scope, return_numpy, use_program_cache): if feed is None: feed = {} @@ -589,11 +588,11 @@ class Executor(object): fetch_var_name=fetch_var_name) self._feed_data(program, feed, feed_var_name, scope) - self.executor.run(program.desc, scope, 0, True, True) + exe.run(program.desc, scope, 0, True, True) outs = self._fetch_data(fetch_list, fetch_var_name, scope) if return_numpy: outs = as_numpy(outs) return outs - def _run_inference(self, program, feed): - return self.executor.run(feed) + def _run_inference(self, exe, feed): + return exe.run(feed)