diff --git a/CMakeLists.txt b/CMakeLists.txt index 997672169fbb4d24028a4529b1a97880b7480503..23bb27e77b9eab0c322a71a8ff570d12d1050377 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -65,6 +65,7 @@ option(REPLACE_ENFORCE_GLOG "Replace PADDLE_ENFORCE with glog/CHECK for better d option(WITH_ANAKIN "Compile with Anakin library" OFF) option(WITH_GRPC "Use grpc as the default rpc framework" ${WITH_DISTRIBUTE}) option(WITH_BRPC_RDMA "Use brpc rdma as the rpc protocal" OFF) +option(WITH_SYSTEM_BLAS "Use system blas library" OFF) # CMAKE_BUILD_TYPE if(NOT CMAKE_BUILD_TYPE) diff --git a/benchmark/fluid/args.py b/benchmark/fluid/args.py index 99c9d79b068f5886012fd702d84d0666b9d197b5..a79f25ccc6ace1594f3f331633130eaace5e175b 100644 --- a/benchmark/fluid/args.py +++ b/benchmark/fluid/args.py @@ -125,6 +125,10 @@ def parse_args(): parser.add_argument( '--use_inference_transpiler', action='store_true', - help='If set, uses inference transpiler to optimize the program.') + help='If set, use inference transpiler to optimize the program.') + parser.add_argument( + '--no_random', + action='store_true', + help='If set, keep the random seed and do not shuffle the data.') args = parser.parse_args() return args diff --git a/benchmark/fluid/fluid_benchmark.py b/benchmark/fluid/fluid_benchmark.py old mode 100755 new mode 100644 index dcd4d9ea95d816029317a29055b5ca8273ac9f43..94ea7bd6aca7c9595037a2dacc5e36d4c77827e7 --- a/benchmark/fluid/fluid_benchmark.py +++ b/benchmark/fluid/fluid_benchmark.py @@ -132,10 +132,6 @@ def train(avg_loss, infer_prog, optimizer, train_reader, test_reader, batch_acc, exe.run(startup_prog) # Use inference_transpiler to speedup - if args.use_inference_transpiler: - t = fluid.InferenceTranspiler() - t.transpile(infer_prog, place) - if not args.use_reader_op: feed_var_list = [ var for var in train_prog.global_block().vars.itervalues() @@ -186,6 +182,10 @@ def train(avg_loss, infer_prog, optimizer, train_reader, test_reader, batch_acc, print("Pass: %d, Loss: %f" % (pass_id, np.mean(train_losses))), # evaluation if not args.no_test and batch_acc and not args.use_reader_op: + if args.use_inference_transpiler: + t = fluid.InferenceTranspiler() + t.transpile(infer_prog, place) + pass_test_acc = test(exe, infer_prog, test_reader, feeder, batch_acc) print(", Test Accuracy: %f" % pass_test_acc) @@ -316,6 +316,8 @@ def main(): args = parse_args() print_arguments(args) print_paddle_envs() + if args.no_random: + fluid.default_startup_program().random_seed = 1 # the unique trainer id, starting from 0, needed by trainer # only diff --git a/benchmark/fluid/models/resnet.py b/benchmark/fluid/models/resnet.py index 9ed1093c54a501cc93dbbf9c3651fe70914ce26b..d44a9c07d31cfae9d54ad5949b85c77e60eae258 100644 --- a/benchmark/fluid/models/resnet.py +++ b/benchmark/fluid/models/resnet.py @@ -197,12 +197,12 @@ def get_model(args): optimizer = fluid.optimizer.Momentum(learning_rate=0.01, momentum=0.9) batched_train_reader = paddle.batch( - paddle.reader.shuffle( + train_reader if args.no_random else paddle.reader.shuffle( train_reader, buf_size=5120), batch_size=args.batch_size * args.gpus, drop_last=True) batched_test_reader = paddle.batch( - train_reader, batch_size=args.batch_size, drop_last=True) + test_reader, batch_size=args.batch_size, drop_last=True) return avg_cost, inference_program, optimizer, batched_train_reader,\ batched_test_reader, batch_acc diff --git a/cmake/cblas.cmake b/cmake/cblas.cmake index e3b9d94215a858c5c9a34e1b7e97540f1876801d..6ed51c648478efb9784d0c43b169c285e740e0f3 100644 --- a/cmake/cblas.cmake +++ b/cmake/cblas.cmake @@ -83,18 +83,20 @@ else() set(REFERENCE_CBLAS_LIB_SEARCH_PATHS ${REFERENCE_CBLAS_ROOT}/lib) endif() -find_path(REFERENCE_CBLAS_INCLUDE_DIR NAMES cblas.h PATHS +if(WITH_SYSTEM_BLAS) + find_path(REFERENCE_CBLAS_INCLUDE_DIR NAMES cblas.h PATHS ${REFERENCE_CBLAS_INCLUDE_SEARCH_PATHS}) -find_library(REFERENCE_CBLAS_LIBRARY NAMES cblas PATHS + find_library(REFERENCE_CBLAS_LIBRARY NAMES cblas PATHS ${REFERENCE_CBLAS_LIB_SEARCH_PATHS}) -if(REFERENCE_CBLAS_INCLUDE_DIR AND REFERENCE_CBLAS_LIBRARY) - set(CBLAS_FOUND ON) - set(CBLAS_PROVIDER REFERENCE) - set(CBLAS_INC_DIR ${REFERENCE_CBLAS_INCLUDE_DIR}) - set(CBLAS_LIBRARIES ${REFERENCE_CBLAS_LIBRARY}) - add_definitions(-DPADDLE_USE_REFERENCE_CBLAS) - message(STATUS "Found reference-cblas (include: ${CBLAS_INC_DIR}, library: ${CBLAS_LIBRARIES})") + if(REFERENCE_CBLAS_INCLUDE_DIR AND REFERENCE_CBLAS_LIBRARY) + set(CBLAS_FOUND ON) + set(CBLAS_PROVIDER REFERENCE) + set(CBLAS_INC_DIR ${REFERENCE_CBLAS_INCLUDE_DIR}) + set(CBLAS_LIBRARIES ${REFERENCE_CBLAS_LIBRARY}) + add_definitions(-DPADDLE_USE_REFERENCE_CBLAS) + message(STATUS "Found reference-cblas (include: ${CBLAS_INC_DIR}, library: ${CBLAS_LIBRARIES})") + endif() endif() if(IOS_USE_VECLIB_FOR_BLAS AND VECLIB_FOUND) diff --git a/paddle/CMakeLists.txt b/paddle/CMakeLists.txt index efa59fc4a5cf21e885435f564d2a19f892cb534b..7a4bd9183a6dce606d595044852555b04f0e06b2 100644 --- a/paddle/CMakeLists.txt +++ b/paddle/CMakeLists.txt @@ -7,18 +7,18 @@ if(NOT WITH_FLUID_ONLY) add_subdirectory(legacy/parameter) if(MOBILE_INFERENCE) - add_subdirectory(capi) + add_subdirectory(legacy/capi) else() add_subdirectory(legacy/pserver) add_subdirectory(trainer) add_subdirectory(scripts) if(WITH_C_API) - add_subdirectory(capi) + add_subdirectory(legacy/capi) endif() if(WITH_SWIG_PY) - add_subdirectory(api) + add_subdirectory(legacy/api) endif() endif() endif() diff --git a/paddle/fluid/framework/details/CMakeLists.txt b/paddle/fluid/framework/details/CMakeLists.txt index 3c73b6cc55c187c3f6e7edd1ce38cc58f4e8413d..4fb4ec38ee965a2790d11378a1ce6befa0ef5a00 100644 --- a/paddle/fluid/framework/details/CMakeLists.txt +++ b/paddle/fluid/framework/details/CMakeLists.txt @@ -25,11 +25,12 @@ else() cc_library(broadcast_op_handle SRCS broadcast_op_handle.cc DEPS op_handle_base scope ddim memory variable_visitor) endif() +cc_library(data_balance_op_handle SRCS data_balance_op_handle.cc DEPS op_handle_base scope lod_tensor) cc_library(gather_op_handle SRCS gather_op_handle.cc DEPS op_handle_base scope ddim memory variable_visitor) cc_library(fuse_vars_op_handle SRCS fuse_vars_op_handle.cc DEPS op_handle_base scope) cc_library(multi_devices_graph_builder SRCS multi_devices_graph_builder.cc DEPS ssa_graph_builder computation_op_handle - scale_loss_grad_op_handle rpc_op_handle all_reduce_op_handle reduce_op_handle broadcast_op_handle) + scale_loss_grad_op_handle rpc_op_handle all_reduce_op_handle reduce_op_handle broadcast_op_handle data_balance_op_handle) cc_library(ssa_graph_builder_factory SRCS ssa_graph_builder_factory.cc DEPS multi_devices_graph_builder ssa_graph_printer ssa_graph_checker) diff --git a/paddle/fluid/framework/details/build_strategy.h b/paddle/fluid/framework/details/build_strategy.h index 64e83acb4dc1995800c4ca3caf81668b24a7c9fe..9c2c845c6efb206fb1ad5150189430b9a6fe9ea3 100644 --- a/paddle/fluid/framework/details/build_strategy.h +++ b/paddle/fluid/framework/details/build_strategy.h @@ -33,6 +33,8 @@ struct BuildStrategy { GradientScaleStrategy gradient_scale_{GradientScaleStrategy::kCoeffNumDevice}; std::string debug_graphviz_path_{""}; + + bool enable_data_balance_{true}; }; } // namespace details diff --git a/paddle/fluid/framework/details/data_balance_op_handle.cc b/paddle/fluid/framework/details/data_balance_op_handle.cc new file mode 100644 index 0000000000000000000000000000000000000000..b914851fe0add74f6d85589f4686224b668b8064 --- /dev/null +++ b/paddle/fluid/framework/details/data_balance_op_handle.cc @@ -0,0 +1,154 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/framework/details/data_balance_op_handle.h" +#include +#include "paddle/fluid/framework/details/container_cast.h" + +namespace paddle { +namespace framework { +namespace details { + +#ifdef PADDLE_WITH_CUDA +DataBalanceOpHandle::DataBalanceOpHandle( + const std::vector &local_scopes, + const std::vector &places, + const platform::NCCLContextMap *ctxs) + : local_scopes_(local_scopes), places_(places) { + if (ctxs) { + for (auto &p : places_) { + this->dev_ctxes_[p] = ctxs->DevCtx(p); + } + } +} +#else +DataBalanceOpHandle::DataBalanceOpHandle( + const std::vector &local_scopes, + const std::vector &places) + : local_scopes_(local_scopes), places_(places) {} +#endif + +std::string DataBalanceOpHandle::Name() const { return "data balance"; } + +std::vector> DataBalanceOpHandle::GetBalancePlan( + const std::vector &device_sizes) { + int device_num = device_sizes.size(); + int total_size = 0; + int empty_num = 0; + std::vector> size_device_vec; + size_device_vec.reserve(device_num); + for (int i = 0; i < device_num; ++i) { + if (device_sizes[i] == 0) { + ++empty_num; + } + total_size += device_sizes[i]; + size_device_vec.push_back({{device_sizes[i], i}}); + } + std::vector> res; + if (empty_num == 0) { + // No need to do data balance. + return res; + } + if (total_size < device_num) { + // No enough data. + PADDLE_THROW("There is no next data."); + } + std::sort(size_device_vec.begin(), size_device_vec.end(), + [](const std::array &a, const std::array &b) { + return a[0] > b[0]; + }); + int expected_device_size = total_size / device_num; + int src_idx = 0; + for (int dst_idx = device_num - empty_num; dst_idx < device_num; ++dst_idx) { + if (size_device_vec[src_idx][0] <= expected_device_size) { + ++src_idx; + PADDLE_ENFORCE_LT( + src_idx, device_num - empty_num, + "In current srategy an empty tensor should not be copy source."); + } + size_device_vec[src_idx][0] -= expected_device_size; + size_device_vec[dst_idx][0] += expected_device_size; + res.push_back({{size_device_vec[src_idx][1], size_device_vec[dst_idx][1], + expected_device_size}}); + } + return res; +} + +void DataBalanceOpHandle::RunImpl() { + if (places_.size() == 1) { + return; + } + auto in_var_handles = DynamicCast(inputs_); + auto out_var_handles = DynamicCast(outputs_); + PADDLE_ENFORCE(in_var_handles.size() % places_.size() == 0); + PADDLE_ENFORCE_EQ( + in_var_handles.size(), out_var_handles.size(), + "The NoDummyInputSize and NoDummyOutputSize should be equal."); + int data_num = in_var_handles.size() / places_.size(); + WaitInputVarGenerated(); + std::vector> lod_tensors(data_num); + std::vector device_sizes; + for (int i = 0; i < static_cast(in_var_handles.size()); ++i) { + PADDLE_ENFORCE_EQ(in_var_handles[i]->name_, out_var_handles[i]->name_, + "The name of input and output should be equal."); + int place_idx = i / data_num; + int data_idx = i % data_num; + auto *local_scope = + local_scopes_[place_idx]->FindVar(kLocalExecScopeName)->Get(); + auto *tensor_var = local_scope->FindVar(in_var_handles[i]->name_); + PADDLE_ENFORCE(tensor_var->IsType()); + auto *tensor = tensor_var->GetMutable(); + lod_tensors[data_idx].push_back(tensor); + int ins_size = + tensor->lod().empty() ? tensor->dims()[0] : tensor->NumElements(); + if (data_idx == 0) { + device_sizes.emplace_back(ins_size); + } else { + PADDLE_ENFORCE_EQ( + ins_size, device_sizes.at(place_idx), + "All data on the same device shall have the same batch size."); + } + } + const auto &balance_plan = GetBalancePlan(device_sizes); + + for (const auto &trans : balance_plan) { + for (int data_idx = 0; data_idx < data_num; ++data_idx) { + LoDTensor *src_tensor = lod_tensors[data_idx][trans[0]]; + LoDTensor *dst_tensor = lod_tensors[data_idx][trans[1]]; + int trans_ins_size = trans[2]; + LoD src_lod = src_tensor->lod(); + int src_ins_size = + src_lod.empty() ? src_tensor->dims()[0] : src_tensor->NumElements(); + int cut_point = src_ins_size - trans_ins_size; + if (!src_lod.empty()) { + for (auto &level : src_lod) { + cut_point = level[cut_point]; + } + } + TensorCopySync(src_tensor->Slice(cut_point, src_tensor->dims()[0]), + dst_tensor->place(), dst_tensor); + src_tensor->ShareDataWith(src_tensor->Slice(0, cut_point)); + if (!src_lod.empty()) { + dst_tensor->set_lod(SliceInLevel( + src_lod, 0, src_ins_size - trans_ins_size, src_ins_size)); + src_tensor->set_lod( + SliceInLevel(src_lod, 0, 0, src_ins_size - trans_ins_size)); + } + } + } +} + +} // namespace details +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/details/data_balance_op_handle.h b/paddle/fluid/framework/details/data_balance_op_handle.h new file mode 100644 index 0000000000000000000000000000000000000000..76a407e3610e8bb48facf1f814779f4c23f92d98 --- /dev/null +++ b/paddle/fluid/framework/details/data_balance_op_handle.h @@ -0,0 +1,59 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include "paddle/fluid/framework/details/op_handle_base.h" +#include "paddle/fluid/framework/lod_tensor.h" +#include "paddle/fluid/framework/scope.h" +#ifdef PADDLE_WITH_CUDA +#include "paddle/fluid/platform/nccl_helper.h" +#endif + +namespace paddle { +namespace framework { +namespace details { + +struct DataBalanceOpHandle : public OpHandleBase { + public: +#ifdef PADDLE_WITH_CUDA + DataBalanceOpHandle(const std::vector &local_scopes, + const std::vector &places, + const platform::NCCLContextMap *ctxs); +#else + DataBalanceOpHandle(const std::vector &local_scopes, + const std::vector &places); +#endif + + std::string Name() const override; + + bool IsMultiDeviceTransfer() override { return false; }; + + protected: + void RunImpl() override; + + private: + // std::vector<(src_dev_id, dst_dev_id, trans_size)> + std::vector> GetBalancePlan( + const std::vector &batch_size_per_device); + + const std::vector local_scopes_; + const std::vector places_; +}; + +} // namespace details +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/details/fetch_op_handle.cc b/paddle/fluid/framework/details/fetch_op_handle.cc index 224e8e1f6efd7a894591ac51c929517cae7539ce..d646c944601e81477787740189d7ac60ae97fa80 100644 --- a/paddle/fluid/framework/details/fetch_op_handle.cc +++ b/paddle/fluid/framework/details/fetch_op_handle.cc @@ -67,8 +67,8 @@ void FetchOpHandle::RunImpl() { #endif } else { tensors_[i].ShareDataWith(t); - tensors_[i].set_lod(t.lod()); } + tensors_[i].set_lod(t.lod()); } this->WaitAndMergeCPUTensors(); diff --git a/paddle/fluid/framework/details/multi_devices_graph_builder.cc b/paddle/fluid/framework/details/multi_devices_graph_builder.cc index cc7b94d0653e34c8ac711a7db7ab6ab1a9ac46a2..46d0c2769cb334f5cb75ae0ef5e48da45448c48f 100644 --- a/paddle/fluid/framework/details/multi_devices_graph_builder.cc +++ b/paddle/fluid/framework/details/multi_devices_graph_builder.cc @@ -20,6 +20,7 @@ #include "paddle/fluid/framework/details/all_reduce_op_handle.h" #include "paddle/fluid/framework/details/broadcast_op_handle.h" #include "paddle/fluid/framework/details/computation_op_handle.h" +#include "paddle/fluid/framework/details/data_balance_op_handle.h" #include "paddle/fluid/framework/details/multi_devices_graph_builder.h" #include "paddle/fluid/framework/details/reduce_op_handle.h" #include "paddle/fluid/framework/details/rpc_op_handle.h" @@ -215,7 +216,14 @@ std::unique_ptr MultiDevSSAGraphBuilder::Build( } else { // This op runs on all devices, and its output may have parameter's // gradients. - CreateComputationalOps(&result, *op, places_.size()); + if (op->Type() == "read" && strategy_.enable_data_balance_) { + op->SetAttr("throw_eof_exp", false); + CreateComputationalOps(&result, *op, places_.size()); + const auto &data_var_names = op->Output("Out"); + InsertDataBalanceOp(&result, data_var_names); + } else { + CreateComputationalOps(&result, *op, places_.size()); + } if (!is_forwarding && places_.size() > 1) { // Currently, we assume that once gradient is generated, it can be @@ -360,6 +368,29 @@ void MultiDevSSAGraphBuilder::InsertAllReduceOp(SSAGraph *result, } } +void MultiDevSSAGraphBuilder::InsertDataBalanceOp( + SSAGraph *result, const std::vector &datas) const { +#ifdef PADDLE_WITH_CUDA + result->ops_.emplace_back( + new DataBalanceOpHandle(local_scopes_, places_, nccl_ctxs_)); +#else + result->ops_.emplace_back(new DataBalanceOpHandle(local_scopes_, places_)); +#endif + auto *op_handle = result->ops_.back().get(); + for (size_t i = 0; i < places_.size(); ++i) { + auto &p = places_[i]; + SetCommunicationContext(op_handle, p); + for (const std::string &d_name : datas) { + auto &vars = result->vars_[i][d_name]; + PADDLE_ENFORCE(!vars.empty()); + op_handle->AddInput(vars.back().get()); + auto var = new VarHandle(vars.size(), i, d_name, p); + vars.emplace_back(var); + op_handle->AddOutput(var); + } + } +} + bool MultiDevSSAGraphBuilder::IsParameterGradientOnce( const std::string &og, std::unordered_set *og_has_been_broadcast) const { @@ -512,7 +543,8 @@ void MultiDevSSAGraphBuilder::CreateRPCOp(SSAGraph *result, op_dev_id = GetVarDeviceID(op.InputArgumentNames()[0]); // the variable name which contains .block means it was splited by // split_byref op - // so that we can balance the variable blocks to all the pserver instances. + // so that we can balance the variable blocks to all the pserver + // instances. if (strategy_.reduce_ == BuildStrategy::ReduceStrategy::kAllReduce && op.InputArgumentNames()[0].find(".block") == std::string::npos) { op_dev_id = GetAppropriateDeviceID(op.InputArgumentNames()); diff --git a/paddle/fluid/framework/details/multi_devices_graph_builder.h b/paddle/fluid/framework/details/multi_devices_graph_builder.h index 0b6347bf51dc1c347073a0fdcf4ddd91865d846d..a964e024885e56693224a6199e00ff30beaa1df4 100644 --- a/paddle/fluid/framework/details/multi_devices_graph_builder.h +++ b/paddle/fluid/framework/details/multi_devices_graph_builder.h @@ -101,6 +101,9 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder { void InsertAllReduceOp(SSAGraph *result, const std::string &og) const; + void InsertDataBalanceOp(SSAGraph *result, + const std::vector &datas) const; + void CreateBroadcastOp(SSAGraph *result, const std::string &p_name, size_t src_dev_id) const; diff --git a/paddle/fluid/framework/details/op_handle_base.cc b/paddle/fluid/framework/details/op_handle_base.cc index 1f84c3b9e2d7ee9ae51959988fceeb3451b7b3b8..3560fabb424375a770432586fe7c8e51210b3d0c 100644 --- a/paddle/fluid/framework/details/op_handle_base.cc +++ b/paddle/fluid/framework/details/op_handle_base.cc @@ -58,8 +58,10 @@ void OpHandleBase::Run(bool use_cuda) { void OpHandleBase::RecordWaitEventOnCtx(platform::DeviceContext *waited_ctx) { #ifdef PADDLE_WITH_CUDA + PADDLE_ENFORCE_NOT_NULL(waited_ctx); if (platform::is_cpu_place(waited_ctx->GetPlace()) || events_.empty()) { for (auto &dev_ctx : dev_ctxes_) { + PADDLE_ENFORCE_NOT_NULL(dev_ctx.second); dev_ctx.second->Wait(); } } else { diff --git a/paddle/fluid/framework/lod_tensor.cc b/paddle/fluid/framework/lod_tensor.cc index 5373d769a4993bb378b30c3b23885c072b778e5c..cba0064f38f89c1dd27cfac1ddb2339a5ee6c93f 100644 --- a/paddle/fluid/framework/lod_tensor.cc +++ b/paddle/fluid/framework/lod_tensor.cc @@ -20,6 +20,7 @@ limitations under the License. */ #include "paddle/fluid/framework/data_type.h" #include "paddle/fluid/framework/framework.pb.h" #include "paddle/fluid/framework/lod_tensor.h" +#include "paddle/fluid/framework/var_type.h" #include "paddle/fluid/memory/memcpy.h" #include "paddle/fluid/memory/memory.h" @@ -68,9 +69,9 @@ std::ostream &operator<<(std::ostream &os, const LoDTensor &t) { // only print first ten elements int64_t size = t.numel() < 10 ? t.numel() : 10; for (int64_t i = 0; i < size; ++i) { - if (t.type().hash_code() == typeid(float).hash_code()) { // NOLINT + if (IsType(t.type())) { os << t.data()[i] << " "; - } else if (t.type().hash_code() == typeid(int64_t).hash_code()) { + } else if (IsType(t.type())) { os << t.data()[i] << " "; } else { PADDLE_THROW("LoDTensor data type not in [float, int64_t]"); @@ -89,6 +90,7 @@ std::string LoDToString(const LoD &lod) { LoD SliceInLevel(const LoD &in, size_t level, size_t elem_begin, size_t elem_end) { PADDLE_ENFORCE_LT(level, in.size()); + PADDLE_ENFORCE_LT(elem_begin, elem_end); PADDLE_ENFORCE_LT(elem_end, in[level].size()); LoD res; @@ -384,7 +386,7 @@ void LoDTensor::MergeLoDTensor( LoD new_lod = lod_tensors[0]->lod(); for (size_t i = 1; i < lod_tensors.size(); ++i) { auto *t = lod_tensors[i]; - PADDLE_ENFORCE_EQ(new_type.hash_code(), t->type().hash_code()); + PADDLE_ENFORCE_EQ(new_type, t->type()); PADDLE_ENFORCE_EQ(new_layout, t->layout()); PADDLE_ENFORCE_EQ(framework::product(new_dim) / new_dim[0], @@ -392,6 +394,7 @@ void LoDTensor::MergeLoDTensor( new_dim[0] += t->dims()[0]; auto &lod = t->lod(); + PADDLE_ENFORCE_EQ(new_lod.size(), lod.size()); for (size_t j = 0; j < lod.size(); ++j) { auto &sub_lod = new_lod[j]; auto &offset = sub_lod.back(); diff --git a/paddle/fluid/framework/op_registry.h b/paddle/fluid/framework/op_registry.h index 43ab227a9478707445892c14723801992d0041aa..3314e41cc51d74f87be0e2cd5eba9bb260c16be7 100644 --- a/paddle/fluid/framework/op_registry.h +++ b/paddle/fluid/framework/op_registry.h @@ -76,6 +76,20 @@ class OpRegistry { template struct OpKernelRegistrarFunctor; +template +inline void RegisterKernelClass(const char* op_type, const char* library_type, + Func func) { + std::string library(library_type); + std::string data_layout = "ANYLAYOUT"; + if (library == "MKLDNN") { + data_layout = "MKLDNNLAYOUT"; + } + OpKernelType key(ToDataType(std::type_index(typeid(T))), PlaceType(), + StringToDataLayout(data_layout), + StringToLibraryType(library_type)); + OperatorWithKernel::AllOpKernels()[op_type][key] = func; +} + template struct OpKernelRegistrarFunctor { using KERNEL_TYPE = @@ -83,16 +97,10 @@ struct OpKernelRegistrarFunctor { void operator()(const char* op_type, const char* library_type) const { using T = typename KERNEL_TYPE::ELEMENT_TYPE; - std::string library(library_type); - std::string data_layout = "ANYLAYOUT"; - if (library == "MKLDNN") { - data_layout = "MKLDNNLAYOUT"; - } - OpKernelType key(ToDataType(std::type_index(typeid(T))), PlaceType(), - StringToDataLayout(data_layout), - StringToLibraryType(library_type)); - OperatorWithKernel::AllOpKernels()[op_type][key].reset(new KERNEL_TYPE); - + RegisterKernelClass( + op_type, library_type, [](const framework::ExecutionContext& ctx) { + KERNEL_TYPE().Compute(ctx); + }); constexpr auto size = std::tuple_size>::value; OpKernelRegistrarFunctor func; @@ -116,6 +124,47 @@ class OpKernelRegistrar : public Registrar { } }; +template +struct OpKernelRegistrarFunctorEx; + +template +class OpKernelRegistrarEx : public Registrar { + public: + explicit OpKernelRegistrarEx(const char* op_type, const char* library_type) { + OpKernelRegistrarFunctorEx + func; + func(op_type, library_type); + } +}; + +template +struct OpKernelRegistrarFunctorEx { + void operator()(const char* op_type, const char* library_type) const {} +}; + +template +struct OpKernelRegistrarFunctorEx { + using Functor = + typename std::tuple_element>::type; + using T = + typename std::tuple_element>::type; + + void operator()(const char* op_type, const char* library_type) const { + RegisterKernelClass(op_type, library_type, Functor()); + + constexpr auto size = + std::tuple_size>::value; + OpKernelRegistrarFunctorEx= size, I + 2, + DataTypeAndKernelType...> + func; + func(op_type, library_type); + } +}; + /** * check if MACRO is used in GLOBAL NAMESPACE. */ @@ -174,6 +223,25 @@ class OpKernelRegistrar : public Registrar { #define REGISTER_OP_CPU_KERNEL(op_type, ...) \ REGISTER_OP_KERNEL(op_type, CPU, ::paddle::platform::CPUPlace, __VA_ARGS__) +#define REGISTER_OP_KERNEL_EX(op_type, library_type, place_class, ...) \ + STATIC_ASSERT_GLOBAL_NAMESPACE( \ + __reg_op_kernel_##op_type##_##library_type##__, \ + "REGISTER_OP_KERNEL_EX must be called in global namespace"); \ + static ::paddle::framework::OpKernelRegistrarEx \ + __op_kernel_registrar_##op_type##_##library_type##__(#op_type, \ + #library_type); \ + int TouchOpKernelRegistrar_##op_type##_##library_type() { \ + __op_kernel_registrar_##op_type##_##library_type##__.Touch(); \ + return 0; \ + } + +#define REGISTER_OP_CUDA_KERNEL_FUNCTOR(op_type, ...) \ + REGISTER_OP_KERNEL_EX(op_type, CUDA, ::paddle::platform::CUDAPlace, \ + __VA_ARGS__) + +#define REGISTER_OP_CPU_KERNEL_FUNCTOR(op_type, ...) \ + REGISTER_OP_KERNEL_EX(op_type, CPU, ::paddle::platform::CPUPlace, __VA_ARGS__) + /** * Macro to mark what Operator and Kernel * we will use and tell the compiler to diff --git a/paddle/fluid/framework/operator.cc b/paddle/fluid/framework/operator.cc index 4586183d8d206d07f8dcdc000a5ce0bc65d847d5..3cf8e8696d739e3f2894e490161b9fb5b459bc41 100644 --- a/paddle/fluid/framework/operator.cc +++ b/paddle/fluid/framework/operator.cc @@ -592,8 +592,7 @@ static void CheckTensorNANOrInf(const std::string& name, if (tensor.memory_size() == 0) { return; } - if (tensor.type().hash_code() != typeid(float).hash_code() && // NOLINT - tensor.type().hash_code() != typeid(double).hash_code()) { // NOLINT + if (!IsType(tensor.type()) && !IsType(tensor.type())) { return; } PADDLE_ENFORCE(!framework::TensorContainsInf(tensor), @@ -652,7 +651,7 @@ void OperatorWithKernel::RunImpl(const Scope& scope, dev_ctx = pool.Get(expected_kernel_key.place_); } - kernel_iter->second->Compute(ExecutionContext(*this, exec_scope, *dev_ctx)); + kernel_iter->second(ExecutionContext(*this, exec_scope, *dev_ctx)); if (!transfered_inplace_vars.empty()) { // there is inplace variable has been transfered. diff --git a/paddle/fluid/framework/operator.h b/paddle/fluid/framework/operator.h index 1550d5df172f0599e1b42e7f1ccf51ac4dd1e0c3..01d750efbb8aaa35701f6caa7ec103ec21dd529e 100644 --- a/paddle/fluid/framework/operator.h +++ b/paddle/fluid/framework/operator.h @@ -347,9 +347,9 @@ class OpKernel : public OpKernelBase { class OperatorWithKernel : public OperatorBase { public: + using OpKernelFunc = std::function; using OpKernelMap = - std::unordered_map, - OpKernelType::Hash>; + std::unordered_map; OperatorWithKernel(const std::string& type, const VariableNameMap& inputs, const VariableNameMap& outputs, const AttributeMap& attrs) diff --git a/paddle/fluid/framework/var_type.h b/paddle/fluid/framework/var_type.h index 2b646d78f0b23ec3e065c891826856c2341d4ac1..429997c8b89fef7aa164e878095ab3b5c9998e5b 100644 --- a/paddle/fluid/framework/var_type.h +++ b/paddle/fluid/framework/var_type.h @@ -24,18 +24,24 @@ limitations under the License. */ namespace paddle { namespace framework { + +template +bool IsType(const std::type_index& type_index) { + return type_index == std::type_index(typeid(T)); +} + inline proto::VarType::Type ToVarType(std::type_index type) { - if (type.hash_code() == typeid(LoDTensor).hash_code()) { + if (IsType(type)) { return proto::VarType_Type_LOD_TENSOR; - } else if (type.hash_code() == typeid(LoDRankTable).hash_code()) { + } else if (IsType(type)) { return proto::VarType_Type_LOD_RANK_TABLE; - } else if (type.hash_code() == typeid(LoDTensorArray).hash_code()) { + } else if (IsType(type)) { return proto::VarType_Type_LOD_TENSOR_ARRAY; - } else if (type.hash_code() == typeid(SelectedRows).hash_code()) { + } else if (IsType(type)) { return proto::VarType_Type_SELECTED_ROWS; - } else if (type.hash_code() == typeid(ReaderHolder).hash_code()) { + } else if (IsType(type)) { return proto::VarType_Type_READER; - } else if (type.hash_code() == typeid(ChannelHolder).hash_code()) { + } else if (IsType(type)) { return proto::VarType_Type_CHANNEL; } else { PADDLE_THROW("ToVarType:Unsupported type %s", type.name()); diff --git a/paddle/fluid/inference/analysis/helper.h b/paddle/fluid/inference/analysis/helper.h index fff1621d3f1bb31cfa04110d1f3cf5dbfe927331..f1064cd20f28092d80d3fd23a862da080b6cc2f3 100644 --- a/paddle/fluid/inference/analysis/helper.h +++ b/paddle/fluid/inference/analysis/helper.h @@ -16,6 +16,7 @@ limitations under the License. */ #include #include +#include #include #include @@ -41,7 +42,7 @@ int AccuDims(Vec &&vec, int size) { return res; } -#define SET_TYPE(type__) dic_[typeid(type__).hash_code()] = #type__; +#define SET_TYPE(type__) dic_[std::type_index(typeid(type__))] = #type__; /* * Map typeid to representation. */ @@ -53,14 +54,14 @@ struct DataTypeNamer { template const std::string &repr() const { - auto x = typeid(T).hash_code(); + auto x = std::type_index(typeid(T)); PADDLE_ENFORCE(dic_.count(x), "unknown type for representation"); return dic_.at(x); } - const std::string &repr(size_t &hash) const { // NOLINT - PADDLE_ENFORCE(dic_.count(hash), "unknown type for representation"); - return dic_.at(hash); + const std::string &repr(const std::type_index &type) const { // NOLINT + PADDLE_ENFORCE(dic_.count(type), "unknown type for representation"); + return dic_.at(type); } private: @@ -71,9 +72,7 @@ struct DataTypeNamer { SET_TYPE(void *); } - std::unordered_map - dic_; + std::unordered_map dic_; }; #undef SET_TYPE diff --git a/paddle/fluid/inference/analysis/node.cc b/paddle/fluid/inference/analysis/node.cc index d9d265d225bb77a3f5f83cbd0b8b1c670fb34a31..f2e918f3ff41d9db0c3ec38561015967bed26f4e 100644 --- a/paddle/fluid/inference/analysis/node.cc +++ b/paddle/fluid/inference/analysis/node.cc @@ -23,9 +23,9 @@ namespace analysis { template <> std::string &NodeAttr::As() { if (data_.empty()) { - type_hash_ = typeid(std::string).hash_code(); + type_index_ = std::type_index(typeid(std::string)); } - PADDLE_ENFORCE_EQ(type_hash_, typeid(std::string).hash_code()); + PADDLE_ENFORCE_EQ(type_index_, std::type_index(typeid(std::string))); return data_; } diff --git a/paddle/fluid/inference/analysis/node.h b/paddle/fluid/inference/analysis/node.h index 8ecd1ae730e6ec6775f4a22fdc5dec0e8ca8e2d1..47e524bc5c4a6b1324d5f182053129311487522d 100644 --- a/paddle/fluid/inference/analysis/node.h +++ b/paddle/fluid/inference/analysis/node.h @@ -25,6 +25,7 @@ limitations under the License. */ #include #include +#include "paddle/fluid/framework/var_type.h" #include "paddle/fluid/inference/analysis/device.h" #include "paddle/fluid/inference/analysis/dot.h" #include "paddle/fluid/inference/analysis/helper.h" @@ -57,12 +58,12 @@ struct NodeAttr { // init storage in the first usage. if (data_.empty()) { VLOG(4) << "resize data to " << sizeof(T); - type_hash_ = typeid(T).hash_code(); + type_index_ = std::type_index(typeid(T)); data_.resize(sizeof(T)); } - PADDLE_ENFORCE(type_hash_ == typeid(T).hash_code(), + PADDLE_ENFORCE(framework::IsType(type_index_), "type not matched, origin is %s, want %s", - DataTypeNamer::Global().repr(type_hash_), + DataTypeNamer::Global().repr(type_index_), DataTypeNamer::Global().repr()); PADDLE_ENFORCE_EQ(data_.size(), sizeof(T), "Node attr type recast error"); return *reinterpret_cast(&data_[0]); @@ -70,7 +71,7 @@ struct NodeAttr { private: std::string data_; - size_t type_hash_{std::numeric_limits::max()}; + std::type_index type_index_{typeid(NodeAttr)}; }; /* diff --git a/paddle/fluid/operators/conditional_block_op.cc b/paddle/fluid/operators/conditional_block_op.cc index 5984f80d04bdeb232f8e24264ae979725af24ef4..8cc1d94260baccfe28d213b7e021956819e2e79e 100644 --- a/paddle/fluid/operators/conditional_block_op.cc +++ b/paddle/fluid/operators/conditional_block_op.cc @@ -14,6 +14,7 @@ limitations under the License. */ #include #include "paddle/fluid/framework/executor.h" #include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/var_type.h" namespace paddle { namespace operators { @@ -47,7 +48,7 @@ class ConditionalOp : public framework::OperatorBase { if (!(ips.size() == 1UL && ips[0]->IsInitialized())) { PADDLE_THROW("should have one initialized input as condition"); } - if (!(ips[0]->type().hash_code() == typeid(bool).hash_code() && // NOLINT + if (!(framework::IsType(ips[0]->type()) && // NOLINT ips[0]->numel() == 1)) { PADDLE_THROW( "condition input's data type should be bool, " diff --git a/paddle/fluid/operators/fc_mkldnn_op.cc b/paddle/fluid/operators/fc_mkldnn_op.cc index 847b7b0c12e1679501dbe83d578b23ca2aef3e9e..99fa659a351249a4a93f71700e1c646465861aba 100644 --- a/paddle/fluid/operators/fc_mkldnn_op.cc +++ b/paddle/fluid/operators/fc_mkldnn_op.cc @@ -115,6 +115,7 @@ class MKLDNNMemory { template class FCMKLDNNOpKernel : public paddle::framework::OpKernel { + public: void Compute(const paddle::framework::ExecutionContext& ctx) const override { PADDLE_ENFORCE(paddle::platform::is_cpu_place(ctx.GetPlace()), "It must use CPUPlace."); diff --git a/paddle/fluid/operators/print_op.cc b/paddle/fluid/operators/print_op.cc index db7634918a5179a61304315ecd08350d23fb4642..cceac402951ae6bf3fe0b4c96af5b7ce9ca1ba0e 100644 --- a/paddle/fluid/operators/print_op.cc +++ b/paddle/fluid/operators/print_op.cc @@ -16,6 +16,7 @@ #include #include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/var_type.h" #include "paddle/fluid/framework/variable.h" namespace paddle { @@ -62,7 +63,7 @@ struct Formater { } } void PrintDtype() { - if (dtype.hash_code() != typeid(const char).hash_code()) { + if (!framework::IsType(dtype)) { CLOG << "\tdtype: " << dtype.name() << std::endl; } } @@ -83,15 +84,15 @@ struct Formater { void PrintData(size_t size) { PADDLE_ENFORCE_NOT_NULL(data); // print float - if (dtype.hash_code() == typeid(const float).hash_code()) { + if (framework::IsType(dtype)) { Display(size); - } else if (dtype.hash_code() == typeid(const double).hash_code()) { + } else if (framework::IsType(dtype)) { Display(size); - } else if (dtype.hash_code() == typeid(const int).hash_code()) { + } else if (framework::IsType(dtype)) { Display(size); - } else if (dtype.hash_code() == typeid(const int64_t).hash_code()) { + } else if (framework::IsType(dtype)) { Display(size); - } else if (dtype.hash_code() == typeid(const bool).hash_code()) { + } else if (framework::IsType(dtype)) { Display(size); } else { CLOG << "\tdata: unprintable type: " << dtype.name() << std::endl; diff --git a/paddle/fluid/operators/read_op.cc b/paddle/fluid/operators/read_op.cc index 72a27d43584d55cd0859c63577ae85ff0f5fdfa8..60e4eb757668e1482090f02aea529aaad3a674d8 100644 --- a/paddle/fluid/operators/read_op.cc +++ b/paddle/fluid/operators/read_op.cc @@ -66,9 +66,19 @@ class ReadOp : public framework::OperatorBase { std::vector out_arg_names = Outputs("Out"); std::vector ins; reader->ReadNext(&ins); - PADDLE_ENFORCE(!ins.empty(), "There is no next data."); + if (ins.empty()) { + if (Attr("throw_eof_exp")) { + PADDLE_THROW("There is no next data."); + } else { + ins.resize(out_arg_names.size()); + for (auto& tensor : ins) { + // data type is not important for subsequent DataBalanceOpHandle + tensor.mutable_data(framework::make_ddim({0}), dev_place); + } + } + } PADDLE_ENFORCE_EQ(ins.size(), out_arg_names.size()); - for (size_t i = 0; i < ins.size(); ++i) { + for (size_t i = 0; i < out_arg_names.size(); ++i) { auto* out = scope.FindVar(out_arg_names[i])->GetMutable(); out->ShareDataWith(ins[i]); @@ -82,6 +92,10 @@ class ReadOpMaker : public framework::OpProtoAndCheckerMaker { void Make() override { AddInput("Reader", "(ReaderHolder) The executed reader."); AddOutput("Out", "(LoDTensor) The output data.").AsDuplicable(); + AddAttr("throw_eof_exp", + "If set true, an exception will be thrown when the Reader " + "yields empty (which means there is no next data).") + .SetDefault(true); AddComment(R"DOC( Read Operator diff --git a/paddle/fluid/operators/reshape_op.cc b/paddle/fluid/operators/reshape_op.cc index 7f743f577fbcdaf6f62e01031e25ef09a842c2e9..918f3be533d51367eade5f5108ad2eab954a9303 100644 --- a/paddle/fluid/operators/reshape_op.cc +++ b/paddle/fluid/operators/reshape_op.cc @@ -12,14 +12,108 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/fluid/operators/reshape_op.h" - #include #include +#include "paddle/fluid/framework/op_registry.h" namespace paddle { namespace operators { +class ReshapeOp : public framework::OperatorWithKernel { + public: + ReshapeOp(const std::string &type, const framework::VariableNameMap &inputs, + const framework::VariableNameMap &outputs, + const framework::AttributeMap &attrs) + : OperatorWithKernel(type, inputs, outputs, attrs) {} + + void InferShape(framework::InferShapeContext *ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("X"), + "Input(X) of ReshapeOp should not be null."); + PADDLE_ENFORCE(ctx->HasOutput("Out"), + "Output(Out) of ReshapeOp should not be null."); + + const std::vector &shape = ctx->Attrs().Get>("shape"); + PADDLE_ENFORCE(!shape.empty(), + "The shape information must be set by Attr(shape)."); + + if (ctx->HasInput("Shape") && ctx->IsRuntime()) { + // If true, set the shape of Output(Out) according to Input(Shape) in + // ReshapeKernel with ExecutionContext. Also check LoD in ReshapeKernel. + ctx->ShareLoD("X", /*->*/ "Out"); + return; + } + + auto x_dims = ctx->GetInputDim("X"); + auto out_dims = ValidateShape(shape, x_dims); + ctx->SetOutputDim("Out", out_dims); + if (x_dims[0] == out_dims[0]) { + // Only pass LoD when the first dimension of output and Input(X) + // are the same. + ctx->ShareLoD("X", /*->*/ "Out"); + } + } + + static framework::DDim ValidateShape(const std::vector shape, + const framework::DDim &in_dims) { + const int64_t in_size = framework::product(in_dims); + // only one dimension can be set to -1, whose size will be automatically + // infered. + const int64_t unk_dim_val = -1; + const int64_t copy_dim_val = 0; + + std::vector output_shape(shape.size(), 0); + int64_t capacity = 1; + int unk_dim_idx = -1; + for (size_t i = 0; i < shape.size(); ++i) { + if (shape[i] == unk_dim_val) { + PADDLE_ENFORCE( + unk_dim_idx == -1, + "Only one input dimension of Attr(shape) can be unknown."); + unk_dim_idx = i; + } else if (shape[i] == copy_dim_val) { + PADDLE_ENFORCE( + static_cast(i) < in_dims.size(), + "The index of dimension to copy from input shape must be less " + "than the size of input shape."); + } else { + PADDLE_ENFORCE( + shape[i] > 0, + "Each input dimension of Attr(shape) must not be negtive except " + "one unknown dimension."); + } + + capacity *= (shape[i] ? shape[i] : in_dims[i]); + output_shape[i] = + (shape[i] ? static_cast(shape[i]) : in_dims[i]); + } + + if (unk_dim_idx != -1) { + if (in_size > 0) { + // in_size < 0 and is un-determinate in compile time, skip the check, + // for example, in_dims = [-1, 8, 1, 1], shape = [-1, 3, 8], + // capacity = -24, in_size = -8, output_shape[0] = 0 + // the following check will fail. + output_shape[unk_dim_idx] = -in_size / capacity; + PADDLE_ENFORCE_EQ(output_shape[unk_dim_idx] * capacity, -in_size, + "Invalid shape is given."); + } else { + output_shape[unk_dim_idx] = -1; + } + } else { + PADDLE_ENFORCE_EQ(capacity, in_size, "Invalid shape is given."); + } + return framework::make_ddim(output_shape); + } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext &ctx) const override { + return framework::OpKernelType( + framework::ToDataType(ctx.Input("X")->type()), + ctx.device_context()); + } +}; + class ReshapeOpMaker : public framework::OpProtoAndCheckerMaker { public: void Make() override { @@ -107,19 +201,93 @@ class ReshapeGradOp : public framework::OperatorWithKernel { } }; +class ReshapeKernel { + public: + void operator()(const framework::ExecutionContext &ctx) const { + auto *out = ctx.Output("Out"); + auto *in = ctx.Input("X"); + + auto *shape_tensor = ctx.HasInput("Shape") + ? ctx.Input("Shape") + : nullptr; + + framework::DDim out_dims = out->dims(); + + if (shape_tensor) { + auto *shape_data = shape_tensor->data(); + framework::Tensor cpu_shape_tensor; + if (platform::is_gpu_place(ctx.GetPlace())) { + TensorCopySync(*shape_tensor, platform::CPUPlace(), &cpu_shape_tensor); + shape_data = cpu_shape_tensor.data(); + } + auto shape = + std::vector(shape_data, shape_data + shape_tensor->numel()); + out_dims = ReshapeOp::ValidateShape(shape, in->dims()); + } + if (!in->lod().empty()) { + PADDLE_ENFORCE_EQ( + out_dims[0], in->dims()[0], + "Reshape operator cannot reshape an input sequence batch " + "into an output sequence batch that has a different " + "number of time steps. Please consider using " + "sequence_reshape op."); + } + + bool inplace = ctx.Attr("inplace"); + out->Resize(out_dims); + if (!inplace) { + out->mutable_data(ctx.GetPlace(), in->type()); + framework::TensorCopySync(*in, ctx.GetPlace(), out); + out->Resize(out_dims); + } else { + out->ShareDataWith(*in); + out->Resize(out_dims); + } + } +}; + +class ReshapeGradKernel { + public: + void operator()(const framework::ExecutionContext &ctx) const { + auto *d_out = ctx.Input(framework::GradVarName("Out")); + auto *d_x = ctx.Output(framework::GradVarName("X")); + + d_x->mutable_data(ctx.GetPlace(), d_out->type()); + bool inplace = ctx.Attr("inplace"); + + auto in_dims = d_x->dims(); + if (!inplace) { + framework::TensorCopy(*d_out, ctx.GetPlace(), ctx.device_context(), d_x); + ctx.device_context().Wait(); + d_x->Resize(in_dims); + } else { + d_x->ShareDataWith(*d_out); + d_x->Resize(in_dims); + } + } +}; + } // namespace operators } // namespace paddle namespace ops = paddle::operators; -using CPU = paddle::platform::CPUDeviceContext; REGISTER_OPERATOR(reshape, ops::ReshapeOp, ops::ReshapeOpMaker, paddle::framework::DefaultGradOpDescMaker); REGISTER_OPERATOR(reshape_grad, ops::ReshapeGradOp); -REGISTER_OP_CPU_KERNEL(reshape, ops::ReshapeKernel, - ops::ReshapeKernel, - ops::ReshapeKernel, - ops::ReshapeKernel); -REGISTER_OP_CPU_KERNEL(reshape_grad, ops::ReshapeGradKernel, - ops::ReshapeGradKernel, - ops::ReshapeGradKernel, - ops::ReshapeGradKernel); +REGISTER_OP_CPU_KERNEL_FUNCTOR(reshape, float, ops::ReshapeKernel, double, + ops::ReshapeKernel, int, ops::ReshapeKernel, + int64_t, ops::ReshapeKernel); +REGISTER_OP_CPU_KERNEL_FUNCTOR(reshape_grad, float, ops::ReshapeGradKernel, + double, ops::ReshapeGradKernel, int, + ops::ReshapeGradKernel, int64_t, + ops::ReshapeGradKernel); + +#ifdef PADDLE_WITH_CUDA +REGISTER_OP_CUDA_KERNEL_FUNCTOR(reshape, float, ops::ReshapeKernel, double, + ops::ReshapeKernel, int, ops::ReshapeKernel, + int64_t, ops::ReshapeKernel); +REGISTER_OP_CUDA_KERNEL_FUNCTOR(reshape_grad, float, ops::ReshapeGradKernel, + double, ops::ReshapeGradKernel, int, + ops::ReshapeGradKernel, int64_t, + ops::ReshapeGradKernel); +#endif diff --git a/paddle/fluid/operators/reshape_op.cu b/paddle/fluid/operators/reshape_op.cu deleted file mode 100644 index c628c634e2bc9ae260948a6e7ccf786cbd6c5c3c..0000000000000000000000000000000000000000 --- a/paddle/fluid/operators/reshape_op.cu +++ /dev/null @@ -1,26 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#include "paddle/fluid/operators/reshape_op.h" -using CUDA = paddle::platform::CUDADeviceContext; - -REGISTER_OP_CUDA_KERNEL(reshape, paddle::operators::ReshapeKernel, - paddle::operators::ReshapeKernel, - paddle::operators::ReshapeKernel, - paddle::operators::ReshapeKernel); -REGISTER_OP_CUDA_KERNEL(reshape_grad, - paddle::operators::ReshapeGradKernel, - paddle::operators::ReshapeGradKernel, - paddle::operators::ReshapeGradKernel, - paddle::operators::ReshapeGradKernel); diff --git a/paddle/fluid/operators/reshape_op.h b/paddle/fluid/operators/reshape_op.h deleted file mode 100644 index 3dd8c7c11eca241e747bfa129962032d882ce44c..0000000000000000000000000000000000000000 --- a/paddle/fluid/operators/reshape_op.h +++ /dev/null @@ -1,189 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#pragma once - -#include -#include - -#include "paddle/fluid/framework/eigen.h" -#include "paddle/fluid/framework/op_registry.h" - -namespace paddle { -namespace operators { - -class ReshapeOp : public framework::OperatorWithKernel { - public: - ReshapeOp(const std::string &type, const framework::VariableNameMap &inputs, - const framework::VariableNameMap &outputs, - const framework::AttributeMap &attrs) - : OperatorWithKernel(type, inputs, outputs, attrs) {} - - void InferShape(framework::InferShapeContext *ctx) const override { - PADDLE_ENFORCE(ctx->HasInput("X"), - "Input(X) of ReshapeOp should not be null."); - PADDLE_ENFORCE(ctx->HasOutput("Out"), - "Output(Out) of ReshapeOp should not be null."); - - const std::vector &shape = ctx->Attrs().Get>("shape"); - PADDLE_ENFORCE(!shape.empty(), - "The shape information must be set by Attr(shape)."); - - if (ctx->HasInput("Shape") && ctx->IsRuntime()) { - // If true, set the shape of Output(Out) according to Input(Shape) in - // ReshapeKernel with ExecutionContext. Also check LoD in ReshapeKernel. - ctx->ShareLoD("X", /*->*/ "Out"); - return; - } - - auto x_dims = ctx->GetInputDim("X"); - auto out_dims = ValidateShape(shape, x_dims); - ctx->SetOutputDim("Out", out_dims); - if (x_dims[0] == out_dims[0]) { - // Only pass LoD when the first dimension of output and Input(X) - // are the same. - ctx->ShareLoD("X", /*->*/ "Out"); - } - } - - static framework::DDim ValidateShape(const std::vector shape, - const framework::DDim &in_dims) { - const int64_t in_size = framework::product(in_dims); - // only one dimension can be set to -1, whose size will be automatically - // infered. - const int64_t unk_dim_val = -1; - const int64_t copy_dim_val = 0; - - std::vector output_shape(shape.size(), 0); - int64_t capacity = 1; - int unk_dim_idx = -1; - for (size_t i = 0; i < shape.size(); ++i) { - if (shape[i] == unk_dim_val) { - PADDLE_ENFORCE( - unk_dim_idx == -1, - "Only one input dimension of Attr(shape) can be unknown."); - unk_dim_idx = i; - } else if (shape[i] == copy_dim_val) { - PADDLE_ENFORCE( - static_cast(i) < in_dims.size(), - "The index of dimension to copy from input shape must be less " - "than the size of input shape."); - } else { - PADDLE_ENFORCE( - shape[i] > 0, - "Each input dimension of Attr(shape) must not be negtive except " - "one unknown dimension."); - } - - capacity *= (shape[i] ? shape[i] : in_dims[i]); - output_shape[i] = - (shape[i] ? static_cast(shape[i]) : in_dims[i]); - } - - if (unk_dim_idx != -1) { - if (in_size > 0) { - // in_size < 0 and is un-determinate in compile time, skip the check, - // for example, in_dims = [-1, 8, 1, 1], shape = [-1, 3, 8], - // capacity = -24, in_size = -8, output_shape[0] = 0 - // the following check will fail. - output_shape[unk_dim_idx] = -in_size / capacity; - PADDLE_ENFORCE_EQ(output_shape[unk_dim_idx] * capacity, -in_size, - "Invalid shape is given."); - } else { - output_shape[unk_dim_idx] = -1; - } - } else { - PADDLE_ENFORCE_EQ(capacity, in_size, "Invalid shape is given."); - } - return framework::make_ddim(output_shape); - } - - protected: - framework::OpKernelType GetExpectedKernelType( - const framework::ExecutionContext &ctx) const override { - return framework::OpKernelType( - framework::ToDataType(ctx.Input("X")->type()), - ctx.device_context()); - } -}; - -template -class ReshapeKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext &ctx) const { - auto *out = ctx.Output("Out"); - auto *in = ctx.Input("X"); - - auto *shape_tensor = ctx.HasInput("Shape") - ? ctx.Input("Shape") - : nullptr; - - framework::DDim out_dims = out->dims(); - - if (shape_tensor) { - auto *shape_data = shape_tensor->data(); - framework::Tensor cpu_shape_tensor; - if (platform::is_gpu_place(ctx.GetPlace())) { - TensorCopySync(*shape_tensor, platform::CPUPlace(), &cpu_shape_tensor); - shape_data = cpu_shape_tensor.data(); - } - auto shape = - std::vector(shape_data, shape_data + shape_tensor->numel()); - out_dims = ReshapeOp::ValidateShape(shape, in->dims()); - } - if (!in->lod().empty()) { - PADDLE_ENFORCE_EQ( - out_dims[0], in->dims()[0], - "Reshape operator cannot reshape an input sequence batch " - "into an output sequence batch that has a different " - "number of time steps. Please consider using " - "sequence_reshape op."); - } - - bool inplace = ctx.Attr("inplace"); - out->Resize(out_dims); - if (!inplace) { - out->mutable_data(ctx.GetPlace()); - framework::TensorCopySync(*in, ctx.GetPlace(), out); - out->Resize(out_dims); - } else { - out->ShareDataWith(*in); - out->Resize(out_dims); - } - } -}; - -template -class ReshapeGradKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext &ctx) const { - auto *d_out = ctx.Input(framework::GradVarName("Out")); - auto *d_x = ctx.Output(framework::GradVarName("X")); - - d_x->mutable_data(ctx.GetPlace()); - bool inplace = ctx.Attr("inplace"); - - auto in_dims = d_x->dims(); - if (!inplace) { - framework::TensorCopy(*d_out, ctx.GetPlace(), ctx.device_context(), d_x); - ctx.device_context().Wait(); - d_x->Resize(in_dims); - } else { - d_x->ShareDataWith(*d_out); - d_x->Resize(in_dims); - } - } -}; -} // namespace operators -} // namespace paddle diff --git a/paddle/fluid/operators/while_op.cc b/paddle/fluid/operators/while_op.cc index f440058e8db2024f5c8a0129db3af87a80d6e551..733157ea05ed39434b9a750e3a94ea548f512ce6 100644 --- a/paddle/fluid/operators/while_op.cc +++ b/paddle/fluid/operators/while_op.cc @@ -17,6 +17,7 @@ limitations under the License. */ #include "paddle/fluid/framework/lod_tensor_array.h" #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/operator.h" +#include "paddle/fluid/framework/var_type.h" #include "paddle/fluid/operators/detail/safe_ref.h" namespace paddle { @@ -135,15 +136,14 @@ class WhileGradOp : public framework::OperatorBase { auto &og_inside = detail::Ref(cur_scope.Var(inside_og_name), "Cannot find inside gradient %s", inside_og_name); - if (og_outside.Type().hash_code() == - typeid(framework::LoDTensor).hash_code()) { + if (framework::IsType(og_outside.Type())) { auto &outside_tensor = og_outside.Get(); auto &inside_tensor = detail::Ref(og_inside.GetMutable()); inside_tensor.set_lod(outside_tensor.lod()); inside_tensor.ShareDataWith(outside_tensor); - } else if (og_outside.Type().hash_code() == - typeid(framework::LoDTensorArray).hash_code()) { + } else if (framework::IsType( + og_outside.Type())) { auto &outside_array = og_outside.Get(); auto &inside_array = detail::Ref(og_inside.GetMutable()); diff --git a/paddle/fluid/pybind/pybind.cc b/paddle/fluid/pybind/pybind.cc index 36d080996831d4ad90d92baeafbe964693e2332a..9fc647a7d2a2bdfbaeeb91b00b4183f5c80b5aba 100644 --- a/paddle/fluid/pybind/pybind.cc +++ b/paddle/fluid/pybind/pybind.cc @@ -643,7 +643,11 @@ All parameter, weight, gradient are variables in Paddle. [](const BuildStrategy &self) { return self.debug_graphviz_path_; }, [](BuildStrategy &self, const std::string &path) { self.debug_graphviz_path_ = path; - }); + }) + .def_property( + "enable_data_balance", + [](const BuildStrategy &self) { return self.enable_data_balance_; }, + [](BuildStrategy &self, bool b) { self.enable_data_balance_ = b; }); pe.def(py::init &, const std::unordered_set &, diff --git a/paddle/api/Arguments.cpp b/paddle/legacy/api/Arguments.cpp similarity index 100% rename from paddle/api/Arguments.cpp rename to paddle/legacy/api/Arguments.cpp diff --git a/paddle/api/CMakeLists.txt b/paddle/legacy/api/CMakeLists.txt similarity index 100% rename from paddle/api/CMakeLists.txt rename to paddle/legacy/api/CMakeLists.txt diff --git a/paddle/api/ConfigParser.cpp b/paddle/legacy/api/ConfigParser.cpp similarity index 100% rename from paddle/api/ConfigParser.cpp rename to paddle/legacy/api/ConfigParser.cpp diff --git a/paddle/api/Evaluator.cpp b/paddle/legacy/api/Evaluator.cpp similarity index 100% rename from paddle/api/Evaluator.cpp rename to paddle/legacy/api/Evaluator.cpp diff --git a/paddle/api/GradientMachine.cpp b/paddle/legacy/api/GradientMachine.cpp similarity index 100% rename from paddle/api/GradientMachine.cpp rename to paddle/legacy/api/GradientMachine.cpp diff --git a/paddle/api/Internal.h b/paddle/legacy/api/Internal.h similarity index 100% rename from paddle/api/Internal.h rename to paddle/legacy/api/Internal.h diff --git a/paddle/api/Matrix.cpp b/paddle/legacy/api/Matrix.cpp similarity index 100% rename from paddle/api/Matrix.cpp rename to paddle/legacy/api/Matrix.cpp diff --git a/paddle/api/Paddle.i b/paddle/legacy/api/Paddle.i similarity index 98% rename from paddle/api/Paddle.i rename to paddle/legacy/api/Paddle.i index 3237e73745dca58bed923b20851f0f0039a3487c..e6165fb10689ae3183d094a0be340aae5644c1cf 100644 --- a/paddle/api/Paddle.i +++ b/paddle/legacy/api/Paddle.i @@ -2,7 +2,7 @@ %include "std_string.i" %{ #define SWIG_FILE_WITH_INIT -#include "api/PaddleAPI.h" +#include "legacy/api/PaddleAPI.h" %} %include "exception.i" @@ -199,4 +199,4 @@ namespace std { %ignore OptimizationConfigPrivate; %ignore ParameterTraverseCallbackPrivate; %include "utils/GlobalConstants.h" -%include "api/PaddleAPI.h" +%include "legacy/api/PaddleAPI.h" diff --git a/paddle/api/PaddleAPI.h b/paddle/legacy/api/PaddleAPI.h similarity index 100% rename from paddle/api/PaddleAPI.h rename to paddle/legacy/api/PaddleAPI.h diff --git a/paddle/api/PaddleAPIPrivate.h b/paddle/legacy/api/PaddleAPIPrivate.h similarity index 100% rename from paddle/api/PaddleAPIPrivate.h rename to paddle/legacy/api/PaddleAPIPrivate.h diff --git a/paddle/api/Parameter.cpp b/paddle/legacy/api/Parameter.cpp similarity index 100% rename from paddle/api/Parameter.cpp rename to paddle/legacy/api/Parameter.cpp diff --git a/paddle/api/ParameterOptimizer.cpp b/paddle/legacy/api/ParameterOptimizer.cpp similarity index 100% rename from paddle/api/ParameterOptimizer.cpp rename to paddle/legacy/api/ParameterOptimizer.cpp diff --git a/paddle/api/ParameterUpdater.cpp b/paddle/legacy/api/ParameterUpdater.cpp similarity index 100% rename from paddle/api/ParameterUpdater.cpp rename to paddle/legacy/api/ParameterUpdater.cpp diff --git a/paddle/api/SequenceGenerator.cpp b/paddle/legacy/api/SequenceGenerator.cpp similarity index 100% rename from paddle/api/SequenceGenerator.cpp rename to paddle/legacy/api/SequenceGenerator.cpp diff --git a/paddle/api/Trainer.cpp b/paddle/legacy/api/Trainer.cpp similarity index 100% rename from paddle/api/Trainer.cpp rename to paddle/legacy/api/Trainer.cpp diff --git a/paddle/api/Util.cpp b/paddle/legacy/api/Util.cpp similarity index 100% rename from paddle/api/Util.cpp rename to paddle/legacy/api/Util.cpp diff --git a/paddle/api/Vector.cpp b/paddle/legacy/api/Vector.cpp similarity index 100% rename from paddle/api/Vector.cpp rename to paddle/legacy/api/Vector.cpp diff --git a/paddle/api/__init__.py b/paddle/legacy/api/__init__.py similarity index 100% rename from paddle/api/__init__.py rename to paddle/legacy/api/__init__.py diff --git a/paddle/api/numpy.i b/paddle/legacy/api/numpy.i similarity index 100% rename from paddle/api/numpy.i rename to paddle/legacy/api/numpy.i diff --git a/paddle/api/test/.gitignore b/paddle/legacy/api/test/.gitignore similarity index 100% rename from paddle/api/test/.gitignore rename to paddle/legacy/api/test/.gitignore diff --git a/paddle/api/test/CMakeLists.txt b/paddle/legacy/api/test/CMakeLists.txt similarity index 100% rename from paddle/api/test/CMakeLists.txt rename to paddle/legacy/api/test/CMakeLists.txt diff --git a/paddle/api/test/testArguments.py b/paddle/legacy/api/test/testArguments.py similarity index 100% rename from paddle/api/test/testArguments.py rename to paddle/legacy/api/test/testArguments.py diff --git a/paddle/api/test/testGradientMachine.py b/paddle/legacy/api/test/testGradientMachine.py similarity index 100% rename from paddle/api/test/testGradientMachine.py rename to paddle/legacy/api/test/testGradientMachine.py diff --git a/paddle/api/test/testMatrix.py b/paddle/legacy/api/test/testMatrix.py similarity index 100% rename from paddle/api/test/testMatrix.py rename to paddle/legacy/api/test/testMatrix.py diff --git a/paddle/api/test/testTrain.py b/paddle/legacy/api/test/testTrain.py similarity index 100% rename from paddle/api/test/testTrain.py rename to paddle/legacy/api/test/testTrain.py diff --git a/paddle/api/test/testTrainConfig.py b/paddle/legacy/api/test/testTrainConfig.py similarity index 100% rename from paddle/api/test/testTrainConfig.py rename to paddle/legacy/api/test/testTrainConfig.py diff --git a/paddle/api/test/testTrainer.py b/paddle/legacy/api/test/testTrainer.py similarity index 100% rename from paddle/api/test/testTrainer.py rename to paddle/legacy/api/test/testTrainer.py diff --git a/paddle/api/test/testVector.py b/paddle/legacy/api/test/testVector.py similarity index 100% rename from paddle/api/test/testVector.py rename to paddle/legacy/api/test/testVector.py diff --git a/paddle/api/test/util.py b/paddle/legacy/api/test/util.py similarity index 100% rename from paddle/api/test/util.py rename to paddle/legacy/api/test/util.py diff --git a/paddle/capi/Arguments.cpp b/paddle/legacy/capi/Arguments.cpp similarity index 100% rename from paddle/capi/Arguments.cpp rename to paddle/legacy/capi/Arguments.cpp diff --git a/paddle/capi/CMakeLists.txt b/paddle/legacy/capi/CMakeLists.txt similarity index 100% rename from paddle/capi/CMakeLists.txt rename to paddle/legacy/capi/CMakeLists.txt diff --git a/paddle/capi/Main.cpp b/paddle/legacy/capi/Main.cpp similarity index 100% rename from paddle/capi/Main.cpp rename to paddle/legacy/capi/Main.cpp diff --git a/paddle/capi/Matrix.cpp b/paddle/legacy/capi/Matrix.cpp similarity index 100% rename from paddle/capi/Matrix.cpp rename to paddle/legacy/capi/Matrix.cpp diff --git a/paddle/capi/Vector.cpp b/paddle/legacy/capi/Vector.cpp similarity index 100% rename from paddle/capi/Vector.cpp rename to paddle/legacy/capi/Vector.cpp diff --git a/paddle/capi/arguments.h b/paddle/legacy/capi/arguments.h similarity index 100% rename from paddle/capi/arguments.h rename to paddle/legacy/capi/arguments.h diff --git a/paddle/capi/capi.h b/paddle/legacy/capi/capi.h similarity index 100% rename from paddle/capi/capi.h rename to paddle/legacy/capi/capi.h diff --git a/paddle/capi/capi_private.h b/paddle/legacy/capi/capi_private.h similarity index 100% rename from paddle/capi/capi_private.h rename to paddle/legacy/capi/capi_private.h diff --git a/paddle/capi/config.h.in b/paddle/legacy/capi/config.h.in similarity index 100% rename from paddle/capi/config.h.in rename to paddle/legacy/capi/config.h.in diff --git a/paddle/capi/error.cpp b/paddle/legacy/capi/error.cpp similarity index 100% rename from paddle/capi/error.cpp rename to paddle/legacy/capi/error.cpp diff --git a/paddle/capi/error.h b/paddle/legacy/capi/error.h similarity index 100% rename from paddle/capi/error.h rename to paddle/legacy/capi/error.h diff --git a/paddle/capi/examples/.gitignore b/paddle/legacy/capi/examples/.gitignore similarity index 100% rename from paddle/capi/examples/.gitignore rename to paddle/legacy/capi/examples/.gitignore diff --git a/paddle/capi/examples/README.md b/paddle/legacy/capi/examples/README.md similarity index 100% rename from paddle/capi/examples/README.md rename to paddle/legacy/capi/examples/README.md diff --git a/paddle/capi/examples/model_inference/README.md b/paddle/legacy/capi/examples/model_inference/README.md similarity index 100% rename from paddle/capi/examples/model_inference/README.md rename to paddle/legacy/capi/examples/model_inference/README.md diff --git a/paddle/capi/examples/model_inference/common/common.h b/paddle/legacy/capi/examples/model_inference/common/common.h similarity index 100% rename from paddle/capi/examples/model_inference/common/common.h rename to paddle/legacy/capi/examples/model_inference/common/common.h diff --git a/paddle/capi/examples/model_inference/dense/CMakeLists.txt b/paddle/legacy/capi/examples/model_inference/dense/CMakeLists.txt similarity index 100% rename from paddle/capi/examples/model_inference/dense/CMakeLists.txt rename to paddle/legacy/capi/examples/model_inference/dense/CMakeLists.txt diff --git a/paddle/capi/examples/model_inference/dense/convert_protobin.sh b/paddle/legacy/capi/examples/model_inference/dense/convert_protobin.sh similarity index 100% rename from paddle/capi/examples/model_inference/dense/convert_protobin.sh rename to paddle/legacy/capi/examples/model_inference/dense/convert_protobin.sh diff --git a/paddle/capi/examples/model_inference/dense/main.c b/paddle/legacy/capi/examples/model_inference/dense/main.c similarity index 100% rename from paddle/capi/examples/model_inference/dense/main.c rename to paddle/legacy/capi/examples/model_inference/dense/main.c diff --git a/paddle/capi/examples/model_inference/dense/merge_v2_model.py b/paddle/legacy/capi/examples/model_inference/dense/merge_v2_model.py similarity index 100% rename from paddle/capi/examples/model_inference/dense/merge_v2_model.py rename to paddle/legacy/capi/examples/model_inference/dense/merge_v2_model.py diff --git a/paddle/capi/examples/model_inference/dense/mnist_v2.py b/paddle/legacy/capi/examples/model_inference/dense/mnist_v2.py similarity index 100% rename from paddle/capi/examples/model_inference/dense/mnist_v2.py rename to paddle/legacy/capi/examples/model_inference/dense/mnist_v2.py diff --git a/paddle/capi/examples/model_inference/dense/trainer_config.py b/paddle/legacy/capi/examples/model_inference/dense/trainer_config.py similarity index 100% rename from paddle/capi/examples/model_inference/dense/trainer_config.py rename to paddle/legacy/capi/examples/model_inference/dense/trainer_config.py diff --git a/paddle/capi/examples/model_inference/multi_thread/.gitignore b/paddle/legacy/capi/examples/model_inference/multi_thread/.gitignore similarity index 100% rename from paddle/capi/examples/model_inference/multi_thread/.gitignore rename to paddle/legacy/capi/examples/model_inference/multi_thread/.gitignore diff --git a/paddle/capi/examples/model_inference/multi_thread/CMakeLists.txt b/paddle/legacy/capi/examples/model_inference/multi_thread/CMakeLists.txt similarity index 100% rename from paddle/capi/examples/model_inference/multi_thread/CMakeLists.txt rename to paddle/legacy/capi/examples/model_inference/multi_thread/CMakeLists.txt diff --git a/paddle/capi/examples/model_inference/multi_thread/convert_protobin.sh b/paddle/legacy/capi/examples/model_inference/multi_thread/convert_protobin.sh similarity index 100% rename from paddle/capi/examples/model_inference/multi_thread/convert_protobin.sh rename to paddle/legacy/capi/examples/model_inference/multi_thread/convert_protobin.sh diff --git a/paddle/capi/examples/model_inference/multi_thread/main.c b/paddle/legacy/capi/examples/model_inference/multi_thread/main.c similarity index 100% rename from paddle/capi/examples/model_inference/multi_thread/main.c rename to paddle/legacy/capi/examples/model_inference/multi_thread/main.c diff --git a/paddle/capi/examples/model_inference/multi_thread/main_gpu.c b/paddle/legacy/capi/examples/model_inference/multi_thread/main_gpu.c similarity index 100% rename from paddle/capi/examples/model_inference/multi_thread/main_gpu.c rename to paddle/legacy/capi/examples/model_inference/multi_thread/main_gpu.c diff --git a/paddle/capi/examples/model_inference/multi_thread/trainer_config.py b/paddle/legacy/capi/examples/model_inference/multi_thread/trainer_config.py similarity index 100% rename from paddle/capi/examples/model_inference/multi_thread/trainer_config.py rename to paddle/legacy/capi/examples/model_inference/multi_thread/trainer_config.py diff --git a/paddle/capi/examples/model_inference/sequence/.gitignore b/paddle/legacy/capi/examples/model_inference/sequence/.gitignore similarity index 100% rename from paddle/capi/examples/model_inference/sequence/.gitignore rename to paddle/legacy/capi/examples/model_inference/sequence/.gitignore diff --git a/paddle/capi/examples/model_inference/sequence/CMakeLists.txt b/paddle/legacy/capi/examples/model_inference/sequence/CMakeLists.txt similarity index 100% rename from paddle/capi/examples/model_inference/sequence/CMakeLists.txt rename to paddle/legacy/capi/examples/model_inference/sequence/CMakeLists.txt diff --git a/paddle/capi/examples/model_inference/sequence/convert_protobin.sh b/paddle/legacy/capi/examples/model_inference/sequence/convert_protobin.sh similarity index 100% rename from paddle/capi/examples/model_inference/sequence/convert_protobin.sh rename to paddle/legacy/capi/examples/model_inference/sequence/convert_protobin.sh diff --git a/paddle/capi/examples/model_inference/sequence/main.c b/paddle/legacy/capi/examples/model_inference/sequence/main.c similarity index 100% rename from paddle/capi/examples/model_inference/sequence/main.c rename to paddle/legacy/capi/examples/model_inference/sequence/main.c diff --git a/paddle/capi/examples/model_inference/sequence/trainer_config.py b/paddle/legacy/capi/examples/model_inference/sequence/trainer_config.py similarity index 100% rename from paddle/capi/examples/model_inference/sequence/trainer_config.py rename to paddle/legacy/capi/examples/model_inference/sequence/trainer_config.py diff --git a/paddle/capi/examples/model_inference/sparse_binary/.gitignore b/paddle/legacy/capi/examples/model_inference/sparse_binary/.gitignore similarity index 100% rename from paddle/capi/examples/model_inference/sparse_binary/.gitignore rename to paddle/legacy/capi/examples/model_inference/sparse_binary/.gitignore diff --git a/paddle/capi/examples/model_inference/sparse_binary/CMakeLists.txt b/paddle/legacy/capi/examples/model_inference/sparse_binary/CMakeLists.txt similarity index 100% rename from paddle/capi/examples/model_inference/sparse_binary/CMakeLists.txt rename to paddle/legacy/capi/examples/model_inference/sparse_binary/CMakeLists.txt diff --git a/paddle/capi/examples/model_inference/sparse_binary/convert_protobin.sh b/paddle/legacy/capi/examples/model_inference/sparse_binary/convert_protobin.sh similarity index 100% rename from paddle/capi/examples/model_inference/sparse_binary/convert_protobin.sh rename to paddle/legacy/capi/examples/model_inference/sparse_binary/convert_protobin.sh diff --git a/paddle/capi/examples/model_inference/sparse_binary/main.c b/paddle/legacy/capi/examples/model_inference/sparse_binary/main.c similarity index 100% rename from paddle/capi/examples/model_inference/sparse_binary/main.c rename to paddle/legacy/capi/examples/model_inference/sparse_binary/main.c diff --git a/paddle/capi/examples/model_inference/sparse_binary/trainer_config.py b/paddle/legacy/capi/examples/model_inference/sparse_binary/trainer_config.py similarity index 100% rename from paddle/capi/examples/model_inference/sparse_binary/trainer_config.py rename to paddle/legacy/capi/examples/model_inference/sparse_binary/trainer_config.py diff --git a/paddle/capi/gradient_machine.cpp b/paddle/legacy/capi/gradient_machine.cpp similarity index 100% rename from paddle/capi/gradient_machine.cpp rename to paddle/legacy/capi/gradient_machine.cpp diff --git a/paddle/capi/gradient_machine.h b/paddle/legacy/capi/gradient_machine.h similarity index 100% rename from paddle/capi/gradient_machine.h rename to paddle/legacy/capi/gradient_machine.h diff --git a/paddle/capi/main.h b/paddle/legacy/capi/main.h similarity index 100% rename from paddle/capi/main.h rename to paddle/legacy/capi/main.h diff --git a/paddle/capi/matrix.h b/paddle/legacy/capi/matrix.h similarity index 100% rename from paddle/capi/matrix.h rename to paddle/legacy/capi/matrix.h diff --git a/paddle/capi/paddle_capi.map b/paddle/legacy/capi/paddle_capi.map similarity index 100% rename from paddle/capi/paddle_capi.map rename to paddle/legacy/capi/paddle_capi.map diff --git a/paddle/capi/tests/.gitignore b/paddle/legacy/capi/tests/.gitignore similarity index 100% rename from paddle/capi/tests/.gitignore rename to paddle/legacy/capi/tests/.gitignore diff --git a/paddle/capi/tests/CMakeLists.txt b/paddle/legacy/capi/tests/CMakeLists.txt similarity index 100% rename from paddle/capi/tests/CMakeLists.txt rename to paddle/legacy/capi/tests/CMakeLists.txt diff --git a/paddle/capi/tests/test_Arguments.cpp b/paddle/legacy/capi/tests/test_Arguments.cpp similarity index 100% rename from paddle/capi/tests/test_Arguments.cpp rename to paddle/legacy/capi/tests/test_Arguments.cpp diff --git a/paddle/capi/tests/test_GradientMachine.cpp b/paddle/legacy/capi/tests/test_GradientMachine.cpp similarity index 100% rename from paddle/capi/tests/test_GradientMachine.cpp rename to paddle/legacy/capi/tests/test_GradientMachine.cpp diff --git a/paddle/capi/tests/test_Matrix.cpp b/paddle/legacy/capi/tests/test_Matrix.cpp similarity index 100% rename from paddle/capi/tests/test_Matrix.cpp rename to paddle/legacy/capi/tests/test_Matrix.cpp diff --git a/paddle/capi/tests/test_Vector.cpp b/paddle/legacy/capi/tests/test_Vector.cpp similarity index 100% rename from paddle/capi/tests/test_Vector.cpp rename to paddle/legacy/capi/tests/test_Vector.cpp diff --git a/paddle/capi/tests/test_predict_network.py b/paddle/legacy/capi/tests/test_predict_network.py similarity index 100% rename from paddle/capi/tests/test_predict_network.py rename to paddle/legacy/capi/tests/test_predict_network.py diff --git a/paddle/capi/vector.h b/paddle/legacy/capi/vector.h similarity index 100% rename from paddle/capi/vector.h rename to paddle/legacy/capi/vector.h diff --git a/paddle/scripts/paddle_build.sh b/paddle/scripts/paddle_build.sh index b66a05aaebda645196721fd6ed840e5584813348..d8f0b76b7ba0fedfe411aa86f6f8a0c77a02beca 100755 --- a/paddle/scripts/paddle_build.sh +++ b/paddle/scripts/paddle_build.sh @@ -106,7 +106,7 @@ function cmake_gen() { -DWITH_FLUID_ONLY=${WITH_FLUID_ONLY:-OFF} -DCMAKE_EXPORT_COMPILE_COMMANDS=ON -DWITH_CONTRIB=${WITH_CONTRIB:-ON} - -DWITH_ANAKIN=${WITH_ANAKIN:-ON} + -DWITH_ANAKIN=${WITH_ANAKIN:-OFF} -DWITH_INFERENCE_DEMO=${WITH_INFERENCE_DEMO:-ON} ======================================== EOF @@ -135,7 +135,7 @@ EOF -DWITH_FLUID_ONLY=${WITH_FLUID_ONLY:-OFF} \ -DCMAKE_EXPORT_COMPILE_COMMANDS=ON \ -DWITH_CONTRIB=${WITH_CONTRIB:-ON} \ - -DWITH_ANAKIN=${WITH_ANAKIN:-ON} \ + -DWITH_ANAKIN=${WITH_ANAKIN:-OFF} \ -DWITH_INFERENCE_DEMO=${WITH_INFERENCE_DEMO:-ON} } diff --git a/python/paddle/fluid/io.py b/python/paddle/fluid/io.py index d94564e11f982575dd9c065deb20d29396203227..5c8f4f6507c7dd9b3d005639d962ce1e55b2c704 100644 --- a/python/paddle/fluid/io.py +++ b/python/paddle/fluid/io.py @@ -18,7 +18,7 @@ import time import shutil from paddle.fluid.evaluator import Evaluator -from paddle.fluid.framework import Program, Parameter, default_main_program, Variable +from paddle.fluid.framework import Program, Parameter, default_main_program, default_startup_program, Variable from . import core __all__ = [ @@ -1374,3 +1374,101 @@ def get_latest_checkpoint_serial(checkpoint_dir): if success_num > current_dir: current_dir = success_num return current_dir + + +def get_test_program(filelist, program=None, startup_program=None): + """ + Transpile current train program to a program to read test dataset + if the program is using reader ops like "open_files_op". + """ + + def _copy_reader_var_(block, var, new_name=None): + if new_name == None: + new_name = var.name + new_var = block.create_var( + name=str(new_name), type=core.VarDesc.VarType.READER) + new_var.desc.set_shapes(var.desc.shapes()) + new_var.desc.set_dtypes(var.desc.dtypes()) + new_var.persistable = True + return new_var + + def _get_test_reader_name(train_reader_name): + return train_reader_name + "_test" + + def _is_reader_op(op): + block = op.block + if "Out" in op.output_names: + reader_out = block.vars[op.output("Out")[0]] + if reader_out.type == core.VarDesc.VarType.READER: + return True + return False + + if program == None: + program = default_main_program() + if startup_program == None: + startup_program = default_startup_program() + startup_block = startup_program.global_block() + + # 1. find out the orignal reader var name + startup_reader_op_list = [] + + for op in startup_block.ops: + if _is_reader_op(op): + startup_reader_op_list.append(op) + + if len(startup_reader_op_list) == 0: + return program + + root_reader_op = startup_reader_op_list[0] + train_test_reader_map = {} + # 2. add operators to startup to read open and read test data files + for op in startup_reader_op_list: + assert (len(op.output("Out")) == 1) + train_reader_name = op.output("Out")[0] + train_reader = startup_block.vars[train_reader_name] + test_reader = _copy_reader_var_( + startup_block, + train_reader, + new_name=_get_test_reader_name(train_reader_name)) + train_test_reader_map[train_reader.name] = test_reader + + test_op_inputs = {} + for name in op.input_names: + train_arg_names = op.input(name) + test_arg_vars = [] + for arg_name in train_arg_names: + arg_var = train_test_reader_map[ + arg_name] if name == "UnderlyingReader" else startup_block.vars[ + arg_name] + test_arg_vars.append(arg_var) + test_op_inputs[name] = test_arg_vars + + test_op = startup_block.append_op( + type=op.type, + inputs=test_op_inputs, + outputs={'Out': [test_reader]}, + attrs=op.attrs) + # root reader op's filelist attr for read test files + if op.type == root_reader_op.type: + test_op.set_attr("file_names", filelist) + if op.type == "create_multi_pass_reader": + test_op.set_attr("pass_num", 1) + + # 3. rename reader vars in inference program to different name + # to avoid read from train data. + main_block = program.global_block() + for var in main_block.vars.values(): + if var.type == core.VarDesc.VarType.READER: + main_block.rename_var( + str(var.name), str(_get_test_reader_name(var.name))) + + for op in main_block.ops: + if op.type == root_reader_op.type: + test_op.set_attr("file_names", filelist) + if op.type == "create_multi_pass_reader": + test_op.set_attr("pass_num", 1) + + startup_program.sync_with_cpp() + program.sync_with_cpp() + + return program diff --git a/python/paddle/fluid/tests/unittests/.gitignore b/python/paddle/fluid/tests/unittests/.gitignore index 3538a9c2009bb133609153427981fb66974377fa..b1e8fda03aa42f5f7528eafb46c16d55b868bae5 100644 --- a/python/paddle/fluid/tests/unittests/.gitignore +++ b/python/paddle/fluid/tests/unittests/.gitignore @@ -4,3 +4,5 @@ mnist_1.recordio mnist_2.recordio flowers.recordio wmt16.recordio +data_balance_test.recordio +data_balance_with_lod_test.recordio diff --git a/python/paddle/fluid/tests/unittests/test_data_balance.py b/python/paddle/fluid/tests/unittests/test_data_balance.py new file mode 100644 index 0000000000000000000000000000000000000000..b558d7c2ea172d9c7526c865a4bc54c32f8998b6 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_data_balance.py @@ -0,0 +1,187 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +import paddle.fluid as fluid +import paddle.v2 as paddle +import numpy as np + + +class TestDataBalance(unittest.TestCase): + def prepare_data(self): + def fake_data_generator(): + for n in xrange(self.total_ins_num): + yield np.ones((3, 4)) * n, n + + # Prepare data + with fluid.program_guard(fluid.Program(), fluid.Program()): + reader = paddle.batch( + fake_data_generator, batch_size=self.batch_size) + feeder = fluid.DataFeeder( + feed_list=[ + fluid.layers.data( + name='image', shape=[3, 4], dtype='float32'), + fluid.layers.data( + name='label', shape=[1], dtype='int64'), + ], + place=fluid.CPUPlace()) + self.num_batches = fluid.recordio_writer.convert_reader_to_recordio_file( + self.data_file_name, reader, feeder) + + def prepare_lod_data(self): + def fake_data_generator(): + for n in xrange(1, self.total_ins_num + 1): + d1 = (np.ones((n, 3)) * n).astype('float32') + d2 = (np.array(n).reshape((1, 1))).astype('int32') + yield d1, d2 + + # Prepare lod data + with fluid.program_guard(fluid.Program(), fluid.Program()): + with fluid.recordio_writer.create_recordio_writer( + filename=self.lod_data_file_name) as writer: + eof = False + generator = fake_data_generator() + while (not eof): + data_batch = [ + np.array([]).reshape((0, 3)), np.array([]).reshape( + (0, 1)) + ] + lod = [0] + for _ in xrange(self.batch_size): + try: + ins = generator.next() + except StopIteration: + eof = True + break + for i, d in enumerate(ins): + data_batch[i] = np.concatenate( + (data_batch[i], d), axis=0) + lod.append(lod[-1] + ins[0].shape[0]) + if data_batch[0].shape[0] > 0: + for i, d in enumerate(data_batch): + t = fluid.LoDTensor() + t.set(data_batch[i], fluid.CPUPlace()) + if i == 0: + t.set_lod([lod]) + writer.append_tensor(t) + writer.complete_append_tensor() + + def setUp(self): + self.use_cuda = fluid.core.is_compiled_with_cuda() + self.data_file_name = './data_balance_test.recordio' + self.lod_data_file_name = './data_balance_with_lod_test.recordio' + self.total_ins_num = 50 + self.batch_size = 10 + self.prepare_data() + self.prepare_lod_data() + + def main(self): + main_prog = fluid.Program() + startup_prog = fluid.Program() + with fluid.program_guard(main_prog, startup_prog): + data_reader = fluid.layers.io.open_files( + filenames=[self.data_file_name], + shapes=[[-1, 3, 4], [-1, 1]], + lod_levels=[0, 0], + dtypes=['float32', 'int64']) + if self.use_cuda: + data_reader = fluid.layers.double_buffer(data_reader) + image, label = fluid.layers.read_file(data_reader) + + place = fluid.CUDAPlace(0) if self.use_cuda else fluid.CPUPlace() + exe = fluid.Executor(place) + exe.run(startup_prog) + + parallel_exe = fluid.ParallelExecutor( + use_cuda=self.use_cuda, main_program=main_prog) + + if (parallel_exe.device_count > self.batch_size): + print("WARNING: Unittest TestDataBalance skipped. \ + For the result is not correct when device count \ + is larger than batch size.") + exit(0) + fetch_list = [image.name, label.name] + + data_appeared = [False] * self.total_ins_num + while (True): + try: + image_val, label_val = parallel_exe.run(fetch_list, + return_numpy=True) + except fluid.core.EnforceNotMet as ex: + self.assertIn("There is no next data.", ex.message) + break + ins_num = image_val.shape[0] + broadcasted_label = np.ones( + (ins_num, 3, 4)) * label_val.reshape((ins_num, 1, 1)) + self.assertEqual(image_val.all(), broadcasted_label.all()) + for l in label_val: + self.assertFalse(data_appeared[l[0]]) + data_appeared[l[0]] = True + for i in data_appeared: + self.assertTrue(i) + + def main_lod(self): + main_prog = fluid.Program() + startup_prog = fluid.Program() + with fluid.program_guard(main_prog, startup_prog): + data_reader = fluid.layers.io.open_files( + filenames=[self.lod_data_file_name], + shapes=[[-1, 3], [-1, 1]], + lod_levels=[1, 0], + dtypes=['float32', 'int32'], + thread_num=1) + ins, label = fluid.layers.read_file(data_reader) + + place = fluid.CUDAPlace(0) if self.use_cuda else fluid.CPUPlace() + exe = fluid.Executor(place) + exe.run(startup_prog) + + parallel_exe = fluid.ParallelExecutor( + use_cuda=self.use_cuda, main_program=main_prog) + + if (parallel_exe.device_count > self.batch_size): + print("WARNING: Unittest TestDataBalance skipped. \ + For the result is not correct when device count \ + is larger than batch size.") + exit(0) + fetch_list = [ins.name, label.name] + + data_appeared = [False] * self.total_ins_num + while (True): + try: + ins_tensor, label_tensor = parallel_exe.run( + fetch_list, return_numpy=False) + except fluid.core.EnforceNotMet as ex: + self.assertIn("There is no next data.", ex.message) + break + + ins_val = np.array(ins_tensor) + label_val = np.array(label_tensor) + ins_lod = ins_tensor.lod()[0] + self.assertEqual(ins_val.shape[1], 3) + self.assertEqual(label_val.shape[1], 1) + self.assertEqual(len(ins_lod) - 1, label_val.shape[0]) + for i in range(0, len(ins_lod) - 1): + ins_elem = ins_val[ins_lod[i]:ins_lod[i + 1]][:] + label_elem = label_val[i][0] + self.assertEqual(ins_elem.all(), label_elem.all()) + self.assertFalse(data_appeared[int(label_elem - 1)]) + data_appeared[int(label_elem - 1)] = True + + for i in data_appeared: + self.assertTrue(i) + + def test_all(self): + self.main() + self.main_lod() diff --git a/python/paddle/fluid/tests/unittests/test_dist_transpiler.py b/python/paddle/fluid/tests/unittests/test_dist_transpiler.py index b4379ad447e01683325dfcbb6a5b322f0b8eac3d..75b4b4e50da04521021dcb1e97cfe495f2619433 100644 --- a/python/paddle/fluid/tests/unittests/test_dist_transpiler.py +++ b/python/paddle/fluid/tests/unittests/test_dist_transpiler.py @@ -15,51 +15,248 @@ import unittest import paddle.fluid as fluid from paddle.fluid.transpiler.distribute_transpiler import delete_ops +import traceback -from transpiler_test import TranspilerTest - -class TestDistTranspiler(TranspilerTest): +class TranspilerTest(unittest.TestCase): def setUp(self): - self.current_pserver_ep = "127.0.0.1:6174" + self.trainer_id = 0 + self.trainers = 2 + self.pservers = 2 + # NOTE: we do not actually bind this port + self.pserver_eps = "127.0.0.1:6174,127.0.0.1:6175" + self.pserver1_ep = "127.0.0.1:6174" + self.pserver2_ep = "127.0.0.1:6175" + self.slice_var_up = True + self.sync_mode = True + self.transpiler = None + + def net_conf(self): + x = fluid.layers.data(name='x', shape=[1000], dtype='float32') + y_predict = fluid.layers.fc(input=x, + size=1000, + act=None, + param_attr=fluid.ParamAttr(name='fc_w'), + bias_attr=fluid.ParamAttr(name='fc_b')) + y = fluid.layers.data(name='y', shape=[1], dtype='float32') + cost = fluid.layers.square_error_cost(input=y_predict, label=y) + avg_cost = fluid.layers.mean(cost) + sgd_optimizer = fluid.optimizer.SGD(learning_rate=0.1) + sgd_optimizer.minimize(avg_cost) + return + + def get_main_program(self): + main = fluid.Program() + with fluid.program_guard(main): + self.net_conf() + self.origin_prog = main.clone() + return main + + def get_trainer(self): + t = self._transpiler_instance() + return t.get_trainer_program() + + def get_pserver(self, ep): + t = self._transpiler_instance() + pserver = t.get_pserver_program(ep) + startup = t.get_startup_program(ep, pserver) + return pserver, startup + + def _transpiler_instance(self): + if not self.transpiler: + main = self.get_main_program() + self.transpiler = fluid.DistributeTranspiler() + self.transpiler.transpile( + self.trainer_id, + program=main, + pservers=self.pserver_eps, + trainers=self.trainers, + slice_var_up=self.slice_var_up, + sync_mode=self.sync_mode) + return self.transpiler + +class TestBasicModel(TranspilerTest): def test_transpiler(self): + pserver, startup = self.get_pserver(self.pserver1_ep) + pserver2, startup2 = self.get_pserver(self.pserver2_ep) + trainer = self.get_trainer() - pserver, startup = self.get_pserver(self.current_pserver_ep) - self.assertEqual([op.type for op in trainer.global_block().ops], - self.get_expect_trainer_ops()) + + self.assertEqual([op.type for op in trainer.global_block().ops], [ + 'mul', 'elementwise_add', 'elementwise_sub', 'square', 'mean', + 'fill_constant', 'mean_grad', 'square_grad', 'elementwise_sub_grad', + 'elementwise_add_grad', 'send', 'mul_grad', 'split_byref', 'send', + 'send_barrier', 'recv', 'recv', 'fetch_barrier', 'concat' + ]) self.assertEqual(len(pserver.blocks), 3) # block0: listen_and_serv self.assertEqual([op.type for op in pserver.blocks[0].ops], ["listen_and_serv"]) - # block2: optimize pass + # block1~2: optimize pass self.assertEqual([op.type for op in pserver.blocks[1].ops], ["sum", "scale", "sgd"]) - # confirm startup program - - self.assertEqual([op.type for op in startup.global_block().ops], [ - "fill_constant", "fill_constant", "uniform_random", "uniform_random" - ]) - + self.assertEqual([op.type for op in startup.global_block().ops], + ["fill_constant", "fill_constant", "uniform_random"]) # the variable #fc_w will be split into two blocks fc_w_var = startup.global_block().var("fc_w.block1") self.assertEqual(fc_w_var.shape, (500, 1000)) + # all parameters should be optimized on pserver + + pserver_params = [] + for prog in [pserver, pserver2]: + for blk in prog.blocks: + for op in blk.ops: + if "Param" in op.input_names: + param_name = op.input("Param")[0] + is_block_idx = param_name.find(".block") + if is_block_idx != -1: + origin_param_name = param_name[:is_block_idx] + else: + origin_param_name = param_name + pserver_params.append(origin_param_name) + trainer_params = [] + for op in self.origin_prog.global_block().ops: + if "Param" in op.input_names: + trainer_params.append(op.input("Param")[0]) + self.assertEqual(set(pserver_params), set(trainer_params)) + + +class TestNoSliceVar(TranspilerTest): + def setUp(self): + super(TestNoSliceVar, self).setUp() + self.slice_var_up = False + + def test_transpiler(self): + _, startup = self.get_pserver(self.pserver1_ep) + _, startup2 = self.get_pserver(self.pserver2_ep) + + if startup.global_block().vars.has_key("fc_w"): + fc_w_var = startup.global_block().vars["fc_w"] + elif startup2.global_block().vars.has_key("fc_w"): + fc_w_var = startup2.global_block().vars["fc_w"] + + self.assertEqual(fc_w_var.shape, (1000, 1000)) - def get_expect_trainer_ops(self): - trainer = fluid.Program() - with fluid.program_guard(trainer): - optimize_ops, params_grads = self.net_conf() +class TestLRDecay(TranspilerTest): + def net_conf(self): + x = fluid.layers.data(name='x', shape=[1000], dtype='float32') + y_predict = fluid.layers.fc(input=x, + size=1000, + act=None, + param_attr=fluid.ParamAttr(name='fc_w'), + bias_attr=fluid.ParamAttr(name='fc_b')) + y = fluid.layers.data(name='y', shape=[1], dtype='float32') + cost = fluid.layers.square_error_cost(input=y_predict, label=y) + avg_cost = fluid.layers.mean(cost) + sgd_optimizer = fluid.optimizer.SGD( + learning_rate=fluid.layers.exponential_decay( + learning_rate=1.0, + decay_steps=2100, + decay_rate=0.1, + staircase=True)) + sgd_optimizer.minimize(avg_cost) + return + + def test_transpiler(self): + pserver, startup = self.get_pserver(self.pserver1_ep) + trainer = self.get_trainer() + + self.assertEqual(len(pserver.blocks), 4) + lr_decay_ops = [op.type for op in pserver.blocks[1].ops] + self.assertEqual(lr_decay_ops, [ + "increment", "cast", "fill_constant", "elementwise_div", "floor", + "fill_constant", "elementwise_pow", "fill_constant", + "elementwise_mul" + ]) + + +class TestLRDecayConditional(TranspilerTest): + def net_conf(self): + x = fluid.layers.data(name='x', shape=[1000], dtype='float32') + y_predict = fluid.layers.fc(input=x, + size=1000, + act=None, + param_attr=fluid.ParamAttr(name='fc_w'), + bias_attr=fluid.ParamAttr(name='fc_b')) + y = fluid.layers.data(name='y', shape=[1], dtype='float32') + cost = fluid.layers.square_error_cost(input=y_predict, label=y) + avg_cost = fluid.layers.mean(cost) + sgd_optimizer = fluid.optimizer.SGD( + learning_rate=fluid.layers.piecewise_decay([10000, 20000], + [1.0, 0.5, 1.0])) + sgd_optimizer.minimize(avg_cost) + return + + def test_transpiler(self): + pserver, startup = self.get_pserver(self.pserver1_ep) + trainer = self.get_trainer() + + serv_op = pserver.blocks[0].ops[0] + sub_blocks = [] + optimize_blocks = [] + for b in serv_op.attrs["optimize_blocks"]: + optimize_blocks.append(b.idx) + for b in pserver.blocks: + if b.idx not in optimize_blocks: + sub_blocks.append(b.idx) + + self.assertEqual(len(pserver.blocks), 7) + lr_decay_ops = [op.type for op in pserver.blocks[1].ops] + self.assertEqual(lr_decay_ops, [ + "increment", "cast", "fill_constant", "fill_constant", "less_than", + "logical_not", "conditional_block", "fill_constant", + "fill_constant", "less_than", "logical_not", "logical_and", + "logical_and", "conditional_block", "fill_constant", + "conditional_block" + ]) + # test the condition blocks + for b in sub_blocks: + if b == 0: + continue + block = pserver.blocks[b] + self.assertEqual([op.type for op in block.ops], ["assign"]) + + +class TestL2Decay(TranspilerTest): + def net_conf(self): + x = fluid.layers.data(name='x', shape=[1000], dtype='float32') + y_predict = fluid.layers.fc( + input=x, + size=1000, + act=None, + param_attr=fluid.ParamAttr( + name='fc_w', + regularizer=fluid.regularizer.L2Decay(), + gradient_clip=fluid.clip.GradientClipByValue(0.1)), + bias_attr=fluid.ParamAttr(name='fc_b')) + y = fluid.layers.data(name='y', shape=[1], dtype='float32') + cost = fluid.layers.square_error_cost(input=y_predict, label=y) + avg_cost = fluid.layers.mean(cost) + sgd_optimizer = fluid.optimizer.SGD(learning_rate=0.1) + sgd_optimizer.minimize(avg_cost) + return + + def test_transpiler(self): + pserver, startup = self.get_pserver(self.pserver1_ep) + trainer = self.get_trainer() + + self.assertEqual(len(pserver.blocks), 3) + self.assertEqual([op.type for op in pserver.blocks[1].ops], + ["sum", "scale", "clip", "sgd"]) + self.assertEqual( + [op.type for op in pserver.blocks[2].ops], + ["sum", "scale", "clip", "scale", "elementwise_add", "sgd"]) + # TODO(typhoonzero): test clipping and L2Decay ops are removed from trainer + - delete_ops(trainer.global_block(), optimize_ops) - ops = [op.type for op in trainer.global_block().ops] + [ - "split_byref", "send", "send_barrier", "recv", "recv", - "fetch_barrier", "concat" - ] - ops.insert(ops.index("elementwise_add_grad") + 1, "send") - return ops + # FIXME(typhoonzero): need to add test for async case: + # see https://github.com/PaddlePaddle/Paddle/issues/11691 +class TestAsyncSGD(TranspilerTest): + pass if __name__ == "__main__": diff --git a/python/paddle/fluid/tests/unittests/test_simple_dist_transpiler.py b/python/paddle/fluid/tests/unittests/test_simple_dist_transpiler.py deleted file mode 100644 index f4aa7426bc315be501348a64e2f15caed6dc8810..0000000000000000000000000000000000000000 --- a/python/paddle/fluid/tests/unittests/test_simple_dist_transpiler.py +++ /dev/null @@ -1,80 +0,0 @@ -# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import numpy as np - -import paddle.fluid as fluid -from paddle.fluid.transpiler.distribute_transpiler import delete_ops - -from transpiler_test import TranspilerTest - - -class TestSimpleDistTranspiler(TranspilerTest): - def setUp(self): - self.current_pserver_ep = "127.0.0.1:6175" - - def test_simple_transpiler(self): - np.random.seed(1) - - trainer = self.get_trainer() - pserver, startup = self.get_pserver(self.current_pserver_ep) - self.assertEqual([op.type for op in trainer.global_block().ops], - self.get_expect_trainer_ops()) - - self.assertEqual(len(pserver.blocks), 2) - # block0: listen_and_serv - self.assertEqual([op.type for op in pserver.blocks[0].ops], - ["listen_and_serv"]) - # block1: optimize pass - self.assertEqual([op.type for op in pserver.blocks[1].ops], - ["sum", "scale", "sgd"]) - - # confirm startup program - self.assertEqual([op.type for op in startup.global_block().ops], - ["fill_constant", "uniform_random", "uniform_random"]) - - # the variable #fc_w will NOT be splited - fc_w_var = startup.global_block().var("fc_w@GRAD") - self.assertEqual(fc_w_var.shape, (1000, 1000)) - - fc_w_var = startup.global_block().var("fc_w@GRAD.trainer_0") - self.assertEqual(fc_w_var.shape, (1000, 1000)) - - def get_expect_trainer_ops(self): - trainer = fluid.Program() - - with fluid.program_guard(trainer): - optimize_ops, params_grads = self.net_conf() - - delete_ops(trainer.global_block(), optimize_ops) - ops = [op.type for op in trainer.global_block().ops] + [ - "send", "send_barrier", "recv", "recv", "fetch_barrier" - ] - ops.insert(ops.index("elementwise_add_grad") + 1, "send") - return ops - - def _transpiler_instance(self): - main = self.get_main_program() - t = fluid.DistributeTranspiler() - t.transpile( - self.trainer_id, - program=main, - pservers=self.pserver_eps, - trainers=self.trainers, - slice_var_up=False) - return t - - -if __name__ == "__main__": - unittest.main() diff --git a/python/paddle/fluid/tests/unittests/transpiler_test.py b/python/paddle/fluid/tests/unittests/transpiler_test.py deleted file mode 100644 index d84c5d9c41c705cf6d14cc0b5a8c692b0d646337..0000000000000000000000000000000000000000 --- a/python/paddle/fluid/tests/unittests/transpiler_test.py +++ /dev/null @@ -1,73 +0,0 @@ -# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import unittest -import numpy as np - -import paddle.fluid as fluid -import paddle.fluid.core as core -import paddle.fluid.layers as layers - - -class TranspilerTest(unittest.TestCase): - @classmethod - def setUpClass(self): - self.trainer_id = 0 - self.trainers = 2 - self.pservers = 2 - self.pserver_eps = "127.0.0.1:6174,127.0.0.1:6175" - - def net_conf(self): - x = fluid.layers.data(name='x', shape=[1000], dtype='float32') - - y_predict = fluid.layers.fc(input=x, - size=1000, - act=None, - param_attr=fluid.ParamAttr(name='fc_w')) - - y = fluid.layers.data(name='y', shape=[1], dtype='float32') - - cost = fluid.layers.square_error_cost(input=y_predict, label=y) - avg_cost = fluid.layers.mean(cost) - sgd_optimizer = fluid.optimizer.SGD(learning_rate=0.1) - - optimize_ops, params_grads = sgd_optimizer.minimize(avg_cost) - return optimize_ops, params_grads - - def get_main_program(self): - main = fluid.Program() - - with fluid.program_guard(main): - self.net_conf() - - return main - - def get_trainer(self): - return self._transpiler_instance().get_trainer_program() - - def get_pserver(self, ep): - t = self._transpiler_instance() - pserver = t.get_pserver_program(ep) - startup = t.get_startup_program(ep, pserver) - return pserver, startup - - def _transpiler_instance(self): - main = self.get_main_program() - t = fluid.DistributeTranspiler() - t.transpile( - self.trainer_id, - program=main, - pservers=self.pserver_eps, - trainers=self.trainers) - return t diff --git a/python/paddle/fluid/transpiler/distribute_transpiler.py b/python/paddle/fluid/transpiler/distribute_transpiler.py index f086600702eeafc7948e168d77dfbd1d1c4b901c..53d6ca86a008f798af2854a154cce8b7242d2f35 100644 --- a/python/paddle/fluid/transpiler/distribute_transpiler.py +++ b/python/paddle/fluid/transpiler/distribute_transpiler.py @@ -455,6 +455,8 @@ class DistributeTranspiler(object): __append_optimize_op__(op, per_opt_block, grad_to_block_id, merged_var, lr_ops) + # dedup grad to ids list + grad_to_block_id = list(set(grad_to_block_id)) # append global ops if global_ops: opt_state_block = pserver_program.create_block( @@ -960,8 +962,6 @@ class DistributeTranspiler(object): if not block_map.has_key(varname): block_map[varname] = [] block_map[varname].append((long(offset), long(size))) - # Do not remove this important debug message: - print("block map: %s" % block_map) for varname, splited in block_map.iteritems(): orig_var = program.global_block().var(varname) @@ -1401,6 +1401,16 @@ class DistributeTranspiler(object): break return lr_ops + def _is_opt_role_op(self, op): + # NOTE: depend on oprole to find out whether this op is for + # optimize + op_maker = core.op_proto_and_checker_maker + optimize_role = core.op_proto_and_checker_maker.OpRole.Optimize + if op_maker.kOpRoleAttrName() in op.attrs and \ + int(op.attrs[op_maker.kOpRoleAttrName()]) == int(optimize_role): + return True + return False + def _get_optimize_pass(self): """ Get optimizer operators, paramters and gradients from origin_program @@ -1413,10 +1423,7 @@ class DistributeTranspiler(object): params_grads = [] origin_var_dict = self.origin_program.global_block().vars for op in block.ops: - # NOTE(Yancey1989): we can not use op role to distinguish an optimizer op - # or not, because all ops in optimizer sub-graph would - # sign the optimizer op role - if self._is_optimizer_op(op): + if self._is_opt_role_op(op): opt_ops.append(op) # HACK(wuyi): if we find grad vars from input of optimize # ops, we may get the output of clip op. Use syntax "@GRAD"