提交 e021ad67 编写于 作者: Y Yang Yang

Merge remote-tracking branch 'upstream/develop' into backward_on_parallel_do

......@@ -56,6 +56,7 @@ ExternalProject_Add(
PREFIX ${MKLDNN_SOURCES_DIR}
UPDATE_COMMAND ""
CMAKE_ARGS -DCMAKE_INSTALL_PREFIX=${MKLDNN_INSTALL_DIR}
CMAKE_ARGS -DCMAKE_BUILD_TYPE=${CMAKE_BUILD_TYPE}
CMAKE_ARGS -DMKLROOT=${MKLML_ROOT}
CMAKE_ARGS -DCMAKE_C_FLAGS=${MKLDNN_CFLAG}
CMAKE_ARGS -DCMAKE_CXX_FLAGS=${MKLDNN_CXXFLAG}
......
......@@ -130,8 +130,6 @@ class Tensor {
inline void set_layout(const DataLayout layout) { layout_ = layout; }
private:
friend class LoDTensor;
/**
* @note Placeholder hides type T, so it doesn't appear as a template
* parameter of Variable.
......
......@@ -151,7 +151,7 @@ class DetectionMAPOpKernel : public framework::OpKernel<T> {
for (int n = 0; n < batch_size; ++n) {
std::map<int, std::vector<Box>> boxes;
for (int i = label_index[n]; i < label_index[n + 1]; ++i) {
for (size_t i = label_index[n]; i < label_index[n + 1]; ++i) {
Box box(labels(i, 2), labels(i, 3), labels(i, 4), labels(i, 5));
int label = labels(i, 0);
auto is_difficult = labels(i, 1);
......@@ -167,7 +167,7 @@ class DetectionMAPOpKernel : public framework::OpKernel<T> {
auto detect_index = detect_lod[0];
for (int n = 0; n < batch_size; ++n) {
std::map<int, std::vector<std::pair<T, Box>>> boxes;
for (int i = detect_index[n]; i < detect_index[n + 1]; ++i) {
for (size_t i = detect_index[n]; i < detect_index[n + 1]; ++i) {
Box box(detect(i, 2), detect(i, 3), detect(i, 4), detect(i, 5));
int label = detect(i, 0);
auto score = detect(i, 1);
......@@ -269,8 +269,8 @@ class DetectionMAPOpKernel : public framework::OpKernel<T> {
std::map<int, std::vector<std::pair<T, int>>>& pos) {
const T* pos_data = pos_tensor.data<T>();
auto pos_data_lod = pos_tensor.lod();
for (int i = 0; i < pos_data_lod.size(); ++i) {
for (int j = pos_data_lod[0][i]; j < pos_data_lod[0][i + 1]; ++j) {
for (size_t i = 0; i < pos_data_lod.size(); ++i) {
for (size_t j = pos_data_lod[0][i]; j < pos_data_lod[0][i + 1]; ++j) {
T score = pos_data[j * 2];
int flag = 1;
if (pos_data[j * 2 + 1] < kEPS) flag = 0;
......
......@@ -107,6 +107,9 @@ class DataFeeder(object):
dtype=dtype))
for each_sample in iterable:
assert len(each_sample) == len(converter), (
"The number of fields in data (%s) does not match " +
"len(feed_list) (%s)") % (len(each_sample), len(converter))
for each_converter, each_slot in six.zip(converter, each_sample):
each_converter.feed(each_slot)
ret_dict = {}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册