Unverified commit 4e8bc024, authored by Zhang Ting, committed by GitHub

add fluid.device_guard to specify the device type for Op (#22254)

* add fluid.device_guard to specify the device type for Op
Parent 063c51c7
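For orientation, here is a minimal usage sketch of the API this commit adds (assuming a Paddle 1.x static-graph build; it condenses the example from the `device_guard` docstring further below):

    import paddle.fluid as fluid

    main, startup = fluid.Program(), fluid.Program()
    with fluid.program_guard(main, startup):
        with fluid.device_guard('cpu'):
            # OPs created here carry op_device == 'cpu'
            a = fluid.layers.fill_constant(shape=[1], value=1.0, dtype='float32')
        with fluid.device_guard('gpu'):
            # OPs created here carry op_device == 'gpu'; with no CUDA kernel or
            # no CUDA build available they fall back to CPU with a warning
            b = fluid.layers.fill_constant(shape=[1], value=2.0, dtype='float32')
        # no guard: the device is assigned automatically
        c = fluid.layers.elementwise_add(a, b)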
......@@ -86,7 +86,8 @@ void OpProtoAndCheckerMaker::operator()(proto::OpProto* proto,
AddAttr<std::vector<std::string>>(OpCreationCallstackAttrName(),
"Callstack for Op Creatation.")
.SetDefault({});
AddAttr<std::string>(OpDeviceAttrName(), "Device type of this operator.")
.SetDefault("");
Validate();
}
......
......@@ -48,6 +48,7 @@ class OpProtoAndCheckerMaker {
static const char *OpRoleVarAttrName() { return "op_role_var"; }
static const char *OpNamescopeAttrName() { return "op_namescope"; }
static const char *OpCreationCallstackAttrName() { return "op_callstack"; }
static const char *OpDeviceAttrName() { return "op_device"; }
void operator()(proto::OpProto *proto, OpAttrChecker *attr_checker);
......
......@@ -1056,6 +1056,22 @@ void OperatorWithKernel::ChooseKernel(const RuntimeContext& ctx,
auto expected_kernel_key = this->GetExpectedKernelType(
ExecutionContext(*this, scope, *dev_ctx, ctx, nullptr));
if (HasAttr("op_device")) {
if (Attr<std::string>("op_device") == "cpu") {
expected_kernel_key.place_ = platform::CPUPlace();
} else if (Attr<std::string>("op_device") == "gpu") {
// When an Op that has only a CPU kernel is assigned to GPU, the CPU kernel
// will be executed and a warning will be emitted.
if (SupportGPU()) {
expected_kernel_key.place_ = dev_ctx->GetPlace();
} else {
expected_kernel_key.place_ = platform::CPUPlace();
LOG_FIRST_N(WARNING, 1)
<< "Op(" << type_
<< ") has no CUDA implementation. It will be assigned to CPUPlace.";
}
}
}
VLOG(3) << "expected_kernel_key:" << expected_kernel_key;
auto kernel_iter = kernels.find(expected_kernel_key);
......
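The fallback above is observable from Python: an OP pinned to 'gpu' still executes when no CUDA kernel is available, and the warning is logged only once (a sketch, assuming a CPU-only build so that SupportGPU() is false for every OP):

    import paddle.fluid as fluid

    main, startup = fluid.Program(), fluid.Program()
    with fluid.program_guard(main, startup):
        with fluid.device_guard('gpu'):
            out = fluid.layers.fill_constant(shape=[1], value=1.0, dtype='float32')

    exe = fluid.Executor(fluid.CPUPlace())
    exe.run(startup)
    exe.run(main)
    # On a CPU-only build this logs once:
    #   Op(fill_constant) has no CUDA implementation. It will be assigned to CPUPlace.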
......@@ -57,6 +57,8 @@ void BindConstValue(pybind11::module* m) {
op_proto_and_checker_maker.def(
"kOpCreationCallstackAttrName",
framework::OpProtoAndCheckerMaker::OpCreationCallstackAttrName);
op_proto_and_checker_maker.def(
"kOpDeviceAttrName", framework::OpProtoAndCheckerMaker::OpDeviceAttrName);
#if defined(PADDLE_WITH_DGC)
auto dgc = m->def_submodule("dgc");
dgc.def("kDGCKName", [] { return framework::details::g_dgc_k; });
......
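The binding above lets Python code look up the canonical attribute name instead of hard-coding the string (a small sketch):

    import paddle.fluid.core as core

    device_attr_name = core.op_proto_and_checker_maker.kOpDeviceAttrName()
    print(device_attr_name)  # 'op_device', per OpDeviceAttrName() above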
......@@ -876,6 +876,12 @@ def _append_backward_ops_(block,
grad_op_desc, op_grad_to_var = core.get_grad_op_desc(
op.desc, cpt.to_text(no_grad_dict[block.idx]), grad_sub_block_list)
# Set the device of each grad_op to match its forward Op
device_attr_name = core.op_proto_and_checker_maker.kOpDeviceAttrName()
op_device = op.desc.attr(device_attr_name)
for op_desc in grad_op_desc:
op_desc._set_attr(device_attr_name, op_device)
# If input_grad_names_set is not None, extend grad_op_descs only when
# any input grad in outputs of previous grad ops.
# But this strategy is not suited for while op for some control flow,
......
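The propagation above is easy to verify: gradient OPs inherit the op_device of their forward OPs (a sketch; the elementwise_add_grad naming follows the usual <type>_grad convention and is an assumption):

    import paddle.fluid as fluid
    import paddle.fluid.core as core

    main, startup = fluid.Program(), fluid.Program()
    with fluid.program_guard(main, startup):
        x = fluid.layers.data(name='x', shape=[1], dtype='float32',
                              stop_gradient=False)
        with fluid.device_guard('cpu'):
            y = fluid.layers.elementwise_add(x, x)
        loss = fluid.layers.mean(y)
        fluid.backward.append_backward(loss)

    attr_name = core.op_proto_and_checker_maker.kOpDeviceAttrName()
    for op in main.global_block().ops:
        if op.type.startswith('elementwise_add'):
            # forward and grad op should both report 'cpu'
            print(op.type, op.desc.attr(attr_name))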
......@@ -51,6 +51,7 @@ __all__ = [
'Variable',
'load_op_library',
'require_version',
'device_guard',
]
EMPTY_VAR_NAME = core.kEmptyVarName()
......@@ -61,6 +62,7 @@ CONTROL_DEP_VAR_PREFIX = core.kControlDepVarName()
_dygraph_tracer_ = None
_dygraph_current_expected_place_ = None
_current_device = None
def require_version(min_version, max_version=None):
......@@ -1696,7 +1698,8 @@ class OpProtoHolder(object):
core.op_proto_and_checker_maker.kOpRoleAttrName(),
core.op_proto_and_checker_maker.kOpRoleVarAttrName(),
core.op_proto_and_checker_maker.kOpNameScopeAttrName(),
core.op_proto_and_checker_maker.kOpCreationCallstackAttrName()
core.op_proto_and_checker_maker.kOpCreationCallstackAttrName(),
core.op_proto_and_checker_maker.kOpDeviceAttrName()
}
......@@ -1804,6 +1807,24 @@ class Operator(object):
namescope_var_name = op_maker.kOpNameScopeAttrName()
op_attrs[namescope_var_name] = _full_name_scope()
# Set the device for OPs that have kernels; warn for OPs without kernels.
# When force_cpu and device_guard are used at the same time, a warning is given.
# TODO(zhangting2020): remove the warning below once force_cpu is removed.
if _current_device is not None:
if self._has_kernel(type):
op_device = op_maker.kOpDeviceAttrName()
op_attrs[op_device] = _current_device
else:
warnings.warn("The Op(%s) is not support to set device." %
type)
if 'force_cpu' in op_attrs:
if (type == 'less_than' and op_attrs['force_cpu'] is not None
        ) or op_attrs['force_cpu'] != False:
warnings.warn(
"The Attr(force_cpu) of Op(%s) will be deprecated in the future, "
"please use 'device_guard' instead. 'device_guard' has higher priority when they are "
"used at the same time." % type)
def find_name(var_list, name):
for var_name in var_list:
if var_list[var_name] is not None and var_name == name:
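The force_cpu warning above fires when both mechanisms are combined; device_guard takes priority (a sketch mirroring the behavior, not a recommended pattern):

    import warnings
    import paddle.fluid as fluid

    main, startup = fluid.Program(), fluid.Program()
    with fluid.program_guard(main, startup):
        with warnings.catch_warnings(record=True) as w:
            warnings.simplefilter("always")
            with fluid.device_guard('gpu'):
                x = fluid.layers.fill_constant(
                    shape=[1], value=3.0, dtype='float32', force_cpu=True)
    # w now holds the force_cpu deprecation warning; the op still
    # carries op_device 'gpu'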
......@@ -5056,3 +5077,62 @@ def load_op_library(lib_filename):
"""
core.load_op_library(lib_filename)
OpProtoHolder.instance().update_op_proto()
def switch_device(device):
global _current_device
pre_device = _current_device
_current_device = device
return pre_device
@signature_safe_contextmanager
def device_guard(device=None):
"""
**Notes**:
**The API only supports static graph mode.**
A context manager that specifies the device on which the OP will be placed.
Args:
device(str|None): Specify the device to use in the context. It should be 'cpu' or 'gpu'.
    When it is set to 'cpu' or 'gpu', all OPs created in the context will be placed
    on CPUPlace or CUDAPlace, respectively. When 'gpu' is set and the program runs on
    a single card, the device index will be the same as the device on which the
    executor runs. Default: None, in which case OPs in the context are automatically
    assigned to devices.
Examples:
.. code-block:: python
import paddle.fluid as fluid
support_gpu = fluid.is_compiled_with_cuda()
place = fluid.CPUPlace()
if support_gpu:
place = fluid.CUDAPlace(0)
# if GPU is supported, the three OPs below will be automatically assigned to CUDAPlace(0)
data1 = fluid.layers.fill_constant(shape=[1, 3, 8, 8], value=0.5, dtype='float32')
data2 = fluid.layers.fill_constant(shape=[1, 3, 5, 5], value=0.5, dtype='float32')
shape = fluid.layers.shape(data2)
with fluid.device_guard("cpu"):
# Ops created here will be placed on CPUPlace
shape = fluid.layers.slice(shape, axes=[0], starts=[0], ends=[4])
with fluid.device_guard('gpu'):
# if GPU is supported, OPs created here will be placed on CUDAPlace(0), otherwise on CPUPlace
out = fluid.layers.crop_tensor(data1, shape=shape)
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())
result = exe.run(fetch_list=[out])
"""
if device not in ['cpu', 'gpu', '', None]:
    raise ValueError(
        "The Attr(device) should be 'cpu' or 'gpu', or it can be an empty string or None "
        "when there is no need to specify a device. But received %s" % device)
pre_device = switch_device(device)
try:
    yield
finally:
    # restore the previous device even if the body raises
    switch_device(pre_device)
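Because switch_device returns the previous value, guards nest and restore cleanly (a sketch):

    import paddle.fluid as fluid

    main, startup = fluid.Program(), fluid.Program()
    with fluid.program_guard(main, startup):
        with fluid.device_guard('gpu'):
            a = fluid.layers.fill_constant(shape=[1], value=1.0, dtype='float32')  # gpu
            with fluid.device_guard('cpu'):
                b = fluid.layers.fill_constant(shape=[1], value=2.0, dtype='float32')  # cpu
            c = fluid.layers.fill_constant(shape=[1], value=3.0, dtype='float32')  # gpu again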
......@@ -1135,6 +1135,9 @@ def save_inference_model(dirname,
# remind user to set auc_states to zeros if the program contains auc op
all_ops = main_program.global_block().ops
for op in all_ops:
# clear the op_device attribute so the saved inference program
# is not pinned to the devices used during training
device_attr_name = core.op_proto_and_checker_maker.kOpDeviceAttrName()
op._set_attr(device_attr_name, "")
if op.type == 'auc':
warnings.warn(
"please ensure that you have set the auc states to zeros before saving inference model"
......
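After saving, every OP in the inference program carries an empty op_device, so placement is decided afresh wherever the model is loaded (a sketch; the './infer_model' path is illustrative):

    import paddle.fluid as fluid
    import paddle.fluid.core as core

    main, startup = fluid.Program(), fluid.Program()
    with fluid.program_guard(main, startup):
        x = fluid.layers.data(name='x', shape=[1], dtype='float32')
        with fluid.device_guard('cpu'):
            out = fluid.layers.fc(input=x, size=1)

    exe = fluid.Executor(fluid.CPUPlace())
    exe.run(startup)
    fluid.io.save_inference_model('./infer_model', ['x'], [out], exe,
                                  main_program=main)

    program, _, _ = fluid.io.load_inference_model('./infer_model', exe)
    attr_name = core.op_proto_and_checker_maker.kOpDeviceAttrName()
    # op_device was cleared at save time, so the program is device-neutral
    assert all(op.desc.attr(attr_name) == ''
               for op in program.global_block().ops)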
......@@ -18,7 +18,7 @@ import numpy as np
from collections import defaultdict
from paddle.fluid.distribute_lookup_table import find_distributed_lookup_table
from paddle.fluid.framework import Program, Variable, name_scope, default_main_program, default_startup_program
from paddle.fluid.framework import Program, Variable, name_scope, default_main_program, default_startup_program, device_guard
from . import framework
from . import layers
......@@ -108,6 +108,7 @@ class Optimizer(object):
self.helper = None
self._opti_name_list = []
self._accumulators_holder = {}
self._param_device_map = dict()
@framework.dygraph_only
def state_dict(self):
......@@ -405,7 +406,7 @@ class Optimizer(object):
fill_value=0.0,
shape=None,
type=None,
force_cpu=False):
device=None):
"""Utility function to add an accumulator for a parameter
Args:
......@@ -438,10 +439,11 @@ class Optimizer(object):
type=param.type if type is None else type,
shape=shape,
belong_to_optimizer=True)
self.helper.set_variable_initializer(
var,
initializer=Constant(
value=float(fill_value), force_cpu=force_cpu))
if device is None:
device = self._get_device_for_param(param.name)
with device_guard(device):
self.helper.set_variable_initializer(
var, initializer=Constant(value=float(fill_value)))
if framework.in_dygraph_mode():
if len(self._accumulators_holder) > 0:
......@@ -470,6 +472,27 @@ class Optimizer(object):
format(name, param.name))
return self._accumulators[name][param.name]
def _update_param_device_map(self, parameters_and_grads, target_block):
    for param_and_grad in parameters_and_grads:
        if param_and_grad[0].trainable is True:
            param_name = param_and_grad[0].name
            device_attr_name = core.op_proto_and_checker_maker.kOpDeviceAttrName()
            for op in target_block.ops:
                # the first op that consumes the parameter decides its device
                if param_name in op.input_arg_names:
                    self._param_device_map[param_name] = op.attr(
                        device_attr_name)
                    break
            else:
                self._param_device_map[param_name] = None
def _get_device_for_param(self, param_name):
device = None
if param_name in self._param_device_map:
device = self._param_device_map[param_name]
return device
def _create_optimization_pass(self, parameters_and_grads):
"""Add optimization operators to update gradients to variables.
......@@ -505,6 +528,7 @@ class Optimizer(object):
start = len(target_block.ops)
self.helper = LayerHelper(self.__class__.__name__)
self._update_param_device_map(parameters_and_grads, target_block)
self._create_accumulators(
target_block,
[p[0] for p in parameters_and_grads if p[0].trainable])
......@@ -523,7 +547,11 @@ class Optimizer(object):
with param_and_grad[0].block.program._optimized_guard(
param_and_grad), name_scope("optimizer"):
if param_and_grad[0].trainable is True:
self._append_optimize_op(target_block, param_and_grad)
device = self._get_device_for_param(param_and_grad[0]
.name)
with device_guard(device):
optimize_op = self._append_optimize_op(
target_block, param_and_grad)
# Get custom finish ops for subclasses
# FIXME: Need to fix this once we figure out how to handle dependencies
......@@ -1793,14 +1821,14 @@ class AdamOptimizer(Optimizer):
fill_value=0.9 if isinstance(self._beta1, Variable) \
else self._beta1,
shape=[1],
type=core.VarDesc.VarType.LOD_TENSOR, force_cpu=True)
type=core.VarDesc.VarType.LOD_TENSOR, device='cpu')
self._add_accumulator(
name=self._beta2_pow_acc_str,
param=p,
fill_value=0.999 if isinstance(self._beta2, Variable) \
else self._beta2,
shape=[1],
type=core.VarDesc.VarType.LOD_TENSOR, force_cpu=True)
type=core.VarDesc.VarType.LOD_TENSOR, device='cpu')
def _append_optimize_op(self, block, param_and_grad):
assert isinstance(block, framework.Block)
......
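Taken together, the optimizer changes make parameter updates follow the device of the parameter itself: _update_param_device_map records where each parameter is consumed in the forward pass, and _append_optimize_op then runs under a device_guard for that device (a sketch):

    import paddle.fluid as fluid
    import paddle.fluid.core as core

    main, startup = fluid.Program(), fluid.Program()
    with fluid.program_guard(main, startup):
        x = fluid.layers.data(name='x', shape=[1], dtype='float32')
        with fluid.device_guard('cpu'):
            y = fluid.layers.fc(input=x, size=1)  # fc parameters are consumed on cpu
        loss = fluid.layers.mean(y)
        fluid.optimizer.AdamOptimizer(learning_rate=0.01).minimize(loss)

    attr_name = core.op_proto_and_checker_maker.kOpDeviceAttrName()
    for op in main.global_block().ops:
        if op.type == 'adam':
            print(op.desc.attr(attr_name))  # expected: 'cpu'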
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
from op_test import OpTest
import numpy as np
import paddle.fluid as fluid
import paddle.fluid.core as core
import warnings
def execute(main_program, startup_program):
if core.is_compiled_with_cuda():
place = core.CUDAPlace(0)
else:
place = core.CPUPlace()
exe = fluid.Executor(place)
exe.run(startup_program)
exe.run(main_program)
class TestDeviceGuard(unittest.TestCase):
def test_device_guard(self):
main_program = fluid.Program()
startup_program = fluid.Program()
with fluid.program_guard(main_program, startup_program):
data1 = fluid.layers.fill_constant(
shape=[1, 3, 8, 8], value=0.5, dtype='float32')
data2 = fluid.layers.fill_constant(
shape=[1, 3, 5, 5], value=0.5, dtype='float32')
shape = fluid.layers.shape(data2)
with fluid.device_guard("cpu"):
shape = fluid.layers.slice(
shape, axes=[0], starts=[0], ends=[4])
with fluid.device_guard("gpu"):
out = fluid.layers.crop_tensor(data1, shape=shape)
# check if the device attr is set correctly
all_ops = main_program.global_block().ops
device_attr_name = core.op_proto_and_checker_maker.kOpDeviceAttrName()
for op in all_ops:
if op.type == 'slice':
self.assertEqual(op.desc.attr(device_attr_name), "cpu")
if op.type == 'crop_tensor':
self.assertEqual(op.desc.attr(device_attr_name), "gpu")
execute(main_program, startup_program)
def test_cpu_only_op(self):
main_program = fluid.Program()
startup_program = fluid.Program()
with fluid.program_guard(main_program, startup_program):
x = fluid.layers.fill_constant(
shape=[2, 255, 13, 13], value=0.3, dtype='float32')
gt_box = fluid.layers.fill_constant(
shape=[2, 6, 4], value=0.5, dtype='float32')
gt_label = fluid.layers.fill_constant(
shape=[2, 6], value=1.0, dtype='int32')
gt_score = fluid.layers.fill_constant(
shape=[2, 6], value=0.5, dtype='float32')
anchors = [
10, 13, 16, 30, 33, 23, 30, 61, 62, 45, 59, 119, 116, 90, 156,
198, 373, 326
]
anchor_mask = [0, 1, 2]
with fluid.device_guard("gpu"):
# yolov3_loss has only a CPU kernel, so its CPU kernel is executed
# even under device_guard("gpu")
loss = fluid.layers.yolov3_loss(
x=x,
gt_box=gt_box,
gt_label=gt_label,
gt_score=gt_score,
anchors=anchors,
anchor_mask=anchor_mask,
class_num=80,
ignore_thresh=0.7,
downsample_ratio=32)
execute(main_program, startup_program)
def test_without_kernel_op(self):
main_program = fluid.Program()
startup_program = fluid.Program()
with fluid.program_guard(main_program, startup_program):
i = fluid.layers.fill_constant(shape=[1], dtype='int64', value=0)
loop_len = fluid.layers.fill_constant(
shape=[1], dtype='int64', value=10)
cond = fluid.layers.less_than(x=i, y=loop_len)
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always")
with fluid.device_guard("cpu"):
while_op = fluid.layers.While(cond=cond)
with while_op.block():
i = fluid.layers.increment(x=i, value=1, in_place=True)
fluid.layers.less_than(x=i, y=loop_len, cond=cond)
assert len(w) == 1
all_ops = main_program.global_block().ops
device_attr_name = core.op_proto_and_checker_maker.kOpDeviceAttrName()
for op in all_ops:
if op.type == 'while':
self.assertEqual(op.desc.attr(device_attr_name), "")
execute(main_program, startup_program)
def test_error(self):
def device_attr():
with fluid.device_guard("cpu1"):
out = fluid.layers.fill_constant(
shape=[1], value=0.2, dtype='float32')
self.assertRaises(ValueError, device_attr)
def test_warning(self):
main_program = fluid.Program()
startup_program = fluid.Program()
with fluid.program_guard(main_program, startup_program):
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always")
with fluid.device_guard("gpu"):
x = fluid.layers.fill_constant(
shape=[1], value=3.0, dtype='float32', force_cpu=True)
y = fluid.layers.fill_constant(
shape=[1], value=4.0, dtype='float32')
result = fluid.layers.less_than(x=x, y=y, force_cpu=False)
assert len(w) == 2
all_ops = main_program.global_block().ops
device_attr_name = core.op_proto_and_checker_maker.kOpDeviceAttrName()
for op in all_ops:
self.assertEqual(op.desc.attr(device_attr_name), "gpu")
if __name__ == '__main__':
unittest.main()
......@@ -70,7 +70,7 @@ class TestOperator(unittest.TestCase):
set([
"x_num_col_dims", "y_num_col_dims", "op_role", "op_role_var",
"use_mkldnn", "scale_x", "scale_y", "scale_out",
"force_fp32_output", "op_namescope", "op_callstack"
"force_fp32_output", "op_namescope", "op_callstack", "op_device"
]))
self.assertEqual(mul_op.has_attr("x_num_col_dims"), True)
self.assertEqual(mul_op.attr_type("x_num_col_dims"), core.AttrType.INT)
......