提交 d26ff8cb 编写于 作者: Q Qiao Longfei

Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into cpu-for-1.1-merge-with-shape

......@@ -142,5 +142,10 @@ def parse_args():
choices=['reduce', 'all_reduce'],
default='all_reduce',
help='Specify the reduce strategy, can be reduce, all_reduce')
parser.add_argument(
'--fuse_broadcast_op',
action='store_true',
help='If set, would fuse multiple broadcast operators into one fused_broadcast operator.'
)
args = parser.parse_args()
return args
......@@ -177,6 +177,7 @@ def train_parallel(train_args, test_args, args, train_prog, test_prog,
else:
build_strategy.reduce_strategy = fluid.BuildStrategy(
).ReduceStrategy.AllReduce
build_strategy.fuse_broadcast_op = args.fuse_broadcast_op
avg_loss = train_args[0]
......@@ -240,7 +241,6 @@ def train_parallel(train_args, test_args, args, train_prog, test_prog,
if args.use_fake_data or args.use_reader_op:
try:
fetch_ret = exe.run(fetch_list)
except fluid.core.EOFException as eof:
break
......
......@@ -355,6 +355,8 @@ paddle.fluid.optimizer.ModelAverage.__init__ ArgSpec(args=['self', 'average_wind
paddle.fluid.optimizer.ModelAverage.apply ArgSpec(args=[], varargs='args', keywords='kwds', defaults=None)
paddle.fluid.optimizer.ModelAverage.minimize ArgSpec(args=['self', 'loss', 'startup_program', 'parameter_list', 'no_grad_set'], varargs=None, keywords=None, defaults=(None, None, None))
paddle.fluid.optimizer.ModelAverage.restore ArgSpec(args=['self', 'executor'], varargs=None, keywords=None, defaults=None)
paddle.fluid.optimizer.LarsMomentumOptimizer.__init__ ArgSpec(args=['self', 'learning_rate', 'momentum', 'lars_coeff', 'lars_weight_decay', 'regularization', 'name'], varargs=None, keywords=None, defaults=(0.001, 0.0005, None, None))
paddle.fluid.optimizer.LarsMomentumOptimizer.minimize ArgSpec(args=['self', 'loss', 'startup_program', 'parameter_list', 'no_grad_set'], varargs=None, keywords=None, defaults=(None, None, None))
paddle.fluid.backward.append_backward ArgSpec(args=['loss', 'parameter_list', 'no_grad_set', 'callbacks'], varargs=None, keywords=None, defaults=(None, None, None))
paddle.fluid.regularizer.L1DecayRegularizer.__init__ ArgSpec(args=['self', 'regularization_coeff'], varargs=None, keywords=None, defaults=(0.0,))
paddle.fluid.regularizer.L2DecayRegularizer.__init__ ArgSpec(args=['self', 'regularization_coeff'], varargs=None, keywords=None, defaults=(0.0,))
......
......@@ -16,12 +16,14 @@ if(WITH_GPU)
dynload_cuda variable_visitor)
nv_library(reduce_op_handle SRCS reduce_op_handle.cc DEPS op_handle_base variable_visitor scope ddim dynload_cuda)
nv_library(broadcast_op_handle SRCS broadcast_op_handle.cc DEPS op_handle_base scope ddim memory variable_visitor dynload_cuda)
nv_library(fused_broadcast_op_handle SRCS fused_broadcast_op_handle.cc DEPS broadcast_op_handle)
else()
cc_library(all_reduce_op_handle SRCS all_reduce_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory
variable_visitor)
cc_library(reduce_op_handle SRCS reduce_op_handle.cc DEPS op_handle_base variable_visitor scope ddim)
cc_library(broadcast_op_handle SRCS broadcast_op_handle.cc DEPS op_handle_base scope ddim memory variable_visitor)
cc_library(fused_broadcast_op_handle SRCS fused_broadcast_op_handle.cc DEPS broadcast_op_handle)
endif()
cc_library(data_balance_op_handle SRCS data_balance_op_handle.cc DEPS op_handle_base scope lod_tensor)
......@@ -34,7 +36,7 @@ if(WITH_GPU)
endif()
cc_library(multi_devices_graph_pass SRCS multi_devices_graph_pass.cc DEPS multi_devices_helper computation_op_handle
scale_loss_grad_op_handle rpc_op_handle all_reduce_op_handle reduce_op_handle broadcast_op_handle data_balance_op_handle)
scale_loss_grad_op_handle rpc_op_handle all_reduce_op_handle reduce_op_handle broadcast_op_handle data_balance_op_handle fused_broadcast_op_handle)
if(WITH_GPU)
cc_library(ssa_graph_executor SRCS ssa_graph_executor.cc DEPS graph framework_proto reference_count_pass)
......@@ -58,4 +60,4 @@ cc_library(fast_threaded_ssa_graph_executor SRCS fast_threaded_ssa_graph_executo
cc_library(build_strategy SRCS build_strategy.cc DEPS
graph_viz_pass multi_devices_graph_pass
multi_devices_graph_print_pass multi_devices_graph_check_pass
fuse_elewise_add_act_pass)
fuse_elewise_add_act_pass multi_batch_merge_pass)
......@@ -48,8 +48,15 @@ void BroadcastOpHandle::RunImpl() {
var_scopes.emplace_back(s->FindVar(kLocalExecScopeName)->Get<Scope *>());
}
BroadcastOneVar(*in_var_handle, out_var_handles, var_scopes);
}
void BroadcastOpHandle::BroadcastOneVar(
const VarHandle &in_var_handle,
const std::vector<VarHandle *> &out_var_handles,
const std::vector<const Scope *> &var_scopes) {
auto *in_var =
var_scopes.at(in_var_handle->scope_idx_)->FindVar(in_var_handle->name_);
var_scopes.at(in_var_handle.scope_idx_)->FindVar(in_var_handle.name_);
PADDLE_ENFORCE_NOT_NULL(in_var);
Tensor &in_tensor = VariableVisitor::GetMutableTensor(in_var);
if (!in_tensor.IsInitialized()) {
......@@ -57,11 +64,11 @@ void BroadcastOpHandle::RunImpl() {
return;
}
InitOutputValue(*in_var_handle, out_var_handles);
InitOutputValue(in_var_handle, out_var_handles);
if (platform::is_cpu_place(in_tensor.place())) {
for (auto *out_var_handle : out_var_handles) {
if (out_var_handle->IsTheSameVar(*in_var_handle)) {
if (out_var_handle->IsTheSameVar(in_var_handle)) {
continue;
}
auto &out_p = out_var_handle->place_;
......@@ -118,12 +125,12 @@ void BroadcastOpHandle::RunImpl() {
}
}
if (!out_handle->IsTheSameVar(*in_var_handle)) {
auto out_var = var_scopes.at(in_var_handle->scope_idx_)
if (!out_handle->IsTheSameVar(in_var_handle)) {
auto out_var = var_scopes.at(in_var_handle.scope_idx_)
->FindVar(out_var_handles[0]->name_);
paddle::framework::TensorCopy(
in_tensor, in_var_handle->place_,
*(dev_ctxes_.at(in_var_handle->place_)),
in_tensor, in_var_handle.place_,
*(dev_ctxes_.at(in_var_handle.place_)),
&VariableVisitor::GetMutableTensor(out_var));
}
});
......
......@@ -61,7 +61,10 @@ struct BroadcastOpHandle : public OpHandleBase {
protected:
void RunImpl() override;
private:
void BroadcastOneVar(const VarHandle &in_var_handle,
const std::vector<VarHandle *> &out_var_handles,
const std::vector<const Scope *> &var_scopes);
std::vector<Scope *> local_scopes_;
std::vector<platform::Place> places_;
#ifdef PADDLE_WITH_CUDA
......
......@@ -121,6 +121,7 @@ std::unique_ptr<ir::Graph> BuildStrategy::Apply(
USE_PASS(fuse_elewise_add_act_pass);
USE_PASS(graph_viz_pass);
USE_PASS(multi_batch_merge_pass);
USE_PASS(multi_devices_pass);
USE_PASS(multi_devices_check_pass);
USE_PASS(multi_devices_print_pass);
......@@ -69,6 +69,8 @@ struct BuildStrategy {
bool enable_data_balance_{false};
bool fuse_broadcast_op_{false};
// User normally doesn't need to call this API.
// The PassBuilder allows for more customized insert, remove of passes
// from python side.
......
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/details/fused_broadcast_op_handle.h"
#include "paddle/fluid/framework/details/container_cast.h"
#include "paddle/fluid/framework/details/variable_visitor.h"
#include "paddle/fluid/platform/profiler.h"
namespace paddle {
namespace framework {
namespace details {
void FusedBroadcastOpHandle::RunImpl() {
platform::RecordEvent record_event(Name(), dev_ctxes_.begin()->second);
if (places_.size() == 1UL) return;
auto in_var_handles = DynamicCast<VarHandle>(inputs_);
auto out_var_handles = DynamicCast<VarHandle>(outputs_);
WaitInputVarGenerated();
std::vector<const Scope *> var_scopes;
for (auto *s : local_scopes_) {
var_scopes.emplace_back(s->FindVar(kLocalExecScopeName)->Get<Scope *>());
}
size_t place_num = places_.size();
PADDLE_ENFORCE_EQ(in_var_handles.size() * place_num, out_var_handles.size());
for (size_t i = 0; i < in_var_handles.size(); ++i) {
BroadcastOneVar(
*in_var_handles[i],
std::vector<VarHandle *>(out_var_handles.begin() + i * place_num,
out_var_handles.begin() + (i + 1) * place_num),
var_scopes);
}
}
std::string FusedBroadcastOpHandle::Name() const { return "fused_broadcast"; }
} // namespace details
} // namespace framework
} // namespace paddle
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <map>
#include <string>
#include <vector>
#include "paddle/fluid/framework/details/broadcast_op_handle.h"
#include "paddle/fluid/framework/details/multi_devices_helper.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/platform/device_context.h"
#ifdef PADDLE_WITH_CUDA
#include "paddle/fluid/platform/nccl_helper.h"
#endif
namespace paddle {
namespace framework {
namespace details {
struct FusedBroadcastOpHandle : public BroadcastOpHandle {
public:
#ifdef PADDLE_WITH_CUDA
FusedBroadcastOpHandle(ir::Node *node,
const std::vector<Scope *> local_scopes,
const std::vector<platform::Place> &places,
const platform::NCCLContextMap *nccl_ctx)
: BroadcastOpHandle(node, local_scopes, places, nccl_ctx) {}
#else
FusedBroadcastOpHandle(ir::Node* node, const std::vector<Scope*> local_scopes,
const std::vector<platform::Place>& places)
: BroadcastOpHandle(node, local_scopes, places) {}
#endif
std::string Name() const override;
protected:
void RunImpl() override;
};
} // namespace details
} // namespace framework
} // namespace paddle
......@@ -21,6 +21,7 @@
#include "paddle/fluid/framework/details/broadcast_op_handle.h"
#include "paddle/fluid/framework/details/computation_op_handle.h"
#include "paddle/fluid/framework/details/data_balance_op_handle.h"
#include "paddle/fluid/framework/details/fused_broadcast_op_handle.h"
#include "paddle/fluid/framework/details/multi_devices_graph_pass.h"
#include "paddle/fluid/framework/details/reduce_op_handle.h"
#include "paddle/fluid/framework/details/rpc_op_handle.h"
......@@ -347,7 +348,7 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::ApplyImpl(
BuildStrategy::GradientScaleStrategy::kCustomized) {
// TODO(paddle-dev): Why is there no input for this op_handle?
auto loss_grad_name = node->Op()->OutputArgumentNames()[0];
CreateScaleLossGradOp(&result, loss_grad_name);
CreateScaleLossGradOp(&result, loss_grad_name, node->outputs[0]);
}
// This assumes the backward generating code will ensure IsScaleLossOp
// is true only for the op that scale the final scalar loss.
......@@ -436,10 +437,14 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::ApplyImpl(
if ((use_gpu &&
strategy_.reduce_ == BuildStrategy::ReduceStrategy::kReduce) ||
is_dist_train) {
for (size_t dev_id = 0; dev_id < bcast_var_name_set.size(); ++dev_id) {
auto &to_bcast_set = bcast_var_name_set[dev_id];
for (auto &bcast_name : to_bcast_set) {
CreateBroadcastOp(&result, bcast_name, dev_id);
if (strategy_.fuse_broadcast_op_) {
CreateFusedBroadcastOp(&result, bcast_var_name_set);
} else {
for (size_t dev_id = 0; dev_id < bcast_var_name_set.size(); ++dev_id) {
auto &to_bcast_set = bcast_var_name_set[dev_id];
for (auto &bcast_name : to_bcast_set) {
CreateBroadcastOp(&result, bcast_name, dev_id);
}
}
}
}
......@@ -508,6 +513,44 @@ void MultiDevSSAGraphBuilder::CreateBroadcastOp(ir::Graph *result,
}
}
void MultiDevSSAGraphBuilder::CreateFusedBroadcastOp(
ir::Graph *result,
const std::vector<std::unordered_set<std::string>> &bcast_varnames) const {
#ifdef PADDLE_WITH_CUDA
auto *op_handle = new FusedBroadcastOpHandle(
result->CreateEmptyNode("fused_broadcast", ir::Node::Type::kOperation),
local_scopes_, places_, nccl_ctxs_);
#else
auto *op_handle = new FusedBroadcastOpHandle(
result->CreateEmptyNode("fused_broadcast", ir::Node::Type::kOperation),
local_scopes_, places_);
#endif
result->Get<GraphOps>(kGraphOps).emplace_back(op_handle);
for (size_t i = 0; i < places_.size(); ++i) {
auto &p = places_[i];
SetCommunicationContext(op_handle, p);
}
for (size_t dev_id = 0; dev_id < bcast_varnames.size(); ++dev_id) {
for (auto &p_name : bcast_varnames[dev_id]) {
auto *in =
result->Get<GraphVars>(kGraphVars).at(dev_id).at(p_name).back().get();
op_handle->AddInput(in);
for (size_t out_dev_id = 0; out_dev_id < places_.size(); ++out_dev_id) {
auto &p = places_[out_dev_id];
auto &vars =
result->Get<GraphVars>(kGraphVars).at(out_dev_id).at(p_name);
auto *out_var = new VarHandle(
result->CreateEmptyNode(p_name, ir::Node::Type::kVariable),
vars.size(), out_dev_id, p_name, p);
vars.emplace_back(out_var);
op_handle->AddOutput(out_var);
}
}
}
}
void MultiDevSSAGraphBuilder::CreateComputationalOp(ir::Graph *result,
ir::Node *node,
int dev_id) const {
......@@ -602,7 +645,8 @@ int MultiDevSSAGraphBuilder::GetVarDeviceID(const ir::Graph &graph,
}
void MultiDevSSAGraphBuilder::CreateScaleLossGradOp(
ir::Graph *result, const std::string &loss_grad_name) const {
ir::Graph *result, const std::string &loss_grad_name,
ir::Node *out_var_node) const {
for (size_t i = 0; i < places_.size(); ++i) {
// Insert ScaleCost OpHandle
auto *dev_ctx = platform::DeviceContextPool::Instance().Get(places_[i]);
......@@ -617,10 +661,8 @@ void MultiDevSSAGraphBuilder::CreateScaleLossGradOp(
// loss->pending_ops_.emplace_back(op_handle);
// op_handle->inputs_.emplace_back(loss);
CreateOpOutput(
result, op_handle,
result->CreateEmptyNode(loss_grad_name, ir::Node::Type::kVariable),
places_[i], i);
CreateOpOutput(result, op_handle,
result->CreateVarNode(out_var_node->Var()), places_[i], i);
}
}
......
......@@ -61,7 +61,8 @@ class MultiDevSSAGraphBuilder : public ir::Pass {
size_t num_places) const;
void CreateScaleLossGradOp(ir::Graph *result,
const std::string &loss_grad_name) const;
const std::string &loss_grad_name,
ir::Node *out_var_node) const;
VarHandle *CreateReduceOp(ir::Graph *result, const std::string &og,
int dst_dev_id) const;
......@@ -78,6 +79,10 @@ class MultiDevSSAGraphBuilder : public ir::Pass {
void CreateBroadcastOp(ir::Graph *result, const std::string &p_name,
size_t src_dev_id) const;
void CreateFusedBroadcastOp(
ir::Graph *result,
const std::vector<std::unordered_set<std::string>> &bcast_varnames) const;
bool IsSparseGradient(const std::string &og) const;
size_t GetAppropriateDeviceID(
......
......@@ -36,6 +36,7 @@ pass_library(fc_lstm_fuse_pass inference)
pass_library(embedding_fc_lstm_fuse_pass inference)
pass_library(fc_gru_fuse_pass inference)
pass_library(seq_concat_fc_fuse_pass inference)
pass_library(multi_batch_merge_pass base)
pass_library(conv_bn_fuse_pass inference)
pass_library(seqconv_eltadd_relu_fuse_pass inference)
if(WITH_MKLDNN)
......
......@@ -27,14 +27,20 @@ namespace ir {
Graph::Graph(const ProgramDesc &program) : program_(program) {
// Make the nodes id start from 0.
Node::ResetId();
auto var_nodes = InitFromProgram(program_);
ResolveHazard(var_nodes);
}
std::map<std::string, std::vector<ir::Node *>> Graph::InitFromProgram(
const ProgramDesc &program) {
VLOG(3) << "block in program:" << program_.Size();
std::unordered_map<std::string, VarDesc *> all_vars;
// var nodes for each var name, will have multiple versions in SSA
std::map<std::string, std::vector<ir::Node *>> var_nodes;
for (auto *var : program.Block(0).AllVars()) {
all_vars.emplace(var->Name(), var);
}
std::map<std::string, std::vector<ir::Node *>> var_nodes;
for (auto *op : program.Block(0).AllOps()) {
ir::Node *node = CreateOpNode(op);
// For input args, reuse the same var name if it was created before.
......@@ -72,7 +78,11 @@ Graph::Graph(const ProgramDesc &program) : program_(program) {
var->inputs.push_back(node);
}
}
return std::move(var_nodes);
}
void Graph::ResolveHazard(
const std::map<std::string, std::vector<ir::Node *>> &var_nodes) {
/**
* We should handle write after read(WAR) and write after write(WAW) here.
* Because some of the operators of the program can be executed parallelly.
......@@ -91,6 +101,7 @@ Graph::Graph(const ProgramDesc &program) : program_(program) {
auto it_old = versions.rbegin();
++it_old;
for (; it_old != versions.rend(); it_new = it_old, ++it_old) {
VLOG(3) << "deal with var: " << (*it_new)->Name();
ir::Node *write_op =
(*it_new)->inputs.empty() ? nullptr : (*it_new)->inputs[0];
const auto &read_ops = (*it_old)->outputs;
......
......@@ -160,6 +160,12 @@ class Graph {
return nullptr;
}
std::map<std::string, std::vector<ir::Node *>> InitFromProgram(
const ProgramDesc &program);
void ResolveHazard(
const std::map<std::string, std::vector<ir::Node *>> &var_nodes);
private:
// This method takes ownership of `node`.
ir::Node *AddNode(ir::Node *node) {
......
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/multi_batch_merge_pass.h"
#include <map>
#include <string>
#include <vector>
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/op_proto_maker.h"
namespace paddle {
namespace framework {
namespace ir {
static const char kNumRepeats[] = "num_repeats";
typedef std::unordered_map<std::string, std::vector<ir::Node*>> SSAVarList;
ir::Node* SameNameVar(std::unordered_set<ir::Node*> all, ir::Node* target) {
for (auto n : all) {
if (target->IsVar() && target->Name() == n->Name()) {
return n;
}
}
return nullptr;
}
VarDesc CopyVarDesc(VarDesc* var_desc) {
VarDesc repeated_var(var_desc->Name());
// copy other variable attributes
if (var_desc->GetType() != proto::VarType::READER) {
repeated_var.SetType(var_desc->GetType());
repeated_var.SetShape(var_desc->GetShape());
repeated_var.SetDataType(var_desc->GetDataType());
repeated_var.SetLoDLevel(var_desc->GetLoDLevel());
repeated_var.SetPersistable(var_desc->Persistable());
} else {
// TODO(typhoonzero): copy reader var
}
return repeated_var;
}
VarDesc UpdateGradVarDesc(
VarDesc* var_desc, int repeat,
const std::unordered_set<std::string>& grad_names,
const std::unordered_set<std::string>& bn_vars_need_rename) {
if (grad_names.find(var_desc->Name()) != grad_names.end() ||
bn_vars_need_rename.find(var_desc->Name()) != bn_vars_need_rename.end()) {
std::string new_gname =
string::Sprintf("%s.repeat.%d", var_desc->Name(), repeat);
VarDesc repeated_var = CopyVarDesc(var_desc);
repeated_var.SetName(new_gname);
VLOG(3) << "update " << var_desc->Name() << " to repeat " << repeat;
return repeated_var;
}
return *var_desc;
}
std::unique_ptr<Graph> BatchMergePass::ApplyImpl(
std::unique_ptr<Graph> graph) const {
int num_repeats = Get<const int>(kNumRepeats);
std::vector<Node*> forward_backward_ops;
std::vector<Node*> optimize_ops;
std::vector<Node*> lr_ops; // ops other than forward/backward/optimize
std::unordered_set<std::string> grad_names;
std::vector<ir::Node*> nodes = TopologySortOperations(*graph);
auto origin_nodes = graph->ReleaseNodes();
VLOG(3) << "origin nodes count: " << origin_nodes.size();
ir::Graph& result = *graph;
// 1. record op nodes of different roles
for (auto node : nodes) {
if (node->IsVar()) continue;
int op_role = boost::get<int>(node->Op()->GetAttr(
framework::OpProtoAndCheckerMaker::OpRoleAttrName()));
if ((op_role == static_cast<int>(framework::OpRole::kForward)) ||
(op_role & static_cast<int>(framework::OpRole::kBackward)) ||
(op_role & static_cast<int>(framework::OpRole::kLoss))) {
forward_backward_ops.push_back(node);
} else if ((op_role & static_cast<int>(framework::OpRole::kOptimize)) ||
(op_role & static_cast<int>(framework::OpRole::kDist)) ||
(op_role & static_cast<int>(framework::OpRole::kRPC))) {
optimize_ops.push_back(node);
auto op_role_var = node->Op()->GetNullableAttr(
OpProtoAndCheckerMaker::OpRoleVarAttrName());
auto op_role_vars = boost::get<std::vector<std::string>>(op_role_var);
for (size_t i = 0; i < op_role_vars.size(); i += 2) {
grad_names.insert(op_role_vars[i + 1]);
}
} else if (op_role & static_cast<int>(framework::OpRole::kLRSched)) {
lr_ops.push_back(node);
} else { // NOLINT
PADDLE_THROW("Invalid op_role: %d", static_cast<int>(op_role));
}
}
// 2. copy forward backward
ir::Node* prev_repeat_last_op_node = nullptr;
// record origin_grad -> repeated grad list map.
std::map<ir::Node*, std::vector<ir::Node*>> grad_repeated_map;
std::map<std::string, std::vector<ir::Node*>> created;
std::unordered_set<std::string> bn_vars_need_rename;
for (int i = 0; i < num_repeats; ++i) {
std::unordered_set<ir::Node*> copied;
for (size_t node_idx = 0; node_idx < forward_backward_ops.size();
++node_idx) {
auto node = forward_backward_ops[node_idx];
OpDesc repeated_op(*(node->Op()), node->Op()->Block());
// 3. rename grad outputs to current repeat.
for (auto outname : repeated_op.OutputArgumentNames()) {
if (grad_names.find(outname) != grad_names.end()) {
std::string new_gname = string::Sprintf("%s.repeat.%d", outname, i);
repeated_op.RenameOutput(outname, new_gname);
}
}
// 3.5 let batch_norm ops use independent vars, note batch_norm_grad do
// not need this update
if (node->Name() == "batch_norm") {
// NOTE: assume bn op created by layers use save var as output mean and
// variance
std::string new_mean_name =
string::Sprintf("%s.repeat.%d", repeated_op.Input("Mean")[0], i);
std::string new_var_name = string::Sprintf(
"%s.repeat.%d", repeated_op.Input("Variance")[0], i);
bn_vars_need_rename.insert(repeated_op.Input("Mean")[0]);
bn_vars_need_rename.insert(repeated_op.Input("Variance")[0]);
VLOG(3) << "renaming " << repeated_op.Input("Mean")[0] << " to "
<< new_mean_name;
repeated_op.RenameInput(repeated_op.Input("Mean")[0], new_mean_name);
repeated_op.RenameInput(repeated_op.Input("Variance")[0], new_var_name);
repeated_op.RenameOutput(repeated_op.Output("MeanOut")[0],
new_mean_name);
repeated_op.RenameOutput(repeated_op.Output("VarianceOut")[0],
new_var_name);
}
// 3.9 do copy
auto repeated_node = result.CreateOpNode(&repeated_op);
copied.insert(node);
// 4. add deps between repeats
if (node_idx == forward_backward_ops.size() - 1) {
prev_repeat_last_op_node = repeated_node;
}
if (node_idx == 0 && prev_repeat_last_op_node) {
auto* depvar = result.CreateControlDepVar();
prev_repeat_last_op_node->outputs.push_back(depvar);
depvar->inputs.push_back(prev_repeat_last_op_node);
repeated_node->inputs.push_back(depvar);
depvar->outputs.push_back(repeated_node);
}
for (auto in_node : node->inputs) {
if (in_node->IsCtrlVar()) {
continue;
}
ir::Node* var = nullptr;
auto updated_var = UpdateGradVarDesc(in_node->Var(), i, grad_names,
bn_vars_need_rename);
// should be initialized by startup, how to initilize tensor in the
// scope?
if (node->Name() == "batch_norm" &&
bn_vars_need_rename.find(in_node->Name()) !=
bn_vars_need_rename.end()) {
// Create bn mean/variance for each repeat
var = result.CreateVarNode(&updated_var);
created[updated_var.Name()].push_back(var);
copied.insert(in_node);
repeated_node->inputs.push_back(var);
var->outputs.push_back(repeated_node);
continue;
}
// for other ops
if (in_node->inputs.empty() && i > 0) {
// do not copy head vars (inputs, params) in repeats > 0
var = created.at(in_node->Name()).back();
} else {
if (copied.find(in_node) == copied.end()) {
var = result.CreateVarNode(&updated_var);
if (grad_names.find(in_node->Var()->Name()) != grad_names.end()) {
grad_repeated_map[in_node].push_back(var);
}
copied.insert(in_node);
created[updated_var.Name()].push_back(var);
} else {
var = created.at(updated_var.Name()).back();
}
}
repeated_node->inputs.push_back(var);
var->outputs.push_back(repeated_node);
}
for (auto out_node : node->outputs) {
if (out_node->IsCtrlVar()) {
continue;
}
ir::Node* var = nullptr;
auto updated_var = UpdateGradVarDesc(out_node->Var(), i, grad_names,
bn_vars_need_rename);
if (copied.find(out_node) == copied.end()) {
var = result.CreateVarNode(&updated_var);
if (grad_names.find(out_node->Var()->Name()) != grad_names.end()) {
grad_repeated_map[out_node].push_back(var);
}
copied.insert(out_node);
created[updated_var.Name()].push_back(var);
} else {
var = created.at(updated_var.Name()).back();
}
repeated_node->outputs.push_back(var);
var->inputs.push_back(repeated_node);
}
}
}
// 5. create GRAD merge op node
for (auto kv : grad_repeated_map) {
OpDesc sum_op;
sum_op.SetType("sum");
std::vector<std::string> repeated_grad_names;
for (auto r : kv.second) {
repeated_grad_names.push_back(r->Var()->Name());
}
sum_op.SetInput("X", repeated_grad_names);
sum_op.SetOutput("Out", {kv.first->Var()->Name()});
sum_op.SetAttr(OpProtoAndCheckerMaker::OpRoleAttrName(),
static_cast<int>(OpRole::kBackward));
auto sum_op_node = result.CreateOpNode(&sum_op);
for (auto r : kv.second) {
sum_op_node->inputs.push_back(r);
r->outputs.push_back(sum_op_node);
}
auto sum_out_var_node = result.CreateVarNode(kv.first->Var());
sum_op_node->outputs.push_back(sum_out_var_node);
sum_out_var_node->inputs.push_back(sum_op_node);
created[sum_out_var_node->Name()].push_back(sum_out_var_node);
OpDesc scale_op;
scale_op.SetType("scale");
scale_op.SetInput("X", {sum_out_var_node->Var()->Name()});
// NOTE: inplace scale.
scale_op.SetOutput("Out", {sum_out_var_node->Var()->Name()});
scale_op.SetAttr("scale", static_cast<float>(1.0f / num_repeats));
scale_op.SetAttr(OpProtoAndCheckerMaker::OpRoleAttrName(),
static_cast<int>(OpRole::kBackward));
auto scale_op_node = result.CreateOpNode(&scale_op);
scale_op_node->inputs.push_back(sum_out_var_node);
sum_out_var_node->outputs.push_back(scale_op_node);
auto scale_out_var_node = result.CreateVarNode(sum_out_var_node->Var());
scale_op_node->outputs.push_back(scale_out_var_node);
scale_out_var_node->inputs.push_back(scale_op_node);
created[scale_out_var_node->Name()].push_back(scale_out_var_node);
}
// 6. add optimize ops
{
auto copy_node = [&result, &created](ir::Node* node) {
auto op_node = result.CreateOpNode(node->Op());
// copy op ins/outs
// NOTE: for send/recv ops, the OpDesc uses ctrldepvar to describe
// dependencies, so create those depvars if OpDesc have in/outs.
for (auto in_node : node->inputs) {
if (in_node->IsCtrlVar() && !in_node->Var()) {
continue;
}
ir::Node* var = nullptr;
if (created.find(in_node->Name()) == created.end()) {
var = result.CreateVarNode(in_node->Var());
created[in_node->Name()].push_back(var);
} else {
var = created.at(in_node->Name()).back();
}
op_node->inputs.push_back(var);
var->outputs.push_back(op_node);
}
for (auto out_node : node->outputs) {
if (out_node->IsCtrlVar() && !out_node->Var()) {
continue;
}
auto var = result.CreateVarNode(out_node->Var());
created[out_node->Name()].push_back(var);
op_node->outputs.push_back(var);
var->inputs.push_back(op_node);
}
};
for (auto node : lr_ops) {
copy_node(node);
}
for (auto node : optimize_ops) {
copy_node(node);
}
}
result.ResolveHazard(created);
return graph;
}
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(multi_batch_merge_pass, paddle::framework::ir::BatchMergePass)
.RequirePassAttr(paddle::framework::ir::kNumRepeats);
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/pass.h"
namespace paddle {
namespace framework {
namespace ir {
// BatchMergePass is used to copy forward and backward ops for several
// times to run several batches to simulate large batch size training
// as if we have more than 1 GPUs.
// User can define how many batches to run, gradients will be merged
// through those repeats, and then do optimization using merged gradients.
// This pass is extremely useful when doing large batch-size distributed
// sync training, we can simulate even large batch size as if we have more
// GPUs.
class BatchMergePass : public Pass {
public:
virtual ~BatchMergePass() {}
protected:
std::unique_ptr<Graph> ApplyImpl(std::unique_ptr<Graph> graph) const override;
};
} // namespace ir
} // namespace framework
} // namespace paddle
......@@ -109,18 +109,9 @@ ParallelExecutor::ParallelExecutor(
if (member_->local_scopes_.size() != 1 && local_scopes.empty()) {
BCastParamsToDevices(bcast_vars);
}
// Startup Program has been run. All local scopes has correct parameters.
// Startup Program has been run. All local scopes has correct parameters.
// Step 2. Create vars in each scope;
std::vector<details::VariableInfo> var_infos;
for (auto *var : main_program.Block(0).AllVars()) {
var_infos.emplace_back();
var_infos.back().name_ = var->Name();
var_infos.back().type_ = var->GetType();
var_infos.back().persistable_ = var->Persistable();
}
// Step 3. Convert main_program to SSA form and dependency graph. Also, insert
// Step 2. Convert main_program to SSA form and dependency graph. Also, insert
// ncclOp
#ifdef PADDLE_WITH_CUDA
std::unique_ptr<ir::Graph> graph = build_strategy.Apply(
......@@ -156,6 +147,17 @@ ParallelExecutor::ParallelExecutor(
params, member_->local_scopes_, member_->use_cuda_);
#endif
// Step 3. Create vars in each scope. Passes may also create new vars.
// skip control vars and empty vars
std::vector<details::VariableInfo> var_infos;
for (auto &node : graph->Nodes()) {
if (node->IsVar() && !node->IsCtrlVar() && node->Var()) {
var_infos.emplace_back();
var_infos.back().name_ = node->Var()->Name();
var_infos.back().type_ = node->Var()->GetType();
var_infos.back().persistable_ = node->Var()->Persistable();
}
}
// If the loss_var_name is given, the number of graph should be only one.
if (loss_var_name.size()) {
PADDLE_ENFORCE_EQ(ir::GraphNum(*graph), 1,
......
......@@ -160,7 +160,8 @@ static void PrintTime(int batch_size, int repeat, int num_threads, int tid,
double latency, int epoch = 1) {
LOG(INFO) << "====== batch_size: " << batch_size << ", repeat: " << repeat
<< ", threads: " << num_threads << ", thread id: " << tid
<< ", latency: " << latency << "ms ======";
<< ", latency: " << latency << "ms, fps: " << 1 / (latency / 1000.f)
<< " ======";
if (epoch > 1) {
int samples = batch_size * epoch;
LOG(INFO) << "====== sample number: " << samples
......
......@@ -139,6 +139,9 @@ void TestMultiThreadPrediction(
}
for (int tid = 0; tid < num_threads; ++tid) {
threads.emplace_back([&, tid]() {
#ifdef PADDLE_WITH_MKLDNN
platform::set_cur_thread_id(static_cast<int>(tid) + 1);
#endif
// Each thread should have local inputs and outputs.
// The inputs of each thread are all the same.
std::vector<std::vector<PaddleTensor>> inputs_tid = inputs;
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/lars_momentum_op.h"
#include "paddle/fluid/operators/momentum_op.h"
namespace paddle {
namespace operators {
class LarsMomentumOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("Param",
"(LoDTensor, default LoDTensor<float>) "
"Input parameter that has to be updated");
AddInput("Grad",
"(LoDTensor, default LoDTensor<float>) "
"Input gradient of the parameter");
AddInput("Velocity",
"(LoDTensor, default LoDTensor<float>) "
"Input velocity (corresponding to the parameter) "
"that has to be updated");
AddInput("LearningRate",
"(LoDTensor, default LoDTensor<float>) "
"Input learning rate");
AddOutput("ParamOut",
"(LoDTensor) This output is updated parameter. "
"It shared memory with Input(Param).");
AddOutput("VelocityOut",
"(LoDTensor) This output is updated velocity. "
"It shared memory with Input(Velocity).");
AddAttr<float>("mu", "(float) Momentum coefficient");
AddAttr<float>("lars_coeff", "(float, default 0.001) LARS coefficient.")
.SetDefault(0.001);
AddAttr<float>("lars_weight_decay",
"(float, default 0.0005) LARS weight decay")
.SetDefault(0.0005);
AddComment(R"DOC(
Lars Momentum Optimizer.
This optimizer use LARS (https://arxiv.org/abs/1708.03888) to optimize each
weight using a local learning rate:
$$
local\_lr = \eta *
\frac{\left \| param \right \|}{\left \| grad \right \| + \beta *\left \| param \right \|} \\
velocity = mu * velocity +
local\_lr * (grad + \beta * param) \\
param = param - velocity. \\
$$
Note that we use lars_weight_decay here to decay weights, you may need not to
use L2 regularizers in case of using LARS.
)DOC");
}
};
class LarsMomentumOpVarTypeInference : public framework::VarTypeInference {
public:
void operator()(const framework::OpDesc &op_desc,
framework::BlockDesc *block) const override {}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(lars_momentum, ops::MomentumOp, ops::LarsMomentumOpMaker,
paddle::framework::EmptyGradOpMaker,
ops::LarsMomentumOpVarTypeInference);
REGISTER_OP_CPU_KERNEL(lars_momentum, ops::LarsMomentumOpKernel<float>,
ops::LarsMomentumOpKernel<double>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/lars_momentum_op.h"
namespace paddle {
namespace operators {
template <typename T>
__global__ void MomentumLarsKernel(const T* p, const T* g, const T* v,
const T* learning_rate, const T mu,
const int64_t num, const T lars_coeff,
const T lars_weight_decay, const T* p_norm,
const T* g_norm, T* p_out, T* v_out) {
T lr = learning_rate[0];
T local_lr = learning_rate[0];
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < num;
i += blockDim.x * gridDim.x) {
if (p_norm[0] > 0 && g_norm[0] > 0) {
local_lr = lr * lars_coeff * p_norm[0] /
(g_norm[0] + lars_weight_decay * p_norm[0]);
}
T v_new = v[i] * mu + local_lr * (g[i] + lars_weight_decay * p[i]);
v_out[i] = v_new;
p_out[i] = p[i] - v_new;
}
}
template <typename DeviceContext, typename T>
class LarsMomentumOpCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto param_out = ctx.Output<framework::LoDTensor>("ParamOut");
auto velocity_out = ctx.Output<framework::LoDTensor>("VelocityOut");
auto param = ctx.Input<framework::LoDTensor>("Param");
auto velocity = ctx.Input<framework::LoDTensor>("Velocity");
auto grad = ctx.Input<framework::LoDTensor>("Grad");
auto learning_rate = ctx.Input<framework::LoDTensor>("LearningRate");
T* p_out = param_out->mutable_data<T>(ctx.GetPlace());
T* v_out = velocity_out->mutable_data<T>(ctx.GetPlace());
T mu = static_cast<T>(ctx.Attr<float>("mu"));
T lars_coeff = ctx.Attr<float>("lars_coeff");
T lars_weight_decay = ctx.Attr<float>("lars_weight_decay");
auto* p = param->data<T>();
auto* v = velocity->data<T>();
auto* g = grad->data<T>();
auto* lr = learning_rate->data<T>();
int block = 512;
int grid = (param->numel() + block - 1) / block;
auto eigen_p = framework::EigenVector<T>::Flatten(*param);
auto eigen_g = framework::EigenVector<T>::Flatten(*grad);
// calculate norms using eigein and launch the kernel.
framework::Tensor p_norm_t, g_norm_t;
p_norm_t.Resize({1});
g_norm_t.Resize({1});
auto* p_norm_data = p_norm_t.mutable_data<T>(ctx.GetPlace());
auto* g_norm_data = g_norm_t.mutable_data<T>(ctx.GetPlace());
auto ep_norm = framework::EigenScalar<T>::From(p_norm_t);
auto eg_norm = framework::EigenScalar<T>::From(g_norm_t);
auto* place = ctx.template device_context<DeviceContext>().eigen_device();
ep_norm.device(*place) = eigen_p.square().sum().sqrt();
eg_norm.device(*place) = eigen_g.square().sum().sqrt();
MomentumLarsKernel<<<grid, block, 0, ctx.cuda_device_context().stream()>>>(
p, g, v, lr, mu, param->numel(), lars_coeff, lars_weight_decay,
p_norm_data, g_norm_data, p_out, v_out);
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(
lars_momentum,
ops::LarsMomentumOpCUDAKernel<paddle::platform::CUDADeviceContext, float>,
ops::LarsMomentumOpCUDAKernel<paddle::platform::CUDADeviceContext, double>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
namespace paddle {
namespace operators {
template <typename T>
class LarsMomentumOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto param_out = ctx.Output<framework::LoDTensor>("ParamOut");
auto velocity_out = ctx.Output<framework::LoDTensor>("VelocityOut");
auto param = ctx.Input<framework::LoDTensor>("Param");
auto velocity = ctx.Input<framework::LoDTensor>("Velocity");
auto learning_rate = ctx.Input<framework::LoDTensor>("LearningRate");
auto* grad_var = ctx.InputVar("Grad");
// only support dense for now.
PADDLE_ENFORCE(grad_var->IsType<framework::LoDTensor>());
auto grad = ctx.Input<framework::LoDTensor>("Grad");
param_out->mutable_data<T>(ctx.GetPlace());
velocity_out->mutable_data<T>(ctx.GetPlace());
T mu = static_cast<T>(ctx.Attr<float>("mu"));
T lars_coeff = ctx.Attr<float>("lars_coeff");
T lars_weight_decay = ctx.Attr<float>("lars_weight_decay");
auto p_out = framework::EigenVector<T>::Flatten(*param_out);
auto v_out = framework::EigenVector<T>::Flatten(*velocity_out);
auto p = framework::EigenVector<T>::Flatten(*param);
auto v = framework::EigenVector<T>::Flatten(*velocity);
auto g = framework::EigenVector<T>::Flatten(*grad);
auto* lr = learning_rate->data<T>();
framework::Tensor p_norm_t, g_norm_t;
p_norm_t.Resize({1});
g_norm_t.Resize({1});
p_norm_t.mutable_data<T>(ctx.GetPlace());
g_norm_t.mutable_data<T>(ctx.GetPlace());
auto ep_norm = framework::EigenScalar<T>::From(p_norm_t);
auto eg_norm = framework::EigenScalar<T>::From(g_norm_t);
ep_norm = p.square().sum().sqrt();
eg_norm = g.square().sum().sqrt();
T local_lr = lr[0];
if (ep_norm(0) > 0 && eg_norm(0) > 0) {
local_lr = lr[0] * lars_coeff * ep_norm(0) /
(eg_norm(0) + lars_weight_decay * ep_norm(0));
}
v_out = v * mu + local_lr * (g + lars_weight_decay * p);
p_out = p - v_out;
}
};
} // namespace operators
} // namespace paddle
......@@ -19,54 +19,6 @@ namespace operators {
using Tensor = framework::Tensor;
class MomentumOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("Param"),
"Input(param) of Momentum should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Grad"),
"Input(grad) of Momentum should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Velocity"),
"Input(velocity) of Momentum should not be null.");
PADDLE_ENFORCE(ctx->HasInput("LearningRate"),
"Input(LearningRate) of Momentum should not be null.");
PADDLE_ENFORCE(
ctx->GetInputsVarType("Param").front() ==
framework::proto::VarType::LOD_TENSOR,
"The input var's type should be LoDTensor, but the received is %s",
ctx->Inputs("Param").front(), ctx->GetInputsVarType("Param").front());
PADDLE_ENFORCE(ctx->HasOutput("ParamOut"),
"Output(ParamOut) of Momentum should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("VelocityOut"),
"Output(VelocityOut) of Momentum should not be null.");
auto param_dim = ctx->GetInputDim("Param");
if (ctx->GetInputsVarType("Grad")[0] ==
framework::proto::VarType::LOD_TENSOR) {
PADDLE_ENFORCE_EQ(
param_dim, ctx->GetInputDim("Grad"),
"Param and Grad input of MomentumOp should have the same dimension.");
PADDLE_ENFORCE_EQ(
param_dim, ctx->GetInputDim("Velocity"),
"Param and Velocity of MomentumOp should have the same dimension.");
}
PADDLE_ENFORCE_EQ(framework::product(ctx->GetInputDim("LearningRate")), 1,
"Learning_rate should be a scalar");
ctx->SetOutputDim("ParamOut", param_dim);
ctx->SetOutputDim("VelocityOut", param_dim);
}
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
auto input_data_type = framework::GetDataTypeOfVar(ctx.InputVar("Param"));
return framework::OpKernelType(input_data_type, ctx.GetPlace());
}
};
class MomentumOpInferVarType : public framework::VarTypeInference {
public:
void operator()(const framework::OpDesc& op_desc,
......
......@@ -28,6 +28,54 @@ using framework::SelectedRows;
struct NoNesterov;
struct UseNesterov;
class MomentumOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("Param"),
"Input(param) of Momentum should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Grad"),
"Input(grad) of Momentum should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Velocity"),
"Input(velocity) of Momentum should not be null.");
PADDLE_ENFORCE(ctx->HasInput("LearningRate"),
"Input(LearningRate) of Momentum should not be null.");
PADDLE_ENFORCE(
ctx->GetInputsVarType("Param").front() ==
framework::proto::VarType::LOD_TENSOR,
"The input var's type should be LoDTensor, but the received is %s",
ctx->Inputs("Param").front(), ctx->GetInputsVarType("Param").front());
PADDLE_ENFORCE(ctx->HasOutput("ParamOut"),
"Output(ParamOut) of Momentum should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("VelocityOut"),
"Output(VelocityOut) of Momentum should not be null.");
auto param_dim = ctx->GetInputDim("Param");
if (ctx->GetInputsVarType("Grad")[0] ==
framework::proto::VarType::LOD_TENSOR) {
PADDLE_ENFORCE_EQ(
param_dim, ctx->GetInputDim("Grad"),
"Param and Grad input of MomentumOp should have the same dimension.");
PADDLE_ENFORCE_EQ(
param_dim, ctx->GetInputDim("Velocity"),
"Param and Velocity of MomentumOp should have the same dimension.");
}
PADDLE_ENFORCE_EQ(framework::product(ctx->GetInputDim("LearningRate")), 1,
"Learning_rate should be a scalar");
ctx->SetOutputDim("ParamOut", param_dim);
ctx->SetOutputDim("VelocityOut", param_dim);
}
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
auto input_data_type = framework::GetDataTypeOfVar(ctx.InputVar("Param"));
return framework::OpKernelType(input_data_type, ctx.GetPlace());
}
};
template <typename T>
class CPUDenseMomentumFunctor {
private:
......
......@@ -296,38 +296,73 @@ Place CUDAPinnedDeviceContext::GetPlace() const { return place_; }
#ifdef PADDLE_WITH_MKLDNN
MKLDNNDeviceContext::MKLDNNDeviceContext(CPUPlace place)
: CPUDeviceContext(place), engine_(mkldnn::engine::cpu, 0), p_blobs_() {
p_blobs_.reset(new std::unordered_map<std::string, std::shared_ptr<void>>());
: CPUDeviceContext(place), engine_(mkldnn::engine::cpu, 0), p_blobmap_() {
p_blobmap_.reset(new BlobMap());
p_mutex_.reset(new std::mutex());
}
namespace {
// Current thread's id.
thread_local int cur_thread_id = 0;
}
void set_cur_thread_id(int tid) { cur_thread_id = tid; }
int get_cur_thread_id(void) { return cur_thread_id; }
void MKLDNNDeviceContext::SetBlob(const std::string& name,
std::shared_ptr<void> data) const {
std::unordered_map<std::string, std::shared_ptr<void>>* p;
p = p_blobs_.get();
BlobMap* pMap = p_blobmap_.get();
std::shared_ptr<KeyBlob> pBlob = nullptr;
int tid = platform::get_cur_thread_id();
auto it = p->find(name);
std::lock_guard<std::mutex> lock(*p_mutex_.get());
if (it == p->end()) {
(*p)[name] = data; // create new blob
// Find KeyBlob for current thread
auto map_it = pMap->find(tid);
if (map_it == pMap->end()) {
// 1st time to set blob in current thread
pBlob = std::shared_ptr<KeyBlob>(new KeyBlob());
(*pMap)[tid] = pBlob;
} else {
it->second = data; // set data to existing blob
pBlob = map_it->second;
}
// Find Key in found (or newly created) KeyBlob
auto key_it = pBlob->find(name);
if (key_it == pBlob->end()) {
(*pBlob)[name] = data; // create new blob
} else {
key_it->second = data; // set data to existing blob
}
// lock will be automatically released when out of scope
return;
}
std::shared_ptr<void> MKLDNNDeviceContext::GetBlob(
const std::string& name) const {
std::unordered_map<std::string, std::shared_ptr<void>>* p;
p = p_blobs_.get();
BlobMap* pMap = p_blobmap_.get();
std::shared_ptr<KeyBlob> pBlob = nullptr;
auto it = p->find(name);
int tid = platform::get_cur_thread_id();
if (it != p->end()) {
return it->second;
}
std::lock_guard<std::mutex> lock(*p_mutex_.get());
// Find KeyBlob for current thread firstly
auto map_it = pMap->find(tid);
if (map_it == pMap->end()) return nullptr;
pBlob = map_it->second;
// Find Blob via name
auto key_it = pBlob->find(name);
if (key_it == pBlob->end()) return nullptr;
return nullptr;
// lock will be automatically released when out of scope
return key_it->second;
}
#endif
......
......@@ -176,6 +176,12 @@ struct DefaultDeviceContextType<platform::CUDAPinnedPlace> {
#endif
#ifdef PADDLE_WITH_MKLDNN
using KeyBlob = std::unordered_map<std::string, std::shared_ptr<void>>;
using BlobMap = std::unordered_map<int, std::shared_ptr<KeyBlob>>;
void set_cur_thread_id(int);
int get_cur_thread_id(void);
class MKLDNNDeviceContext : public CPUDeviceContext {
public:
explicit MKLDNNDeviceContext(CPUPlace place);
......@@ -191,8 +197,8 @@ class MKLDNNDeviceContext : public CPUDeviceContext {
private:
mkldnn::engine engine_;
std::shared_ptr<std::unordered_map<std::string, std::shared_ptr<void>>>
p_blobs_;
std::shared_ptr<BlobMap> p_blobmap_;
std::shared_ptr<std::mutex> p_mutex_;
};
#endif
......
......@@ -645,9 +645,13 @@ All parameter, weight, gradient are variables in Paddle.
py::class_<ir::Pass, std::shared_ptr<ir::Pass>> pass(m, "Pass");
pass.def(py::init())
.def("set_str", [](ir::Pass &self, const std::string &name,
const std::string &attr) {
self.Set<std::string>(name, new std::string(attr));
.def(
"set_str",
[](ir::Pass &self, const std::string &name, const std::string &attr) {
self.Set<std::string>(name, new std::string(attr));
})
.def("set_int", [](ir::Pass &self, const std::string &name, int val) {
self.Set<const int>(name, new int(val));
});
py::class_<ir::PassBuilder, std::shared_ptr<ir::PassBuilder>> pb(
......
......@@ -27,7 +27,7 @@ from . import nn
from . import ops
from . import tensor
from ..initializer import init_on_cpu
from ..framework import default_main_program, Parameter, unique_name
from ..framework import default_main_program, Parameter, unique_name, name_scope
__all__ = [
'exponential_decay', 'natural_exp_decay', 'inverse_time_decay',
......@@ -332,14 +332,16 @@ def append_LARS(params_grads, learning_rate, weight_decay):
return grad_norm + weight_decay * param_norm
for param, grad in params_grads:
param_lr = param.optimize_attr['learning_rate']
param_norm = ops.sqrt(nn.reduce_sum(input=ops.square(param)))
grad_norm = ops.sqrt(nn.reduce_sum(input=ops.square(grad)))
if type(param_lr) == float and param_lr == 1.0:
decayed_lr = learning_rate * param_norm \
/ _balanced_weight(param_norm, grad_norm)
else:
decayed_lr = learning_rate * param_lr * param_norm \
/ _balanced_weight(param_norm, grad_norm)
# set back param local learning rate
param.optimize_attr['learning_rate'] = decayed_lr
with param.block.program.optimized_guard(
[param, grad]), name_scope("optimizer"):
param_lr = param.optimize_attr['learning_rate']
param_norm = ops.sqrt(nn.reduce_sum(input=ops.square(param)))
grad_norm = ops.sqrt(nn.reduce_sum(input=ops.square(grad)))
if type(param_lr) == float and param_lr == 1.0:
decayed_lr = learning_rate * param_norm \
/ _balanced_weight(param_norm, grad_norm)
else:
decayed_lr = learning_rate * param_lr * param_norm \
/ _balanced_weight(param_norm, grad_norm)
# set back param local learning rate
param.optimize_attr['learning_rate'] = decayed_lr
......@@ -14,6 +14,7 @@
from __future__ import print_function
import re
import sys
from collections import defaultdict
from paddle.fluid.framework import Program, Variable, name_scope, default_main_program
from . import framework
......@@ -32,7 +33,8 @@ __all__ = [
'SGD', 'Momentum', 'Adagrad', 'Adam', 'Adamax', 'DecayedAdagrad', 'Ftrl',
'SGDOptimizer', 'MomentumOptimizer', 'AdagradOptimizer', 'AdamOptimizer',
'AdamaxOptimizer', 'DecayedAdagradOptimizer', 'RMSPropOptimizer',
'FtrlOptimizer', 'Adadelta', 'ModelAverage', 'RMSPropOptimizer'
'FtrlOptimizer', 'Adadelta', 'ModelAverage', 'LarsMomentum',
'LarsMomentumOptimizer'
]
......@@ -105,7 +107,6 @@ class Optimizer(object):
param = param_and_grad[0]
param_lr = param.optimize_attr['learning_rate']
if type(param_lr) == Variable:
print("returns updated param lr ", param_lr)
return param_lr
else:
if param_lr == 1.0:
......@@ -400,6 +401,91 @@ class MomentumOptimizer(Optimizer):
return momentum_op
class LarsMomentumOptimizer(Optimizer):
"""
Momentum optimizer with LARS support
The update equations are as follows:
.. math::
& local\_learning\_rate = learning\_rate * lars\_coeff * \\
\\frac{||param||}{||gradient|| + lars\_weight\_decay * ||param||}
& velocity = mu * velocity + local\_learning\_rate * (gradient + lars\_weight\_decay * param)
& param = param - velocity
Args:
learning_rate (float|Variable): the learning rate used to update parameters. \
Can be a float value or a Variable with one float value as data element.
momentum (float): momentum factor
lars_coeff (float): defines how much we trust the layer to change its weights.
lars_weight_decay (float): weight decay coefficient for decaying using LARS.
regularization: A Regularizer, such as
fluid.regularizer.L2DecayRegularizer.
name: A optional name prefix.
Examples:
.. code-block:: python
optimizer = fluid.optimizer.LarsMomentum(learning_rate=0.2, momentum=0.1, lars_weight_decay=0.001)
optimizer.minimize(cost)
"""
_velocity_acc_str = "velocity"
def __init__(self,
learning_rate,
momentum,
lars_coeff=0.001,
lars_weight_decay=0.0005,
regularization=None,
name=None):
assert learning_rate is not None
assert momentum is not None
super(LarsMomentumOptimizer, self).__init__(
learning_rate=learning_rate,
regularization=regularization,
name=name)
self.type = "lars_momentum"
self._momentum = momentum
self._lars_coeff = float(lars_coeff)
self._lars_weight_decay = float(lars_weight_decay)
def _create_accumulators(self, block, parameters):
assert isinstance(block, framework.Block)
for p in parameters:
self._add_accumulator(self._velocity_acc_str, p)
def _append_optimize_op(self, block, param_and_grad):
assert isinstance(block, framework.Block)
velocity_acc = self._get_accumulator(self._velocity_acc_str,
param_and_grad[0])
# create the momentum optimize op
momentum_op = block.append_op(
type=self.type,
inputs={
"Param": param_and_grad[0],
"Grad": param_and_grad[1],
"Velocity": velocity_acc,
"LearningRate": self._create_param_lr(param_and_grad)
},
outputs={
"ParamOut": param_and_grad[0],
"VelocityOut": velocity_acc
},
attrs={
"mu": self._momentum,
"lars_coeff": self._lars_coeff,
"lars_weight_decay": self._lars_weight_decay
})
return momentum_op
class AdagradOptimizer(Optimizer):
"""
**Adaptive Gradient Algorithm (Adagrad)**
......@@ -1221,6 +1307,7 @@ DecayedAdagrad = DecayedAdagradOptimizer
Adadelta = AdadeltaOptimizer
RMSProp = RMSPropOptimizer
Ftrl = FtrlOptimizer
LarsMomentum = LarsMomentumOptimizer
class ModelAverage(Optimizer):
......
......@@ -95,7 +95,7 @@ class TestDistMnist2x2(TestDistRunnerBase):
# Reader
train_reader = paddle.batch(
paddle.dataset.mnist.train(), batch_size=batch_size)
paddle.dataset.mnist.test(), batch_size=batch_size)
test_reader = paddle.batch(
paddle.dataset.mnist.test(), batch_size=batch_size)
opt.minimize(avg_cost)
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import numpy as np
import argparse
import time
import math
import paddle
import paddle.fluid as fluid
import paddle.fluid.profiler as profiler
from paddle.fluid import core
import unittest
from multiprocessing import Process
import os
import signal
from functools import reduce
from test_dist_base import TestDistRunnerBase, runtime_main
from dist_mnist import cnn_model
DTYPE = "float32"
def test_merge_reader(repeat_batch_size=8):
orig_reader = paddle.dataset.mnist.test()
record_batch = []
b = 0
for d in orig_reader():
if b >= repeat_batch_size:
break
record_batch.append(d)
b += 1
while True:
for d in record_batch:
yield d
class TestDistMnist2x2(TestDistRunnerBase):
def get_model(self, batch_size=2):
# Input data
images = fluid.layers.data(name='pixel', shape=[1, 28, 28], dtype=DTYPE)
label = fluid.layers.data(name='label', shape=[1], dtype='int64')
# Train program
predict = cnn_model(images)
cost = fluid.layers.cross_entropy(input=predict, label=label)
avg_cost = fluid.layers.mean(x=cost)
# Evaluator
batch_size_tensor = fluid.layers.create_tensor(dtype='int64')
batch_acc = fluid.layers.accuracy(
input=predict, label=label, total=batch_size_tensor)
inference_program = fluid.default_main_program().clone()
# Optimization
opt = fluid.optimizer.Momentum(learning_rate=0.001, momentum=0.9)
# Reader
train_reader = paddle.batch(test_merge_reader, batch_size=batch_size)
test_reader = paddle.batch(
paddle.dataset.mnist.test(), batch_size=batch_size)
opt.minimize(avg_cost)
return inference_program, avg_cost, train_reader, test_reader, batch_acc, predict
if __name__ == "__main__":
runtime_main(TestDistMnist2x2)
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import numpy as np
import argparse
import time
import math
import paddle
import paddle.fluid as fluid
import paddle.fluid.profiler as profiler
from paddle.fluid import core
import unittest
from multiprocessing import Process
import os
import signal
from functools import reduce
from test_dist_base import TestDistRunnerBase, runtime_main
from dist_mnist import cnn_model
DTYPE = "float32"
paddle.dataset.mnist.fetch()
# Fix seed for test
fluid.default_startup_program().random_seed = 1
fluid.default_main_program().random_seed = 1
class TestDistMnist2x2(TestDistRunnerBase):
def get_model(self, batch_size=2):
# Input data
images = fluid.layers.data(name='pixel', shape=[1, 28, 28], dtype=DTYPE)
label = fluid.layers.data(name='label', shape=[1], dtype='int64')
# Train program
predict = cnn_model(images)
cost = fluid.layers.cross_entropy(input=predict, label=label)
avg_cost = fluid.layers.mean(x=cost)
# Evaluator
batch_size_tensor = fluid.layers.create_tensor(dtype='int64')
batch_acc = fluid.layers.accuracy(
input=predict, label=label, total=batch_size_tensor)
inference_program = fluid.default_main_program().clone()
# Optimization
opt = fluid.optimizer.LarsMomentumOptimizer(
learning_rate=0.001, momentum=0.9)
# Reader
train_reader = paddle.batch(
paddle.dataset.mnist.test(), batch_size=batch_size)
test_reader = paddle.batch(
paddle.dataset.mnist.test(), batch_size=batch_size)
opt.minimize(avg_cost)
return inference_program, avg_cost, train_reader, test_reader, batch_acc, predict
if __name__ == "__main__":
runtime_main(TestDistMnist2x2)
......@@ -26,10 +26,11 @@ import argparse
import paddle.fluid as fluid
RUN_STEP = 10
DEFAULT_BATCH_SIZE = 2
class TestDistRunnerBase(object):
def get_model(self, batch_size=2):
def get_model(self, batch_size=DEFAULT_BATCH_SIZE):
raise NotImplementedError(
"get_model should be implemented by child classes.")
......@@ -48,8 +49,7 @@ class TestDistRunnerBase(object):
return t
def run_pserver(self, args):
self.get_model(batch_size=2)
self.get_model(batch_size=args.batch_size)
# NOTE: pserver should not call memory optimize
t = self.get_transpiler(args.trainer_id,
fluid.default_main_program(), args.endpoints,
......@@ -65,7 +65,7 @@ class TestDistRunnerBase(object):
def run_trainer(self, args):
test_program, avg_cost, train_reader, test_reader, batch_acc, predict = \
self.get_model(batch_size=2)
self.get_model(batch_size=args.batch_size)
if args.mem_opt:
fluid.memory_optimize(fluid.default_main_program(), skip_grads=True)
......@@ -92,6 +92,11 @@ class TestDistRunnerBase(object):
strategy.allow_op_delay = False
build_stra = fluid.BuildStrategy()
if args.batch_merge_repeat > 1:
pass_builder = build_stra._create_passes_from_strategy()
mypass = pass_builder.insert_pass(
len(pass_builder.all_passes()) - 2, "multi_batch_merge_pass")
mypass.set_int("num_repeats", args.batch_merge_repeat)
if args.use_reduce:
build_stra.reduce_strategy = fluid.BuildStrategy.ReduceStrategy.Reduce
......@@ -145,6 +150,9 @@ def runtime_main(test_class):
parser.add_argument('--use_reduce', action='store_true')
parser.add_argument(
'--use_reader_alloc', action='store_true', required=False, default=True)
parser.add_argument('--batch_size', required=False, type=int, default=2)
parser.add_argument(
'--batch_merge_repeat', required=False, type=int, default=1)
args = parser.parse_args()
......@@ -244,9 +252,18 @@ class TestDistBase(unittest.TestCase):
(e, retry_times))
retry_times -= 1
def _run_local(self, model, envs, check_error_log):
def _run_local(self,
model,
envs,
check_error_log=False,
batch_size=DEFAULT_BATCH_SIZE,
batch_merge_repeat=1):
cmd = "%s %s --role trainer" % (self._python_interp, model)
if batch_size != DEFAULT_BATCH_SIZE:
cmd += " --batch_size %d" % batch_size
if batch_merge_repeat > 1:
cmd += " --batch_merge_repeat %d" % batch_merge_repeat
if self.__use_cuda:
cmd += " --use_cuda"
......
......@@ -26,6 +26,15 @@ class TestDistMnist2x2(TestDistBase):
self.check_with_place("dist_mnist.py", delta=1e-5)
class TestDistMnist2x2Lars(TestDistBase):
def _setup_config(self):
self._sync_mode = True
self._use_reduce = False
def test_se_resnext(self):
self.check_with_place("dist_mnist_lars.py", delta=1e-5)
class TestDistMnist2x2WithMemopt(TestDistBase):
def _setup_config(self):
self._sync_mode = True
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
from test_dist_base import TestDistBase
import os
class TestDistMnist2x2(TestDistBase):
def _setup_config(self):
self._sync_mode = True
self._use_reduce = False
def test_dist_train(self):
self.check_with_place("dist_mnist_batch_merge.py", delta=1e-5)
def check_with_place(self,
model_file,
delta=1e-3,
check_error_log=False,
need_envs={}):
# TODO(typhoonzero): should auto adapt GPU count on the machine.
required_envs = {
"PATH": os.getenv("PATH", ""),
"PYTHONPATH": os.getenv("PYTHONPATH", ""),
"LD_LIBRARY_PATH": os.getenv("LD_LIBRARY_PATH", ""),
"FLAGS_fraction_of_gpu_memory_to_use": "0.15",
"FLAGS_cudnn_deterministic": "1",
}
required_envs.update(need_envs)
if check_error_log:
required_envs["GLOG_v"] = "7"
required_envs["GLOG_logtostderr"] = "1"
no_merge_losses = self._run_local(
model_file,
required_envs,
check_error_log=check_error_log,
batch_size=4)
batch_merge_losses = self._run_local(
model_file,
required_envs,
check_error_log=check_error_log,
batch_size=2,
batch_merge_repeat=2)
# Ensure both result have values.
self.assertGreater(len(no_merge_losses), 1)
self.assertEqual(len(no_merge_losses), len(batch_merge_losses))
if __name__ == "__main__":
unittest.main()
......@@ -90,6 +90,45 @@ class TestMomentumOp2(OpTest):
self.check_output()
class TestLarsMomentumOp(OpTest):
def setUp(self):
self.op_type = "lars_momentum"
param = np.random.random((123, 321)).astype("float32")
grad = np.random.random((123, 321)).astype("float32")
velocity = np.zeros((123, 321)).astype("float32")
learning_rate = np.array([0.001]).astype("float32")
mu = 0.0001
lars_coeff = 0.001
lars_weight_decay = 0.0005
self.inputs = {
'Param': param,
'Grad': grad,
'Velocity': velocity,
'LearningRate': learning_rate
}
self.attrs = {
'mu': mu,
'lars_coeff': lars_coeff,
'lars_weight_decay': lars_weight_decay
}
pnorm = np.sqrt(np.square(param).sum())
gnorm = np.sqrt(np.square(grad).sum())
local_lr = learning_rate * lars_coeff * pnorm / (
gnorm + lars_weight_decay * param)
velocity_out = mu * velocity + local_lr * (grad + lars_weight_decay *
param)
param_out = param - velocity_out
self.outputs = {'ParamOut': param_out, 'VelocityOut': velocity_out}
def test_check_output(self):
self.check_output()
class TestSparseMomentumOp(unittest.TestCase):
def setUp(self):
self.use_nesterov = False
......
......@@ -1433,7 +1433,7 @@ to transpile() call.")
elif op_type == "adamax":
if varkey in ["Moment", "InfNorm"]:
return param_shape
elif op_type == "momentum":
elif op_type in ["momentum", "lars_momentum"]:
if varkey == "Velocity":
return param_shape
elif op_type == "rmsprop":
......@@ -1444,6 +1444,10 @@ to transpile() call.")
return param_shape
elif op_type == "sgd":
pass
else:
raise ValueError(
"Not supported optimizer for distributed training: %s" %
op_type)
return orig_shape
def _get_varname_parts(self, varname):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册