From 8140485aee3ac18b70b6abc8209fd853a70e48ce Mon Sep 17 00:00:00 2001
From: tangwei12
Date: Fri, 2 Apr 2021 11:41:18 +0800
Subject: [PATCH] [Cherry-Pick] logclean & embedding doc (#32009)

* fix en doc for emb (#31980)

* fix en doc for emb, test=document_fix;

Change-Id: I4757e67caacd7189f068493ed45a7445f87ffb40

* LOG CLEAN (#31819)

* upgrade vlog

* train from dataset fetch optimize
---
 cmake/external/brpc.cmake                      |  2 +-
 .../distributed/service/brpc_ps_server.cc      |  5 +--
 .../fluid/distributed/service/brpc_utils.cc    |  2 +-
 paddle/fluid/distributed/service/env.h         | 10 +++---
 paddle/fluid/distributed/service/ps_client.cc  |  3 +-
 paddle/fluid/distributed/service/service.cc    |  2 +-
 .../fluid/distributed/table/depends/dense.h    |  2 --
 .../fluid/distributed/table/depends/sparse.h   |  2 --
 .../fluid/framework/details/build_strategy.cc  | 28 +++++++----------
 paddle/fluid/framework/device_worker.h         |  2 +-
 paddle/fluid/framework/hogwild_worker.cc       | 31 +++++++++++++++----
 paddle/fluid/platform/lodtensor_printer.cc     | 30 +++++++++++++-----
 paddle/fluid/platform/lodtensor_printer.h      |  2 +-
 .../fluid/platform/lodtensor_printer_test.cc   |  3 +-
 .../distributed/fleet/base/fleet_base.py       | 13 ++++----
 .../distributed/fleet/runtime/the_one_ps.py    |  2 +-
 .../fluid/tests/unittests/test_monitor.py      | 17 ++++++++--
 python/paddle/nn/functional/input.py           |  4 +--
 python/paddle/nn/layer/common.py               | 14 ++++-----
 19 files changed, 104 insertions(+), 70 deletions(-)

diff --git a/cmake/external/brpc.cmake b/cmake/external/brpc.cmake
index 0eb590c42d0..582c06e88c1 100644
--- a/cmake/external/brpc.cmake
+++ b/cmake/external/brpc.cmake
@@ -41,7 +41,7 @@ ExternalProject_Add(
     ${EXTERNAL_PROJECT_LOG_ARGS}
     # TODO(gongwb): change to de newst repo when they changed.
     GIT_REPOSITORY "https://github.com/wangjiawei04/brpc"
-    GIT_TAG "6d79e0b17f25107c35b705ea58d888083f59ff47"
+    GIT_TAG "e203afb794caf027da0f1e0776443e7d20c0c28e"
     PREFIX ${BRPC_SOURCES_DIR}
     UPDATE_COMMAND ""
     CMAKE_ARGS -DCMAKE_CXX_COMPILER=${CMAKE_CXX_COMPILER}
diff --git a/paddle/fluid/distributed/service/brpc_ps_server.cc b/paddle/fluid/distributed/service/brpc_ps_server.cc
index 7af7a436364..256baa93194 100644
--- a/paddle/fluid/distributed/service/brpc_ps_server.cc
+++ b/paddle/fluid/distributed/service/brpc_ps_server.cc
@@ -57,7 +57,8 @@ uint64_t BrpcPsServer::start(const std::string &ip, uint32_t port) {
   std::unique_lock<std::mutex> lock(mutex_);

   std::string ip_port = ip + ":" + std::to_string(port);
-  VLOG(3) << "server of rank " << _rank << " starts at " << ip_port;
+  VLOG(0) << "running server with rank id: " << _rank
+          << ", endpoint: " << ip_port;
   brpc::ServerOptions options;

   int num_threads = std::thread::hardware_concurrency();
@@ -535,7 +536,7 @@ int32_t BrpcPsService::stop_server(Table *table,
   auto *p_server = _server;
   std::thread t_stop([p_server]() {
     p_server->stop();
-    LOG(INFO) << "Server Stoped";
+    VLOG(3) << "Server Stoped";
   });
   t_stop.detach();
   return 0;
diff --git a/paddle/fluid/distributed/service/brpc_utils.cc b/paddle/fluid/distributed/service/brpc_utils.cc
index 2822c2faa20..18f16b97224 100644
--- a/paddle/fluid/distributed/service/brpc_utils.cc
+++ b/paddle/fluid/distributed/service/brpc_utils.cc
@@ -331,7 +331,7 @@ std::string GetIntTypeEndpoint(const std::string& ip, const uint32_t& port) {

   while (hp->h_addr_list[i] != NULL) {
     int_ip = inet_ntoa(*(struct in_addr*)hp->h_addr_list[i]);
-    VLOG(0) << "Brpc Get host by name, host:" << ip << " -> ip: " << int_ip;
+    VLOG(3) << "Brpc Get host by name, host:" << ip << " -> ip: " << int_ip;
     break;
   }
diff --git a/paddle/fluid/distributed/service/env.h b/paddle/fluid/distributed/service/env.h
index 206ff2c5cc4..d807e950f27 100644
--- a/paddle/fluid/distributed/service/env.h
+++ b/paddle/fluid/distributed/service/env.h
@@ -39,7 +39,7 @@ struct PSHost {

   // |---ip---|---port---|--rank--|
   // |-32bit--|--20bit---|--12bit-|
-  // for pslib
+
   uint64_t serialize_to_uint64() {
     uint64_t host_label = 0;
     host_label = inet_addr(ip.c_str());
@@ -174,14 +174,12 @@ class PSEnvironment {
     host.ip = ip;
     host.port = port;
     host.rank = rank;
-    if (sign_set.count(rank) > 0) {
-      LOG(WARNING) << "ps-host :" << host.ip << ":" << host.port
-                   << ", rank:" << host.rank
-                   << " already register, ignore register";
-    } else {
+
+    if (sign_set.count(rank) == 0) {
       host_list.push_back(host);
       sign_set.insert(rank);
     }
+
     return 0;
   }
diff --git a/paddle/fluid/distributed/service/ps_client.cc b/paddle/fluid/distributed/service/ps_client.cc
index 866200e7740..709d7aee3d2 100644
--- a/paddle/fluid/distributed/service/ps_client.cc
+++ b/paddle/fluid/distributed/service/ps_client.cc
@@ -80,8 +80,7 @@ PSClient *PSClientFactory::create(const PSParameter &ps_config) {
   }

   TableManager::instance().initialize();
-  LOG(INFO) << "Create PSClient[" << service_param.client_class()
-            << "] success";
+  VLOG(3) << "Create PSClient[" << service_param.client_class() << "] success";
   return client;
 }
 }  // namespace distributed
diff --git a/paddle/fluid/distributed/service/service.cc b/paddle/fluid/distributed/service/service.cc
index 1d360eb5669..a2c43803838 100644
--- a/paddle/fluid/distributed/service/service.cc
+++ b/paddle/fluid/distributed/service/service.cc
@@ -47,7 +47,7 @@ paddle::distributed::PSParameter load_from_prototxt(
 }

 void PSCore::init_gflag(const std::string& gflags) {
-  LOG(INFO) << "Init With Gflags:" << gflags;
+  VLOG(3) << "Init With Gflags:" << gflags;
   std::vector<std::string> flags = paddle::string::split_string<std::string>(gflags);
   if (flags.size() < 1) {
     flags.push_back("-max_body_size=314217728");
diff --git a/paddle/fluid/distributed/table/depends/dense.h b/paddle/fluid/distributed/table/depends/dense.h
index 209595de7e6..bf6c12d2b48 100644
--- a/paddle/fluid/distributed/table/depends/dense.h
+++ b/paddle/fluid/distributed/table/depends/dense.h
@@ -89,7 +89,6 @@ class DSGD : public DenseOptimizer {

     auto blas = GetBlas();
     float lr = *(global_learning_rate_) * (*learning_rate);
-    VLOG(4) << "DSGD LearningRate: " << lr;
     blas.VCOPY(update_numel, update_values + begin, grads.data());
     blas.SCAL(update_numel, lr, grads.data());
     blas.VSUB(update_numel, param + begin, grads.data(), param + begin);
@@ -157,7 +156,6 @@ class DAdam : public DenseOptimizer {
     beta2_pow[0] = beta2_pow[0] * beta2;

     float lr_ = *(global_learning_rate_)*learning_rate[0];
-    VLOG(4) << "DAdam LearningRate: " << lr_;
     lr_ *= sqrt(1 - beta2_pow[0]) / (1 - beta1_pow[0]);

     float* tmp_ = tmp.data();
diff --git a/paddle/fluid/distributed/table/depends/sparse.h b/paddle/fluid/distributed/table/depends/sparse.h
index 38ae03777c8..8b41dcee8b5 100644
--- a/paddle/fluid/distributed/table/depends/sparse.h
+++ b/paddle/fluid/distributed/table/depends/sparse.h
@@ -110,7 +110,6 @@ class SSGD : public SparseOptimizer {
       auto* value = block->Get(id);

       float learning_rate = *(global_learning_rate_) * (value + lr_offset)[0];
-      VLOG(4) << "SSGD LearningRate: " << learning_rate;
       float* param = value + param_offset;

       std::vector<float> grads;
@@ -166,7 +165,6 @@ class SAdam : public SparseOptimizer {
       if (!block->GetEntry(id)) continue;
       auto* values = block->Get(id);
       float lr_ = *(global_learning_rate_) * (values + lr_offset)[0];
-      VLOG(4) << "SAdam LearningRate: " << lr_;
       float* param = values + param_offset;
       float* moment1 = values + m1_offset;
       float* moment2 = values + m2_offset;
diff --git a/paddle/fluid/framework/details/build_strategy.cc b/paddle/fluid/framework/details/build_strategy.cc
index c045dae4717..ef402b65c01 100644
--- a/paddle/fluid/framework/details/build_strategy.cc
+++ b/paddle/fluid/framework/details/build_strategy.cc
@@ -167,9 +167,6 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
     AppendPassWithCheck(strategy_.fuse_bn_add_act_ops_, "fuse_bn_add_act_pass");
 #if defined(PADDLE_WITH_CUDA) && !defined(_WIN32) && !defined(__APPLE__)
     AppendPassWithCheck(strategy_.enable_auto_fusion_, "fusion_group_pass");
-#else
-    LOG(WARNING) << "fusion_group is not enabled for Windows/MacOS now, and "
-                    "only effective when running with CUDA GPU.";
 #endif
     AppendPassWithCheck(strategy_.fuse_elewise_add_act_ops_,
                         "fuse_elewise_add_act_pass");
@@ -271,12 +268,11 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
     if (FLAGS_use_mkldnn) {
       AppendPass(pass_name);
     } else if (!strategy_.mkldnn_enabled_op_types_.empty()) {
-      LOG(WARNING)
-          << "mkldnn_enabled_op_types specify the operator type list to "
-             "use MKLDNN acceleration. It is null in default, means "
-             "that all the operators supported by MKLDNN will be "
-             "accelerated. And it should not be set when "
-             "FLAGS_use_mkldnn=false.";
+      VLOG(1) << "mkldnn_enabled_op_types specify the operator type list to "
+                 "use MKLDNN acceleration. It is null in default, means "
+                 "that all the operators supported by MKLDNN will be "
+                 "accelerated. And it should not be set when "
+                 "FLAGS_use_mkldnn=false.";
     }
 #else
     PADDLE_ENFORCE_NE(FLAGS_use_mkldnn, true,
@@ -409,26 +405,26 @@ ir::Graph *BuildStrategy::Apply(ir::Graph *graph,
                 << ", num_trainers:" << num_trainers_;
     } else if (pass->Type() == "fuse_relu_depthwise_conv_pass") {
       if (use_device != p::kCUDA) {
-        LOG(WARNING) << "fuse_relu_depthwise_conv_pass is only supported on "
-                        "GPU, skipped.";
+        VLOG(1) << "fuse_relu_depthwise_conv_pass is only supported on "
+                   "GPU, skipped.";
         continue;
       }
     } else if (pass->Type() == "fusion_group_pass") {
       pass->Set("use_gpu", new bool((use_device == p::kCUDA)));
       if (use_device != p::kCUDA) {
-        LOG(WARNING) << "fusion_group_pass is only supported on GPU, skipped.";
+        VLOG(1) << "fusion_group_pass is only supported on GPU, skipped.";
         continue;
       }
     } else if (pass->Type() == "fuse_bn_act_pass") {
       if (use_device != p::kCUDA) {
-        LOG(WARNING) << "fuse_bn_act_pass is only supported on "
-                        "GPU, skipped.";
+        VLOG(1) << "fuse_bn_act_pass is only supported on "
+                   "GPU, skipped.";
         continue;
       }
     } else if (pass->Type() == "fuse_bn_add_act_pass") {
       if (use_device != p::kCUDA) {
-        LOG(WARNING) << "fuse_bn_add_act_pass is only supported on "
-                        "GPU, skipped.";
+        VLOG(1) << "fuse_bn_add_act_pass is only supported on "
+                   "GPU, skipped.";
         continue;
       }
     } else if (pass->Type() == "mkldnn_placement_pass") {
diff --git a/paddle/fluid/framework/device_worker.h b/paddle/fluid/framework/device_worker.h
index 6ecc02bbae6..ca805c9243e 100644
--- a/paddle/fluid/framework/device_worker.h
+++ b/paddle/fluid/framework/device_worker.h
@@ -202,7 +202,7 @@ class DeviceWorker {
   Scope* root_scope_ = nullptr;
   Scope* thread_scope_;
   paddle::platform::Place place_;
-  int64_t batch_num_;
+  int64_t batch_num_ = 0;
   FetchConfig fetch_config_;
   bool use_cvm_;
   bool no_cvm_;
diff --git a/paddle/fluid/framework/hogwild_worker.cc b/paddle/fluid/framework/hogwild_worker.cc
index 7aaaba51046..bdadd1becda 100644
--- a/paddle/fluid/framework/hogwild_worker.cc
+++ b/paddle/fluid/framework/hogwild_worker.cc
@@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */

+#include <ctime>
 #include "paddle/fluid/framework/data_type.h"
 #include "paddle/fluid/framework/device_worker.h"
 #include "paddle/fluid/framework/device_worker_factory.h"
@@ -227,14 +228,32 @@ void HogwildWorker::PrintFetchVars() {
   // call count
   batch_num_++;
   int batch_per_print = fetch_config_.print_period();
-  if (thread_id_ == 0) {
-    if (batch_num_ % batch_per_print == 0) {
-      int fetch_var_num = fetch_config_.fetch_var_names_size();
-      for (int i = 0; i < fetch_var_num; ++i) {
-        platform::PrintVar(thread_scope_, fetch_config_.fetch_var_names(i),
-                           fetch_config_.fetch_var_str_format(i));
+  int fetch_var_num = fetch_config_.fetch_var_names_size();
+
+  if (fetch_var_num == 0) {
+    return;
+  }
+
+  if (thread_id_ == 0 && batch_num_ % batch_per_print == 0) {
+    time_t curtime;
+    time(&curtime);
+    char mbstr[80];
+    std::strftime(mbstr, sizeof(mbstr), "%Y-%m-%d %H:%M:%S",
+                  std::localtime(&curtime));
+
+    std::stringstream ss;
+    ss << "time: [" << mbstr << "], ";
+    ss << "batch: [" << batch_num_ << "], ";
+
+    for (int i = 0; i < fetch_var_num; ++i) {
+      platform::PrintVar(thread_scope_, fetch_config_.fetch_var_names(i),
+                         fetch_config_.fetch_var_str_format(i), &ss);
+      if (i < fetch_var_num - 1) {
+        ss << ", ";
       }
     }
+
+    std::cout << ss.str() << std::endl;
   }
 }
diff --git a/paddle/fluid/platform/lodtensor_printer.cc b/paddle/fluid/platform/lodtensor_printer.cc
index 0be4233269e..25ae0ab264f 100644
--- a/paddle/fluid/platform/lodtensor_printer.cc
+++ b/paddle/fluid/platform/lodtensor_printer.cc
@@ -27,24 +27,38 @@ namespace paddle {
 namespace platform {

 void PrintVar(framework::Scope* scope, const std::string& var_name,
-              const std::string& print_info) {
+              const std::string& print_info, std::stringstream* sstream) {
   framework::Variable* var = scope->FindVar(var_name);
   if (var == nullptr) {
-    VLOG(1) << "Variable Name " << var_name << " does not exist in your scope";
+    VLOG(0) << "Variable Name " << var_name << " does not exist in your scope";
     return;
   }
   framework::LoDTensor* tensor = var->GetMutable<framework::LoDTensor>();
   if (tensor == nullptr) {
-    VLOG(1) << "tensor of variable " << var_name
+    VLOG(0) << "tensor of variable " << var_name
             << " does not exist in your scope";
     return;
   }
-  std::ostringstream sstream;
-  sstream << print_info << "\t";
-  sstream << var_name << "\t";
-  sstream << *tensor << "\t";
-  std::cout << sstream.str() << std::endl;
+  *sstream << print_info << ": ";
+
+#define PrintTensorCallback(cpp_type, proto_type)      \
+  do {                                                 \
+    if (tensor->type() == proto_type) {                \
+      *sstream << "[";                                 \
+      auto* data = tensor->data<cpp_type>();           \
+      auto element_num = tensor->numel();              \
+      if (element_num > 0) {                           \
+        *sstream << data[0];                           \
+        for (int j = 1; j < element_num; ++j) {        \
+          *sstream << " " << data[j];                  \
+        }                                              \
+      }                                                \
+      *sstream << "]";                                 \
+    }                                                  \
+  } while (0)
+
+  _ForEachDataType_(PrintTensorCallback);
 }
 }  // end namespace platform
diff --git a/paddle/fluid/platform/lodtensor_printer.h b/paddle/fluid/platform/lodtensor_printer.h
index e0bd1fff197..d30afb62b0b 100644
--- a/paddle/fluid/platform/lodtensor_printer.h
+++ b/paddle/fluid/platform/lodtensor_printer.h
@@ -26,6 +26,6 @@ class Scope;
 namespace paddle {
 namespace platform {
 void PrintVar(framework::Scope* scope, const std::string& var_name,
-              const std::string& print_info);
+              const std::string& print_info, std::stringstream* out);
 }  // end namespace platform
 }  // end namespace paddle
diff --git a/paddle/fluid/platform/lodtensor_printer_test.cc b/paddle/fluid/platform/lodtensor_printer_test.cc
index 5b2af270740..51bd55ebb7f 100644
--- a/paddle/fluid/platform/lodtensor_printer_test.cc
+++ b/paddle/fluid/platform/lodtensor_printer_test.cc
@@ -18,5 +18,6 @@

 TEST(LodTensorPrinter, PrintVar) {
   paddle::framework::Scope scope;
-  paddle::platform::PrintVar(&scope, "NotAVar", "We don't have var");
+  std::stringstream ss;
+  paddle::platform::PrintVar(&scope, "NotAVar", "We don't have var", &ss);
 }
diff --git a/python/paddle/distributed/fleet/base/fleet_base.py b/python/paddle/distributed/fleet/base/fleet_base.py
index f4075e92c4c..f12307cc927 100644
--- a/python/paddle/distributed/fleet/base/fleet_base.py
+++ b/python/paddle/distributed/fleet/base/fleet_base.py
@@ -628,12 +628,13 @@ class Fleet(object):
         self.user_defined_optimizer = optimizer

         if strategy is not None:
-            warnings.warn(
-                "It is recommended to use DistributedStrategy "
-                "in fleet.init(). The strategy here is only for compatibility. "
-                "If the strategy in fleet.distributed_optimizer() is "
-                "not None, then it will overwrite the DistributedStrategy in fleet.init(), "
-                "which will take effect in distributed training.")
+            if self._is_collective:
+                warnings.warn(
+                    "It is recommended to use DistributedStrategy "
+                    "in fleet.init(). The strategy here is only for compatibility. "
+                    "If the strategy in fleet.distributed_optimizer() is "
+                    "not None, then it will overwrite the DistributedStrategy in fleet.init(), "
+                    "which will take effect in distributed training.")
             self._user_defined_strategy = copy.deepcopy(strategy)

         self._context = {}
diff --git a/python/paddle/distributed/fleet/runtime/the_one_ps.py b/python/paddle/distributed/fleet/runtime/the_one_ps.py
index abec4710f5d..17ac6c021af 100644
--- a/python/paddle/distributed/fleet/runtime/the_one_ps.py
+++ b/python/paddle/distributed/fleet/runtime/the_one_ps.py
@@ -767,7 +767,7 @@ class TheOnePSRuntime(RuntimeBase):
         server = self._get_fleet_proto(is_server=True, is_sync=is_sync)
         proto_txt = str(server)

-        debug = bool(os.getenv("PSERVER_DEBUG", "0"))
+        debug = bool(int(os.getenv("PSERVER_DEBUG", "0")))
         if debug:
             print("server: \n{}".format(proto_txt))
diff --git a/python/paddle/fluid/tests/unittests/test_monitor.py b/python/paddle/fluid/tests/unittests/test_monitor.py
index cf273876b1f..bea2f6c8b38 100644
--- a/python/paddle/fluid/tests/unittests/test_monitor.py
+++ b/python/paddle/fluid/tests/unittests/test_monitor.py
@@ -17,6 +17,8 @@ TestCases for Monitor
 from __future__ import print_function

 import paddle
+paddle.enable_static()
+
 import paddle.fluid as fluid
 import paddle.fluid.core as core
 import numpy as np
@@ -52,6 +54,11 @@ class TestDatasetWithStat(unittest.TestCase):
                 name=slot, shape=[1], dtype="int64", lod_level=1)
             slots_vars.append(var)

+        embs = []
+        for x in slots_vars:
+            emb = fluid.layers.embedding(x, is_sparse=True, size=[100001, 4])
+            embs.append(emb)
+
         dataset = paddle.distributed.InMemoryDataset()
         dataset._set_batch_size(32)
         dataset._set_thread(3)
@@ -74,11 +81,17 @@ class TestDatasetWithStat(unittest.TestCase):
             for i in range(self.epoch_num):
                 for data in data_loader():
                     exe.run(fluid.default_main_program(), feed=data)
+
         else:
             for i in range(self.epoch_num):
                 try:
-                    exe.train_from_dataset(fluid.default_main_program(),
-                                           dataset)
+                    exe.train_from_dataset(
+                        fluid.default_main_program(),
+                        dataset,
+                        fetch_list=[embs[0], embs[1]],
+                        fetch_info=["emb0", "emb1"],
+                        print_period=1)
+
                 except Exception as e:
                     self.assertTrue(False)
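
For context on the "train from dataset fetch optimize" item in the commit message, the fetch path exercised by the test above is driven roughly as follows. This is a minimal sketch only, assuming a local slot-format data file named train.txt and the same two sparse embeddings as in test_monitor.py; the file name, slot names, and the sample log line are illustrative, not taken from the patch.

    import paddle
    import paddle.fluid as fluid

    paddle.enable_static()

    # Input slots and embeddings mirror those built in test_monitor.py.
    slots_vars = [
        fluid.layers.data(name=name, shape=[1], dtype="int64", lod_level=1)
        for name in ["slot1", "slot2"]
    ]
    embs = [
        fluid.layers.embedding(x, is_sparse=True, size=[100001, 4])
        for x in slots_vars
    ]

    dataset = paddle.distributed.InMemoryDataset()
    dataset.init(batch_size=32, thread_num=3, use_var=slots_vars, pipe_command="cat")
    dataset.set_filelist(["train.txt"])  # hypothetical slot-format data file
    dataset.load_into_memory()

    exe = fluid.Executor(fluid.CPUPlace())
    exe.run(fluid.default_startup_program())

    # With fetch_list/fetch_info/print_period set, HogwildWorker::PrintFetchVars()
    # now emits one consolidated line per print_period batches, roughly:
    #   time: [2021-04-02 11:41:18], batch: [1], emb0: [...], emb1: [...]
    exe.train_from_dataset(
        fluid.default_main_program(),
        dataset,
        fetch_list=[embs[0], embs[1]],
        fetch_info=["emb0", "emb1"],
        print_period=1)
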
diff --git a/python/paddle/nn/functional/input.py b/python/paddle/nn/functional/input.py
index bf389717518..b88a2b042ff 100644
--- a/python/paddle/nn/functional/input.py
+++ b/python/paddle/nn/functional/input.py
@@ -148,9 +148,7 @@ def embedding(x, weight, padding_idx=None, sparse=False, name=None):
         sparse(bool): The flag indicating whether to use sparse update. This parameter only
             affects the performance of the backwards gradient update. It is recommended to set
             True because sparse update is faster. But some optimizers does not support sparse update,
-            such as :ref:`api_optimizer_AdadeltaOptimizer` , :ref:`api_optimizer_AdamaxOptimizer` ,
-            :ref:`api_optimizer_DecayedAdagradOptimizer` , :ref:`api_optimizer_FtrlOptimizer` ,
-            :ref:`api_optimizer_LambOptimizer` and :ref:`api_optimizer_LarsMomentumOptimizer` .
+            such as :ref:`api_paddle_optimizer_adadelta_Adadelta` , :ref:`api_paddle_optimizer_adamax_Adamax` , :ref:`api_paddle_optimizer_lamb_Lamb`.
             In these cases, sparse must be False. Default: False.
         padding_idx(int|long|None): padding_idx needs to be in the interval [-weight.shape[0], weight.shape[0]).
             If :math:`padding\_idx < 0`, the :math:`padding\_idx` will automatically be converted
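
The hunk above only rewrites optimizer cross-references in the paddle.nn.functional.embedding docstring. For reference, a minimal dygraph sketch of the API being documented (vocabulary size and dimensions chosen arbitrarily):

    import paddle
    import paddle.nn.functional as F

    x = paddle.to_tensor([[1, 3], [2, 4]], dtype="int64")
    weight = paddle.rand([10, 4])  # 10-word vocabulary, 4-dim embeddings

    # sparse=True only speeds up the backward update; it must stay False when
    # the chosen optimizer (e.g. Adadelta, Adamax, Lamb) cannot handle sparse
    # gradients, which is what the corrected references point out.
    out = F.embedding(x, weight, sparse=True)
    print(out.shape)  # [2, 2, 4]
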
diff --git a/python/paddle/nn/layer/common.py b/python/paddle/nn/layer/common.py
index 400a34d5e52..e1310e0dc78 100644
--- a/python/paddle/nn/layer/common.py
+++ b/python/paddle/nn/layer/common.py
@@ -1219,7 +1219,7 @@ class Embedding(layers.Layer):
     For specific usage, refer to code examples. It implements the function of the Embedding Layer.
     This layer is used to lookup embeddings vector of ids provided by :attr:`x` .
     It automatically constructs a 2D embedding matrix based on the
-    input :attr:`num_embeddings` and attr:`embedding_dim`.
+    input :attr:`num_embeddings` and :attr:`embedding_dim`.

     The shape of output Tensor is generated by appending an emb_size dimension to the
     last dimension of the input Tensor shape.
@@ -1231,9 +1231,9 @@ class Embedding(layers.Layer):

         Case 1:

-        input is a Tensor. padding_idx = -1
-            input.data = [[1, 3], [2, 4], [4, 127]
-            input.shape = [3, 2]
+        x is a Tensor. padding_idx = -1
+            x.data = [[1, 3], [2, 4], [4, 127]
+            x.shape = [3, 2]
         Given size = [128, 16]
         output is a Tensor:
             out.shape = [3, 2, 16]
@@ -1251,7 +1251,7 @@ class Embedding(layers.Layer):

     Parameters:
         num_embeddings (int): Just one element which indicate the size of the dictionary of embeddings.
-        embedding_dim:  Just one element which indicate the size of each embedding vector respectively.
+        embedding_dim (int):  Just one element which indicate the size of each embedding vector respectively.
         padding_idx(int|long|None): padding_idx needs to be in the interval [-num_embeddings, num_embeddings).
             If :math:`padding\_idx < 0`, the :math:`padding\_idx` will automatically be converted
             to :math:`vocab\_size + padding\_idx` . It will output all-zero padding data whenever lookup
@@ -1260,9 +1260,7 @@ class Embedding(layers.Layer):
         sparse(bool): The flag indicating whether to use sparse update. This parameter only
             affects the performance of the backwards gradient update. It is recommended to set
             True because sparse update is faster. But some optimizer does not support sparse update,
-            such as :ref:`api_optimizer_AdadeltaOptimizer` , :ref:`api_optimizer_AdamaxOptimizer` ,
-            :ref:`api_optimizer_DecayedAdagradOptimizer` , :ref:`api_optimizer_FtrlOptimizer` ,
-            :ref:`api_optimizer_LambOptimizer` and :ref:`api_optimizer_LarsMomentumOptimizer` .
+            such as :ref:`api_paddle_optimizer_adadelta_Adadelta` , :ref:`api_paddle_optimizer_adamax_Adamax` , :ref:`api_paddle_optimizer_lamb_Lamb`.
             In these case, sparse must be False. Default: False.
         weight_attr(ParamAttr): To specify the weight parameter property. Default: None, which means the
             default weight parameter property is used. See usage for details in :ref:`api_ParamAttr` . In addition,
--
GitLab
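
As a companion to the paddle.nn.Embedding docstring edits above, a small sketch of the padding_idx behaviour described in "Case 1"; shapes follow the docstring, and the commented values are what the documented semantics imply, not output captured from this patch.

    import paddle

    x = paddle.to_tensor([[1, 3], [2, 4], [4, 127]], dtype="int64")

    # num_embeddings=128, embedding_dim=16, i.e. "Given size = [128, 16]".
    # padding_idx=-1 is converted to 128 + (-1) = 127, so ids equal to 127
    # look up an all-zero row.
    embedding = paddle.nn.Embedding(128, 16, padding_idx=-1, sparse=False)

    out = embedding(x)
    print(out.shape)                     # [3, 2, 16]
    print(float(out[2, 1].abs().sum()))  # 0.0, because x[2, 1] == 127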