Unverified commit 95c0c126 authored by ranqiu92, committed by GitHub

Merge pull request #7384 from dzhwinter/feature/sync_wait

Feature/sync wait
@@ -11,6 +11,7 @@ distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
+#include <gflags/gflags.h>
 #include <glog/logging.h>
 #include <algorithm>

@@ -21,6 +22,10 @@ limitations under the License. */
 #include "paddle/framework/shape_inference.h"
 #include "paddle/framework/var_type.h"

+DEFINE_bool(op_sync, false,
+            "Default cuda is asynchronous device, set to True will"
+            "force op run in synchronous mode.");
+
 namespace paddle {
 namespace framework {

@@ -542,8 +547,14 @@ void OperatorWithKernel::Run(const Scope& scope,
   auto kernel_iter = kernels.find(expected_kernel_key);

-  kernel_iter->second->Compute(ExecutionContext(
-      *this, new_scope, *pool.Get(expected_kernel_key.place_)));
+  auto* new_dev_ctx = pool.Get(expected_kernel_key.place_);
+  kernel_iter->second->Compute(
+      ExecutionContext(*this, new_scope, *new_dev_ctx));
+
+  /*For profiling/benchmark only*/
+  if (FLAGS_op_sync) {
+    new_dev_ctx->Wait();
+  }
 }

 proto::DataType OperatorWithKernel::IndicateDataType(
@@ -58,7 +58,7 @@ def __bootstrap__():
     read_env_flags = ['use_pinned_memory', 'check_nan_inf']
     if core.is_compile_gpu():
-        read_env_flags.append('fraction_of_gpu_memory_to_use')
+        read_env_flags += ['fraction_of_gpu_memory_to_use', 'op_sync']
     core.init_gflags([sys.argv[0]] +
                      ["--tryfromenv=" + ",".join(read_env_flags)])
     core.init_glog(sys.argv[0])
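Because __bootstrap__ hands the collected flag names to gflags via --tryfromenv, the new op_sync switch can be turned on from the environment when profiling or benchmarking a GPU program. Below is a minimal usage sketch, assuming gflags' FLAGS_<name> environment-variable convention and the v2-era import path paddle.v2.fluid; the surrounding script is illustrative and not part of this change.

import os

# The flag must be in the environment before the fluid package is imported,
# since __bootstrap__ calls core.init_gflags(... --tryfromenv=...) at import time.
os.environ['FLAGS_op_sync'] = 'true'

import paddle.v2.fluid as fluid  # hypothetical import path; adjust to your install

# Build and run the program as usual. With op_sync enabled, each kernel launch
# is followed by a DeviceContext::Wait(), so per-op timings reflect the actual
# device work rather than just the asynchronous launch overhead.

Equivalently, the flag can be set from the launching shell, e.g. FLAGS_op_sync=true python train.py, where train.py stands in for whatever driver script is being benchmarked.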