未验证 提交 b31647c6 编写于 作者: Q Qiyang Min 提交者: GitHub

Merge branch 'develop' into update_simple_distranspiler

...@@ -144,8 +144,7 @@ TEST(paddle_inference_api_impl, image_classification) { ...@@ -144,8 +144,7 @@ TEST(paddle_inference_api_impl, image_classification) {
float* data = static_cast<float*>(outputs[0].data.data); float* data = static_cast<float*>(outputs[0].data.data);
float* lod_data = output1.data<float>(); float* lod_data = output1.data<float>();
for (size_t j = 0; j < len / sizeof(float); ++j) { for (size_t j = 0; j < len / sizeof(float); ++j) {
EXPECT_LT(lod_data[j] - data[j], 1e-10); EXPECT_NEAR(lod_data[j], data[j], 1e-3);
EXPECT_GT(lod_data[j] - data[j], -1e-10);
} }
free(data); free(data);
} }
......
...@@ -14,7 +14,7 @@ ...@@ -14,7 +14,7 @@
import math import math
import unittest import unittest
from paddle.fluid.transpiler.distribute_transpiler import split_dense_variable from paddle.fluid.transpiler.distribute_transpiler import split_variable
import paddle.fluid as fluid import paddle.fluid as fluid
import paddle.fluid.core as core import paddle.fluid.core as core
import random import random
...@@ -31,7 +31,7 @@ class TestSplitVar(unittest.TestCase): ...@@ -31,7 +31,7 @@ class TestSplitVar(unittest.TestCase):
# dtype=core.VarDesc.VarType.LOD_TENSOR, # dtype=core.VarDesc.VarType.LOD_TENSOR,
shape=shape) shape=shape)
var_list.append(var) var_list.append(var)
blocks = split_dense_variable(var_list, 10, min_size) blocks = split_variable(var_list, 10, min_size)
all_sizes = [] all_sizes = []
for s in expected_sizes: for s in expected_sizes:
for s2 in s: for s2 in s:
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from program_utils import *
from ufind import *
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
def delete_ops(block, ops):
try:
start = list(block.ops).index(ops[0])
end = list(block.ops).index(ops[-1])
[block.remove_op(start) for _ in xrange(end - start + 1)]
except Exception, e:
raise e
block.program.sync_with_cpp()
def find_op_by_input_arg(block, arg_name):
for index, op in enumerate(block.ops):
if arg_name in op.input_arg_names:
return index
return -1
def find_op_by_output_arg(block, arg_name):
for index, op in enumerate(block.ops):
if arg_name in op.output_arg_names:
return index
return -1
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
class UnionFind(object):
""" Union-find data structure.
Union-find is a data structure that keeps track of a set of elements partitioned
into a number of disjoint (non-overlapping) subsets.
Reference:
https://en.wikipedia.org/wiki/Disjoint-set_data_structure
Args:
elements(list): The initialize element list.
"""
def __init__(self, elementes=None):
self._parents = [] # index -> parent index
self._index = {} # element -> index
self._curr_idx = 0
if not elementes:
elementes = []
for ele in elementes:
self._parents.append(self._curr_idx)
self._index.update({ele: self._curr_idx})
self._curr_idx += 1
def find(self, x):
# Find the root index of given element x,
# execute the path compress while findind the root index
if not x in self._index:
return -1
idx = self._index[x]
while idx != self._parents[idx]:
t = self._parents[idx]
self._parents[idx] = self._parents[t]
idx = t
return idx
def union(self, x, y):
# Union two given element
x_root = self.find(x)
y_root = self.find(y)
if x_root == y_root:
return
self._parents[x_root] = y_root
def is_connected(self, x, y):
# If two given elements have the same root index,
# then they are connected.
return self.find(x) == self.find(y)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册