提交 5784e1e3 编写于 作者: A A. Unique TensorFlower 提交者: TensorFlower Gardener

Add HasOutputProperties to check for pruned ops; Return

device name instead of casting it to a short name (GPU:0/CPU:0); VLOG(2) when printing op device placement since it is a lot of output.

PiperOrigin-RevId: 157519077
上级 2994444b
......@@ -253,6 +253,10 @@ Status GraphProperties::InferDynamically(Cluster* cluster) {
return Status::OK();
}
bool GraphProperties::HasOutputProperties(const string& name) const {
return output_properties_.find(name) != output_properties_.end();
}
std::vector<OpInfo::TensorProperties> GraphProperties::GetInputProperties(
const string& node_name) const {
auto it = input_properties_.find(node_name);
......
......@@ -37,6 +37,7 @@ class GraphProperties {
Status InferStatically();
Status InferDynamically(Cluster* cluster);
bool HasOutputProperties(const string& name) const;
std::vector<OpInfo::TensorProperties> GetInputProperties(
const string& node_name) const;
std::vector<OpInfo::TensorProperties> GetOutputProperties(
......
......@@ -184,13 +184,17 @@ void VirtualScheduler::MaybeUpdateInputProperties(
value->add_float_val(1);
inputs->push_back(control_message);
} else {
const auto input_position = NodePosition(input_source_name);
// Use the input source's output property as _Send and _Recv's input
// property.
auto outputs =
graph_properties_.GetOutputProperties(NodeName(input_source_name));
CHECK_GT(outputs.size(), input_position);
inputs->push_back(outputs[input_position]);
// Like with HasInputProperties, if a node does not have output
// properties, it's likely it was pruned during the shape inference run.
if (graph_properties_.HasOutputProperties(NodeName(input_source_name))) {
const auto input_position = NodePosition(input_source_name);
// Use the input source's output property as _Send and _Recv's input
// property.
auto outputs =
graph_properties_.GetOutputProperties(NodeName(input_source_name));
CHECK_GT(outputs.size(), input_position);
inputs->push_back(outputs[input_position]);
}
}
}
}
......@@ -211,16 +215,8 @@ string VirtualScheduler::DeviceName(const NodeDef* node) const {
const auto* to = node_state.outputs[0];
return ChannelDeviceName(from, to);
} else {
const string& device = node->device().empty()
? "/" + default_device_type_ + ":0"
: node->device();
DeviceNameUtils::ParsedName parsed;
if (!DeviceNameUtils::ParseFullName(device, &parsed)) {
LOG(WARNING) << "Device name parse failed: " << device;
return device;
}
// Return a short name like /CPU:0 or /GPU:0.
return "/" + DeviceNameUtils::LocalName(parsed.type, parsed.id);
return node->device().empty() ? "/" + default_device_type_ + ":0"
: node->device();
}
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册