diff --git a/paddle/fluid/framework/details/memory_optimize_helper.cc b/paddle/fluid/framework/details/memory_optimize_helper.cc index 894d7dad2e623649fe96b00bb515c9605c89a404..1af57dc4087d2fd734c43e9549a4bd4526af4d35 100644 --- a/paddle/fluid/framework/details/memory_optimize_helper.cc +++ b/paddle/fluid/framework/details/memory_optimize_helper.cc @@ -131,16 +131,7 @@ size_t NodeSize(const VarDesc& node) { return type_size * std::abs(size); } -size_t NodeSize(ir::Node* n) { - VarDesc* desc = nullptr; - // some op do not have block pointer - if (n->inputs[0]->Op() != nullptr) { - desc = FindVarDescInBlock(n); - } else { - desc = n->Var(); - } - return NodeSize(*desc); -} +size_t NodeSize(ir::Node* n) { return NodeSize(*(n->Var())); } std::string DebugStringImpl(VarDesc* var) { std::stringstream ss; @@ -163,24 +154,22 @@ std::string DebugStringImpl(VarDesc* var) { } std::string DebugString(ir::Node* var) { - return DebugStringImpl(FindVarDescInBlock(var)); + return DebugStringImpl(GetVarDesc(var)); } // NOTE(dzh): based ir node, if a large node has been reused // by a small size node, then next time it appear in pool, it will // have the small size. Find the original node shap from blockdesc. -VarDesc* FindVarDescInBlock(ir::Node* n) { +VarDesc* GetVarDesc(ir::Node* n) { PADDLE_ENFORCE(n->IsVar() && !n->IsCtrlVar() && n->inputs.size() == 1); - BlockDesc* block = n->inputs[0]->Op()->Block(); - PADDLE_ENFORCE(block->HasVar(n->Name()), - string::Sprintf("Block do not has var %s", n->Name())); - return block->FindVar(n->Name()); + return n->Var(); } struct NodeComparator { bool operator()(ir::Node* lhs, ir::Node* rhs) const { - auto* lhs_desc = FindVarDescInBlock(lhs); - auto* rhs_desc = FindVarDescInBlock(rhs); + if (lhs->Var()->GetType() != rhs->Var()->GetType()) return false; + auto* lhs_desc = GetVarDesc(lhs); + auto* rhs_desc = GetVarDesc(rhs); // match data type if (lhs_desc->GetDataType() != rhs_desc->GetDataType()) { return false; @@ -204,7 +193,7 @@ void OrderedSet::Insert(ir::Node* var) { return; } - auto* var_desc = FindVarDescInBlock(var); + auto* var_desc = var->Var(); auto var_shape = var_desc->GetShape(); int batch_size = static_cast(var_shape[0]); @@ -212,7 +201,7 @@ void OrderedSet::Insert(ir::Node* var) { Iter it = nodes_.begin(); while (it != nodes_.end()) { auto& prev = it->front(); - auto* cache_desc = FindVarDescInBlock(prev); + auto* cache_desc = GetVarDesc(prev); int cache_batch_size = cache_desc->GetShape()[0]; if ((cache_batch_size == -1 && batch_size == -1) || (cache_batch_size != -1 && batch_size != -1)) { @@ -336,10 +325,16 @@ int MinChunkSize() { bool NodeCanReused(const VarDesc& node) { auto type = node.GetType(); // only these types holds bulk of gpu memory - if (!(type == proto::VarType::LOD_TENSOR || - type == proto::VarType::LOD_TENSOR_ARRAY)) { - return false; - } + // FIXME(liuwei1031) did not find good ways to test SELECTED_ROWS and + // LOD_TENSOR_ARRAY re-use logic, + // disable them in version 1.4 + // if (!(type == proto::VarType::LOD_TENSOR || + // type == proto::VarType::SELECTED_ROWS || + // type == proto::VarType::LOD_TENSOR_ARRAY)) { + // return false; + // } + if (type != proto::VarType::LOD_TENSOR) return false; + // persistable variable is parameter if (node.Persistable()) { return false; diff --git a/paddle/fluid/framework/details/memory_optimize_helper.h b/paddle/fluid/framework/details/memory_optimize_helper.h index b5348cc66eaa446719b299b63caa340eab3e2ab9..65c7017d2d462976cf8cd4d7b5f660e279e12b6a 100644 --- a/paddle/fluid/framework/details/memory_optimize_helper.h +++ b/paddle/fluid/framework/details/memory_optimize_helper.h @@ -20,6 +20,7 @@ #include #include #include +#include #include #include #include "paddle/fluid/framework/data_type.h" @@ -140,11 +141,7 @@ size_t NodeSize(const VarDesc&); std::string DebugString(ir::Node* var); -// NOTE(dzhwinter) -// after node reuse, the replaced node shape is -// different with its VarDesc. So need to find the -// correct VarDesc in Block. -VarDesc* FindVarDescInBlock(ir::Node* n); +VarDesc* GetVarDesc(ir::Node* n); static inline bool IsSameDesc(OpDesc* op1, OpDesc* op2) { return op1->Type() == op2->Type() && op1->Inputs() == op2->Inputs() && diff --git a/paddle/fluid/framework/ir/multi_batch_merge_pass.cc b/paddle/fluid/framework/ir/multi_batch_merge_pass.cc index dcc48fb934e7a06f2e85fa34fde335261f551415..a8720ff4bfb5c7fa7aee6d23949b030c328b90e6 100644 --- a/paddle/fluid/framework/ir/multi_batch_merge_pass.cc +++ b/paddle/fluid/framework/ir/multi_batch_merge_pass.cc @@ -84,7 +84,8 @@ void BatchMergePass::ApplyImpl(ir::Graph* graph) const { // 1. record op nodes of different roles for (auto node : nodes) { - if (node->IsVar()) continue; + if (!node->IsOp()) continue; + PADDLE_ENFORCE(node->Op(), "must find opdesc"); int op_role = boost::get(node->Op()->GetAttr( framework::OpProtoAndCheckerMaker::OpRoleAttrName())); if ((op_role == static_cast(framework::OpRole::kForward)) || diff --git a/paddle/fluid/imperative/layer.cc b/paddle/fluid/imperative/layer.cc index 036d2a50a4a7ea3ce7e052a56202b1d54465b03e..bc03285a4c5fe6db2abf2b271d6ddc86e75a9412 100644 --- a/paddle/fluid/imperative/layer.cc +++ b/paddle/fluid/imperative/layer.cc @@ -122,14 +122,14 @@ class Autograd { std::map> input_grads = ready_op->ApplyGrad(); - for (auto it : input_grads) { - const std::vector& ingrads = it.second; + for (auto it = input_grads.rbegin(); it != input_grads.rend(); ++it) { + const std::vector& ingrads = it->second; for (size_t i = 0; i < ingrads.size(); ++i) { if (!ingrads[i]) continue; - if (ready_op->input_vars_[it.first][i]->IsStopGradient()) { + if (ready_op->input_vars_[it->first][i]->IsStopGradient()) { continue; } - OpBase* pre_op = ready_op->pre_ops_[it.first][i]; + OpBase* pre_op = ready_op->pre_ops_[it->first][i]; if (!pre_op) continue; dep_counts[pre_op] -= 1; diff --git a/paddle/fluid/inference/tests/api/CMakeLists.txt b/paddle/fluid/inference/tests/api/CMakeLists.txt index 6a31185b097bc0ddf93a6e32e61ac0a9f2d04cfd..647913cc80727786379e2e5525b372818e423d23 100644 --- a/paddle/fluid/inference/tests/api/CMakeLists.txt +++ b/paddle/fluid/inference/tests/api/CMakeLists.txt @@ -148,20 +148,20 @@ inference_analysis_api_test_with_fake_data(test_analyzer_mobilenet_depthwise_con if(WITH_MKLDNN) set(INT8_DATA_DIR "${INFERENCE_DEMO_INSTALL_DIR}/int8") if (NOT EXISTS ${INT8_DATA_DIR}) - inference_download_and_uncompress(${INT8_DATA_DIR} "https://paddle-inference-dist.bj.bcebos.com/int8" "imagenet_val_100.tar.gz") + inference_download_and_uncompress(${INT8_DATA_DIR} ${INFERENCE_URL}"/int8" "imagenet_val_100.tar.gz") endif() #resnet50 int8 set(INT8_RESNET50_MODEL_DIR "${INT8_DATA_DIR}/resnet50") if (NOT EXISTS ${INT8_RESNET50_MODEL_DIR}) - inference_download_and_uncompress(${INT8_RESNET50_MODEL_DIR} "https://paddle-inference-dist.bj.bcebos.com/int8" "resnet50_int8_model.tar.gz" ) + inference_download_and_uncompress(${INT8_RESNET50_MODEL_DIR} ${INFERENCE_URL}"/int8" "resnet50_int8_model.tar.gz" ) endif() inference_analysis_api_int8_test(test_analyzer_int8_resnet50 ${INT8_RESNET50_MODEL_DIR} ${INT8_DATA_DIR} analyzer_int8_image_classification_tester.cc SERIAL) #mobilenet int8 set(INT8_MOBILENET_MODEL_DIR "${INT8_DATA_DIR}/mobilenet") if (NOT EXISTS ${INT8_MOBILENET_MODEL_DIR}) - inference_download_and_uncompress(${INT8_MOBILENET_MODEL_DIR} "https://paddle-inference-dist.bj.bcebos.com/int8" "mobilenetv1_int8_model.tar.gz" ) + inference_download_and_uncompress(${INT8_MOBILENET_MODEL_DIR} ${INFERENCE_URL}"/int8" "mobilenetv1_int8_model.tar.gz" ) endif() inference_analysis_api_int8_test(test_analyzer_int8_mobilenet ${INT8_MOBILENET_MODEL_DIR} ${INT8_DATA_DIR} analyzer_int8_image_classification_tester.cc SERIAL) endif() diff --git a/paddle/fluid/inference/tests/api/full_ILSVRC2012_val_preprocess.py b/paddle/fluid/inference/tests/api/full_ILSVRC2012_val_preprocess.py index 4d968c83d9c9bf9d947204d73f4460e62039cdda..842865933f2b4741aea034b19952d4c59344ba06 100644 --- a/paddle/fluid/inference/tests/api/full_ILSVRC2012_val_preprocess.py +++ b/paddle/fluid/inference/tests/api/full_ILSVRC2012_val_preprocess.py @@ -1,5 +1,4 @@ # copyright (c) 2019 paddlepaddle authors. all rights reserved. -# # licensed under the apache license, version 2.0 (the "license"); # you may not use this file except in compliance with the license. # you may obtain a copy of the license at @@ -11,6 +10,7 @@ # without warranties or conditions of any kind, either express or implied. # see the license for the specific language governing permissions and # limitations under the license. +import hashlib import unittest import os import numpy as np @@ -21,16 +21,20 @@ import functools import contextlib from PIL import Image, ImageEnhance import math -from paddle.dataset.common import download +from paddle.dataset.common import download, md5file +import tarfile random.seed(0) np.random.seed(0) DATA_DIM = 224 - SIZE_FLOAT32 = 4 SIZE_INT64 = 8 - +FULL_SIZE_BYTES = 30106000008 +FULL_IMAGES = 50000 +DATA_DIR_NAME = 'ILSVRC2012' +IMG_DIR_NAME = 'var' +TARGET_HASH = '8dc592db6dcc8d521e4d5ba9da5ca7d2' img_mean = np.array([0.485, 0.456, 0.406]).reshape((3, 1, 1)) img_std = np.array([0.229, 0.224, 0.225]).reshape((3, 1, 1)) @@ -70,19 +74,9 @@ def process_image(img_path, mode, color_jitter, rotate): return img -def download_unzip(): - int8_download = 'int8/download' - - target_name = 'data' - - cache_folder = os.path.expanduser('~/.cache/paddle/dataset/' + - int8_download) - - target_folder = os.path.join(cache_folder, target_name) - +def download_concat(cache_folder, zip_path): data_urls = [] data_md5s = [] - data_urls.append( 'https://paddle-inference-dist.bj.bcebos.com/int8/ILSVRC2012_img_val.tar.gz.partaa' ) @@ -91,72 +85,138 @@ def download_unzip(): 'https://paddle-inference-dist.bj.bcebos.com/int8/ILSVRC2012_img_val.tar.gz.partab' ) data_md5s.append('1e9f15f64e015e58d6f9ec3210ed18b5') - file_names = [] - + print("Downloading full ImageNet Validation dataset ...") for i in range(0, len(data_urls)): download(data_urls[i], cache_folder, data_md5s[i]) - file_names.append(data_urls[i].split('/')[-1]) - - zip_path = os.path.join(cache_folder, 'full_imagenet_val.tar.gz') - + file_name = os.path.join(cache_folder, data_urls[i].split('/')[-1]) + file_names.append(file_name) + print("Downloaded part {0}\n".format(file_name)) if not os.path.exists(zip_path): - cat_command = 'cat' - for file_name in file_names: - cat_command += ' ' + os.path.join(cache_folder, file_name) - cat_command += ' > ' + zip_path - os.system(cat_command) - print('Data is downloaded at {0}\n').format(zip_path) - - if not os.path.exists(target_folder): - cmd = 'mkdir {0} && tar xf {1} -C {0}'.format(target_folder, zip_path) - os.system(cmd) - print('Data is unzipped at {0}\n'.format(target_folder)) - - data_dir = os.path.join(target_folder, 'ILSVRC2012') - print('ILSVRC2012 full val set at {0}\n'.format(data_dir)) - return data_dir + with open(zip_path, "w+") as outfile: + for fname in file_names: + with open(fname) as infile: + outfile.write(infile.read()) + + +def extract(zip_path, extract_folder): + data_dir = os.path.join(extract_folder, DATA_DIR_NAME) + img_dir = os.path.join(data_dir, IMG_DIR_NAME) + print("Extracting...\n") + + if not (os.path.exists(img_dir) and + len(os.listdir(img_dir)) == FULL_IMAGES): + tar = tarfile.open(zip_path) + tar.extractall(path=extract_folder) + tar.close() + print('Extracted. Full Imagenet Validation dataset is located at {0}\n'. + format(data_dir)) + + +def print_processbar(done, total): + done_filled = done * '=' + empty_filled = (total - done) * ' ' + percentage_done = done * 100 / total + sys.stdout.write("\r[%s%s]%d%%" % + (done_filled, empty_filled, percentage_done)) + sys.stdout.flush() + + +def check_integrity(filename, target_hash): + print('\nThe binary file exists. Checking file integrity...\n') + md = hashlib.md5() + count = 0 + total_parts = 50 + chunk_size = 8192 + onepart = FULL_SIZE_BYTES / chunk_size / total_parts + with open(filename) as ifs: + while True: + buf = ifs.read(8192) + if count % onepart == 0: + done = count / onepart + print_processbar(done, total_parts) + count = count + 1 + if not buf: + break + md.update(buf) + hash1 = md.hexdigest() + if hash1 == target_hash: + return True + else: + return False -def reader(): - data_dir = download_unzip() - file_list = os.path.join(data_dir, 'val_list.txt') - output_file = os.path.join(data_dir, 'int8_full_val.bin') +def convert(file_list, data_dir, output_file): + print('Converting 50000 images to binary file ...\n') with open(file_list) as flist: lines = [line.strip() for line in flist] num_images = len(lines) - if not os.path.exists(output_file): - print( - 'Preprocessing to binary file......\n' - ) - with open(output_file, "w+b") as of: - #save num_images(int64_t) to file - of.seek(0) - num = np.array(int(num_images)).astype('int64') - of.write(num.tobytes()) - for idx, line in enumerate(lines): - img_path, label = line.split() - img_path = os.path.join(data_dir, img_path) - if not os.path.exists(img_path): - continue - - #save image(float32) to file - img = process_image( - img_path, 'val', color_jitter=False, rotate=False) - np_img = np.array(img) - of.seek(SIZE_INT64 + SIZE_FLOAT32 * DATA_DIM * DATA_DIM * 3 - * idx) - of.write(np_img.astype('float32').tobytes()) - - #save label(int64_t) to file - label_int = (int)(label) - np_label = np.array(label_int) - of.seek(SIZE_INT64 + SIZE_FLOAT32 * DATA_DIM * DATA_DIM * 3 - * num_images + idx * SIZE_INT64) - of.write(np_label.astype('int64').tobytes()) - - print('The preprocessed binary file path {}\n'.format(output_file)) + with open(output_file, "w+b") as ofs: + #save num_images(int64_t) to file + ofs.seek(0) + num = np.array(int(num_images)).astype('int64') + ofs.write(num.tobytes()) + per_parts = 1000 + full_parts = FULL_IMAGES / per_parts + print_processbar(0, full_parts) + for idx, line in enumerate(lines): + img_path, label = line.split() + img_path = os.path.join(data_dir, img_path) + if not os.path.exists(img_path): + continue + + #save image(float32) to file + img = process_image( + img_path, 'val', color_jitter=False, rotate=False) + np_img = np.array(img) + ofs.seek(SIZE_INT64 + SIZE_FLOAT32 * DATA_DIM * DATA_DIM * 3 * + idx) + ofs.write(np_img.astype('float32').tobytes()) + ofs.flush() + + #save label(int64_t) to file + label_int = (int)(label) + np_label = np.array(label_int) + ofs.seek(SIZE_INT64 + SIZE_FLOAT32 * DATA_DIM * DATA_DIM * 3 * + num_images + idx * SIZE_INT64) + ofs.write(np_label.astype('int64').tobytes()) + ofs.flush() + if (idx + 1) % per_parts == 0: + done = (idx + 1) / per_parts + print_processbar(done, full_parts) + print("Conversion finished.") + + +def run_convert(): + print('Start to download and convert 50000 images to binary file...') + cache_folder = os.path.expanduser('~/.cache/paddle/dataset/int8/download') + extract_folder = os.path.join(cache_folder, 'full_data') + data_dir = os.path.join(extract_folder, DATA_DIR_NAME) + file_list = os.path.join(data_dir, 'val_list.txt') + zip_path = os.path.join(cache_folder, 'full_imagenet_val.tar.gz') + output_file = os.path.join(cache_folder, 'int8_full_val.bin') + retry = 0 + try_limit = 3 + + while not (os.path.exists(output_file) and + os.path.getsize(output_file) == FULL_SIZE_BYTES and + check_integrity(output_file, TARGET_HASH)): + if os.path.exists(output_file): + sys.stderr.write( + "\n\nThe existing binary file is broken. Start to generate new one...\n\n". + format(output_file)) + os.remove(output_file) + if retry < try_limit: + retry = retry + 1 + else: + raise RuntimeError( + "Can not convert the dataset to binary file with try limit {0}". + format(try_limit)) + download_concat(cache_folder, zip_path) + extract(zip_path, extract_folder) + convert(file_list, data_dir, output_file) + print("\nSuccess! The binary file can be found at {0}".format(output_file)) if __name__ == '__main__': - reader() + run_convert() diff --git a/paddle/fluid/inference/tests/test.cmake b/paddle/fluid/inference/tests/test.cmake index df7af71d9b32ba11822e066f574146cfa5c50edd..fc6de70f5a89331cb8940b34c1c9ff5a164c2894 100644 --- a/paddle/fluid/inference/tests/test.cmake +++ b/paddle/fluid/inference/tests/test.cmake @@ -11,7 +11,7 @@ function(inference_download INSTALL_DIR URL FILENAME) ${EXTERNAL_PROJECT_LOG_ARGS} PREFIX ${INSTALL_DIR} URL ${URL}/${FILENAME} - DOWNLOAD_COMMAND wget -q -O ${INSTALL_DIR}/${FILENAME} ${URL}/${FILENAME} + DOWNLOAD_COMMAND wget --no-check-certificate -q -O ${INSTALL_DIR}/${FILENAME} ${URL}/${FILENAME} DOWNLOAD_DIR ${INSTALL_DIR} DOWNLOAD_NO_PROGRESS 1 CONFIGURE_COMMAND "" @@ -30,7 +30,7 @@ function(inference_download_and_uncompress INSTALL_DIR URL FILENAME) ${EXTERNAL_PROJECT_NAME} ${EXTERNAL_PROJECT_LOG_ARGS} PREFIX ${INSTALL_DIR} - DOWNLOAD_COMMAND wget -q -O ${INSTALL_DIR}/${FILENAME} ${URL}/${FILENAME} && + DOWNLOAD_COMMAND wget --no-check-certificate -q -O ${INSTALL_DIR}/${FILENAME} ${URL}/${FILENAME} && ${CMAKE_COMMAND} -E tar xzf ${INSTALL_DIR}/${FILENAME} DOWNLOAD_DIR ${INSTALL_DIR} DOWNLOAD_NO_PROGRESS 1 diff --git a/paddle/fluid/operators/dgc_clip_by_norm_op.h b/paddle/fluid/operators/dgc_clip_by_norm_op.h index bd22d16f7a21877af4e78c30f7e0985c64b543f2..197bf59b2a470e1f6e4e31c6706d1e3f8e73fbbc 100644 --- a/paddle/fluid/operators/dgc_clip_by_norm_op.h +++ b/paddle/fluid/operators/dgc_clip_by_norm_op.h @@ -24,18 +24,21 @@ class DGCClipByNormKernel : public ClipByNormKernel { public: void Compute(const framework::ExecutionContext& context) const override { auto rampup_begin_step = context.Attr("rampup_begin_step"); - if (static_cast(rampup_begin_step) >= 0) { - auto current_step_tensor = - context.Input("current_step"); - auto* current_step = current_step_tensor->data(); - - if (static_cast(*current_step) < - static_cast(rampup_begin_step)) { - VLOG(10) << "current_step:" << *current_step - << " < rampup_begin_step:" << rampup_begin_step - << " so does't use dgc_clip_by_norm"; - return; - } + if (static_cast(rampup_begin_step) < 0) { + return; + } + + auto current_step_tensor = context.Input("current_step"); + auto* current_step = current_step_tensor->data(); + + VLOG(10) << "current_step:" << *current_step + << ", rampup_begin_step:" << rampup_begin_step; + + if (static_cast(*current_step) < static_cast(rampup_begin_step)) { + VLOG(10) << "current_step:" << *current_step + << " < rampup_begin_step:" << rampup_begin_step + << " so does't use dgc_clip_by_norm"; + return; } return ClipByNormKernel::Compute(context); diff --git a/python/paddle/fluid/dygraph/layer_object_helper.py b/python/paddle/fluid/dygraph/layer_object_helper.py index c56652e103ce93bf5459b30b66c7b1f04e7c14d0..f0be5ff3bf2394f1f7da8fbcc341a0d2dfacdab3 100644 --- a/python/paddle/fluid/dygraph/layer_object_helper.py +++ b/python/paddle/fluid/dygraph/layer_object_helper.py @@ -65,7 +65,7 @@ class LayerObjectHelper(LayerHelperBase): def _input(self, inputs_in): inputs = self._multiple_input(inputs_in) if len(inputs) != 1: - raise "{0} layer only takes one input".format(self.layer_type) + raise "{0} layer only takes one input in".format(self.layer_type) return inputs[0] def _multiple_param_attr(self, length, param_attr_in=None): @@ -74,7 +74,8 @@ class LayerObjectHelper(LayerHelperBase): param_attr = [param_attr] if len(param_attr) != 1 and len(param_attr) != length: - raise ValueError("parameter number mismatch") + raise ValueError("parameter number mismatch in {}".format( + self.name)) elif len(param_attr) == 1 and length != 1: tmp = [None] * length for i in six.moves.range(length): @@ -91,6 +92,10 @@ class LayerObjectHelper(LayerHelperBase): Returns input, param_attr """ + param_attr_in = ParamAttr._to_attr(param_attr_in) + if isinstance(param_attr_in, bool): + raise ValueError('Param_attr should not be False in {}'.format( + self.name)) inputs = inputs_in if (inputs_in is not None) else [] inputs = self._multiple_input(inputs) param_attrs = self._multiple_param_attr(len(inputs), param_attr_in) @@ -112,8 +117,8 @@ class LayerObjectHelper(LayerHelperBase): if dtype is None: dtype = each.dtype elif dtype != each.dtype: - raise ValueError("Data Type mismatch: %d to %d" % - (dtype, each.dtype)) + raise ValueError("Data Type mismatch: %d to %d in %s" % + (dtype, each.dtype, self.name)) return dtype def get_parameter(self, name): @@ -126,7 +131,8 @@ class LayerObjectHelper(LayerHelperBase): """ param = self.main_program.global_block().var(name) if not isinstance(param, Parameter): - raise ValueError("no Parameter name %s found" % name) + raise ValueError("no Parameter name %s found in %s" % + (name, self.name)) return param def append_bias_op(self, @@ -184,7 +190,8 @@ class LayerObjectHelper(LayerHelperBase): if isinstance(act, six.string_types): act = {'type': act} else: - raise TypeError(str(act) + " should be unicode or str") + raise TypeError( + str(act) + " should be unicode or str in %s ", self.name) if (use_cudnn is not None) and use_cudnn: act['use_cudnn'] = use_cudnn @@ -211,5 +218,6 @@ class LayerObjectHelper(LayerHelperBase): """ param = param if not isinstance(param, cls): - raise TypeError("The input {0} parameter of method {1} must be {2}", - param, self.layer_type, cls.__name__) + raise TypeError( + "The input {0} parameter of method {1} must be {2}, in layer {3}", + param, self.layer_type, cls.__name__, self.name) diff --git a/python/paddle/fluid/dygraph/nn.py b/python/paddle/fluid/dygraph/nn.py index bc1006d767a4993d651ded6f16bfa63b77e67586..7665703a726050b2a236b5d96bbafb6c24b2617a 100644 --- a/python/paddle/fluid/dygraph/nn.py +++ b/python/paddle/fluid/dygraph/nn.py @@ -20,7 +20,7 @@ import numpy as np from .. import core from ..layers import utils from . import layers -from ..framework import Variable, OpProtoHolder +from ..framework import Variable, OpProtoHolder, Parameter from ..layers import layer_function_generator from ..param_attr import ParamAttr from ..initializer import Normal, Constant, NumpyArrayInitializer @@ -460,46 +460,69 @@ class FC(layers.Layer): self._param_attr = param_attr self._bias_attr = bias_attr self._act = act + self.__w = list() - def _build_once(self, input): - input_shape = input.shape - param_shape = [ - reduce(lambda a, b: a * b, input_shape[self._num_flatten_dims:], 1) - ] + [self._size] - self._w = self.create_parameter( - attr=self._param_attr, - shape=param_shape, - dtype=self._dtype, - is_bias=False) + @property + def _w(self, i=0): + return self.__w[i] - if self._bias_attr: - size = list([self._size]) - self._b = self.create_parameter( - attr=self._bias_attr, - shape=size, - dtype=self._dtype, - is_bias=True) - else: - self._b = None + @_w.setter + def _w(self, value, i=0): + assert isinstance(value, Parameter) + self.__w[i] = value - def forward(self, input): - tmp = self._helper.create_variable_for_type_inference(self._dtype) - self._helper.append_op( - type="mul", - inputs={"X": input, - "Y": self._w}, - outputs={"Out": tmp}, - attrs={ - "x_num_col_dims": self._num_flatten_dims, - "y_num_col_dims": 1 - }) + def _build_once(self, input): + i = 0 + for inp, param in self._helper.iter_inputs_and_params(input, + self._param_attr): + input_shape = inp.shape + + param_shape = [ + reduce(lambda a, b: a * b, input_shape[self._num_flatten_dims:], + 1) + ] + [self._size] + self.__w.append( + self.add_parameter( + '_w%d' % i, + self.create_parameter( + attr=param, + shape=param_shape, + dtype=self._dtype, + is_bias=False))) + i += 1 + + size = list([self._size]) + self._b = self.create_parameter( + attr=self._bias_attr, shape=size, dtype=self._dtype, is_bias=True) - pre_bias = self._helper.create_variable_for_type_inference(self._dtype) - self._helper.append_op( - type="sum", - inputs={"X": [tmp]}, - outputs={"Out": pre_bias}, - attrs={"use_mkldnn": False}) + def forward(self, input): + mul_results = list() + i = 0 + for inp, param in self._helper.iter_inputs_and_params(input, + self._param_attr): + tmp = self._helper.create_variable_for_type_inference(self._dtype) + self._helper.append_op( + type="mul", + inputs={"X": inp, + "Y": self.__w[i]}, + outputs={"Out": tmp}, + attrs={ + "x_num_col_dims": self._num_flatten_dims, + "y_num_col_dims": 1 + }) + i += 1 + mul_results.append(tmp) + + if len(mul_results) == 1: + pre_bias = mul_results[0] + else: + pre_bias = self._helper.create_variable_for_type_inference( + self._dtype) + self._helper.append_op( + type="sum", + inputs={"X": mul_results}, + outputs={"Out": pre_bias}, + attrs={"use_mkldnn": False}) if self._b: pre_activation = self._helper.create_variable_for_type_inference( diff --git a/python/paddle/fluid/framework.py b/python/paddle/fluid/framework.py index b583f17c9619afb63b959e215c5f79ddeefa5173..b8bc4e819e78636b3f227534e01182d9333b8a14 100644 --- a/python/paddle/fluid/framework.py +++ b/python/paddle/fluid/framework.py @@ -493,7 +493,8 @@ class Variable(object): self._ivar._run_backward() def gradient(self): - return np.array(self._ivar._grad_value()) + new_ivar = self._ivar._grad_ivar()._copy_to(core.CPUPlace(), True) + return np.array(new_ivar.value().get_tensor()) def clear_gradient(self): self._ivar._clear_gradient() diff --git a/python/paddle/fluid/optimizer.py b/python/paddle/fluid/optimizer.py index 79accabe87869c832b7467acbaf70d11cbca8a96..7e6e37116fe23f26eb14dd0573dbe031aec98dd8 100644 --- a/python/paddle/fluid/optimizer.py +++ b/python/paddle/fluid/optimizer.py @@ -832,7 +832,7 @@ class DGCMomentumOptimizer(MomentumOptimizer): type=x.type, name=name, dtype=x.dtype, persistable=False) helper.append_op( - type="clip_by_norm", + type="dgc_clip_by_norm", inputs={"X": x, "current_step": self._global_step_var}, attrs={ @@ -845,7 +845,7 @@ class DGCMomentumOptimizer(MomentumOptimizer): def _append_clip_norm(self, grad_var, clip_norm): with grad_var.block.program._backward_role_guard(): return self._clip_by_norm( - x=grad_var, max_norm=clip_norm, name=grad_var.name + "@DGC") + x=grad_var, max_norm=clip_norm, name=grad_var.name) def _dgc_op(self, param_var, clip_var, grad_var, u_var, v_var, k_var, encoded_var): diff --git a/python/paddle/fluid/tests/unittests/test_imperative_transformer.py b/python/paddle/fluid/tests/unittests/test_imperative_transformer.py index 947801334d4a0f4cf6e19be3e45a0cb80209bab0..813ac513dae93b488cc2a686913bdf75ddbbc87b 100644 --- a/python/paddle/fluid/tests/unittests/test_imperative_transformer.py +++ b/python/paddle/fluid/tests/unittests/test_imperative_transformer.py @@ -304,7 +304,7 @@ use_py_reader = False sync = False # how many batches we use -batch_num = 2 +batch_num = 50 np.random.seed = 1 src_word_np = np.random.randint( @@ -1076,19 +1076,19 @@ class TestDygraphTransformer(unittest.TestCase): static_param_updated[static_param_name_list[k - 4]] = out[k] - self.assertTrue(np.allclose(static_avg_cost_value, dy_avg_cost.numpy())) - self.assertTrue(np.allclose(static_sum_cost_value, dy_sum_cost.numpy())) self.assertTrue( - np.allclose( - static_predict_value, dy_predict.numpy(), atol=1e-5)) + np.array_equal(static_avg_cost_value, dy_avg_cost.numpy())) self.assertTrue( - np.allclose(static_token_num_value, dy_token_num.numpy())) + np.array_equal(static_sum_cost_value, dy_sum_cost.numpy())) + self.assertTrue( + np.array_equal(static_predict_value, dy_predict.numpy())) + self.assertTrue( + np.array_equal(static_token_num_value, dy_token_num.numpy())) + for key, value in six.iteritems(static_param_init): - self.assertTrue(np.allclose(value, dy_param_init[key])) + self.assertTrue(np.array_equal(value, dy_param_init[key])) for key, value in six.iteritems(static_param_updated): - self.assertTrue( - np.allclose( - value, dy_param_updated[key], atol=1e-4)) + self.assertTrue(np.array_equal(value, dy_param_updated[key])) if __name__ == '__main__': diff --git a/python/paddle/fluid/tests/unittests/test_layers.py b/python/paddle/fluid/tests/unittests/test_layers.py index 223cdcc386257d37d57245797dc81eeb9ecb5c89..2fd82884f4a8d07ced06d5775e5b414ebc2ef654 100644 --- a/python/paddle/fluid/tests/unittests/test_layers.py +++ b/python/paddle/fluid/tests/unittests/test_layers.py @@ -76,6 +76,41 @@ class LayerTest(unittest.TestCase): class TestLayer(LayerTest): + def test_fc(self): + # pdb.set_trace() + inp = np.ones([3, 32, 32], dtype='float32') + with self.static_graph(): + t = layers.data( + name='data', + shape=[3, 32, 32], + dtype='float32', + append_batch_size=False) + ret = layers.fc(t, size=4, bias_attr=False, num_flatten_dims=1) + ret2 = layers.fc(ret, size=4) + static_ret = self.get_static_graph_result( + feed={'data': inp}, fetch_list=[ret2])[0] + with self.static_graph(): + t = layers.data( + name='data', + shape=[3, 32, 32], + dtype='float32', + append_batch_size=False) + fc1 = nn.FC('fc1', size=4, bias_attr=False, num_flatten_dims=1) + fc2 = nn.FC('fc2', size=4) + ret = fc1(t) + ret2 = fc2(ret) + static_ret2 = self.get_static_graph_result( + feed={'data': inp}, fetch_list=[ret2])[0] + with self.dynamic_graph(): + t = base.to_variable(inp) + fc1 = nn.FC('fc1', size=4, bias_attr=False, num_flatten_dims=1) + fc2 = nn.FC('fc2', size=4) + ret = fc1(t) + dy_ret = fc2(ret) + + self.assertTrue(np.array_equal(static_ret, static_ret2)) + self.assertTrue(np.array_equal(static_ret, dy_ret._numpy())) + def test_layer_norm(self): inp = np.ones([3, 32, 32], dtype='float32') with self.static_graph():