Commit 0f752e89 authored by S sandyhouse

Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into add_timeline

......@@ -95,9 +95,10 @@ void DatasetImpl<T>::SetHdfsConfig(const std::string& fs_name,
const std::string& fs_ugi) {
fs_name_ = fs_name;
fs_ugi_ = fs_ugi;
std::string cmd = std::string("hadoop fs");
std::string cmd = std::string("$HADOOP_HOME/bin/hadoop fs");
cmd += " -D fs.default.name=" + fs_name;
cmd += " -D hadoop.job.ugi=" + fs_ugi;
cmd += " -Ddfs.client.block.write.retries=15 -Ddfs.rpc.timeout=500000";
paddle::framework::hdfs_set_command(cmd);
}
......
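With this change the command is rooted at $HADOOP_HOME instead of assuming `hadoop` is on PATH. As an illustration, with hypothetical arguments fs_name = "hdfs://nameservice" and fs_ugi = "user,passwd", the assembled command would be:

$HADOOP_HOME/bin/hadoop fs -D fs.default.name=hdfs://nameservice -D hadoop.job.ugi=user,passwd -Ddfs.client.block.write.retries=15 -Ddfs.rpc.timeout=500000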
......@@ -3,6 +3,7 @@ cc_library(op_handle_base SRCS op_handle_base.cc DEPS var_handle device_context
cc_library(scale_loss_grad_op_handle SRCS scale_loss_grad_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory)
cc_library(fetch_op_handle SRCS fetch_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory)
cc_library(fetch_async_op_handle SRCS fetch_async_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory)
cc_library(share_tensor_buffer_functor SRCS share_tensor_buffer_functor.cc DEPS framework_proto scope place operator op_registry)
cc_library(computation_op_handle SRCS computation_op_handle.cc DEPS framework_proto scope place operator op_registry)
......@@ -98,7 +99,7 @@ cc_library(scope_buffered_ssa_graph_executor SRCS scope_buffered_ssa_graph_execu
#cc_test(reduce_op_handle_test SRCS reduce_op_handle_test.cc DEPS var_handle op_handle_base scope ddim memory
# device_context reduce_op_handle )
cc_library(fast_threaded_ssa_graph_executor SRCS fast_threaded_ssa_graph_executor.cc
DEPS fetch_op_handle ssa_graph_executor scope simple_threadpool device_context)
DEPS fetch_async_op_handle ssa_graph_executor scope simple_threadpool device_context)
cc_test(fused_broadcast_op_test SRCS fused_broadcast_op_handle_test.cc DEPS fused_broadcast_op_handle)
cc_test(exception_holder_test SRCS exception_holder_test.cc )
......
......@@ -18,7 +18,8 @@
#include <unordered_map>
#include <unordered_set>
#include <vector>
#include "paddle/fluid/framework/details/fetch_op_handle.h"
#include "paddle/fluid/framework/details/computation_op_handle.h"
#include "paddle/fluid/framework/details/fetch_async_op_handle.h"
#include "paddle/fluid/framework/details/multi_devices_helper.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/platform/profiler.h"
......@@ -120,6 +121,11 @@ FetchResultType FastThreadedSSAGraphExecutor::Run(
}
// Wait FetchOps.
ClearFetchOp(graph_, &fetch_ops);
for (auto &place : places_) {
fetch_ctxs_.Get(place)->Wait();
}
return fetches;
}
......@@ -162,8 +168,8 @@ void FastThreadedSSAGraphExecutor::InsertFetchOps(
ir::Node *fetch_node =
graph_->CreateEmptyNode("fetch", ir::Node::Type::kOperation);
auto *op = new FetchOpHandle(fetch_node, fetches, i, &local_scopes_,
&local_exec_scopes_, return_merged);
auto *op = new FetchAsyncOpHandle(fetch_node, fetches, i, &local_scopes_,
&local_exec_scopes_, return_merged);
fetch_ops->emplace_back(op);
for (auto &p : places_) {
......@@ -174,6 +180,14 @@ void FastThreadedSSAGraphExecutor::InsertFetchOps(
op->AddInput(var);
}
for (auto *var : vars) {
auto *op = var->GeneratedOp();
auto *compute_op = dynamic_cast<details::ComputationOpHandle *>(op);
if (compute_op) {
compute_op->SetLockAndRecordEventFree(false);
}
}
int dep = static_cast<int>(op->NotReadyInputSize());
(*op_deps)[op] = dep;
if (dep == 0) {
......@@ -261,7 +275,7 @@ void FastThreadedSSAGraphExecutor::PrepareAtomicOpDeps() {
const ir::Graph &FastThreadedSSAGraphExecutor::Graph() const { return *graph_; }
void FastThreadedSSAGraphExecutor::RecordOps(OpHandleBase *op) {
if (strategy_.num_threads_ == 1 && !dynamic_cast<FetchOpHandle *>(op)) {
if (strategy_.num_threads_ == 1 && !dynamic_cast<FetchAsyncOpHandle *>(op)) {
traced_ops_.emplace_back(op);
}
}
......
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/details/fetch_async_op_handle.h"
#include <string>
#include <utility>
#include <vector>
#include "paddle/fluid/platform/profiler.h"
namespace paddle {
namespace framework {
namespace details {
FetchAsyncOpHandle::FetchAsyncOpHandle(ir::Node *node, FetchResultType *data,
size_t offset,
std::vector<Scope *> *local_scopes,
std::vector<Scope *> *local_exec_scopes,
bool return_merged)
: OpHandleBase(node),
data_(data),
offset_(offset),
local_scopes_(local_scopes),
local_exec_scopes_(local_exec_scopes),
return_merged_(return_merged) {}
FetchAsyncOpHandle::~FetchAsyncOpHandle() {}
void FetchAsyncOpHandle::RecordWaitEventOnCtx(
platform::DeviceContext *waited_ctx) {
PADDLE_THROW(platform::errors::PermissionDenied(
"No nodes need to wait FetchAsyncOp. Unexpceted Error."));
}
static void CheckTensorAttrs(const LoDTensor *tensor,
const proto::VarType::Type &type,
const DataLayout &layout, const DDim &dims,
const LoD &lod, const size_t offset) {
if (tensor->numel() && tensor->IsInitialized()) {
// step1: check type
PADDLE_ENFORCE_EQ(
type, tensor->type(),
platform::errors::InvalidArgument(
"The data type of fetched Tensors or the items of fetched "
"LoDTensorArray are different from each other on different "
"devices(%s vs %s). And the error is caused by the %zu "
"(th) fetched variable. Please set the "
"parameter `return_merged = False` when you "
"call the `Executor.run()` method.",
DataTypeToString(type), DataTypeToString(tensor->type()), offset));
// step2: check layout
PADDLE_ENFORCE_EQ(
layout, tensor->layout(),
platform::errors::InvalidArgument(
"The layout of fetched Tensors or the items of fetched "
"LoDTensorArray are different from each other on different "
"devices(%s vs %s). And the error is caused by the %zu "
"(th) fetched variable. Please set the "
"parameter `return_merged = False` when you "
"call the `Executor.run()` method.",
DataLayoutToString(layout), DataLayoutToString(tensor->layout()),
offset));
}
// step3: check dims
auto tensor_dims = tensor->dims();
PADDLE_ENFORCE_EQ(dims.size(), tensor_dims.size(),
platform::errors::InvalidArgument(
"The dimension sizes of fetched Tensors or "
"the items of fetched LoDTensorArray are "
"different from each other on different "
"devices(%s vs %s). And the error is caused by the %zu "
"(th) fetched variable. Please set the "
"parameter `return_merged = False` when you "
"call the `Executor.run()` method.",
dims, tensor_dims, offset));
for (int j = 1; j < dims.size(); j++) {
PADDLE_ENFORCE_EQ(dims[j], tensor_dims[j],
platform::errors::InvalidArgument(
"The dimensions of fetched Tensors or "
"the items of fetched LoDTensorArray are "
"different from each other on different "
"devices(%s vs %s). And the error is caused by the "
"%zu (th) fetched variable. Please set the "
"parameter `return_merged = False` when "
"you call the `Executor.run()` method.",
dims, tensor_dims, offset));
}
// step4: check lod
PADDLE_ENFORCE_EQ(
lod.size(), tensor->lod().size(),
platform::errors::InvalidArgument(
"The LoD information of fetched Tensors or the items of fetched "
"LoDTensorArray are different from each other on different "
"devices(%s vs %s). And the error is caused by the %zu "
"(th) fetched variable. Please set the "
"parameter `return_merged = False` when you "
"call the `Executor.run()` method.",
lod, tensor->lod(), offset));
}
static void TransData(const framework::Tensor *src_item,
framework::Tensor *dst_item,
const platform::DeviceContext &ctx) {
if (src_item->IsInitialized() && src_item->numel() > 0) {
if (platform::is_gpu_place(src_item->place())) {
#ifdef PADDLE_WITH_CUDA
TensorCopy(*src_item, platform::CUDAPinnedPlace(), ctx, dst_item);
#endif
} else {
TensorCopy(*src_item, platform::CPUPlace(), dst_item);
}
}
}
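// Why CUDAPinnedPlace in TransData above (an explanatory note): copies from
// device memory into page-locked (pinned) host memory can run as truly
// asynchronous DMA on the given stream, whereas copies into ordinary pageable
// CPU memory would have to synchronize.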
void FetchAsyncOpHandle::FetchMergedLodTensor(
const std::vector<const LoDTensor *> &src_lodtensors,
LoDTensor *dst_lodtensor) {
// compute dst type, layout, dim, and lod, and compute the check dim
proto::VarType::Type new_type = proto::VarType::FP32;
framework::DataLayout new_layout;
framework::DDim new_dim;
LoD new_lod = src_lodtensors[0]->lod();
framework::DDim check_dim;
for (auto *t : src_lodtensors) {
if (t->numel() && t->IsInitialized()) {
check_dim = t->dims();
new_type = t->type();
new_layout = t->layout();
break;
}
}
bool find_first_dims = false;
for (auto *t : src_lodtensors) {
if (t->numel() && t->IsInitialized()) {
if (!find_first_dims) {
new_dim = t->dims();
find_first_dims = true;
} else {
new_dim[0] += t->dims()[0];
}
}
}
// check src type, layout, dim, lod consistency
for (size_t i = 1; i < src_lodtensors.size(); ++i) {
CheckTensorAttrs(src_lodtensors[i], new_type, new_layout, check_dim,
new_lod, offset_);
}
// set dst tensor
dst_lodtensor->Resize(new_dim);
dst_lodtensor->set_layout(src_lodtensors[0]->layout());
dst_lodtensor->set_lod(src_lodtensors[0]->lod());
if (platform::is_gpu_place(src_lodtensors[0]->place())) {
dst_lodtensor->mutable_data(platform::CUDAPinnedPlace(),
src_lodtensors[0]->type());
} else {
dst_lodtensor->mutable_data(platform::CPUPlace(),
src_lodtensors[0]->type());
}
// slice and memcpy
int begin = 0;
for (auto *src : src_lodtensors) {
int end = begin + src->dims()[0];
if (end == begin) {
continue;
}
auto dst = dst_lodtensor->Slice(begin, end);
TransData(src, &dst, *dev_ctxes_[src->place()]);
begin = end;
}
}
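// Worked example (hypothetical sizes): merging fetches from two devices with
// dims [2, 3] and [3, 3] yields dst dims [5, 3]; the slice/memcpy loop above
// copies rows [0, 2) from the first tensor and rows [2, 5) from the second.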
void FetchAsyncOpHandle::RunImpl() {
platform::RecordEvent record_event(Name());
WaitInputVarGenerated();
// get src vars
auto &scopes = *local_exec_scopes_;
std::vector<Variable *> src_vars;
src_vars.reserve(inputs_.size());
for (size_t i = 0; i < inputs_.size(); ++i) {
auto *var_handle = static_cast<VarHandle *>(inputs_[i]);
auto &scope = scopes.at(var_handle->scope_idx());
auto *var = scope->FindVar(var_handle->name());
PADDLE_ENFORCE_NOT_NULL(
var,
platform::errors::NotFound(
"Cannot find variable %s in execution scope.", var_handle->name()));
src_vars.emplace_back(var);
}
if (return_merged_) {
auto &val = BOOST_GET(FetchList, *data_);
if (src_vars[0]->IsType<LoDTensor>()) {
// to lodtensor type
std::vector<const LoDTensor *> src_lodtensors;
src_lodtensors.reserve(src_vars.size());
for (size_t i = 0; i < src_vars.size(); ++i) {
src_lodtensors.emplace_back(&src_vars[i]->Get<framework::LoDTensor>());
}
LoDTensor dst_lodtensor;
FetchMergedLodTensor(src_lodtensors, &dst_lodtensor);
val.at(offset_) = std::move(dst_lodtensor);
} else {
// to lodtensorarray type
std::vector<const LoDTensorArray *> src_lodtensor_arrays;
src_lodtensor_arrays.reserve(src_vars.size());
for (size_t i = 0; i < src_vars.size(); ++i) {
src_lodtensor_arrays.emplace_back(
&src_vars[i]->Get<framework::LoDTensorArray>());
}
LoDTensorArray dst_lodtensor_array;
dst_lodtensor_array.resize(src_lodtensor_arrays[0]->size());
for (size_t i = 0; i < dst_lodtensor_array.size(); ++i) {
std::vector<const LoDTensor *> src_lodtensors;
src_lodtensors.reserve(src_lodtensor_arrays.size());
for (size_t j = 0; j < src_lodtensor_arrays.size(); ++j) {
src_lodtensors.emplace_back(&(*src_lodtensor_arrays[j])[i]);
}
FetchMergedLodTensor(src_lodtensors, &dst_lodtensor_array[i]);
}
val.at(offset_) = std::move(dst_lodtensor_array);
}
} else {
auto &val = BOOST_GET(FetchUnmergedList, *data_);
auto &dst_tensors = val.at(offset_);
dst_tensors.reserve(src_vars.size());
for (size_t i = 0; i < src_vars.size(); ++i) {
if (src_vars[i]->IsType<LoDTensor>()) {
auto &t = src_vars[i]->Get<framework::LoDTensor>();
LoDTensor item;
TransData(&t, &item, *dev_ctxes_[t.place()]);
dst_tensors.emplace_back(std::move(item));
} else {
auto &t = src_vars[i]->Get<framework::LoDTensorArray>();
LoDTensorArray item;
item.resize(t.size());
for (size_t j = 0; j < t.size(); ++j) {
TransData(&t[j], &item[j], *dev_ctxes_[t[j].place()]);
}
dst_tensors.emplace_back(std::move(item));
}
}
}
}
bool FetchAsyncOpHandle::IsMultiDeviceTransfer() { return true; }
std::string FetchAsyncOpHandle::Name() const { return "FetchAsync"; }
} // namespace details
} // namespace framework
} // namespace paddle
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include <vector>
#include "paddle/fluid/framework/details/op_handle_base.h"
#include "paddle/fluid/framework/feed_fetch_type.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/platform/device_context.h"
namespace paddle {
namespace framework {
namespace details {
struct FetchAsyncOpHandle : public OpHandleBase {
public:
FetchAsyncOpHandle(ir::Node *node, FetchResultType *data, size_t offset,
std::vector<Scope *> *local_scopes,
std::vector<Scope *> *local_exec_scopes,
bool return_merged);
~FetchAsyncOpHandle();
void RecordWaitEventOnCtx(platform::DeviceContext *waited_ctx) override;
std::string Name() const override;
bool IsMultiDeviceTransfer() override;
protected:
void RunImpl() override;
std::vector<Scope *> GetLocalScopes() override { return *local_scopes_; }
void FetchMergedLodTensor(
const std::vector<const LoDTensor *> &src_lodtensors,
LoDTensor *dst_lodtensor);
private:
FetchResultType *data_;
size_t offset_;
std::vector<Scope *> *local_scopes_;
std::vector<Scope *> *local_exec_scopes_;
bool return_merged_;
};
} // namespace details
} // namespace framework
} // namespace paddle
......@@ -36,7 +36,8 @@ FetchOpHandle::FetchOpHandle(ir::Node *node, FetchResultType *data,
FetchOpHandle::~FetchOpHandle() {}
void FetchOpHandle::RecordWaitEventOnCtx(platform::DeviceContext *waited_ctx) {
PADDLE_THROW("Nobody should wait FetchOp. Unexpceted Error");
PADDLE_THROW(platform::errors::PermissionDenied(
"No nodes need to wait FetchOp. Unexpceted Error."));
}
static void CheckDims(const framework::DDim &tensor_dims,
......
......@@ -13,6 +13,7 @@
// limitations under the License.
#include "paddle/fluid/framework/details/ssa_graph_executor.h"
#include "paddle/fluid/framework/details/fetch_async_op_handle.h"
namespace paddle {
namespace framework {
......@@ -23,9 +24,11 @@ void ClearFetchOp(ir::Graph* graph, std::vector<OpHandleBase*>* fetch_ops) {
if (fetch_ops->empty()) return;
for (auto& op : *fetch_ops) {
PADDLE_ENFORCE_NOT_NULL(
dynamic_cast<FetchOpHandle*>(op),
"The input ops of ClearFetchOp function should be FetchOpHandle.");
PADDLE_ENFORCE_EQ(dynamic_cast<FetchOpHandle*>(op) != nullptr ||
dynamic_cast<FetchAsyncOpHandle*>(op) != nullptr,
true,
"The input ops of ClearFetchOp function should be "
"FetchOpHandle or FetchAsyncOpHandle.");
for (auto& out_var : op->Node()->outputs) {
graph->RemoveNode(out_var);
}
......
......@@ -857,7 +857,7 @@ void FleetWrapper::PushSparseVarsWithLabelAsync(
float* g = g_tensor->data<float>();
if (scale_sparse_gradient_with_batch_size_ && grad_dim > 0) {
int dim = emb_dim + offset;
int dim = emb_dim;
Eigen::Map<
Eigen::Matrix<float, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>
g_mat(g, g_tensor->numel() / dim, dim);
......@@ -1170,6 +1170,21 @@ void FleetWrapper::LoadModelOneTable(const uint64_t table_id,
#endif
}
void FleetWrapper::LoadWithWhitelist(const uint64_t table_id,
const std::string& path, const int mode) {
#ifdef PADDLE_WITH_PSLIB
auto ret = pslib_ptr_->_worker_ptr->load_with_whitelist(table_id, path,
std::to_string(mode));
ret.wait();
if (ret.get() != 0) {
LOG(ERROR) << "load model of table id: " << table_id
<< ", from path: " << path << " failed";
}
#else
VLOG(0) << "FleetWrapper::LoadWhitelist does nothing when no pslib";
#endif
}
void FleetWrapper::SaveModel(const std::string& path, const int mode) {
#ifdef PADDLE_WITH_PSLIB
auto ret = pslib_ptr_->_worker_ptr->save(path, std::to_string(mode));
......@@ -1285,6 +1300,26 @@ int32_t FleetWrapper::SaveCache(int table_id, const std::string& path,
#endif
}
int32_t FleetWrapper::SaveWithWhitelist(int table_id, const std::string& path,
const int mode,
const std::string& whitelist_path) {
#ifdef PADDLE_WITH_PSLIB
auto ret = pslib_ptr_->_worker_ptr->save_with_whitelist(
table_id, path, std::to_string(mode), whitelist_path);
ret.wait();
int32_t feasign_cnt = ret.get();
if (feasign_cnt == -1) {
LOG(ERROR) << "table save cache failed";
sleep(sleep_seconds_before_fail_exit_);
exit(-1);
}
return feasign_cnt;
#else
VLOG(0) << "FleetWrapper::SaveCache does nothing when no pslib";
return -1;
#endif
}
void FleetWrapper::ShrinkSparseTable(int table_id) {
#ifdef PADDLE_WITH_PSLIB
auto ret = pslib_ptr_->_worker_ptr->shrink(table_id);
......
......@@ -273,6 +273,11 @@ class FleetWrapper {
// save cache model
// cache model can speed up online predict
int32_t SaveCache(int table_id, const std::string& path, const int mode);
// save sparse table filtered by user-defined whitelist
int32_t SaveWithWhitelist(int table_id, const std::string& path,
const int mode, const std::string& whitelist_path);
void LoadWithWhitelist(const uint64_t table_id, const std::string& path,
const int mode);
// copy feasign key/value from src_table_id to dest_table_id
int32_t CopyTable(const uint64_t src_table_id, const uint64_t dest_table_id);
// copy feasign key/value from src_table_id to dest_table_id
......
......@@ -21,6 +21,8 @@
#include "paddle/fluid/platform/profiler.h"
#include "paddle/fluid/string/string_helper.h"
DECLARE_bool(use_mkldnn);
namespace paddle {
namespace imperative {
......@@ -47,6 +49,9 @@ void Tracer::TraceOp(const std::string& type, const NameVarBaseMap& ins,
const NameVarBaseMap& outs, framework::AttributeMap attrs,
const platform::Place& place, bool trace_backward) {
VLOG(1) << "Trace Op: " << type;
if (FLAGS_use_mkldnn) {
attrs["use_mkldnn"] = true;
}
auto op = framework::OpRegistry::CreateOp(type, {}, {}, {}, false);
const auto& op_info = op->Info();
auto* attr_checker = op_info.Checker();
......
......@@ -1058,6 +1058,7 @@ USE_TRT_CONVERTER(fused_embedding_eltwise_layernorm);
USE_TRT_CONVERTER(skip_layernorm);
USE_TRT_CONVERTER(slice);
USE_TRT_CONVERTER(scale);
USE_TRT_CONVERTER(stack);
#endif
namespace paddle_infer {
......
......@@ -3,8 +3,8 @@ nv_library(tensorrt_converter
SRCS mul_op.cc conv2d_op.cc fc_op.cc pool2d_op.cc elementwise_op.cc
batch_norm_op.cc activation_op.cc softmax_op.cc concat_op.cc dropout_op.cc
pad_op.cc split_op.cc prelu_op.cc leaky_relu_op.cc gelu_op.cc layer_norm_op.cc multihead_matmul_op.cc
shuffle_channel_op.cc swish_op.cc instance_norm_op.cc
emb_eltwise_layernorm.cc skip_layernorm.cc scale_op.cc slice_op.cc hard_sigmoid_op.cc hard_swish_op.cc
shuffle_channel_op.cc swish_op.cc instance_norm_op.cc stack_op.cc
emb_eltwise_layernorm.cc skip_layernorm.cc scale_op.cc slice_op.cc hard_sigmoid_op.cc hard_swish_op.cc
DEPS tensorrt_engine tensorrt_plugin operator scope framework_proto op_registry)
nv_test(test_op_converter SRCS test_op_converter.cc DEPS
......
......@@ -58,6 +58,24 @@ class ScaleOpConverter : public OpConverter {
TensorRTEngine::Weight power_weights{nvinfer1::DataType::kFLOAT, nullptr,
0};
nvinfer1::ILayer* layer = nullptr;
auto input_dim = input->getDimensions();
PADDLE_ENFORCE_GE(input_dim.nbDims, 3,
platform::errors::Fatal(
"Paddle-TRT scale mode only support dimension >= 3"));
nvinfer1::IShuffleLayer* expand_layer = nullptr;
nvinfer1::IShuffleLayer* squeeze_layer = nullptr;
if (input_dim.nbDims == 3) {
// TensorRT's scale layer does not support input dims < 4 when using
// explicit batch
expand_layer = TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *input);
nvinfer1::Dims4 target_shape(0, 0, 0, 1);  // expand one dim
expand_layer->setReshapeDimensions(target_shape);
input = expand_layer->getOutput(0);
}
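// Reshape semantics (TensorRT IShuffleLayer): a 0 in the reshape dims copies
// the corresponding input dimension, so Dims4(0, 0, 0, 1) expands
// (C, H, W) -> (C, H, W, 1) here, and Dims3(0, 0, 0) squeezes it back to
// (C, H, W) after the scale layer below.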
if (bias_after_scale) {
layer = TRT_ENGINE_ADD_LAYER(
engine_, Scale, *input, nvinfer1::ScaleMode::kUNIFORM,
......@@ -73,6 +91,18 @@ class ScaleOpConverter : public OpConverter {
power_weights.get(), scale_weights.get(), power_weights.get());
}
PADDLE_ENFORCE_EQ(layer != nullptr, true,
platform::errors::Fatal("Create scale layer failed."));
if (input_dim.nbDims == 3) {
// TensorRT's scale layer does not support input dims < 4 when using
// explicit batch
squeeze_layer =
TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *(layer->getOutput(0)));
nvinfer1::Dims3 target_shape(0, 0, 0);  // squeeze one dim
squeeze_layer->setReshapeDimensions(target_shape);
layer = static_cast<nvinfer1::ILayer*>(squeeze_layer);
}
RreplenishLayerAndOutput(layer, "scale", {out_name}, test_mode);
}
};
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
#include "paddle/fluid/inference/tensorrt/plugin/stack_op_plugin.h"
namespace paddle {
namespace inference {
namespace tensorrt {
/*
* Stack converter from fluid to tensorRT.
*/
class StackOpConverter : public OpConverter {
public:
void operator()(const framework::proto::OpDesc& op,
const framework::Scope& scope, bool test_mode) override {
VLOG(4) << "convert fluid stack op to tensorrt stack layer";
framework::OpDesc op_desc(op, nullptr);
auto input = op_desc.Input("X");
int input_num = input.size();
nvinfer1::ITensor** inputs =
(nvinfer1::ITensor**)malloc(input_num * sizeof(nvinfer1::ITensor*));
for (int i = 0; i < input_num; ++i) {
inputs[i] = engine_->GetITensor(input[i]);
}
int axis = BOOST_GET_CONST(int, op_desc.GetAttr("axis"));
if (axis < 0) {
axis = axis + inputs[0]->getDimensions().nbDims + 1;
}
nvinfer1::ILayer* layer = nullptr;
if (engine_->with_dynamic_shape()) {
#if IS_TRT_VERSION_GE(6000)
plugin::StackPluginDynamic* plugin =
new plugin::StackPluginDynamic(axis, input_num);
layer = engine_->AddPluginV2(inputs, input_num, plugin);
assert(layer != nullptr);
#else
PADDLE_THROW(platform::errors::Fatal(
"You are running the TRT Dynamic Shape mode, need to confirm that "
"your TRT version is no less than 6.0"));
#endif
} else {
PADDLE_THROW(platform::errors::Fatal(
"You are running the Ernie(Bert) model in static"
"shape mode, which is not supported for the time being.\n"
"You can use the config.SetTRTDynamicShapeInfo(...) interface"
" to set the shape information to run the dynamic shape mode."));
}
auto output_name = op_desc.Output("Y").front();
RreplenishLayerAndOutput(layer, "stack", {output_name}, test_mode);
free(inputs);
}
};
} // namespace tensorrt
} // namespace inference
} // namespace paddle
REGISTER_TRT_OP_CONVERTER(stack, StackOpConverter);
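// Shape sketch (hypothetical values): with axis = 1, stacking num_stack
// inputs of dims (batch, seq, seq) yields (batch, num_stack, seq, seq),
// matching what StackPluginDynamic::getOutputDimensions computes in the
// plugin implementation further below.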
......@@ -186,6 +186,14 @@ void TensorRTEngine::FreezeNetwork() {
Vec2TRT_Dims(optim_input_shape_[input.first], input.first, true));
}
infer_builder_config_->addOptimizationProfile(optim_profile_);
infer_builder_config_->setMaxWorkspaceSize(max_workspace_);
if (enable_int8) {
// Due to a bug of TRT, we must set precision BuilderFlag to kFP16 before
// kINT8 here to perform INT8 inference.
infer_builder_config_->setFlag(nvinfer1::BuilderFlag::kFP16);
infer_builder_config_->setFlag(nvinfer1::BuilderFlag::kINT8);
infer_builder_config_->setFlag(nvinfer1::BuilderFlag::kSTRICT_TYPES);
}
if (WithFp16()) {
infer_builder_config_->setFlag(nvinfer1::BuilderFlag::kFP16);
if (disable_trt_plugin_fp16()) {
......
......@@ -51,6 +51,7 @@ struct SimpleOpTypeSetTeller : public Teller {
"relu",
"depthwise_conv2d",
"softmax",
"sigmoid",
"batch_norm",
"elementwise_add",
"leaky_relu",
......@@ -87,6 +88,7 @@ struct SimpleOpTypeSetTeller : public Teller {
"gelu",
"layer_norm",
"scale",
"stack",
};
};
......
nv_library(tensorrt_plugin
SRCS trt_plugin.cc split_op_plugin.cu elementwise_op_plugin.cu
prelu_op_plugin.cu trt_plugin_factory.cc gelu_op_plugin.cu
prelu_op_plugin.cu trt_plugin_factory.cc gelu_op_plugin.cu
pool_op_plugin.cu swish_op_plugin.cu layer_norm_op_plugin.cu
instance_norm_op_plugin.cu emb_eltwise_layernorm_plugin.cu
qkv_to_context_plugin.cu skip_layernorm_op_plugin.cu slice_op_plugin.cu hard_swish_op_plugin.cu
DEPS enforce tensorrt_engine prelu tensor bert_encoder_functor)
instance_norm_op_plugin.cu emb_eltwise_layernorm_plugin.cu
qkv_to_context_plugin.cu skip_layernorm_op_plugin.cu slice_op_plugin.cu
hard_swish_op_plugin.cu stack_op_plugin.cu
DEPS enforce tensorrt_engine prelu tensor bert_encoder_functor)
......@@ -104,32 +104,51 @@ nvinfer1::DimsExprs PoolPluginDynamic::getOutputDimensions(
auto stri_0 = expr_builder.constant(strides_[0]);
auto stri_1 = expr_builder.constant(strides_[1]);
auto one_value = expr_builder.constant(1);
auto tmp1_0 =
expr_builder.constant((-ksize_[0] + 2 * paddings_[0]) / strides_[0] + 1);
auto tmp1_1 =
expr_builder.constant((-ksize_[1] + 2 * paddings_[1]) / strides_[1] + 1);
auto v0_tmp = expr_builder.constant(-ksize_[0] + 2 * paddings_[0]);
auto v1_tmp = expr_builder.constant(-ksize_[1] + 2 * paddings_[1]);
auto tmp2_0 = expr_builder.constant(
(-ksize_[0] + 2 * paddings_[0] + strides_[0] - 1) / strides_[0] + 1);
auto tmp2_1 = expr_builder.constant(
(-ksize_[1] + 2 * paddings_[1] + strides_[1] - 1) / strides_[1] + 1);
auto *a_d = expr_builder.operation(nvinfer1::DimensionOperation::kCEIL_DIV,
*inputs[0].d[2], *stri_0);
auto *b_d = expr_builder.operation(nvinfer1::DimensionOperation::kCEIL_DIV,
*inputs[0].d[3], *stri_1);
auto ceil_tmp =
expr_builder.constant(-ksize_[0] + 2 * paddings_[0] + strides_[0] - 1);
auto ceil1_tmp =
expr_builder.constant(-ksize_[1] + 2 * paddings_[1] + strides_[1] - 1);
if (!ceil_mode_) {
output.d[2] = expr_builder.operation(nvinfer1::DimensionOperation::kSUM,
*a_d, *tmp1_0);
output.d[3] = expr_builder.operation(nvinfer1::DimensionOperation::kSUM,
*b_d, *tmp1_1);
output.d[2] = expr_builder.operation(
nvinfer1::DimensionOperation::kSUM,
*expr_builder.operation(
nvinfer1::DimensionOperation::kFLOOR_DIV,
*expr_builder.operation(nvinfer1::DimensionOperation::kSUM,
*inputs[0].d[2], *v0_tmp),
*stri_0),
*one_value);
output.d[3] = expr_builder.operation(
nvinfer1::DimensionOperation::kSUM,
*expr_builder.operation(
nvinfer1::DimensionOperation::kFLOOR_DIV,
*expr_builder.operation(nvinfer1::DimensionOperation::kSUM,
*inputs[0].d[3], *v1_tmp),
*stri_1),
*one_value);
} else {
output.d[2] = expr_builder.operation(nvinfer1::DimensionOperation::kSUM,
*a_d, *tmp2_0);
output.d[3] = expr_builder.operation(nvinfer1::DimensionOperation::kSUM,
*b_d, *tmp2_1);
output.d[2] = expr_builder.operation(
nvinfer1::DimensionOperation::kSUM,
*expr_builder.operation(
nvinfer1::DimensionOperation::kFLOOR_DIV,
*expr_builder.operation(nvinfer1::DimensionOperation::kSUM,
*inputs[0].d[2], *ceil_tmp),
*stri_0),
*one_value);
output.d[3] = expr_builder.operation(
nvinfer1::DimensionOperation::kSUM,
*expr_builder.operation(
nvinfer1::DimensionOperation::kFLOOR_DIV,
*expr_builder.operation(nvinfer1::DimensionOperation::kSUM,
*inputs[0].d[3], *ceil1_tmp),
*stri_1),
*one_value);
}
return output;
......
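The rewrite above sums the terms before dividing: the old expressions combined kCEIL_DIV(in, s) with a separately precomputed constant term, but floor(a/s) + floor(b/s) is not in general equal to floor((a+b)/s). A minimal scalar sketch (a hypothetical helper, not Paddle code) of what the IExprBuilder graph now encodes:

static int PooledDim(int in, int k, int p, int s, bool ceil_mode) {
  // sum everything first, then do a single floor division and add 1
  int num = in - k + 2 * p + (ceil_mode ? s - 1 : 0);
  return num / s + 1;
}
// e.g. in = 8, k = 3, p = 0, s = 2: floor mode gives 3, ceil mode gives 4.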
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <cassert>
#include <cstring>
#include <vector>
#include "paddle/fluid/inference/tensorrt/plugin/stack_op_plugin.h"
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin_factory.h"
namespace paddle {
namespace inference {
namespace tensorrt {
namespace plugin {
#if IS_TRT_VERSION_GE(6000)
StackPluginDynamic::StackPluginDynamic(int axis, int num_stack)
: axis_(axis), num_stack_(num_stack) {}
StackPluginDynamic::StackPluginDynamic(void const* serial_data,
size_t serial_length) {
DeserializeValue(&serial_data, &serial_length, &axis_);
DeserializeValue(&serial_data, &serial_length, &num_stack_);
}
StackPluginDynamic::~StackPluginDynamic() {}
nvinfer1::IPluginV2DynamicExt* StackPluginDynamic::clone() const {
return new StackPluginDynamic(axis_, num_stack_);
}
const char* StackPluginDynamic::getPluginType() const { return "stack_plugin"; }
int StackPluginDynamic::getNbOutputs() const { return 1; }
int StackPluginDynamic::initialize() { return 0; }
size_t StackPluginDynamic::getSerializationSize() const {
size_t serialize_size = 0;
serialize_size += SerializedSize(axis_);
serialize_size += SerializedSize(num_stack_);
return serialize_size;
}
void StackPluginDynamic::serialize(void* buffer) const {
SerializeValue(&buffer, axis_);
SerializeValue(&buffer, num_stack_);
}
nvinfer1::DimsExprs StackPluginDynamic::getOutputDimensions(
int output_index, const nvinfer1::DimsExprs* inputs, int nb_inputs,
nvinfer1::IExprBuilder& expr_builder) {
nvinfer1::DimsExprs output(inputs[0]);
output.nbDims = inputs[0].nbDims + 1;
for (int i = inputs[0].nbDims; i > axis_; --i) {
output.d[i] = inputs[0].d[i - 1];
}
output.d[axis_] = expr_builder.constant(nb_inputs);
return output;
}
void StackPluginDynamic::configurePlugin(
const nvinfer1::DynamicPluginTensorDesc* in, int nbInputs,
const nvinfer1::DynamicPluginTensorDesc* out, int nbOutputs) {}
size_t StackPluginDynamic::getWorkspaceSize(
const nvinfer1::PluginTensorDesc* inputs, int nbInputs,
const nvinfer1::PluginTensorDesc* outputs, int nbOutputs) const {
return num_stack_ * sizeof(uintptr_t);
}
void StackPluginDynamic::destroy() { delete this; }
void StackPluginDynamic::terminate() {}
bool StackPluginDynamic::supportsFormatCombination(
int pos, const nvinfer1::PluginTensorDesc* in_out, int nb_inputs,
int nb_outputs) {
PADDLE_ENFORCE_NOT_NULL(
in_out, platform::errors::InvalidArgument(
"The input of stack plugin should not be nullptr."));
PADDLE_ENFORCE_LT(
pos, nb_inputs + nb_outputs,
platform::errors::InvalidArgument("The pos(%d) should be less than the "
"num(%d) of the input and the output.",
pos, nb_inputs + nb_outputs));
const nvinfer1::PluginTensorDesc& in = in_out[pos];
if (pos == 0) {
#ifdef SUPPORTS_CUDA_FP16
return (in.type == nvinfer1::DataType::kFLOAT ||
in.type == nvinfer1::DataType::kHALF) &&
(in.format == nvinfer1::TensorFormat::kLINEAR);
#else
return (in.type == nvinfer1::DataType::kFLOAT) &&
(in.format == nvinfer1::TensorFormat::kLINEAR);
#endif
}
const nvinfer1::PluginTensorDesc& prev = in_out[pos - 1];
// output
return in.type == prev.type && in.format == prev.format;
}
nvinfer1::DataType StackPluginDynamic::getOutputDataType(
int index, const nvinfer1::DataType* input_types, int nb_inputs) const {
PADDLE_ENFORCE_EQ(index, 0, platform::errors::InvalidArgument(
"The index should be equal to 0"));
return input_types[0];
}
template <typename T>
__global__ void StackKernel(const T* const* input, T* output, int num_stack,
int base_unit) {
int stack_id = blockIdx.x;
int lead_id = blockIdx.y;
for (int i = threadIdx.x; i < base_unit; i += blockDim.x) {
output[lead_id * num_stack * base_unit + stack_id * base_unit + i] =
input[stack_id][lead_id * base_unit + i];
}
}
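// Launch/indexing sketch: enqueue below uses dim3 num_blocks(num_stacks,
// lead_unit), so blockIdx.x selects the input tensor (stack_id) and
// blockIdx.y the leading slice (lead_id); each thread strides over base_unit
// contiguous elements, writing input[stack_id][lead_id * base_unit + i] into
// the output viewed as (lead_unit, num_stack, base_unit).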
int StackPluginDynamic::enqueue(const nvinfer1::PluginTensorDesc* input_desc,
const nvinfer1::PluginTensorDesc* output_desc,
const void* const* inputs, void* const* outputs,
void* workspace, cudaStream_t stream) {
auto input_dims = input_desc[0].dims; // (batch, seq, seq)
auto out_dims = output_desc[0].dims; // (batch, num_head, seq, seq)
auto out_num_dims = out_dims.nbDims;
int base_unit = 1;
for (int i = axis_ + 1; i < out_num_dims; ++i) {
PADDLE_ENFORCE_GT(out_dims.d[i], 0,
platform::errors::InvalidArgument(
"Input dimensions should be greater than 0"));
base_unit *= out_dims.d[i];
}
int lead_unit = 1;
for (int i = 0; i < axis_; ++i) {
PADDLE_ENFORCE_GT(out_dims.d[i], 0,
platform::errors::InvalidArgument(
"Input dimensions should be greater than 0"));
lead_unit *= out_dims.d[i];
}
PADDLE_ENFORCE_EQ(
out_dims.d[axis_], num_stack_,
platform::errors::InvalidArgument("number of stack axis should be same"));
cudaMemcpyAsync(workspace, reinterpret_cast<const void* const>(inputs),
sizeof(void*) * out_dims.d[axis_], cudaMemcpyHostToDevice,
stream);
const int num_stacks = out_dims.d[axis_];
dim3 num_blocks(num_stacks, lead_unit);
const int num_threads = 256;
auto infer_type = input_desc[0].type;
if (infer_type == nvinfer1::DataType::kFLOAT) {
float* output = static_cast<float*>(outputs[0]);
StackKernel<float><<<num_blocks, num_threads, 0, stream>>>(
reinterpret_cast<const float* const*>(workspace), output, num_stacks,
base_unit);
} else if (infer_type == nvinfer1::DataType::kHALF) {
#ifdef SUPPORTS_CUDA_FP16
__half* output = static_cast<__half*>(outputs[0]);
StackKernel<__half><<<num_blocks, num_threads, 0, stream>>>(
reinterpret_cast<const __half* const*>(workspace), output, num_stacks,
base_unit);
#else
PADDLE_THROW(platform::errors::Fatal(
"The cuda archs you specific should greater than 600."));
#endif
} else {
PADDLE_THROW(
platform::errors::Fatal("The Stack TRT Plugin's input type only "
"support float or half currently."));
}
return cudaGetLastError() != cudaSuccess;
}
StackPluginDynamicCreator::StackPluginDynamicCreator() {}
const char* StackPluginDynamicCreator::getPluginName() const {
return "stack_plugin";
}
const char* StackPluginDynamicCreator::getPluginVersion() const { return "1"; }
const nvinfer1::PluginFieldCollection*
StackPluginDynamicCreator::getFieldNames() {
return &field_collection_;
}
nvinfer1::IPluginV2* StackPluginDynamicCreator::createPlugin(
const char* name, const nvinfer1::PluginFieldCollection* fc) {
int axis = -1;
int num_stack = -1;
for (int i = 0; i < fc->nbFields; ++i) {
const std::string name(fc->fields[i].name);
if (name == "axis") {
axis = static_cast<const int*>(fc->fields[i].data)[0];
} else if (name == "num_stack") {
num_stack = static_cast<const int*>(fc->fields[i].data)[0];
} else {
PADDLE_THROW(platform::errors::Fatal("Meet an unknown plugin field '" +
name +
"' when creating stack op plugin."));
}
}
return new StackPluginDynamic(axis, num_stack);
}
nvinfer1::IPluginV2* StackPluginDynamicCreator::deserializePlugin(
const char* name, const void* serial_data, size_t serial_length) {
auto plugin = new StackPluginDynamic(serial_data, serial_length);
return plugin;
}
void StackPluginDynamicCreator::setPluginNamespace(const char* lib_namespace) {
plugin_namespace_ = lib_namespace;
}
const char* StackPluginDynamicCreator::getPluginNamespace() const {
return plugin_namespace_.c_str();
}
#endif
} // namespace plugin
} // namespace tensorrt
} // namespace inference
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <stdio.h>
#include <cassert>
#include <string>
#include <vector>
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin.h"
namespace paddle {
namespace inference {
namespace tensorrt {
namespace plugin {
#if IS_TRT_VERSION_GE(6000)
class StackPluginDynamic : public DynamicPluginTensorRT {
public:
explicit StackPluginDynamic(int axis, int num_stack);
StackPluginDynamic(void const* serial_data, size_t serial_length);
~StackPluginDynamic();
nvinfer1::IPluginV2DynamicExt* clone() const override;
nvinfer1::DimsExprs getOutputDimensions(
int outputIndex, const nvinfer1::DimsExprs* inputs, int nbInputs,
nvinfer1::IExprBuilder& exprBuilder) override;
bool supportsFormatCombination(int pos,
const nvinfer1::PluginTensorDesc* inOut,
int nbInputs, int nbOutputs) override;
void configurePlugin(const nvinfer1::DynamicPluginTensorDesc* in,
int nbInputs,
const nvinfer1::DynamicPluginTensorDesc* out,
int nbOutputs) override;
size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs,
int nbInputs,
const nvinfer1::PluginTensorDesc* outputs,
int nbOutputs) const override;
int enqueue(const nvinfer1::PluginTensorDesc* inputDesc,
const nvinfer1::PluginTensorDesc* outputDesc,
const void* const* inputs, void* const* outputs, void* workspace,
cudaStream_t stream) override;
nvinfer1::DataType getOutputDataType(int index,
const nvinfer1::DataType* inputTypes,
int nbInputs) const override;
const char* getPluginType() const override;
int getNbOutputs() const override;
int initialize() override;
void terminate() override;
size_t getSerializationSize() const override;
void serialize(void* buffer) const override;
void destroy() override;
private:
int axis_;
int num_stack_;
};
class StackPluginDynamicCreator : public nvinfer1::IPluginCreator {
public:
StackPluginDynamicCreator();
const char* getPluginName() const override;
const char* getPluginVersion() const override;
const nvinfer1::PluginFieldCollection* getFieldNames() override;
nvinfer1::IPluginV2* createPlugin(
const char* name, const nvinfer1::PluginFieldCollection* fc) override;
nvinfer1::IPluginV2* deserializePlugin(const char* name,
const void* serial_data,
size_t serial_length) override;
void setPluginNamespace(const char* lib_namespace) override;
const char* getPluginNamespace() const override;
private:
std::string plugin_namespace_;
nvinfer1::PluginFieldCollection field_collection_{0, nullptr};
std::vector<nvinfer1::PluginField> plugin_attributes_;
};
REGISTER_TRT_PLUGIN_V2(StackPluginDynamicCreator);
#endif
} // namespace plugin
} // namespace tensorrt
} // namespace inference
} // namespace paddle
......@@ -132,7 +132,9 @@ if(NOT APPLE AND WITH_MKLML)
set(SEQ_POOL1_INSTALL_DIR "${INFERENCE_DEMO_INSTALL_DIR}/seq_pool")
download_model_and_data(${SEQ_POOL1_INSTALL_DIR} "seq_pool1_model_.tar.gz" "seq_pool1_data.txt.tar.gz")
inference_analysis_api_test(test_analyzer_seq_pool1 ${SEQ_POOL1_INSTALL_DIR} analyzer_seq_pool1_tester.cc)
set_tests_properties(test_analyzer_seq_pool1 PROPERTIES TIMEOUT 150)
if(NOT WIN32)
set_tests_properties(test_analyzer_seq_pool1 PROPERTIES TIMEOUT 150)
endif()
else()
# TODO: fix this test on MACOS and OPENBLAS, the reason is that
# fusion_seqexpand_concat_fc_op is not supported on MACOS and OPENBLAS
......@@ -192,8 +194,9 @@ download_result(${ERNIE_INSTALL_DIR} "Ernie_large_result.txt.tar.gz")
inference_analysis_test(test_analyzer_ernie_large SRCS analyzer_ernie_tester.cc
EXTRA_DEPS ${INFERENCE_EXTRA_DEPS}
ARGS --infer_model=${ERNIE_INSTALL_DIR}/model --infer_data=${ERNIE_INSTALL_DIR}/data.txt --refer_result=${ERNIE_INSTALL_DIR}/result.txt --ernie_large=true)
set_tests_properties(test_analyzer_ernie_large PROPERTIES TIMEOUT 150 LABELS "RUN_TYPE=NIGHTLY")
if(NOT WIN32 AND NOT APPLE)
set_tests_properties(test_analyzer_ernie_large PROPERTIES TIMEOUT 150 LABELS "RUN_TYPE=NIGHTLY")
endif()
# text_classification
set(TEXT_CLASSIFICATION_INSTALL_DIR "${INFERENCE_DEMO_INSTALL_DIR}/text_classification")
......@@ -215,7 +218,7 @@ inference_analysis_test(test_analyzer_transformer SRCS analyzer_transformer_test
# ocr
set(OCR_INSTALL_DIR "${INFERENCE_DEMO_INSTALL_DIR}/ocr")
if (NOT EXISTS ${OCR_INSTALL_DIR})
if (NOT EXISTS ${OCR_INSTALL_DIR}/ocr.tar.gz)
inference_download_and_uncompress(${OCR_INSTALL_DIR} "http://paddlemodels.bj.bcebos.com/" "inference-vis-demos%2Focr.tar.gz")
endif()
inference_analysis_api_test(test_analyzer_ocr ${OCR_INSTALL_DIR} analyzer_vis_tester.cc)
......@@ -231,7 +234,7 @@ set_property(TEST test_analyzer_detect PROPERTY ENVIRONMENT GLOG_vmodule=analysi
# mobilenet with transpose op
set(MOBILENET_INSTALL_DIR "${INFERENCE_DEMO_INSTALL_DIR}/mobilenet")
if (NOT EXISTS ${MOBILENET_INSTALL_DIR})
if (NOT EXISTS ${MOBILENET_INSTALL_DIR}/mobilenet.tar.gz)
inference_download_and_uncompress(${MOBILENET_INSTALL_DIR} "http://paddlemodels.bj.bcebos.com/" "inference-vis-demos%2Fmobilenet.tar.gz")
endif()
inference_analysis_api_test(test_analyzer_mobilenet_transpose ${MOBILENET_INSTALL_DIR} analyzer_vis_tester.cc)
......@@ -395,15 +398,15 @@ inference_analysis_api_test(test_analyzer_bert ${BERT_INSTALL_DIR} analyzer_bert
if(WITH_GPU AND TENSORRT_FOUND)
set(TRT_MODEL_INSTALL_DIR "${INFERENCE_DEMO_INSTALL_DIR}/trt_models")
if (NOT EXISTS ${TRT_MODEL_INSTALL_DIR})
if (NOT EXISTS ${TRT_MODEL_INSTALL_DIR}/trt_inference_test_models.tar.gz)
inference_download_and_uncompress(${TRT_MODEL_INSTALL_DIR} ${INFERENCE_URL}/tensorrt_test "trt_inference_test_models.tar.gz")
endif()
set(TEST_SPLIT_CONVERTER_MODEL "${TRT_MODEL_INSTALL_DIR}/trt_split_op_converter_test")
if (NOT EXISTS ${TEST_SPLIT_CONVERTER_MODEL})
if (NOT EXISTS ${TEST_SPLIT_CONVERTER_MODEL}/split_converter.tgz)
inference_download_and_uncompress(${TEST_SPLIT_CONVERTER_MODEL} ${INFERENCE_URL}/tensorrt_test "split_converter.tgz")
endif()
set(TEST_INSTANCE_NORM_MODEL "${TRT_MODEL_INSTALL_DIR}/trt_instance_norm_test")
if (NOT EXISTS ${TEST_INSTANCE_NORM_MODEL})
if (NOT EXISTS ${TEST_INSTANCE_NORM_MODEL}/instance_norm.tgz)
inference_download_and_uncompress(${TEST_INSTANCE_NORM_MODEL} ${INFERENCE_URL}/tensorrt_test "instance_norm.tgz")
endif()
inference_analysis_test(trt_mobilenet_test SRCS trt_mobilenet_test.cc
......@@ -431,16 +434,16 @@ if(WITH_GPU AND TENSORRT_FOUND)
EXTRA_DEPS ${INFERENCE_EXTRA_DEPS} paddle_fluid_c
ARGS --infer_model=${TRT_MODEL_INSTALL_DIR}/trt_inference_test_models)
set(TRT_MODEL_QUANT_RESNET_DIR "${INFERENCE_DEMO_INSTALL_DIR}/quant_small_model")
if (NOT EXISTS ${TRT_MODEL_QUANT_RESNET_DIR})
inference_download_and_uncompress(${INFERENCE_DEMO_INSTALL_DIR} ${INFERENCE_URL}/tensorrt_test "quant_small_model.tar.gz")
set(TRT_MODEL_QUANT_RESNET_DIR "${INFERENCE_DEMO_INSTALL_DIR}/small_quant_model")
if (NOT EXISTS ${TRT_MODEL_QUANT_RESNET_DIR}/small_quant_model.tgz)
inference_download_and_uncompress(${INFERENCE_DEMO_INSTALL_DIR} ${INFERENCE_URL}/tensorrt_test "small_quant_model.tgz")
endif()
inference_analysis_test(trt_quant_int8_test SRCS trt_quant_int8_test.cc
EXTRA_DEPS ${INFERENCE_EXTRA_DEPS}
ARGS --infer_model=${TRT_MODEL_QUANT_RESNET_DIR})
set(TRT_MODEL_QUANT_YOLOV3_DIR "${INFERENCE_DEMO_INSTALL_DIR}/yolov3_r50_quant_aware")
if (NOT EXISTS ${TRT_MODEL_QUANT_YOLOV3_DIR})
if (NOT EXISTS ${TRT_MODEL_QUANT_YOLOV3_DIR}/yolov3_r50_quant_aware.tgz)
inference_download_and_uncompress(${INFERENCE_DEMO_INSTALL_DIR} ${INFERENCE_URL}/tensorrt_test "yolov3_r50_quant_aware.tgz")
endif()
inference_analysis_test(trt_quant_int8_yolov3_r50_test SRCS trt_quant_int8_yolov3_r50_test.cc
......@@ -448,12 +451,12 @@ if(WITH_GPU AND TENSORRT_FOUND)
ARGS --infer_model=${TRT_MODEL_QUANT_YOLOV3_DIR})
set(TEST_TRT_DYNAMIC_MODEL2 "${TRT_MODEL_INSTALL_DIR}/complex_model_dynamic")
if (NOT EXISTS ${TEST_TRT_DYNAMIC_MODEL2})
if (NOT EXISTS ${TEST_TRT_DYNAMIC_MODEL2}/complex_model_dynamic2.tar.gz)
inference_download_and_uncompress(${TEST_TRT_DYNAMIC_MODEL2} ${INFERENCE_URL}/tensorrt_test "complex_model_dynamic2.tar.gz")
endif()
set(TEST_TRT_DYNAMIC_MODEL "${TRT_MODEL_INSTALL_DIR}/conv_bn_swish_split_gelu")
if (NOT EXISTS ${TEST_TRT_DYNAMIC_MODEL})
if (NOT EXISTS ${TEST_TRT_DYNAMIC_MODEL}/conv_bn_swish_split_gelu.tar.gz)
inference_download_and_uncompress(${TEST_TRT_DYNAMIC_MODEL} ${INFERENCE_URL}/tensorrt_test "conv_bn_swish_split_gelu.tar.gz")
endif()
inference_analysis_test(trt_dynamic_shape_test SRCS trt_dynamic_shape_test.cc
......@@ -461,7 +464,7 @@ if(WITH_GPU AND TENSORRT_FOUND)
ARGS --infer_model=${TRT_MODEL_INSTALL_DIR})
set(TEST_TRT_ERNIE_MODEL "${TRT_MODEL_INSTALL_DIR}/ernie_test")
if (NOT EXISTS ${TEST_TRT_ERNIE_MODEL})
if (NOT EXISTS ${TEST_TRT_ERNIE_MODEL}/ernie_model_4.tar.gz)
inference_download_and_uncompress(${TEST_TRT_ERNIE_MODEL} ${INFERENCE_URL}/tensorrt_test "ernie_model_4.tar.gz")
endif()
......@@ -470,7 +473,7 @@ if(WITH_GPU AND TENSORRT_FOUND)
ARGS --infer_model=${TEST_TRT_ERNIE_MODEL}/ernie_model_4)
set(TEST_TRT_ERNIE_UNSER_MODEL "${TRT_MODEL_INSTALL_DIR}/ernie_test/ernie_model_4_unserialized/")
if (NOT EXISTS ${TEST_TRT_ERNIE_UNSER_MODEL})
if (NOT EXISTS ${TEST_TRT_ERNIE_UNSER_MODEL}/ernie_model_4_unserialized.tgz)
inference_download_and_uncompress(${TEST_TRT_ERNIE_MODEL} ${INFERENCE_URL}/tensorrt_test "ernie_model_4_unserialized.tgz")
endif()
......
......@@ -54,9 +54,6 @@ TEST(PD_AnalysisConfig, use_gpu) {
PD_SwitchIrOptim(config, true);
bool ir_optim = PD_IrOptim(config);
CHECK(ir_optim) << "NO";
PD_EnableMkldnnBfloat16(config);
bool bfloat16_enable = PD_MkldnnBfloat16Enabled(config);
CHECK(!bfloat16_enable) << "NO";
PD_EnableTensorRtEngine(config, 1 << 20, 1, 3, Precision::kFloat32, false,
false);
bool trt_enable = PD_TensorrtEngineEnabled(config);
......
......@@ -90,7 +90,6 @@ void trt_ernie(bool with_fp16, std::vector<float> result) {
config.SwitchUseFeedFetchOps(false);
int head_number = 12;
int batch = 1;
int min_seq_len = 1;
int max_seq_len = 128;
......@@ -104,17 +103,17 @@ void trt_ernie(bool with_fp16, std::vector<float> result) {
{"read_file_0.tmp_0", min_shape},
{"read_file_0.tmp_1", min_shape},
{"read_file_0.tmp_2", min_shape},
{"stack_0.tmp_0", {batch, head_number, min_seq_len, min_seq_len}}};
{"matmul_0.tmp_0", {batch, min_seq_len, min_seq_len}}};
std::map<std::string, std::vector<int>> max_input_shape = {
{"read_file_0.tmp_0", max_shape},
{"read_file_0.tmp_1", max_shape},
{"read_file_0.tmp_2", max_shape},
{"stack_0.tmp_0", {batch, head_number, max_seq_len, max_seq_len}}};
{"matmul_0.tmp_0", {batch, max_seq_len, max_seq_len}}};
std::map<std::string, std::vector<int>> opt_input_shape = {
{"read_file_0.tmp_0", opt_shape},
{"read_file_0.tmp_1", opt_shape},
{"read_file_0.tmp_2", opt_shape},
{"stack_0.tmp_0", {batch, head_number, opt_seq_len, opt_seq_len}}};
{"matmul_0.tmp_0", {batch, opt_seq_len, opt_seq_len}}};
auto precision = AnalysisConfig::Precision::kFloat32;
if (with_fp16) {
......
......@@ -90,7 +90,6 @@ void trt_ernie(bool with_fp16, std::vector<float> result) {
config.SwitchUseFeedFetchOps(false);
int head_number = 12;
int batch = 1;
int min_seq_len = 1;
int max_seq_len = 128;
......@@ -104,17 +103,17 @@ void trt_ernie(bool with_fp16, std::vector<float> result) {
{"read_file_0.tmp_0", min_shape},
{"read_file_0.tmp_1", min_shape},
{"read_file_0.tmp_2", min_shape},
{"stack_0.tmp_0", {batch, head_number, min_seq_len, min_seq_len}}};
{"matmul_0.tmp_0", {batch, min_seq_len, min_seq_len}}};
std::map<std::string, std::vector<int>> max_input_shape = {
{"read_file_0.tmp_0", max_shape},
{"read_file_0.tmp_1", max_shape},
{"read_file_0.tmp_2", max_shape},
{"stack_0.tmp_0", {batch, head_number, max_seq_len, max_seq_len}}};
{"matmul_0.tmp_0", {batch, max_seq_len, max_seq_len}}};
std::map<std::string, std::vector<int>> opt_input_shape = {
{"read_file_0.tmp_0", opt_shape},
{"read_file_0.tmp_1", opt_shape},
{"read_file_0.tmp_2", opt_shape},
{"stack_0.tmp_0", {batch, head_number, opt_seq_len, opt_seq_len}}};
{"matmul_0.tmp_0", {batch, opt_seq_len, opt_seq_len}}};
auto precision = AnalysisConfig::Precision::kFloat32;
if (with_fp16) {
......
......@@ -25,12 +25,20 @@ namespace inference {
TEST(quant_int8, resnet50) {
std::string model_dir = FLAGS_infer_model;
AnalysisConfig config;
config.EnableUseGpu(100, 0);
config.EnableUseGpu(1000, 0);
config.SetModel(model_dir);
config.SwitchUseFeedFetchOps(false);
config.EnableTensorRtEngine(1 << 30, 1, 1, AnalysisConfig::Precision::kInt8,
false, false);
std::map<std::string, std::vector<int>> min_input_shape = {
{"image", {1, 1, 3, 3}}};
std::map<std::string, std::vector<int>> max_input_shape = {
{"image", {1, 1, 10, 10}}};
std::map<std::string, std::vector<int>> opt_input_shape = {
{"image", {1, 1, 3, 3}}};
config.SetTRTDynamicShapeInfo(min_input_shape, max_input_shape,
opt_input_shape);
auto predictor = CreatePaddlePredictor(config);
auto input_names = predictor->GetInputNames();
int channels = 1;
......
......@@ -45,7 +45,7 @@ function(inference_download_and_uncompress INSTALL_DIR URL FILENAME)
endfunction()
set(WORD2VEC_INSTALL_DIR "${INFERENCE_DEMO_INSTALL_DIR}/word2vec")
if(NOT EXISTS ${WORD2VEC_INSTALL_DIR})
if(NOT EXISTS ${WORD2VEC_INSTALL_DIR}/word2vec.inference.model.tar.gz)
inference_download_and_uncompress(${WORD2VEC_INSTALL_DIR} ${INFERENCE_URL} "word2vec.inference.model.tar.gz")
endif()
set(WORD2VEC_MODEL_DIR "${WORD2VEC_INSTALL_DIR}/word2vec.inference.model")
......
......@@ -244,7 +244,6 @@ UNUSED constexpr char CosDoc[] = R"DOC(
Cosine Operator. Computes cosine of x element-wise.
Input range is `(-inf, inf)` and output range is `[-1,1]`.
Return `nan` if input is out of boundary.
$$out = cos(x)$$
......@@ -342,7 +341,9 @@ $$out = \cos^{-1}(x)$$
class AsinOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X", "Input of asin operator");
AddInput("X",
"Input of asin operator, an N-D Tensor, with data type float32, "
"float64 or float16.");
AddOutput("Out", "Output of asin operator");
AddComment(R"DOC(
Arcsine Operator.
......@@ -356,7 +357,9 @@ $$out = \sin^{-1}(x)$$
class AtanOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X", "Input of atan operator");
AddInput("X",
"Input of atan operator, an N-D Tensor, with data type float32, "
"float64 or float16.");
AddOutput("Out", "Output of atan operator");
AddComment(R"DOC(
Arctangent Operator.
......@@ -1252,4 +1255,14 @@ REGISTER_OP_VERSION(hard_shrink)
"((x < -threshold) + (x > threshold)); after checkpoint: out = "
"x * (((x < -threshold) + (x > threshold)) > 0)"));
REGISTER_OP_VERSION(softplus)
.AddCheckpoint(
R"ROC(add new attributes [beta] and [threshold], and the formula is changed to "
" softplus(x) = \\frac{1}{beta} * \\log(1 + e^{beta * x}) \\\\ \\text{For numerical"
" stability, the implementation reverts to the linear function when: beta * x > threshold.})ROC",
paddle::framework::compatible::OpVersionDesc()
.NewAttr("beta", "The beta value of the new formula", 1.0f)
.NewAttr("threshold", "The threshold value of the new formula",
20.0f));
/* ========================================================================== */
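For reference, a minimal scalar sketch (an illustration under the attribute defaults above, not Paddle's actual kernel) of the upgraded softplus formula:

#include <cmath>
// softplus(x) = (1/beta) * log(1 + e^{beta * x}), reverting to the linear
// function x when beta * x > threshold, for numerical stability.
float Softplus(float x, float beta = 1.0f, float threshold = 20.0f) {
  const float bx = beta * x;
  if (bx > threshold) return x;            // linear fallback
  return std::log1p(std::exp(bx)) / beta;
}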
......@@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/operators/arg_min_max_op_base.h"
REGISTER_OPERATOR(
......@@ -31,3 +32,20 @@ REGISTER_OP_CPU_KERNEL(
int16_t>,
paddle::operators::ArgMaxKernel<paddle::platform::CPUDeviceContext,
uint8_t>);
REGISTER_OP_VERSION(arg_max)
.AddCheckpoint(
R"ROC(
Upgrade argmax: add a new attribute [flatten] and modify the attribute of dtype)ROC",
paddle::framework::compatible::OpVersionDesc()
.NewAttr("flatten",
"In order to compute the argmax over the flattened array "
"when the "
"argument `axis` in python API is None.",
false)
.ModifyAttr(
"dtype",
"change the default value of dtype, the older version "
"is -1, means return the int64 indices."
"The new version is 3, return the int64 indices directly."
"And supporting the dtype of -1 in new version.",
3));
......@@ -70,6 +70,8 @@ struct VisitDataArgMinMaxFunctor {
auto axis = ctx.Attr<int64_t>("axis");
auto keepdims = ctx.Attr<bool>("keepdims");
const bool& flatten = ctx.Attr<bool>("flatten");
// paddle does not have a scalar tensor, so just return a tensor of shape [1]
if (flatten) keepdims = true;
// if flatten, will construct the new dims for the calculation
framework::DDim x_dims;
......@@ -164,15 +166,30 @@ class ArgMinMaxOp : public framework::OperatorWithKernel {
platform::errors::InvalidArgument(
"'axis'(%d) must be less than Rank(X)(%d).", axis, x_dims.size()));
auto x_rank = x_dims.size();
if (axis < 0) axis += x_rank;
if (ctx->IsRuntime()) {
const int& dtype = ctx->Attrs().Get<int>("dtype");
if (dtype == framework::proto::VarType::INT32) {
int64_t all_element_num = 0;
if (flatten) {
all_element_num = framework::product(x_dims);
} else {
all_element_num = x_dims[axis];
}
PADDLE_ENFORCE_LE(
all_element_num, INT_MAX,
"The element num of the argmin/argmax input at axis is "
"%d, is larger than int32 maximum value:%d, you must "
"set the dtype of argmin/argmax to 'int64'.",
all_element_num, INT_MAX);
}
}
std::vector<int64_t> vec;
if (flatten) {
// if flatten, will return only one element
if (keepdims) {
vec.emplace_back(static_cast<int64_t>(1));
}
vec.emplace_back(static_cast<int64_t>(1));
} else {
auto x_rank = x_dims.size();
if (axis < 0) axis += x_rank;
for (int64_t i = 0; i < axis; i++) vec.emplace_back(x_dims[i]);
if (keepdims) {
vec.emplace_back(static_cast<int64_t>(1));
......@@ -194,10 +211,14 @@ class BaseArgMinMaxOpMaker : public framework::OpProtoAndCheckerMaker {
AddOutput("Out", "Output tensor.");
AddAttr<int64_t>("axis", "The axis in which to compute the arg indics.");
AddAttr<bool>("keepdims", "Keep the dim that to reduce.").SetDefault(false);
AddAttr<int>("dtype", "Keep the dim that to reduce.").SetDefault(-1);
AddAttr<bool>("flatten",
"Flatten the input value, and search the min or max indices")
.SetDefault(false);
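// Flatten semantics (as implied by the kernel change above): with
// flatten = true the search runs over all elements of the input, keepdims is
// forced to true at runtime, and the result is a single index into the
// flattened array, e.g. an input of shape [2, 3, 4] is treated as 24 elements.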
AddAttr<int>("dtype",
"(int, 3), the dtype of indices, the indices dtype must be "
"int32, int64."
"default dtype is int64, and proto value is 3.")
.SetDefault(3);
AddComment(string::Sprintf(R"DOC(
%s Operator.
......
......@@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/operators/arg_min_max_op_base.h"
REGISTER_OPERATOR(
......@@ -31,3 +32,20 @@ REGISTER_OP_CPU_KERNEL(
int16_t>,
paddle::operators::ArgMinKernel<paddle::platform::CPUDeviceContext,
uint8_t>);
REGISTER_OP_VERSION(arg_min)
.AddCheckpoint(
R"ROC(
Upgrade argmin: add a new attribute [flatten] and modify the attribute of dtype)ROC",
paddle::framework::compatible::OpVersionDesc()
.NewAttr("flatten",
"In order to compute the argmin over the flattened array "
"when the "
"argument `axis` in python API is None.",
false)
.ModifyAttr(
"dtype",
"change the default value of dtype, the older version "
"is -1, means return the int64 indices."
"The new version is 3, return the int64 indices directly."
"And supporting the dtype of -1 in new version.",
3));
......@@ -31,6 +31,10 @@ struct BernoulliCudaFunctor {
__host__ __device__ BernoulliCudaFunctor(int seed) : seed_(seed) {}
__host__ __device__ T operator()(const unsigned int n, const T p) const {
// NOTE(zhiqiu): currently, PADDLE_ENFORCE in cuda kernel may print several
// lines of error messages if it fails, and it should be refined.
PADDLE_ENFORCE(p >= 0.0 && p <= 1.0,
"The probability should be >=0 and <= 1, but got %f", p);
thrust::minstd_rand rng;
rng.seed(seed_);
thrust::uniform_real_distribution<T> dist(0.0, 1.0);
......
......@@ -25,10 +25,12 @@ namespace operators {
template <typename T>
inline HOSTDEVICE T BernoulliFunctor(T p, T rand) {
PADDLE_ENFORCE_LE(p, 1, platform::errors::OutOfRange(
"The probability should be <= 1, but got %f", p));
PADDLE_ENFORCE_GE(p, 0, platform::errors::OutOfRange(
"The probability should be >= 1, but got %f", p));
PADDLE_ENFORCE_LE(p, 1.0,
platform::errors::OutOfRange(
"The probability should be <= 1, but got %f", p));
PADDLE_ENFORCE_GE(p, 0.0,
platform::errors::OutOfRange(
"The probability should be >= 0, but got %f", p));
return static_cast<T>(rand < p);
}
......
......@@ -25,25 +25,32 @@ class DistributedLookupTableOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE(ctx->HasInputs("Ids"),
"Input(Ids) of LookupTableOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("W"),
"Input(W) of LookupTableOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutputs("Outputs"),
"Output(Outs) of LookupTableOp should not be null.");
PADDLE_ENFORCE_EQ(ctx->HasInputs("Ids"), true,
platform::errors::InvalidArgument(
"Input(Ids) of LookupTableOp should not be null."));
PADDLE_ENFORCE_EQ(ctx->HasInput("W"), true,
platform::errors::InvalidArgument(
"Input(W) of LookupTableOp should not be null."));
PADDLE_ENFORCE_EQ(ctx->HasOutputs("Outputs"), true,
platform::errors::InvalidArgument(
"Output(Outs) of LookupTableOp should not be null."));
auto ids_dims = ctx->GetInputsDim("Ids");
auto table_dims = ctx->GetInputDim("W");
PADDLE_ENFORCE_EQ(table_dims.size(), 2,
"Only 2 dimensions of the 'Embedding' is supported.");
PADDLE_ENFORCE_EQ(
table_dims.size(), 2,
platform::errors::InvalidArgument(
"Only 2 dimensions of the 'Embedding' is supported."));
for (auto &ids_dim : ids_dims) {
PADDLE_ENFORCE_EQ(ids_dim.size(), 2,
"The dimension of the 'Ids' tensor must be 2.");
platform::errors::InvalidArgument(
"The dimension of the 'Ids' tensor must be 2."));
}
auto endpoints = ctx->Attrs().Get<std::vector<std::string>>("endpoints");
// for fluid.embedding
auto lookup_table_version =
ctx->Attrs().Get<std::string>("lookup_table_version");
......
......@@ -35,9 +35,30 @@ class DistributedLookupTableKernel : public framework::OpKernel<T> {
auto endpoints = context.Attr<std::vector<std::string>>("endpoints");
auto is_distributed = context.Attr<bool>("is_distributed");
auto lookup_table_version =
context.Attr<std::string>("lookup_table_version");
operators::distributed::prefetchs(id_names, out_names, embedding_name,
is_distributed, lookup_tables, endpoints,
context, context.scope());
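    // For fluid.embedding (lookup_table_v2) the ids keep two dimensions, so
    // restore the 3-D output shape [ids_dim0, ids_dim1, emb_dim] that the
    // 2-D prefetch above flattened.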
if (lookup_table_version == "lookup_table_v2") {
auto &scope = context.scope();
auto emb_dim =
scope.FindVar(embedding_name)->Get<framework::LoDTensor>().dims()[1];
for (size_t i = 0; i < id_names.size(); ++i) {
auto *id_var = scope.FindVar(id_names[i]);
auto *out_var = scope.FindVar(out_names[i]);
auto *id_tensor = id_var->GetMutable<framework::LoDTensor>();
auto *out_tensor = out_var->GetMutable<framework::LoDTensor>();
auto id_dims = id_tensor->dims();
out_tensor->Resize(framework::make_ddim(
{static_cast<int64_t>(id_dims[0]), static_cast<int64_t>(id_dims[1]),
static_cast<int64_t>(emb_dim)}));
}
}
}
};
......
......@@ -70,6 +70,7 @@ class LoadCombineOpKernel : public framework::OpKernel<T> {
auto out_vars = context.MultiOutputVar("Out");
for (size_t i = 0; i < out_var_names.size(); i++) {
VLOG(4) << "loading tensor: " << out_var_names[i];
PADDLE_ENFORCE_NOT_NULL(
out_vars[i], platform::errors::InvalidArgument(
"The variable %s to be loaded cannot be found.",
......
......@@ -15,8 +15,8 @@ limitations under the License. */
#include "paddle/fluid/operators/lookup_table_v2_op.h"
#include <memory>
#include "paddle/fluid/framework/no_need_buffer_vars_inference.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/framework/var_type_inference.h"
namespace paddle {
......@@ -196,3 +196,14 @@ REGISTER_OP_CPU_KERNEL(lookup_table_v2, ops::LookupTableV2Kernel<float>,
REGISTER_OP_CPU_KERNEL(lookup_table_v2_grad,
ops::LookupTableV2GradKernel<float>,
ops::LookupTableV2GradKernel<double>);
/* ========================== register checkpoint ===========================*/
REGISTER_OP_VERSION(lookup_table_v2)
.AddCheckpoint(
R"ROC(fix lookup_table_v2, add input type `int32`)ROC",
paddle::framework::compatible::OpVersionDesc()
.BugfixWithBehaviorChanged("lookup_table_v2 support input type "
"`int64`; after support input type "
"`int32/int64`"));
/* ========================================================================== */
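A hedged sketch of what this checkpoint enables, assuming the 2.0-style paddle.nn.Embedding(num_embeddings, embedding_dim) API: int32 ids are now accepted alongside int64.

    import numpy as np
    import paddle

    paddle.disable_static()
    emb = paddle.nn.Embedding(10, 4)  # 10 rows, 4-dim embeddings
    ids32 = paddle.to_tensor(np.array([[1, 3], [2, 0]], dtype='int32'))
    out = emb(ids32)   # before this fix, only int64 ids were supported
    print(out.shape)   # [2, 2, 4]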
......@@ -85,6 +85,14 @@ __global__ void LookupTableV2Grad(T *table, const T *output, const int64_t *ids,
}
}
template <typename T>
__global__ void InputTypeCovert(const T *in_ids, const int64_t K,
int64_t *out_ids) {
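  // Widen each of the K ids to int64 so downstream kernels handle a single
  // index type; the loop is not partitioned across threads.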
for (int i = 0; i < K; i++) {
out_ids[i] = (int64_t)(in_ids[i]);
}
}
template <typename T>
class LookupTableV2CUDAKernel : public framework::OpKernel<T> {
public:
......@@ -101,23 +109,37 @@ class LookupTableV2CUDAKernel : public framework::OpKernel<T> {
size_t D = table_t->dims()[1];
size_t K = ids_t->numel();
auto *ids = ids_t->data<int64_t>();
auto *table = table_t->data<T>();
auto *output = output_t->mutable_data<T>(context.GetPlace());
dim3 threads(256, 4);
dim3 grids(80, 1);
// copy GPU memory to CPU pinned memory
framework::Vector<int64_t> ids;
ids.resize(K);
const int64_t *ids_p = nullptr;
if (ids_t->type() == framework::proto::VarType::INT32) {
InputTypeCovert<
int><<<grids, threads, 0, context.cuda_device_context().stream()>>>(
ids_t->data<int>(), K, ids.MutableData(context.GetPlace()));
ids_p = ids.MutableData(context.GetPlace());
} else {
ids_p = ids_t->data<int64_t>();
}
auto *table = table_t->data<T>();
auto *output = output_t->mutable_data<T>(context.GetPlace());
if (padding_idx == -1)
LookupTableV2<
T, 256, 4, 80,
false><<<grids, threads, 0, context.cuda_device_context().stream()>>>(
output, table, ids, N, K, D, padding_idx);
output, table, ids_p, N, K, D, padding_idx);
else
LookupTableV2<
T, 256, 4, 80,
true><<<grids, threads, 0, context.cuda_device_context().stream()>>>(
output, table, ids, N, K, D, padding_idx);
output, table, ids_p, N, K, D, padding_idx);
}
};
......@@ -139,16 +161,24 @@ class LookupTableV2GradCUDAKernel : public framework::OpKernel<T> {
auto *ids_data = ids->data<int64_t>();
int64_t ids_num = ids->numel();
dim3 threads(128, 8);
dim3 grids(8, 1);
auto stream = dev_ctx.stream();
// copy GPU memory to CPU pinned memory
framework::Vector<int64_t> new_rows;
new_rows.resize(ids_num);
auto gpu_place = BOOST_GET_CONST(platform::CUDAPlace, context.GetPlace());
// TODO(yuyang18): Strange code here.
memory::Copy(gpu_place, new_rows.CUDAMutableData(context.GetPlace()),
gpu_place, ids_data, ids_num * sizeof(int64_t), stream);
if (ids->type() == framework::proto::VarType::INT32) {
InputTypeCovert<
int><<<grids, threads, 0, context.cuda_device_context().stream()>>>(
ids->data<int>(), ids_num,
new_rows.MutableData(context.GetPlace()));
} else {
memory::Copy(gpu_place, new_rows.CUDAMutableData(context.GetPlace()),
gpu_place, ids_data, ids_num * sizeof(int64_t), stream);
}
d_table->set_rows(new_rows);
auto *d_table_value = d_table->mutable_value();
......@@ -177,17 +207,32 @@ class LookupTableV2GradCUDAKernel : public framework::OpKernel<T> {
int N = d_table_t->dims()[0];
int D = d_table_t->dims()[1];
int K = ids_t->numel();
const int64_t *ids = ids_t->data<int64_t>();
dim3 threads(128, 8);
dim3 grids(8, 1);
// copy GPU memory to CPU pinned memory
framework::Vector<int64_t> ids;
ids.resize(K);
const int64_t *ids_p = nullptr;
if (ids_t->type() == framework::proto::VarType::INT32) {
InputTypeCovert<
int><<<grids, threads, 0, context.cuda_device_context().stream()>>>(
ids_t->data<int>(), K, ids.MutableData(context.GetPlace()));
ids_p = ids.MutableData(context.GetPlace());
} else {
ids_p = ids_t->data<int64_t>();
}
const T *d_output = d_output_t->data<T>();
T *d_table = d_table_t->mutable_data<T>(context.GetPlace());
auto t = framework::EigenVector<T>::Flatten(*d_table_t);
t.device(*dev_ctx.eigen_device()) = t.constant(static_cast<T>(0));
dim3 threads(128, 8);
dim3 grids(8, 1);
LookupTableV2Grad<T, 128, 8, 8><<<grids, threads, 0, dev_ctx.stream()>>>(
d_table, d_output, ids, N, K, D);
d_table, d_output, ids_p, N, K, D);
}
}
};
......
......@@ -14,6 +14,7 @@ limitations under the License. */
#pragma once
#include <algorithm>
#include <string>
#include <vector>
......@@ -45,84 +46,70 @@ class LookupTableV2Kernel : public framework::OpKernel<T> {
auto *output_t = context.Output<LoDTensor>("Out"); // float tensor
auto *table_var = context.InputVar("W");
auto id_name = context.InputNames("Ids").front();
auto embedding_name = context.InputNames("W").front();
auto out_name = context.OutputNames("Out").front();
// for remote prefetch
auto epmap = context.Attr<std::vector<std::string>>("epmap");
auto remote_prefetch = context.Attr<bool>("remote_prefetch");
auto table_names = context.Attr<std::vector<std::string>>("table_names");
int64_t padding_idx = context.Attr<int64_t>("padding_idx");
int64_t ids_numel = ids_t->numel();
if (remote_prefetch && !epmap.empty()) {
// if epmap is not empty, then the parameter will be fetched from remote
// parameter server
std::vector<int64_t> ids;
ids.reserve(ids_numel);
#ifdef PADDLE_WITH_DISTRIBUTE
operators::distributed::prefetch(id_name, out_name, embedding_name, false,
table_names, epmap, context,
context.scope());
#else
PADDLE_THROW(
"paddle is not compiled with distribute support, can not do "
"parameter prefetch!");
#endif
if (ids_t->type() == framework::proto::VarType::INT32) {
std::transform(ids_t->data<int>(), ids_t->data<int>() + ids_numel,
std::back_inserter(ids),
[&](int id) { return static_cast<int64_t>(id); });
} else {
int64_t padding_idx = context.Attr<int64_t>("padding_idx");
int64_t *ids = const_cast<int64_t *>(ids_t->data<int64_t>());
int64_t ids_numel = ids_t->numel();
if (table_var->IsType<LoDTensor>()) {
auto *table_t = context.Input<LoDTensor>("W");
int64_t row_number = table_t->dims()[0];
int64_t row_width = table_t->dims()[1];
auto *table = table_t->data<T>();
auto *output = output_t->mutable_data<T>(context.GetPlace());
for (int64_t i = 0; i < ids_numel; ++i) {
if (padding_idx != kNoPadding && ids[i] == padding_idx) {
memset(output + i * row_width, 0, row_width * sizeof(T));
} else {
PADDLE_ENFORCE_LT(
ids[i], row_number,
"Variable value (input) of OP(fluid.layers.embedding) "
"expected >= 0 and < %ld, but got %ld. Please check input "
"value.",
row_number, ids[i]);
PADDLE_ENFORCE_GE(
ids[i], 0,
"Variable value (input) of OP(fluid.layers.embedding) "
"expected >= 0 and < %ld, but got %ld. Please check input "
"value.",
row_number, ids[i]);
memcpy(output + i * row_width, table + ids[i] * row_width,
row_width * sizeof(T));
}
framework::TensorToVector(*ids_t, &ids);
}
if (table_var->IsType<LoDTensor>()) {
auto *table_t = context.Input<LoDTensor>("W");
int64_t row_number = table_t->dims()[0];
int64_t row_width = table_t->dims()[1];
auto *table = table_t->data<T>();
auto *output = output_t->mutable_data<T>(context.GetPlace());
for (int64_t i = 0; i < ids_numel; ++i) {
if (padding_idx != kNoPadding && ids[i] == padding_idx) {
memset(output + i * row_width, 0, row_width * sizeof(T));
} else {
PADDLE_ENFORCE_LT(
ids[i], row_number,
"Variable value (input) of OP(fluid.layers.embedding) "
"expected >= 0 and < %ld, but got %ld. Please check input "
"value.",
row_number, ids[i]);
PADDLE_ENFORCE_GE(
ids[i], 0,
"Variable value (input) of OP(fluid.layers.embedding) "
"expected >= 0 and < %ld, but got %ld. Please check input "
"value.",
row_number, ids[i]);
memcpy(output + i * row_width, table + ids[i] * row_width,
row_width * sizeof(T));
}
} else if (table_var->IsType<SelectedRows>()) {
const auto &table_t = table_var->Get<SelectedRows>();
int64_t row_width = table_t.value().dims()[1];
const auto *table = table_t.value().data<T>();
auto *output = output_t->mutable_data<T>(context.GetPlace());
auto blas = math::GetBlas<platform::CPUDeviceContext, T>(context);
for (int64_t i = 0; i < ids_numel; ++i) {
if (padding_idx != kNoPadding && ids[i] == padding_idx) {
memset(output + i * row_width, 0, row_width * sizeof(T));
} else {
PADDLE_ENFORCE_GE(
ids[i], 0,
"Variable value (input) of OP(fluid.layers.embedding) "
"expected >= 0. But received %ld",
ids[i]);
auto id_index = table_t.Index(ids[i]);
PADDLE_ENFORCE_GE(
id_index, 0, "the input key should be exists. But received %d.",
id_index);
blas.VCOPY(row_width, table + id_index * row_width,
output + i * row_width);
}
}
} else if (table_var->IsType<SelectedRows>()) {
const auto &table_t = table_var->Get<SelectedRows>();
int64_t row_width = table_t.value().dims()[1];
const auto *table = table_t.value().data<T>();
auto *output = output_t->mutable_data<T>(context.GetPlace());
auto blas = math::GetBlas<platform::CPUDeviceContext, T>(context);
for (int64_t i = 0; i < ids_numel; ++i) {
if (padding_idx != kNoPadding && ids[i] == padding_idx) {
memset(output + i * row_width, 0, row_width * sizeof(T));
} else {
PADDLE_ENFORCE_GE(
ids[i], 0,
"Variable value (input) of OP(fluid.layers.embedding) "
"expected >= 0. But received %ld",
ids[i]);
auto id_index = table_t.Index(ids[i]);
PADDLE_ENFORCE_GE(id_index, 0,
"the input key should be exists. But received %d.",
id_index);
blas.VCOPY(row_width, table + id_index * row_width,
output + i * row_width);
}
}
}
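The dense-table branch above is a row gather with two special cases: ids equal to padding_idx produce zero rows, and out-of-range ids fail the enforce. A NumPy sketch of the same semantics (kNoPadding is -1 here, as in the kernel):

    import numpy as np

    def embedding_lookup(table, ids, padding_idx=-1):
        rows, width = table.shape
        out = np.empty((len(ids), width), dtype=table.dtype)
        for i, idx in enumerate(ids):
            if padding_idx != -1 and idx == padding_idx:
                out[i] = 0  # padding rows are zero; their gradient is 0 too
            else:
                assert 0 <= idx < rows, "id out of range"
                out[i] = table[idx]  # memcpy of one row in the kernel
        return out

    table = np.arange(12, dtype=np.float32).reshape(4, 3)
    print(embedding_lookup(table, [2, 0, 1], padding_idx=1))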
......@@ -151,17 +138,23 @@ class LookupTableV2GradKernel : public framework::OpKernel<T> {
// Since paddings are not trainable and fixed in forward, the gradient of
// paddings makes no sense and we don't deal with it in backward.
if (is_sparse) {
auto *ids = context.Input<LoDTensor>("Ids");
auto *ids_t = context.Input<LoDTensor>("Ids");
auto *d_output = context.Input<LoDTensor>(framework::GradVarName("Out"));
auto *d_table = context.Output<SelectedRows>(framework::GradVarName("W"));
int64_t ids_num = ids_t->numel();
std::vector<int64_t> ids;
ids.reserve(ids_num);
auto *ids_data = ids->data<int64_t>();
int64_t ids_num = ids->numel();
if (ids_t->type() == framework::proto::VarType::INT32) {
std::transform(ids_t->data<int>(), ids_t->data<int>() + ids_num,
std::back_inserter(ids),
[&](int id) { return static_cast<int64_t>(id); });
} else {
framework::TensorToVector(*ids_t, &ids);
}
std::vector<int64_t> new_rows;
new_rows.resize(ids_num);
std::memcpy(&new_rows[0], ids_data, ids_num * sizeof(int64_t));
d_table->set_rows(new_rows);
d_table->set_rows(ids);
auto *d_table_value = d_table->mutable_value();
d_table_value->Resize({ids_num, table_dim[1]});
......@@ -185,11 +178,23 @@ class LookupTableV2GradKernel : public framework::OpKernel<T> {
memcpy(d_table_data, d_output_data, sizeof(T) * d_output->numel());
} else {
auto *ids = context.Input<LoDTensor>("Ids");
auto *ids_t = context.Input<LoDTensor>("Ids");
auto *d_output = context.Input<LoDTensor>(framework::GradVarName("Out"));
auto *d_table = context.Output<LoDTensor>(framework::GradVarName("W"));
int64_t ids_num = ids_t->numel();
std::vector<int64_t> ids;
ids.reserve(ids_num);
if (ids_t->type() == framework::proto::VarType::INT32) {
std::transform(ids_t->data<int>(), ids_t->data<int>() + ids_num,
std::back_inserter(ids),
[&](int id) { return static_cast<int64_t>(id); });
} else {
framework::TensorToVector(*ids_t, &ids);
}
auto *ids_data = ids->data<int64_t>();
auto *ids_data = ids.data();
int64_t N = table_dim[0];
int64_t D = table_dim[1];
......@@ -199,7 +204,7 @@ class LookupTableV2GradKernel : public framework::OpKernel<T> {
memset(d_table_data, 0, d_table->numel() * sizeof(T));
for (int64_t i = 0; i < ids->numel(); ++i) {
for (int64_t i = 0; i < ids_num; ++i) {
if (padding_idx != kNoPadding && ids_data[i] == padding_idx) {
// the gradient of padding_idx should be 0, already done by memset, so
// do nothing.
......
......@@ -33,10 +33,12 @@ class MKLDNNActivationKernel
public:
void Compute(const framework::ExecutionContext &ctx) const override {
const auto *x = ctx.Input<Tensor>("X");
PADDLE_ENFORCE_EQ(x->layout(), DataLayout::kMKLDNN,
"Wrong layout set for X tensor");
PADDLE_ENFORCE_NE(x->format(), MKLDNNMemoryFormat::undef,
"Wrong format set for X tensor");
PADDLE_ENFORCE_EQ(
x->layout(), DataLayout::kMKLDNN,
platform::errors::InvalidArgument("Wrong layout set for X tensor"));
PADDLE_ENFORCE_NE(
x->format(), MKLDNNMemoryFormat::undef,
platform::errors::InvalidArgument("Wrong format set for X tensor"));
Functor functor;
functor(ctx);
......@@ -50,9 +52,11 @@ class MKLDNNActivationGradKernel
void Compute(const framework::ExecutionContext &ctx) const override {
const auto *diff_y = ctx.Input<Tensor>(framework::GradVarName("Out"));
PADDLE_ENFORCE_EQ(diff_y->layout(), DataLayout::kMKLDNN,
"Wrong layout set for Input OutGrad tensor");
platform::errors::InvalidArgument(
"Wrong layout set for Input OutGrad tensor"));
PADDLE_ENFORCE_NE(diff_y->format(), MKLDNNMemoryFormat::undef,
"Wrong format set for Input OutGrad tensor");
platform::errors::InvalidArgument(
"Wrong format set for Input OutGrad tensor"));
Functor functor;
functor(ctx);
......@@ -82,7 +86,7 @@ void eltwise_forward(const framework::ExecutionContext &ctx,
PADDLE_ENFORCE(
x->dims().size() == 2 || x->dims().size() == 3 || x->dims().size() == 4,
"Input dim must be with 2, 3 or 4");
platform::errors::Unimplemented("Input dim must be with 2, 3 or 4"));
auto src_tz = framework::vectorize<int64_t>(x->dims());
......
......@@ -262,9 +262,11 @@ class BatchNormMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
auto *diff_shift = ctx.Output<Tensor>(framework::GradVarName("Bias"));
PADDLE_ENFORCE_EQ(diff_y->layout(), DataLayout::kMKLDNN,
"Wrong layout set for Input diff_y tensor");
platform::errors::InvalidArgument(
"Wrong layout set for Input diff_y tensor"));
PADDLE_ENFORCE_NE(diff_y->format(), MKLDNNMemoryFormat::undef,
"Wrong format set for Input diff_y tensor");
platform::errors::InvalidArgument(
"Wrong format set for Input diff_y tensor"));
auto src_tz = paddle::framework::vectorize<int64_t>(x->dims());
auto scale_tz = paddle::framework::vectorize<int64_t>(scale->dims());
......
......@@ -30,10 +30,12 @@ using platform::to_void_cast;
static void EnforceLayouts(const std::vector<const Tensor*> inputs) {
for (auto* input : inputs) {
PADDLE_ENFORCE_EQ(input->layout(), DataLayout::kMKLDNN,
"Wrong layout set for Input tensor");
PADDLE_ENFORCE_NE(input->format(), MKLDNNMemoryFormat::undef,
"Wrong format set for Input tensor");
PADDLE_ENFORCE_EQ(
input->layout(), DataLayout::kMKLDNN,
platform::errors::InvalidArgument("Wrong layout set for Input tensor"));
PADDLE_ENFORCE_NE(
input->format(), MKLDNNMemoryFormat::undef,
platform::errors::InvalidArgument("Wrong format set for Input tensor"));
}
}
......@@ -49,7 +51,7 @@ static platform::CPUPlace GetCpuPlace(
const paddle::framework::ExecutionContext& ctx) {
auto place = ctx.GetPlace();
PADDLE_ENFORCE(paddle::platform::is_cpu_place(place),
"It must use CPUPlace.");
platform::errors::InvalidArgument("It must use CPUPlace."));
return BOOST_GET_CONST(platform::CPUPlace, place);
}
......
......@@ -561,7 +561,8 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
PADDLE_ENFORCE_EQ(
!fuse_residual_conn || !force_fp32_output, true,
"residual fusion does not support force output with fp32");
platform::errors::Unimplemented(
"residual fusion does not support force output with fp32"));
auto* bias = ctx.HasInput("Bias") ? ctx.Input<Tensor>("Bias") : nullptr;
......@@ -625,7 +626,8 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
? dilations.size() == 3 && dilations[0] == 1 &&
dilations[1] == 1 && dilations[2] == 1
: dilations.size() == 2 && dilations[0] == 1 && dilations[1] == 1,
true, "dilation in convolution is not implemented yet");
true, platform::errors::Unimplemented(
"dilation in convolution is not implemented yet"));
const K* filter_data = filter->data<K>();
auto scale_in_data = ctx.Attr<float>("Scale_in");
......@@ -887,7 +889,8 @@ class ConvMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
"The output_grad tensor's layout should be %d, but got %d.",
DataLayout::kMKLDNN, output_grad->layout()));
PADDLE_ENFORCE_NE(output_grad->format(), MKLDNNMemoryFormat::undef,
"Wrong format set for output_grad tensor");
platform::errors::InvalidArgument(
"Wrong format set for output_grad tensor"));
PADDLE_ENFORCE_EQ(
ctx.Attr<bool>("is_test"), false,
......
......@@ -117,7 +117,8 @@ class ConvTransposeMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
PADDLE_ENFORCE(
dilations.size() == 2 && dilations[0] == 1 && dilations[1] == 1,
"dilation in convolution is not implemented yet");
platform::errors::Unimplemented(
"dilation in convolution is not implemented yet"));
const T* input_data = input->data<T>();
const T* filter_data = filter->data<T>();
......
......@@ -83,19 +83,24 @@ class PoolMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
const Tensor* out_grad = ctx.Input<Tensor>(framework::GradVarName("Out"));
Tensor* in_x_grad = ctx.Output<Tensor>(framework::GradVarName("X"));
PADDLE_ENFORCE_EQ(in_x->layout(), DataLayout::kMKLDNN,
"Wrong layout set for Input tensor");
PADDLE_ENFORCE_NE(in_x->format(), MKLDNNMemoryFormat::undef,
"Wrong format set for Input tensor");
PADDLE_ENFORCE_EQ(
in_x->layout(), DataLayout::kMKLDNN,
platform::errors::InvalidArgument("Wrong layout set for Input tensor"));
PADDLE_ENFORCE_NE(
in_x->format(), MKLDNNMemoryFormat::undef,
platform::errors::InvalidArgument("Wrong format set for Input tensor"));
PADDLE_ENFORCE_EQ(out_grad->layout(), DataLayout::kMKLDNN,
"Wrong layout set for Input output_grad tensor");
platform::errors::InvalidArgument(
"Wrong layout set for Input output_grad tensor"));
PADDLE_ENFORCE_NE(out_grad->format(), MKLDNNMemoryFormat::undef,
"Wrong format set for Input output_grad tensor");
platform::errors::InvalidArgument(
"Wrong format set for Input output_grad tensor"));
PADDLE_ENFORCE_EQ(
ctx.Attr<bool>("is_test"), false,
"is_test attribute should be set to False in training phase.");
platform::errors::InvalidArgument(
"is_test attribute should be set to False in training phase."));
std::string pooling_type = ctx.Attr<std::string>("pooling_type");
......
......@@ -140,7 +140,8 @@ class SoftmaxMKLDNNGradKernel : public paddle::framework::OpKernel<T> {
PADDLE_ENFORCE_EQ(
dout->dims(), dx->dims(),
"The shape of softmax_grad's input and output must be identical.");
platform::errors::InvalidArgument(
"The shape of softmax_grad's input and output must be identical."));
auto dims = dout->dims(); // input and output share the same shape
const int axis = CanonicalAxis(ctx.Attr<int>("axis"), dims.size());
......
......@@ -12,60 +12,90 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/math/softmax.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/softmax_op.h"
#include "paddle/fluid/platform/cudnn_desc.h"
#include "paddle/fluid/platform/cudnn_helper.h"
namespace paddle {
namespace operators {
using ScopedTensorDescriptor = platform::ScopedTensorDescriptor;
using DataLayout = platform::DataLayout;
using Tensor = framework::Tensor;
static inline int SizeOutAxis(const int axis, DDim dims) {
int size = 1;
for (int i = axis + 1; i < dims.size(); i++) {
size *= dims[i];
}
return size;
}
template <typename T>
class SoftmaxCUDNNKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* X = context.Input<Tensor>("X");
auto* Out = context.Output<Tensor>("Out");
// allocate memory on device.
Out->mutable_data<T>(context.GetPlace());
auto dims = X->dims();
auto flattened_dims = framework::flatten_to_2d(dims, dims.size() - 1);
framework::LoDTensor flattened_x;
framework::LoDTensor flattened_out;
flattened_x.ShareDataWith(*X).Resize(flattened_dims);
flattened_out.ShareDataWith(*Out).Resize(flattened_dims);
math::SoftmaxCUDNNFunctor<T>()(
context.template device_context<platform::CUDADeviceContext>(),
&flattened_x, &flattened_out);
void Compute(const framework::ExecutionContext& ctx) const override {
auto* x = ctx.Input<Tensor>("X");
auto* out = ctx.Output<Tensor>("Out");
out->mutable_data<T>(ctx.GetPlace());
auto* out_data = out->data<T>();
auto dims = x->dims();
const int rank = dims.size();
const int axis = CanonicalAxis(ctx.Attr<int>("axis"), rank);
const int dim = dims[axis];
const int N = SizeToAxis(axis, dims);
const int D = SizeOutAxis(axis, dims);
ScopedTensorDescriptor desc;
std::vector<int> tensor_dims = {N, dim, D, 1};
DataLayout layout = DataLayout::kNCHW;
cudnnTensorDescriptor_t desc_ = desc.descriptor<T>(layout, tensor_dims);
auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
auto handle = dev_ctx.cudnn_handle();
auto mode = axis == rank - 1 ? CUDNN_SOFTMAX_MODE_INSTANCE
: CUDNN_SOFTMAX_MODE_CHANNEL;
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnSoftmaxForward(
handle, CUDNN_SOFTMAX_ACCURATE, mode,
platform::CudnnDataType<T>::kOne(), desc_, x->data<T>(),
platform::CudnnDataType<T>::kZero(), desc_, out_data));
}
};
template <typename T>
class SoftmaxGradCUDNNKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* Out = context.Input<Tensor>("Out");
auto* dOut = context.Input<Tensor>(framework::GradVarName("Out"));
auto* dX = context.Output<Tensor>(framework::GradVarName("X"));
// allocate memory on device.
dX->mutable_data<T>(context.GetPlace());
auto dims = Out->dims();
auto flattened_dims = framework::flatten_to_2d(dims, dims.size() - 1);
framework::LoDTensor flattened_out;
framework::LoDTensor flattened_d_out;
framework::LoDTensor flattened_d_x;
flattened_out.ShareDataWith(*Out).Resize(flattened_dims);
flattened_d_out.ShareDataWith(*dOut).Resize(flattened_dims);
flattened_d_x.ShareDataWith(*dX).Resize(flattened_dims);
math::SoftmaxGradCUDNNFunctor<T>()(
context.template device_context<platform::CUDADeviceContext>(),
&flattened_out, &flattened_d_out, &flattened_d_x);
void Compute(const framework::ExecutionContext& ctx) const override {
auto* out = ctx.Input<Tensor>("Out");
auto* dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
dx->mutable_data<T>(ctx.GetPlace());
auto* dx_data = dx->data<T>();
auto dims = out->dims();
const int rank = dims.size();
const int axis = CanonicalAxis(ctx.Attr<int>("axis"), rank);
const int dim = dims[axis];
const int N = SizeToAxis(axis, dims);
const int D = SizeOutAxis(axis, dims);
ScopedTensorDescriptor desc;
std::vector<int> tensor_dims = {N, dim, D, 1};
DataLayout layout = DataLayout::kNCHW;
cudnnTensorDescriptor_t desc_ = desc.descriptor<T>(layout, tensor_dims);
auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
auto handle = dev_ctx.cudnn_handle();
auto mode = axis == rank - 1 ? CUDNN_SOFTMAX_MODE_INSTANCE
: CUDNN_SOFTMAX_MODE_CHANNEL;
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnSoftmaxBackward(
handle, CUDNN_SOFTMAX_ACCURATE, mode,
platform::CudnnDataType<T>::kOne(), desc_, out->data<T>(), desc_,
dout->data<T>(), platform::CudnnDataType<T>::kZero(), desc_, dx_data));
}
};
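Both kernels hand cuDNN a 4-D view [N, dim, D, 1] of the input, where N = SizeToAxis(axis, dims) is the product of dimensions before `axis`, dim = dims[axis], and D = SizeOutAxis(axis, dims) is the product after it; CUDNN_SOFTMAX_MODE_CHANNEL then reduces over dim. A NumPy sketch of the reshape (shapes only, not the cuDNN call):

    import numpy as np

    x = np.random.rand(2, 3, 4, 5).astype('float32')
    axis = 1
    N = int(np.prod(x.shape[:axis]))      # SizeToAxis  -> 2
    dim = x.shape[axis]                   # softmax dim -> 3
    D = int(np.prod(x.shape[axis + 1:]))  # SizeOutAxis -> 20
    x4d = x.reshape(N, dim, D, 1)         # layout given to cudnnSoftmaxForward
    # softmax over `dim` in this view equals softmax over `axis` in x
    y = np.exp(x4d) / np.exp(x4d).sum(axis=1, keepdims=True)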
......
......@@ -53,13 +53,6 @@ class SoftmaxOp : public framework::OperatorWithKernel {
"Attr(axis) value should be in range [-R, R-1], "
"R is the rank of Input(X)."));
auto use_cudnn = ctx->Attrs().Get<bool>("use_cudnn");
if (axis != rank_x - 1 && axis != -1) {
PADDLE_ENFORCE_EQ(use_cudnn, false,
platform::errors::InvalidArgument(
"CUDNN kernel only support axis as -1."));
}
ctx->SetOutputDim("Out", ctx->GetInputDim("X"));
ctx->ShareLoD("X", /*->*/ "Out");
}
......
......@@ -13,9 +13,11 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/unsqueeze_op.h"
#include <memory>
#include <string>
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
namespace paddle {
......@@ -327,6 +329,7 @@ REGISTER_OPERATOR(unsqueeze2_grad, ops::Unsqueeze2GradOp,
REGISTER_OP_CPU_KERNEL(
unsqueeze, ops::UnsqueezeKernel<paddle::platform::CPUDeviceContext, float>,
ops::UnsqueezeKernel<paddle::platform::CPUDeviceContext, double>,
ops::UnsqueezeKernel<paddle::platform::CPUDeviceContext, bool>,
ops::UnsqueezeKernel<paddle::platform::CPUDeviceContext, int>,
ops::UnsqueezeKernel<paddle::platform::CPUDeviceContext, int8_t>,
ops::UnsqueezeKernel<paddle::platform::CPUDeviceContext, int64_t>);
......@@ -334,12 +337,14 @@ REGISTER_OP_CPU_KERNEL(
unsqueeze_grad,
ops::UnsqueezeGradKernel<paddle::platform::CPUDeviceContext, float>,
ops::UnsqueezeGradKernel<paddle::platform::CPUDeviceContext, double>,
ops::UnsqueezeGradKernel<paddle::platform::CPUDeviceContext, bool>,
ops::UnsqueezeGradKernel<paddle::platform::CPUDeviceContext, int>,
ops::UnsqueezeGradKernel<paddle::platform::CPUDeviceContext, int8_t>,
ops::UnsqueezeGradKernel<paddle::platform::CPUDeviceContext, int64_t>);
REGISTER_OP_CPU_KERNEL(
unsqueeze2, ops::UnsqueezeKernel<paddle::platform::CPUDeviceContext, float>,
ops::UnsqueezeKernel<paddle::platform::CPUDeviceContext, double>,
ops::UnsqueezeKernel<paddle::platform::CPUDeviceContext, bool>,
ops::UnsqueezeKernel<paddle::platform::CPUDeviceContext, int>,
ops::UnsqueezeKernel<paddle::platform::CPUDeviceContext, int8_t>,
ops::UnsqueezeKernel<paddle::platform::CPUDeviceContext, int64_t>);
......@@ -347,6 +352,7 @@ REGISTER_OP_CPU_KERNEL(
unsqueeze2_grad,
ops::Unsqueeze2GradKernel<paddle::platform::CPUDeviceContext, float>,
ops::Unsqueeze2GradKernel<paddle::platform::CPUDeviceContext, double>,
ops::Unsqueeze2GradKernel<paddle::platform::CPUDeviceContext, bool>,
ops::Unsqueeze2GradKernel<paddle::platform::CPUDeviceContext, int>,
ops::Unsqueeze2GradKernel<paddle::platform::CPUDeviceContext, int8_t>,
ops::Unsqueeze2GradKernel<paddle::platform::CPUDeviceContext, int64_t>);
......@@ -21,6 +21,7 @@ REGISTER_OP_CUDA_KERNEL(
unsqueeze, ops::UnsqueezeKernel<paddle::platform::CUDADeviceContext, float>,
ops::UnsqueezeKernel<paddle::platform::CUDADeviceContext, double>,
ops::UnsqueezeKernel<paddle::platform::CUDADeviceContext, plat::float16>,
ops::UnsqueezeKernel<paddle::platform::CUDADeviceContext, bool>,
ops::UnsqueezeKernel<paddle::platform::CUDADeviceContext, int>,
ops::UnsqueezeKernel<paddle::platform::CUDADeviceContext, int8_t>,
ops::UnsqueezeKernel<paddle::platform::CUDADeviceContext, int64_t>);
......@@ -30,6 +31,7 @@ REGISTER_OP_CUDA_KERNEL(
ops::UnsqueezeGradKernel<paddle::platform::CUDADeviceContext, double>,
ops::UnsqueezeGradKernel<paddle::platform::CUDADeviceContext,
plat::float16>,
ops::UnsqueezeGradKernel<paddle::platform::CUDADeviceContext, bool>,
ops::UnsqueezeGradKernel<paddle::platform::CUDADeviceContext, int>,
ops::UnsqueezeGradKernel<paddle::platform::CUDADeviceContext, int8_t>,
ops::UnsqueezeGradKernel<paddle::platform::CUDADeviceContext, int64_t>);
......@@ -38,6 +40,7 @@ REGISTER_OP_CUDA_KERNEL(
ops::UnsqueezeKernel<paddle::platform::CUDADeviceContext, float>,
ops::UnsqueezeKernel<paddle::platform::CUDADeviceContext, double>,
ops::UnsqueezeKernel<paddle::platform::CUDADeviceContext, plat::float16>,
ops::UnsqueezeKernel<paddle::platform::CUDADeviceContext, bool>,
ops::UnsqueezeKernel<paddle::platform::CUDADeviceContext, int>,
ops::UnsqueezeKernel<paddle::platform::CUDADeviceContext, int8_t>,
ops::UnsqueezeKernel<paddle::platform::CUDADeviceContext, int64_t>);
......@@ -47,6 +50,7 @@ REGISTER_OP_CUDA_KERNEL(
ops::Unsqueeze2GradKernel<paddle::platform::CUDADeviceContext, double>,
ops::Unsqueeze2GradKernel<paddle::platform::CUDADeviceContext,
plat::float16>,
ops::Unsqueeze2GradKernel<paddle::platform::CUDADeviceContext, bool>,
ops::Unsqueeze2GradKernel<paddle::platform::CUDADeviceContext, int>,
ops::Unsqueeze2GradKernel<paddle::platform::CUDADeviceContext, int8_t>,
ops::Unsqueeze2GradKernel<paddle::platform::CUDADeviceContext, int64_t>);
......@@ -57,7 +57,11 @@ void BindFleetWrapper(py::module* m) {
.def("get_cache_threshold", &framework::FleetWrapper::GetCacheThreshold)
.def("cache_shuffle", &framework::FleetWrapper::CacheShuffle)
.def("save_cache", &framework::FleetWrapper::SaveCache)
.def("save_model_with_whitelist",
&framework::FleetWrapper::SaveWithWhitelist)
.def("load_model", &framework::FleetWrapper::LoadModel)
.def("load_table_with_whitelist",
&framework::FleetWrapper::LoadWithWhitelist)
.def("clear_model", &framework::FleetWrapper::ClearModel)
.def("clear_one_table", &framework::FleetWrapper::ClearOneTable)
.def("stop_server", &framework::FleetWrapper::StopServer)
......
......@@ -334,8 +334,7 @@ void BindGlobalValueGetterSetter(pybind11::module *module) {
} while (0)
static void RegisterGlobalVarGetterSetter() {
REGISTER_PRIVATE_GLOBAL_VAR(/*is_writable=*/false, FLAGS_use_mkldnn,
FLAGS_free_idle_chunk,
REGISTER_PRIVATE_GLOBAL_VAR(/*is_writable=*/false, FLAGS_free_idle_chunk,
FLAGS_free_when_no_cache_hit);
REGISTER_PUBLIC_GLOBAL_VAR(
......@@ -349,7 +348,7 @@ static void RegisterGlobalVarGetterSetter() {
FLAGS_init_allocated_mem, FLAGS_initial_cpu_memory_in_mb,
FLAGS_memory_fraction_of_eager_deletion, FLAGS_use_pinned_memory,
FLAGS_benchmark, FLAGS_inner_op_parallelism, FLAGS_tracer_profile_fname,
FLAGS_paddle_num_threads);
FLAGS_paddle_num_threads, FLAGS_use_mkldnn);
#ifdef PADDLE_WITH_CUDA
REGISTER_PUBLIC_GLOBAL_VAR(
......
......@@ -29,8 +29,10 @@ function(train_test TARGET_NAME)
PROPERTIES DEPENDS test_${TARGET_NAME})
set_tests_properties(test_train_${TARGET_NAME}${arg}
PROPERTIES LABELS "RUN_TYPE=DIST")
set_tests_properties(test_train_${TARGET_NAME}${arg}
PROPERTIES TIMEOUT 150)
if(NOT WIN32 AND NOT APPLE)
set_tests_properties(test_train_${TARGET_NAME}${arg}
PROPERTIES TIMEOUT 150)
endif()
endforeach()
endfunction(train_test)
......
......@@ -42,7 +42,6 @@ requirements:
- nltk
- scipy
- requests
- pyyaml
- pillow
- graphviz
- protobuf
......@@ -62,7 +61,6 @@ requirements:
- nltk
- scipy
- requests
- pyyaml
- pillow
- graphviz
- protobuf
......@@ -89,13 +87,11 @@ about:
pip install /package/objgraph-3.4.1.tar.gz
pip install /package/prettytable-0.7.tar.gz
pip install /package/rarfile-3.0.tar.gz --no-deps
pip install /package/funcsigs-1.0.2.tar.gz
"""
self.blt_const = r"""
pip install C:\package\objgraph-3.4.1.tar.gz
pip install C:\package\prettytable-0.7.tar.gz
pip install C:\package\funcsigs-1.0.2.tar.gz
pip install C:\package\rarfile-3.0.tar.gz --no-deps
git clone https://github.com/PaddlePaddle/recordio.git
cd recordio\python
......
......@@ -19,10 +19,14 @@ rem =================================================
rem Paddle CI Task On Windows Platform
rem =================================================
rem -------clean up environment-----------
set work_dir=%cd%
if exist build rmdir build /s/q
mkdir build
cd /d build
tree .
dir paddle\fluid\pybind\Release
taskkill /f /im op_function_generator.exe 2>NUL
rem ------initialize the virtual environment------
if not defined PYTHON_ROOT set PYTHON_ROOT=C:\Python37
......@@ -59,13 +63,12 @@ if not defined WITH_INFERENCE_API_TEST set WITH_INFERENCE_API_TEST=OFF
if not defined WITH_TPCACHE set WITH_TPCACHE=ON
rem ------set cache third_party------
set cache_dir=%work_dir%\..\cache
set cache_dir=%work_dir:Paddle=cache%
dir %cache_dir%
set INFERENCE_DEMO_INSTALL_DIR=%cache_dir:\=/%/inference_demo
if not exist %cache_dir%\tools (
git clone https://github.com/zhouwei25/tools.git %cache_dir%\tools
if %ERRORLEVEL% NEQ 0 exit /b %ERRORLEVEL%
)
if "%WITH_TPCACHE%"=="OFF" (
......@@ -125,6 +128,8 @@ echo ========================================
echo Step 1. Cmake ...
echo ========================================
for /F %%# in ('wmic os get localdatetime^|findstr 20') do set start=%%#
set start=%start:~4,10%
echo cmake .. -G "Visual Studio 14 2015 Win64" -DWITH_AVX=%WITH_AVX% -DWITH_GPU=%WITH_GPU% -DWITH_MKL=%WITH_MKL% ^
-DWITH_TESTING=%WITH_TESTING% -DWITH_PYTHON=%WITH_PYTHON% -DCUDA_TOOLKIT_ROOT_DIR=%CUDA_TOOLKIT_ROOT_DIR% ^
-DON_INFER=%ON_INFER% -DWITH_INFERENCE_API_TEST=%WITH_INFERENCE_API_TEST% -DTHIRD_PARTY_PATH=%THIRD_PARTY_PATH% ^
......@@ -150,7 +155,7 @@ call "C:\Program Files (x86)\Microsoft Visual Studio 14.0\VC\vcvarsall.bat" amd6
set build_times=1
:build_tp
echo Build third_party for %build_times% time:
echo Build third_party the %build_times% time:
msbuild /m /p:Configuration=Release /verbosity:quiet third_party.vcxproj
if %ERRORLEVEL% NEQ 0 (
set /a build_times=%build_times%+1
......@@ -165,7 +170,7 @@ echo Build third_party successfully!
set build_times=1
:build_paddle
echo Build Paddle for %build_times% time:
echo Build Paddle the %build_times% time:
msbuild /m /p:Configuration=Release /verbosity:minimal paddle.sln
if %ERRORLEVEL% NEQ 0 (
set /a build_times=%build_times%+1
......@@ -176,7 +181,9 @@ if %ERRORLEVEL% NEQ 0 (
goto :build_paddle
)
)
echo Build Paddle successfully!
goto:eof
:build_error
......@@ -189,6 +196,17 @@ rem ----------------------------------------------------------------------------
echo ========================================
echo Step 3. Test pip install whl package ...
echo ========================================
for /F %%# in ('wmic os get localdatetime^|findstr 20') do set end=%%#
set end=%end:~4,10%
call :timestamp "%start%" "%end%" "Build"
tree /F %cd%\fluid_inference_install_dir\paddle
%cache_dir%\tools\busybox64.exe du -h -d 0 %cd%\fluid_inference_install_dir\paddle\lib > lib_size.txt
set /p libsize=< lib_size.txt
for /F %%i in ("%libsize%") do echo "Windows FLuid_Inference Size: %%i"
%cache_dir%\tools\busybox64.exe du -h -d 0 %cd%\python\dist > whl_size.txt
set /p whlsize=< whl_size.txt
for /F %%i in ("%whlsize%") do echo "Windows PR whl Size: %%i"
dir /s /b python\dist\*.whl > whl_file.txt
set /p PADDLE_WHL_FILE_WIN=< whl_file.txt
......@@ -198,7 +216,7 @@ pip install -U %PADDLE_WHL_FILE_WIN% --user
if %ERRORLEVEL% NEQ 0 (
call paddle_winci\Scripts\deactivate.bat 2>NUL
echo pip install whl package failed!
exit /b 3
exit /b 1
)
python %work_dir%\paddle\scripts\installation_validate.py
......@@ -207,7 +225,7 @@ goto:eof
:test_whl_pacakage_error
call paddle_winci\Scripts\deactivate.bat 2>NUL
echo Test import paddle failed, will exit!
exit /b 3
exit /b 1
rem ---------------------------------------------------------------------------------------------
:unit_test
......@@ -215,6 +233,8 @@ echo ========================================
echo Step 4. Running unit tests ...
echo ========================================
for /F %%# in ('wmic os get localdatetime^|findstr 20') do set start=%%#
set start=%start:~4,10%
dir %THIRD_PARTY_PATH:/=\%\install\openblas\lib
dir %THIRD_PARTY_PATH:/=\%\install\openblas\bin
dir %THIRD_PARTY_PATH:/=\%\install\zlib\bin
......@@ -237,15 +257,18 @@ echo ========================================
echo Step 5. Testing fluid library for inference ...
echo ========================================
cd %work_dir%\paddle\fluid\inference\api\demo_ci
for /F %%# in ('wmic os get localdatetime^|findstr 20') do set end=%%#
set end=%end:~4,10%
call :timestamp "%start%" "%end%" "TestCases Total"
cd %work_dir%\paddle\fluid\inference\api\demo_ci
%cache_dir%\tools\busybox64.exe bash run.sh %work_dir:\=/% %WITH_MKL% %WITH_GPU% %cache_dir:\=/%/inference_demo
goto:eof
:test_inference_error
call paddle_winci\Scripts\deactivate.bat 2>NUL
echo Testing fluid library for inference failed!
exit /b 5
exit /b 1
rem ---------------------------------------------------------------------------------------------
:check_change_of_unittest
......@@ -253,7 +276,6 @@ echo ========================================
echo Step 6. Check whether a unit test has been deleted ...
echo ========================================
set PATH=%PYTHON_ROOT%;%PATH%
cd /d %work_dir%\build
echo set -ex> check_change_of_unittest.sh
echo GITHUB_API_TOKEN=%GITHUB_API_TOKEN% >> check_change_of_unittest.sh
......@@ -325,6 +347,43 @@ call paddle_winci\Scripts\deactivate.bat 2>NUL
exit /b 1
:timestamp
echo on
setlocal enabledelayedexpansion
set start=%~1
set dd=%start:~2,2%
set /a dd=100%dd%%%100
set hh=%start:~4,2%
set /a hh=100%hh%%%100
set nn=%start:~6,2%
set /a nn=100%nn%%%100
set ss=%start:~8,2%
set /a ss=100%ss%%%100
set /a start_sec=dd*86400+hh*3600+nn*60+ss
echo %start_sec%
set end=%~2
set dd=%end:~2,2%
set /a dd=100%dd%%%100
if %start:~0,2% NEQ %end:~0,2% (
set month_day=0
for %%i in (01 03 05 07 08 10 12) DO if %%i EQU %start:~0,2% set month_day=31
for %%i in (04 06 09 11) DO if %%i EQU %start:~0,2% set month_day=30
for %%i in (02) DO if %%i EQU %start:~0,2% set month_day=28
set /a dd=%dd%+!month_day!
)
set hh=%end:~4,2%
set /a hh=100%hh%%%100
set nn=%end:~6,2%
set /a nn=100%nn%%%100
set ss=%end:~8,2%
set /a ss=100%ss%%%100
set /a end_secs=dd*86400+hh*3600+nn*60+ss
set /a cost_secs=end_secs-start_sec
echo "Windows %~3 Time: %cost_secs%s"
goto:eof
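The :timestamp routine turns two MMDDhhmmss stamps into seconds (the `100x %% 100` idiom strips a leading zero that batch would otherwise parse as octal) and, when the run crosses a month boundary, pads the end day by the length of the start month. A Python sketch of the same arithmetic:

    def to_seconds(stamp):  # stamp is 'MMDDhhmmss'
        dd, hh, nn, ss = (int(stamp[i:i + 2]) for i in (2, 4, 6, 8))
        return dd * 86400 + hh * 3600 + nn * 60 + ss

    def elapsed(start, end):
        days = {'01': 31, '03': 31, '05': 31, '07': 31, '08': 31, '10': 31,
                '12': 31, '04': 30, '06': 30, '09': 30, '11': 30, '02': 28}
        secs = to_seconds(end) - to_seconds(start)
        if start[:2] != end[:2]:             # crossed into the next month
            secs += days.get(start[:2], 0) * 86400
        return secs

    print(elapsed('0831235000', '0901000500'))  # 900 seconds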
rem ---------------------------------------------------------------------------------------------
:success
echo ========================================
......@@ -340,7 +399,7 @@ taskkill /f /im git-remote-https.exe 2>NUL
taskkill /f /im vctip.exe 2>NUL
taskkill /f /im cvtres.exe 2>NUL
taskkill /f /im rc.exe 2>NUL
taskkill /f /im %cd%\paddle\fluid\pybind\Release\op_function_generator.exe 2>NUL
taskkill /f /im op_function_generator.exe 2>NUL
taskkill /f /im python.exe 2>NUL
call paddle_winci\Scripts\deactivate.bat 2>NUL
taskkill /f /im python.exe 2>NUL
......
......@@ -959,7 +959,7 @@ set +x
retry_unittests_record="$retry_unittests_record$failed_test_lists"
failed_test_lists_ult=`echo "${failed_test_lists}" |grep -Po '[^ ].*$'`
read retry_unittests <<< $(echo "$failed_test_lists" | grep -oEi "\-.+\(\w+\)" | sed 's/(.\+)//' | sed 's/- //' )
read retry_unittests <<< $(echo "$failed_test_lists" | grep -oEi "\-.+\(.+\)" | sed 's/(.\+)//' | sed 's/- //' )
echo "========================================="
echo "This is the ${exec_time_array[$exec_times]} time to re-run"
echo "========================================="
......@@ -1395,6 +1395,26 @@ function example() {
fi
}
function summary_check_problems() {
set +x
local check_style_code=$1
local example_code=$2
if [ $check_style_code -ne 0 -o $example_code -ne 0 ];then
echo "========================================"
echo "summary problems:"
echo "========================================"
if [ $check_style_code -ne 0 ];then
echo "- Check code style failed! Please check the log and fix problems."
fi
if [ $example_code -ne 0 ];then
echo "- Check example code failed! Please check the log and fix problems."
fi
[ $check_style_code -ne 0 ] && exit $check_style_code
[ $example_code -ne 0 ] && exit $example_code
fi
set -x
}
function main() {
local CMD=$1
local parallel_number=$2
......@@ -1407,12 +1427,15 @@ function main() {
cmake_gen_and_build ${PYTHON_ABI:-""} ${parallel_number}
;;
build_and_check)
check_style
$(check_style >&2)
check_style_code=$?
generate_upstream_develop_api_spec ${PYTHON_ABI:-""} ${parallel_number}
cmake_gen_and_build ${PYTHON_ABI:-""} ${parallel_number}
check_sequence_op_unittest
generate_api_spec ${PYTHON_ABI:-""} "PR"
example
$(example >&2)
example_code=$?
summary_check_problems $check_style_code $example_code
assert_api_spec_approvals
;;
build)
......
......@@ -95,7 +95,6 @@ if (WITH_TESTING)
add_subdirectory(paddle/fluid/tests)
add_subdirectory(paddle/fluid/contrib/tests)
add_subdirectory(paddle/fluid/contrib/slim/tests)
add_subdirectory(paddle/incubate/hapi/tests)
endif()
install(DIRECTORY ${PADDLE_PYTHON_PACKAGE_DIR}
DESTINATION opt/paddle/share/wheels
......
......@@ -256,12 +256,16 @@ from .device import get_device
# from .tensor.tensor import LoDTensor #DEFINE_ALIAS
# from .tensor.tensor import LoDTensorArray #DEFINE_ALIAS
from . import incubate
from .incubate import hapi
from .fluid.dygraph.base import enable_dygraph as disable_static #DEFINE_ALIAS
from .fluid.dygraph.base import disable_dygraph as enable_static #DEFINE_ALIAS
from .fluid.framework import in_dygraph_mode as in_dynamic_mode #DEFINE_ALIAS
from .fluid.dygraph.base import no_grad #DEFINE_ALIAS
from .fluid.dygraph.base import no_grad_ as no_grad #DEFINE_ALIAS
from . import jit
from . import static
# high-level api
from .hapi import Model
from .hapi import callbacks
import paddle.text
import paddle.vision
......@@ -196,3 +196,14 @@ def cluster_files_reader(files_pattern,
yield line
return reader
def _check_exists_and_download(path, url, md5, module_name, download=True):
if path and os.path.exists(path):
return path
if download:
return paddle.dataset.common.download(url, module_name, md5)
else:
raise ValueError('{} does not exist and auto download is disabled'.format(
path))
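A usage sketch (the URL, md5 and module name below are placeholders):

    # Returns `path` when it already exists; otherwise downloads into the
    # paddle dataset cache, or raises when download is disabled.
    data_path = _check_exists_and_download(
        path=None,
        url='https://example.com/train.tar.gz',
        md5='0123456789abcdef0123456789abcdef',
        module_name='my_dataset',
        download=True)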
......@@ -36,7 +36,7 @@ import tarfile
import gzip
from collections import defaultdict
import paddle.dataset.common
import paddle
import paddle.compat as cpt
__all__ = [
......
......@@ -13,9 +13,11 @@
# limitations under the License.
# TODO: define the functions to manipulate devices
import re
from paddle.fluid import core
from paddle.fluid import framework
import re
from paddle.fluid.dygraph.parallel import ParallelEnv
__all__ = [
'get_cudnn_version',
......@@ -81,8 +83,8 @@ def set_device(device):
.. code-block:: python
import paddle
paddle.enable_imperative()
paddle.fluid.dygraph.set_device("gpu:0")
paddle.disable_static()
paddle.set_device("cpu")
x1 = paddle.ones(name='x1', shape=[1, 2], dtype='int32')
x2 = paddle.zeros(name='x2', shape=[1, 2], dtype='int32')
data = paddle.stack([x1,x2], axis=1)
......@@ -90,18 +92,28 @@ def set_device(device):
lower_device = device.lower()
if lower_device == 'cpu':
place = core.CPUPlace()
framework._set_expected_place(place)
elif lower_device == 'gpu':
if not core.is_compiled_with_cuda():
raise ValueError(
"The device should not be 'gpu', " \
"since PaddlePaddle is not compiled with CUDA")
place = core.CUDAPlace(ParallelEnv().dev_id)
else:
avaliable_device = ((lower_device == 'cpu') or
re.match(r'gpu:\d+', lower_device))
avaliable_device = re.match(r'gpu:\d+', lower_device)
if not avaliable_device:
raise ValueError(
"The device must be a string which is like 'cpu' or 'gpu:0'")
"The device must be a string which is like 'cpu', 'gpu' or 'gpu:0'"
)
if not core.is_compiled_with_cuda():
raise ValueError(
"The device should not be {}, since PaddlePaddle is " \
"not compiled with CUDA".format(avaliable_device))
device_info_list = device.split(':', 1)
device_id = device_info_list[1]
device_id = int(device_id)
place = core.CUDAPlace(device_id)
framework._set_expected_place(place)
framework._set_expected_place(place)
return place
def get_device():
......@@ -116,8 +128,8 @@ def get_device():
.. code-block:: python
import paddle
paddle.enable_imperative()
device = paddle.fluid.dygraph.get_device()
paddle.disable_static()
device = paddle.get_device()
"""
device = ''
......
......@@ -73,20 +73,21 @@ def broadcast(tensor, src, group=0):
Examples:
.. code-block:: python
import paddle
import paddle.prepare_context as prepare_context
paddle.disable_static()
paddle.set_device('gpu:%d'%paddle.ParallelEnv().dev_id)
prepare_context()
if paddle.ParallelEnv().local_rank == 0:
np_data = np.array([[4, 5, 6], [4, 5, 6]])
else:
np_data = np.array([[1, 2, 3], [1, 2, 3]])
data = paddle.to_tensor(np_data)
paddle.distributed.broadcast(data, 1)
out = data.numpy()
# [[1, 2, 3], [1, 2, 3]]
import numpy as np
import paddle
from paddle.distributed import init_parallel_env
paddle.disable_static()
paddle.set_device('gpu:%d'%paddle.distributed.ParallelEnv().dev_id)
init_parallel_env()
if paddle.distributed.ParallelEnv().local_rank == 0:
np_data = np.array([[4, 5, 6], [4, 5, 6]])
else:
np_data = np.array([[1, 2, 3], [1, 2, 3]])
data = paddle.to_tensor(np_data)
paddle.distributed.broadcast(data, 1)
out = data.numpy()
# [[1, 2, 3], [1, 2, 3]]
"""
if in_dygraph_mode():
return core.ops.c_broadcast(tensor, tensor, 'root', src,
......@@ -129,21 +130,22 @@ def all_reduce(tensor, op=ReduceOp.SUM, group=0):
Examples:
.. code-block:: python
import paddle
from paddle.distributed import ReduceOp
import paddle.prepare_context as prepare_context
paddle.disable_static()
paddle.set_device('gpu:%d'%paddle.ParallelEnv().dev_id)
prepare_context()
if paddle.ParallelEnv().local_rank == 0:
np_data = np.array([[4, 5, 6], [4, 5, 6]])
else:
np_data = np.array([[1, 2, 3], [1, 2, 3]])
data = paddle.to_tensor(np_data)
paddle.distributed.all_reduce(data)
out = data.numpy()
# [[5, 7, 9], [5, 7, 9]]
import numpy as np
import paddle
from paddle.distributed import ReduceOp
from paddle.distributed import init_parallel_env
paddle.disable_static()
paddle.set_device('gpu:%d'%paddle.distributed.ParallelEnv().dev_id)
init_parallel_env()
if paddle.distributed.ParallelEnv().local_rank == 0:
np_data = np.array([[4, 5, 6], [4, 5, 6]])
else:
np_data = np.array([[1, 2, 3], [1, 2, 3]])
data = paddle.to_tensor(np_data)
paddle.distributed.all_reduce(data)
out = data.numpy()
# [[5, 7, 9], [5, 7, 9]]
"""
if in_dygraph_mode():
if op == ReduceOp.SUM:
......@@ -204,20 +206,21 @@ def reduce(tensor, dst, op=ReduceOp.SUM, group=0):
Examples:
.. code-block:: python
import paddle
import paddle.prepare_context as prepare_context
paddle.disable_static()
paddle.set_device('gpu:%d'%paddle.ParallelEnv().dev_id)
prepare_context()
if paddle.ParallelEnv().local_rank == 0:
np_data = np.array([[4, 5, 6], [4, 5, 6]])
else:
np_data = np.array([[1, 2, 3], [1, 2, 3]])
data = paddle.to_tensor(np_data)
paddle.distributed.reduce(data, 0)
out = data.numpy()
# [[5, 7, 9], [5, 7, 9]]
import numpy as np
import paddle
from paddle.distributed import init_parallel_env
paddle.disable_static()
paddle.set_device('gpu:%d'%paddle.distributed.ParallelEnv().dev_id)
init_parallel_env()
if paddle.distributed.ParallelEnv().local_rank == 0:
np_data = np.array([[4, 5, 6], [4, 5, 6]])
else:
np_data = np.array([[1, 2, 3], [1, 2, 3]])
data = paddle.to_tensor(np_data)
paddle.distributed.reduce(data, 0)
out = data.numpy()
# [[5, 7, 9], [5, 7, 9]]
"""
if in_dygraph_mode():
if op == ReduceOp.SUM:
......@@ -286,25 +289,26 @@ def all_gather(tensor_list, tensor, group=0):
Examples:
.. code-block:: python
import paddle
import paddle.prepare_context as prepare_context
paddle.disable_static()
paddle.set_device('gpu:%d'%paddle.ParallelEnv().dev_id)
prepare_context()
tensor_list = []
if paddle.ParallelEnv().local_rank == 0:
np_data1 = np.array([[4, 5, 6], [4, 5, 6]])
np_data2 = np.array([[4, 5, 6], [4, 5, 6]])
data1 = paddle.to_tensor(np_data1)
data2 = paddle.to_tensor(np_data2)
paddle.distributed.all_gather(tensor_list, data1)
else:
np_data1 = np.array([[1, 2, 3], [1, 2, 3]])
np_data2 = np.array([[1, 2, 3], [1, 2, 3]])
data1 = paddle.to_tensor(np_data1)
data2 = paddle.to_tensor(np_data2)
out = paddle.distributed.all_gather(tensor_list, data2)
import numpy as np
import paddle
from paddle.distributed import init_parallel_env
paddle.disable_static()
paddle.set_device('gpu:%d'%paddle.distributed.ParallelEnv().dev_id)
init_parallel_env()
tensor_list = []
if paddle.distributed.ParallelEnv().local_rank == 0:
np_data1 = np.array([[4, 5, 6], [4, 5, 6]])
np_data2 = np.array([[4, 5, 6], [4, 5, 6]])
data1 = paddle.to_tensor(np_data1)
data2 = paddle.to_tensor(np_data2)
paddle.distributed.all_gather(tensor_list, data1)
else:
np_data1 = np.array([[1, 2, 3], [1, 2, 3]])
np_data2 = np.array([[1, 2, 3], [1, 2, 3]])
data1 = paddle.to_tensor(np_data1)
data2 = paddle.to_tensor(np_data2)
paddle.distributed.all_gather(tensor_list, data2)
"""
op_type = 'c_allgather'
helper = LayerHelper(op_type, **locals())
......@@ -359,25 +363,26 @@ def scatter(tensor, tensor_list=None, src=0, group=0):
Examples:
.. code-block:: python
import paddle
import paddle.prepare_context as prepare_context
paddle.disable_static()
paddle.set_device('gpu:%d'%paddle.ParallelEnv().dev_id)
prepare_context()
if paddle.ParallelEnv().local_rank == 0:
np_data1 = np.array([7, 8, 9])
np_data2 = np.array([10, 11, 12])
else:
np_data1 = np.array([1, 2, 3])
np_data2 = np.array([4, 5, 6])
data1 = paddle.to_tensor(np_data1)
data2 = paddle.to_tensor(np_data2)
if paddle.ParallelEnv().local_rank == 0:
paddle.distributed.scatter(data1, src=1)
else:
paddle.distributed.scatter(data1, tensor_list=[data1, data2], src=1)
out = data1.numpy()
import numpy as np
import paddle
from paddle.distributed import init_parallel_env
paddle.disable_static()
paddle.set_device('gpu:%d'%paddle.distributed.ParallelEnv().dev_id)
init_parallel_env()
if paddle.distributed.ParallelEnv().local_rank == 0:
np_data1 = np.array([7, 8, 9])
np_data2 = np.array([10, 11, 12])
else:
np_data1 = np.array([1, 2, 3])
np_data2 = np.array([4, 5, 6])
data1 = paddle.to_tensor(np_data1)
data2 = paddle.to_tensor(np_data2)
if paddle.distributed.ParallelEnv().local_rank == 0:
paddle.distributed.scatter(data1, src=1)
else:
paddle.distributed.scatter(data1, tensor_list=[data1, data2], src=1)
out = data1.numpy()
"""
op_type = 'c_scatter'
global _default_group
......@@ -425,13 +430,13 @@ def barrier(group=0):
Examples:
.. code-block:: python
import paddle
import paddle.prepare_context as prepare_context
import paddle
from paddle.distributed import init_parallel_env
paddle.disable_static()
paddle.set_device('gpu:%d'%paddle.ParallelEnv().dev_id)
prepare_context()
paddle.distributed.barrier()
paddle.disable_static()
paddle.set_device('gpu:%d'%paddle.distributed.ParallelEnv().dev_id)
init_parallel_env()
paddle.distributed.barrier()
"""
op_type = 'barrier'
temp = paddle.fill_constant([1], dtype="int32", value="1")
......
......@@ -22,8 +22,6 @@ from .runtime_factory import RuntimeFactory
from .util_factory import UtilFactory
from paddle.fluid.wrapped_decorator import wrap_decorator
#__all__ = ['Fleet']
def _inited_runtime_handler_(func):
def __impl__(*args, **kwargs):
......@@ -43,65 +41,123 @@ inited_runtime_handler = wrap_decorator(_inited_runtime_handler_)
class Fleet(object):
"""
Unified API for distributed training of PaddlePaddle
Please refer to https://github.com/PaddlePaddle/Fleet for details
Please refer to https://github.com/PaddlePaddle/FleetX for details
Returns:
Fleet: A Fleet instance
Examples:
Example for collective training:
.. code-block:: python
import paddle.distributed.fleet as fleet
role = fleet.role_maker.PaddleCloudRoleMaker(is_collective=True)
fleet.init(role)
fleet.init(is_collective=True)
strategy = fleet.DistributedStrategy()
optimizer = paddle.optimizer.SGD(learning_rate=0.001)
optimizer = fleet.distributed_optimizer(optimizer, strategy=strategy)
# do distributed training
Example for parameter server training:
.. code-block:: python
import paddle.distributed.fleet as fleet
fleet.init()
strategy = fleet.DistributedStrategy()
optimizer = paddle.optimizer.SGD(learning_rate=0.001)
optimizer = fleet.distributed_optimizer(optimizer, strategy=strategy)
if fleet.is_first_worker():
print("this is first worker")
print("current node index: {}".format(fleet.worker_index()))
print("total number of worker num: {}".format(fleet.worker_num()))
if fleet.is_worker():
print("this is worker")
print("worker endpoints: {}".format(fleet.worker_endpoints(to_string=True)))
print("server num: {}".format(fleet.server_num()))
print("server endpoints: {}".format(fleet.server_endpoints(to_string=True)))
if fleet.is_server():
print("this is server")
fleet.stop_worker()
"""
def __init__(self):
self._runtime_handle = None
self._util = None
self._role_maker = None
self.strategy_compiler = None
self._is_collective = False
self._runtime_handle = None
self._util = None
def init(self, role_maker=None, is_collective=False):
"""
Initialize role_maker in Fleet.
This function is responsible for the distributed architecture
what you want to run your code behind,such as Transpiler,
Collective in PaddleCloudRoleMaker or UserDefinedRoleMaker
This function is responsible for the distributed architecture
you want to run your code on.
Args:
role_maker (RoleMakerBase, optional): A ``RoleMakerBase`` containing the configuration
of environment variables related to distributed training.If you did not initialize
the rolemaker by yourself, it will be automatically initialized to PaddleRoleMaker.
The default value is None.
is_collective (Boolean, optional): A ``Boolean`` variable determines whether the program
runs on the CPU or GPU. False means set distributed training using CPU, and True means
GPU.The default value is False.The default value is False.
Returns:
None
Examples1:
.. code-block:: python
import paddle.distributed.fleet as fleet
fleet.init()
Examples2:
.. code-block:: python
import paddle.distributed.fleet as fleet
fleet.init(is_collective=True)
Examples3:
.. code-block:: python
import paddle.distributed.fleet as fleet
role = fleet.PaddleCloudRoleMaker()
fleet.init(role)
"""
if isinstance(role_maker, RoleMakerBase):
self._role_maker = role_maker
elif role_maker == None:
if role_maker is None:
if isinstance(is_collective, bool):
self._is_collective = is_collective
self._role_maker = PaddleCloudRoleMaker(
is_collective=self._is_collective)
else:
raise ValueError(
"Something wrong occurred, please check whether is_collective is bool value"
)
"`is_collective` should be instance of `bool`, but got {}".
format(type(is_collective)))
else:
raise ValueError(
"Something wrong occurred, please check whether rolemaker is instance of RoleMakerBase"
)
if isinstance(role_maker, RoleMakerBase):
self._role_maker = role_maker
else:
raise ValueError(
"`role_maker` should be subclass of `RoleMakerBase`, but got {}".
format(type(role_maker)))
self.strategy_compiler = StrategyCompiler()
return None
......@@ -113,6 +169,14 @@ class Fleet(object):
bool: True if this is the first worker node,
False if not.
Examples:
.. code-block:: python
import paddle.distributed.fleet as fleet
fleet.init()
fleet.is_first_worker()
"""
return self._role_maker.is_first_worker()
......@@ -122,6 +186,14 @@ class Fleet(object):
Returns:
int: node id
Examples:
.. code-block:: python
import paddle.distributed.fleet as fleet
fleet.init()
fleet.worker_index()
"""
return self._role_maker.worker_index()
......@@ -131,6 +203,14 @@ class Fleet(object):
Returns:
int: number of workers
Examples:
.. code-block:: python
import paddle.distributed.fleet as fleet
fleet.init()
fleet.worker_num()
"""
return self._role_maker.worker_num()
......@@ -141,15 +221,31 @@ class Fleet(object):
Returns:
bool: True if this is a worker node,
False if not.
Examples:
.. code-block:: python
import paddle.distributed.fleet as fleet
fleet.init()
fleet.is_worker()
"""
return self._role_maker.is_worker()
def worker_endpoints(self, to_string=False):
"""
Get current worker endpoints, such as ["127.0.0.1:1001", "127.0.0.1:1002"].
Returns:
list/string: worker endpoints
Examples:
.. code-block:: python
import paddle.distributed.fleet as fleet
fleet.init()
fleet.worker_endpoints()
"""
'''
if to_string:
......@@ -165,6 +261,12 @@ class Fleet(object):
Returns:
int: server number
Examples:
.. code-block:: python
import paddle.distributed.fleet as fleet
fleet.init()
fleet.server_num()
"""
return len(self._role_maker.get_pserver_endpoints())
......@@ -174,6 +276,14 @@ class Fleet(object):
Returns:
int: node id
Examples:
.. code-block:: python
import paddle.distributed.fleet as fleet
fleet.init()
fleet.server_index()
"""
return self._role_maker.server_index()
......@@ -183,14 +293,20 @@ class Fleet(object):
Returns:
list/string: server endpoints
Examples:
.. code-block:: python
import paddle.distributed.fleet as fleet
fleet.init()
fleet.server_endpoints()
"""
'''
if to_string:
return ",".join(self._role_maker.get_pserver_endpoints())
else:
return self._role_maker.get_pserver_endpoints()
'''
return ["127.0.0.1:1001", "127.0.0.1:1002"]
def is_server(self):
"""
......@@ -199,6 +315,14 @@ class Fleet(object):
Returns:
bool: True if this is a server node,
False if not.
Examples:
.. code-block:: python
import paddle.distributed.fleet as fleet
fleet.init()
fleet.is_server()
"""
return self._role_maker.is_server(
) or self._role_maker._is_heter_worker()
......@@ -208,6 +332,19 @@ class Fleet(object):
"""
Utility functions that can be used under certain runtime.
Returns:
UtilBase: instance of UtilBase, can use distributed ops/tools easily.
Examples:
.. code-block:: python
import paddle.distributed.fleet as fleet
fleet.init()
util = fleet.util
files = ["1.log", "2.log", "3.log", "4.log"]
files = util.get_file_shard(files)
"""
return self._util
......@@ -215,41 +352,114 @@ class Fleet(object):
def util(self, util):
"""
Set utility functions for user-defined runtime.
Returns:
None
"""
self._util = util
def barrier_worker(self):
"""
Barrier all workers.
Returns:
None
"""
self._role_maker.barrier_worker()
@inited_runtime_handler
def init_worker(self):
"""
Initialize `Communicator` for parameter server training.
Returns:
None
Examples:
.. code-block:: python
import paddle.distributed.fleet as fleet
fleet.init()
# build net
# fleet.distributed_optimizer(...)
fleet.init_worker()
"""
self._runtime_handle._init_worker()
@inited_runtime_handler
def init_server(self, *args, **kwargs):
"""
Run the executor to initialize the startup program. If `args` is not empty,
it also runs `load_persistables` for incremental training.
Returns:
None
Examples:
.. code-block:: python
import paddle.distributed.fleet as fleet
fleet.init()
# build net
# fleet.distributed_optimizer(...)
fleet.init_server()
"""
self._runtime_handle._init_server(*args, **kwargs)
@inited_runtime_handler
def run_server(self):
"""
Run the pserver main program with the executor.
Returns:
None
Examples:
.. code-block:: python
import paddle.distributed.fleet as fleet
fleet.init()
# build net
# fleet.distributed_optimizer(...)
if fleet.is_server():
    fleet.init_server()
    fleet.run_server()
"""
self._runtime_handle._run_server()
@inited_runtime_handler
def stop_worker(self):
"""
Stop `Communicator` and notify the parameter server that training is complete.
Returns:
None
Examples:
.. code-block:: python
import paddle.distributed.fleet as fleet
fleet.init()
# build net
# fleet.distributed_optimizer(...)
fleet.init_worker()
# ... do training ...
fleet.stop_worker()
"""
self._runtime_handle._stop_worker()
......@@ -260,27 +470,98 @@ class Fleet(object):
target_vars,
main_program=None,
export_for_deployment=True):
"""
Save the inference model for deployment.
Returns:
None
Examples:
.. code-block:: python
import paddle.distributed.fleet as fleet
fleet.init()
# build net
# fleet.distributed_optimizer(...)
fleet.init_server()
"""
self._runtime_handle._save_inference_model(
executor, dirname, feeded_var_names, target_vars, main_program,
export_for_deployment)
def save_persistables(self, executor, dirname, main_program=None):
"""
saves all persistable variables from :code:`main_program` to
the folder :code:`dirname`. The :code:`dirname` is used to specify the folder where
persistable variables are going to be saved. If you would like to save variables in
separate files, set :code:`filename` to None.
Args:
executor(Executor): The executor to run for saving persistable variables.
You can refer to :ref:`api_guide_executor_en` for
more details.
dirname(str, optional): The saving directory path.
When you need to save the parameter to the memory, set it to None.
main_program(Program, optional): The program whose persistable variables will
be saved. Default: None.
Returns:
None
Examples:
.. code-block:: python
import paddle.distributed.fleet as fleet
import paddle.fluid as fluid
fleet.init()
# build net
# fleet.distributed_optimizer(...)
exe = fluid.Executor(fluid.CPUPlace())
fleet.save_persistables(exe, "dirname", fluid.default_main_program())
"""
self._runtime_handle._save_persistables(executor, dirname, main_program)
def distributed_optimizer(self, optimizer, strategy=None):
"""
Optimizer for distributed training.
For distributed training, this method rebuilds a new instance of
DistributedOptimizer, which has the basic Optimizer functionality plus
special features for distributed training.
Args:
optimizer(Optimizer): The optimizer to be wrapped for distributed training.
strategy(DistributedStrategy): Extra properties for distributed optimizer.
Returns:
Fleet: instance of fleet, with a minimize interface like ordinary optimizers.
Examples:
.. code-block:: python
import paddle.distributed.fleet as fleet
role = fleet.role_maker.PaddleCloudRoleMaker(is_collective=True)
fleet.init(role)
strategy = fleet.DistributedStrategy()
optimizer = paddle.optimizer.SGD(learning_rate=0.001)
optimizer = fleet.distributed_optimizer(optimizer, strategy=strategy)
"""
self.user_defined_optimizer = optimizer
if strategy is None:
......@@ -317,23 +598,25 @@ class Fleet(object):
``fetch_list`` before run, see details in ``Executor``.
Examples:
.. code-block:: python

    import paddle
    import paddle.distributed.fleet as fleet

    fc_1 = paddle.fluid.layers.fc(input=input_x, size=hid_dim, act='tanh')
    fc_2 = paddle.fluid.layers.fc(input=fc_1, size=hid_dim, act='tanh')
    prediction = paddle.fluid.layers.fc(input=[fc_2], size=label_dim, act='softmax')
    cost = paddle.fluid.layers.cross_entropy(input=prediction, label=input_y)
    avg_cost = paddle.fluid.layers.mean(x=cost)

    role = fleet.role_maker.PaddleCloudRoleMaker(is_collective=True)
    fleet.init(role)
    strategy = fleet.DistributedStrategy()
    optimizer = paddle.optimizer.SGD(learning_rate=0.001)
    optimizer = fleet.distributed_optimizer(optimizer, strategy=strategy)
    optimizer.minimize(avg_cost)

    # for more examples, please reference https://github.com/PaddlePaddle/FleetX
"""
context = {}
......
......@@ -154,15 +154,16 @@ class ParameterServerRuntime(RuntimeBase):
kwargs["sparse_attrs"] = get_sparse_attrs()
return kwargs
from paddle.fluid.incubate.fleet.parameter_server.ir.public import _get_lr_ops, _has_global_step
from paddle.fluid.incubate.fleet.parameter_server.distribute_transpiler.distributed_strategy import \
SyncStrategy, GeoStrategy
trainer_config = self.async_strategy.get_trainer_runtime_config()
lrs = _has_global_step(_get_lr_ops(self.origin_main_program))
if lrs:
kwargs = {"need_global_step": "1"}
else:
kwargs = {"need_global_step": "0"}
......
......@@ -29,13 +29,13 @@ __all__ = ["init_parallel_env"]
ParallelStrategy = core.ParallelStrategy
def init_parallel_env():
"""
Initialize parallel training environment in dynamic graph mode.
.. note::
Now only supports initializing the GPU parallel training
environment and using NCCL for communication.
Returns:
None
......@@ -89,14 +89,12 @@ def init_parallel_env(backend='nccl'):
dist.spawn(train)
"""
# 1. gpu check
if not core.is_compiled_with_cuda():
raise NotImplementedError(
"Cannot initialize parallel environment in CPU-only version, now only "
"supports initializing the GPU parallel environment. Please recompile "
"or reinstall paddle with GPU support.")
# 2. check env
def _check_var_exists(var_name):
......@@ -112,30 +110,28 @@ def init_parallel_env(backend='nccl'):
_check_var_exists("PADDLE_TRAINERS_NUM")
_check_var_exists("PADDLE_TRAINER_ENDPOINTS")
# 3. init NCCL ParallelStrategy
strategy = ParallelStrategy()
if parallel_helper._is_parallel_ctx_initialized():
warnings.warn("The parallel environment has been initialized.")
strategy.nranks = ParallelEnv().world_size
strategy.local_rank = ParallelEnv().rank
strategy.trainer_endpoints = ParallelEnv().trainer_endpoints
strategy.current_endpoint = ParallelEnv().current_endpoint
if strategy.nranks < 2:
return
# NOTE(chenweihang): [ why config global place here? ]
# the dygraph mode will be set to default mode,
# users will not call `dygraph.guard` or `enable_dygraph`
# directly, if they want to switch default place,
# they need to call a function to change default place,
# here just set correctly place to users
place = core.CUDAPlace(ParallelEnv().device_id)
_set_expected_place(place)
# init nccl context
parallel_helper._set_parallel_ctx(core.NCCLParallelContext(strategy, place))
parallel_helper._init_parallel_ctx()
def get_rank():
......@@ -163,7 +159,7 @@ def get_rank():
def get_world_size():
"""
Returns the number of trainers (number of processes participating in current job).
Its value is equal to the value of the environment variable ``PADDLE_TRAINERS_NUM`` .
The default value is 1.
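A minimal illustrative sketch (assuming the job was launched with two
trainers, so the launcher exported ``PADDLE_TRAINERS_NUM=2``):
.. code-block:: python

    import paddle.distributed as dist

    print(dist.get_world_size())  # 2 under the assumed two-trainer launch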
......
......@@ -236,8 +236,6 @@ def spawn(func, args=(), nprocs=-1, join=True, daemon=False, **options):
func (function): The target function is called by the spawned process.
This function needs to be picklable, so it must be defined
at the top level of a module.
args (tuple, optional): Arguments passed to ``func``.
nprocs (int, optional): Number of processes to start. Default: -1.
when nprocs is -1, the available device will be obtained from
......@@ -246,8 +244,8 @@ def spawn(func, args=(), nprocs=-1, join=True, daemon=False, **options):
variable CUDA_VISIBLE_DEVICES; If use CPU, the currently available
CPU number is obtained from the environment variable CPU_NUM.
For example, export CPU_NUM=4, if the environment variable is not set,
the spawn method will add a default value to the environment variable
and set its value to 1.
join (bool, optional): Perform a blocking join on all spawned processes.
Default: True.
daemon (bool, optional): The spawned processes' daemon flag. Default: False.
......@@ -266,8 +264,8 @@ def spawn(func, args=(), nprocs=-1, join=True, daemon=False, **options):
such as 6170. Default: None;
(5) selected_gpus (string): The training process will run on the
selected_gpus, such as "0,1,2,3". Default: None;
(6) print_config (bool): Print current parallel training config. Default: False;
(7) use_paddlecloud (bool): Whether to use paddlecloud platform to run your
multi-process job. Default: False.
Returns:
......
......@@ -129,7 +129,7 @@ class GradientClipBase(object):
def __str__(self):
raise NotImplementedError()
@imperative_base.no_grad
def _dygraph_clip(self, params_grads):
raise NotImplementedError
......@@ -258,7 +258,7 @@ class GradientClipByValue(GradientClipBase):
def __str__(self):
return "Gradient Clip By Value, min = %f, max=%f" % (self.min, self.max)
@imperative_base.no_grad
def _dygraph_clip(self, params_grads):
params_and_grads = []
for p, g in params_grads:
......@@ -413,7 +413,7 @@ class GradientClipByNorm(GradientClipBase):
def __str__(self):
return "Gradient Clip By Norm, clip_norm=%f" % self.clip_norm
@imperative_base.no_grad
def _dygraph_clip(self, params_grads):
params_and_grads = []
for p, g in params_grads:
......@@ -565,7 +565,7 @@ class GradientClipByGlobalNorm(GradientClipBase):
def __str__(self):
return "Gradient Clip By GlobalNorm, global_norm=%f" % (self.clip_norm)
@imperative_base.no_grad
def _dygraph_clip(self, params_grads):
params_and_grads = []
sum_square_list = []
......
......@@ -69,7 +69,7 @@ class ImperativeQuantAware(object):
from paddle.fluid.contrib.slim.quantization \
import ImperativeQuantAware
from paddle.vision.models \
import resnet
model = resnet.resnet50(pretrained=True)
......
......@@ -270,7 +270,7 @@ foreach(src ${TEST_OPS})
endforeach()
# setting timeout value for old unittests
if(NOT WIN32 AND NOT APPLE)
set_tests_properties(test_post_training_quantization_mobilenetv1 PROPERTIES TIMEOUT 250 LABELS "RUN_TYPE=NIGHTLY")
set_tests_properties(test_post_training_quantization_resnet50 PROPERTIES TIMEOUT 200 LABELS "RUN_TYPE=NIGHTLY")
endif()
......@@ -16,10 +16,12 @@ from __future__ import print_function
from __future__ import division
import numpy as np
import math
from .sampler import Sampler, SequenceSampler, RandomSampler
from .dataset import Dataset, IterableDataset
__all__ = ["BatchSampler"]
__all__ = ["BatchSampler", "DistributedBatchSampler"]
class BatchSampler(Sampler):
......@@ -158,3 +160,185 @@ class _InfiniteIterableSampler(object):
def __iter__(self):
while True:
yield [None] * self.batch_size
class DistributedBatchSampler(BatchSampler):
"""Sampler that restricts data loading to a subset of the dataset.
In such a case, each process can pass a DistributedBatchSampler instance
as a DataLoader sampler, and load a subset of the original dataset that
is exclusive to it.
.. note::
Dataset is assumed to be of constant size.
Args:
dataset(paddle.io.Dataset): this can be an implementation of `paddle.io.Dataset`
or another Python object which implements `__len__`,
from which BatchSampler gets the sample
number of the data source.
batch_size(int): number of sample indices in a mini-batch.
num_replicas(int, optional): process number in distributed training.
If :attr:`num_replicas` is None, :attr:`num_replicas` will be
retrieved from :code:`paddle.fluid.dygraph.parallel.ParallelEnv`.
Default None.
rank(int, optional): the rank of the current process among :attr:`num_replicas`
processes. If :attr:`rank` is None, :attr:`rank` is retrieved from
:code:`paddle.fluid.dygraph.parallel.ParallelEnv`. Default None.
shuffle(bool): whether to shuffle indices order before generating
batch indices. Default False.
drop_last(bool): whether to drop the last incomplete batch when the dataset size
is not divisible by the batch size. Default False.
Examples:
.. code-block:: python
import numpy as np
from paddle.io import Dataset, DistributedBatchSampler
# init with dataset
class RandomDataset(Dataset):
def __init__(self, num_samples):
self.num_samples = num_samples
def __getitem__(self, idx):
image = np.random.random([784]).astype('float32')
label = np.random.randint(0, 9, (1, )).astype('int64')
return image, label
def __len__(self):
return self.num_samples
dataset = RandomDataset(100)
sampler = DistributedBatchSampler(dataset, batch_size=64)
for data in sampler:
# do something
break
"""
def __init__(self,
dataset,
batch_size,
num_replicas=None,
rank=None,
shuffle=False,
drop_last=False):
self.dataset = dataset
assert isinstance(batch_size, int) and batch_size > 0, \
"batch_size should be a positive integer"
self.batch_size = batch_size
assert isinstance(shuffle, bool), \
"shuffle should be a boolean value"
self.shuffle = shuffle
assert isinstance(drop_last, bool), \
"drop_last should be a boolean number"
from paddle.fluid.dygraph.parallel import ParallelEnv
if num_replicas is not None:
assert isinstance(num_replicas, int) and num_replicas > 0, \
"num_replicas should be a positive integer"
self.nranks = num_replicas
else:
self.nranks = ParallelEnv().nranks
if rank is not None:
assert isinstance(rank, int) and rank >= 0, \
"rank should be a non-negative integer"
self.local_rank = rank
else:
self.local_rank = ParallelEnv().local_rank
self.drop_last = drop_last
self.epoch = 0
self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.nranks))
self.total_size = self.num_samples * self.nranks
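# pad the sample count so that it divides evenly across ranks; the extra
# indices are recycled from the head of the index list in __iter__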
def __iter__(self):
num_samples = len(self.dataset)
indices = np.arange(num_samples).tolist()
indices += indices[:(self.total_size - len(indices))]
assert len(indices) == self.total_size
if self.shuffle:
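# seed the shuffle with the epoch so that all ranks generate the same
# permutation; `set_epoch` lets users control this seed explicitly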
np.random.RandomState(self.epoch).shuffle(indices)
self.epoch += 1
# subsample
def _get_indices_by_batch_size(indices):
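# interleave indices across ranks batch-by-batch: each rank takes one
# `batch_size` slice from every `batch_size * nranks` window, then the
# trailing incomplete window is split evenly among the ranks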
subsampled_indices = []
last_batch_size = self.total_size % (self.batch_size * self.nranks)
assert last_batch_size % self.nranks == 0
last_local_batch_size = last_batch_size // self.nranks
for i in range(self.local_rank * self.batch_size,
len(indices) - last_batch_size,
self.batch_size * self.nranks):
subsampled_indices.extend(indices[i:i + self.batch_size])
indices = indices[len(indices) - last_batch_size:]
subsampled_indices.extend(indices[
self.local_rank * last_local_batch_size:(
self.local_rank + 1) * last_local_batch_size])
return subsampled_indices
if self.nranks > 1:
indices = _get_indices_by_batch_size(indices)
assert len(indices) == self.num_samples
_sample_iter = iter(indices)
batch_indices = []
for idx in _sample_iter:
batch_indices.append(idx)
if len(batch_indices) == self.batch_size:
yield batch_indices
batch_indices = []
if not self.drop_last and len(batch_indices) > 0:
yield batch_indices
def __len__(self):
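# batches per rank: ceil(num_samples / batch_size) when drop_last is False,
# floor division otherwise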
num_samples = self.num_samples
num_samples += int(not self.drop_last) * (self.batch_size - 1)
return num_samples // self.batch_size
def set_epoch(self, epoch):
"""
Sets the epoch number. When :attr:`shuffle=True`, this number is used
as the seed for random shuffling. By default, users may not set this; all
replicas (workers) use a different random ordering for each epoch.
If the same number is set at each epoch, this sampler will yield the same
ordering at all epochs.
Arguments:
epoch (int): Epoch number.
Examples:
.. code-block:: python
import numpy as np
from paddle.io import Dataset, DistributedBatchSampler
# init with dataset
class RandomDataset(Dataset):
def __init__(self, num_samples):
self.num_samples = num_samples
def __getitem__(self, idx):
image = np.random.random([784]).astype('float32')
label = np.random.randint(0, 9, (1, )).astype('int64')
return image, label
def __len__(self):
return self.num_samples
dataset = RandomDataset(100)
sampler = DistributedBatchSampler(dataset, batch_size=64)
for epoch in range(10):
sampler.set_epoch(epoch)
"""
self.epoch = epoch
......@@ -12,9 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from ..wrapped_decorator import signature_safe_contextmanager, wrap_decorator
import decorator
import contextlib
import functools
import inspect
import sys
import numpy as np
from paddle.fluid import core
......@@ -26,8 +27,8 @@ import objgraph
from ..data_feeder import convert_dtype
__all__ = [
'no_grad', 'no_grad_', 'grad', 'guard', 'enable_dygraph', 'disable_dygraph',
'enabled', 'to_variable'
]
......@@ -167,7 +168,80 @@ def disable_dygraph():
_functional_dygraph_context_manager = None
@signature_safe_contextmanager
def _switch_tracer_mode_guard_(is_train=True):
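# Temporarily switch the tracer's train mode; `no_grad` passes is_train=False
# so that computed results get stop_gradient=True. The previous mode is
# restored when the context exits.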
tracer = framework._dygraph_tracer()
if tracer:
    mode = tracer._train_mode
    tracer._train_mode = is_train
    try:
        yield
    finally:
        tracer._train_mode = mode
else:
    yield
def no_grad(func=None):
"""
:api_attr: imperative
Create a context which disables dygraph gradient calculation.
In this mode, the result of every computation will have `stop_gradient=True`.
Also functions as a decorator (make sure to use it without parentheses).
Examples:
.. code-block:: python
import numpy as np
import paddle.fluid as fluid
# use as context manager
data = np.array([[2, 3], [4, 5]]).astype('float32')
with fluid.dygraph.guard():
l0 = fluid.Linear(2, 2) # l0.weight.gradient() is None
l1 = fluid.Linear(2, 2)
with fluid.dygraph.no_grad():
# l1.weight.stop_gradient is False
tmp = l1.weight * 2 # tmp.stop_gradient is True
x = fluid.dygraph.to_variable(data)
y = l0(x) + tmp
o = l1(y)
o.backward()
print(tmp.gradient() is None) # True
print(l0.weight.gradient() is None) # False
# use as decorator
@fluid.dygraph.no_grad
def test_layer():
with fluid.dygraph.guard():
inp = np.ones([3, 1024], dtype='float32')
t = fluid.dygraph.base.to_variable(inp)
linear1 = fluid.Linear(1024, 4, bias_attr=False)
linear2 = fluid.Linear(4, 4)
ret = linear1(t)
dy_ret = linear2(ret)
test_layer()
"""
if func is None:
    return _switch_tracer_mode_guard_(is_train=False)
else:

    @decorator.decorator
    def __impl__(func, *args, **kwargs):
        with _switch_tracer_mode_guard_(is_train=False):
            return func(*args, **kwargs)

    return __impl__(func)
class no_grad_:
"""
:api_attr: imperative
......
......@@ -207,6 +207,7 @@ def load_dygraph(model_path, keep_name_table=False):
# NOTE: `jit.save` doesn't save optimizer state
else:
# Load state dict by `save_dygraph` save format
para_dict = {}
if os.path.exists(params_file_path):
with open(params_file_path, 'rb') as f:
para_dict = pickle.load(f) if six.PY2 else pickle.load(
......
......@@ -16,9 +16,7 @@ import astor
import gast
from paddle.fluid.dygraph.dygraph_to_static.static_analysis import AstNodeWrapper
from paddle.fluid.dygraph.dygraph_to_static import utils
class BasicApiTransformer(gast.NodeTransformer):
......@@ -56,7 +54,7 @@ class BasicApiTransformer(gast.NodeTransformer):
if isinstance(child_node, gast.Call):
# TODO(liym27):
# Considers that a dygraph api which modifies the input or has a output.
if utils.is_dygraph_api(child_node):
return
else:
self._visit_Call(child_node)
......@@ -73,7 +71,7 @@ class BasicApiTransformer(gast.NodeTransformer):
if self._is_dygraph_forward(func_name):
class_node = self._get_class_node(func_name)
static_node = utils.to_static_ast(node, class_node)
return static_node
else:
return node
......@@ -91,14 +89,51 @@ class BasicApiTransformer(gast.NodeTransformer):
if is_to_variable(node_value):
return False
if is_dygraph_api(node_value):
if utils.is_dygraph_api(node_value):
dygraph_api = node_value.func.attr
if not utils.dygraph_class_to_static_api.get(dygraph_api):
return False
update_args_of_func(node_value, node_value, "__init__")
utils.update_args_of_func(node_value, node_value, "__init__")
target_str = astor.to_source(gast.gast_to_ast(node.targets[0]))
self.class_node_dict[target_str] = node_value
return True
# TODO: node.value is not dygraph class
return False
def is_to_variable(node):
assert isinstance(node, gast.Call)
api_name = utils.ast_to_source_code(node.func).strip()
if utils.is_dygraph_api(node):
return api_name.endswith("to_variable")
if utils.is_paddle_api(node):
return api_name.endswith("to_tensor")
return False
def to_assign_node(node):
# Transform dygraph api `fluid.dygraph.to_variable` alias `paddle.to_tensor` to static api `fluid.layers.assign`.
# NOTE:
# 1. Api `to_variable` supports data type {float16, float32, float64, int16, int32, int64, uint8, uint16},
# but api `assign` only supports {float32, float64, int32, int64, bool};
# 2. If the input of api `assign` is numpy.ndarray, its size cannot be greater than 1024 * 1024.
assert isinstance(node, gast.Call)
assign_api = gast.parse('fluid.layers.assign').body[0].value
node.func = assign_api
if node.args:
node.args = [node.args[0]]
node.keywords = []
else:
for idx, kw in enumerate(node.keywords):
if kw.arg == 'value' or kw.arg == 'data':
node.keywords[idx].arg = 'input'
node.keywords = [node.keywords[idx]]
node.args = []
break
return node
......@@ -296,7 +296,7 @@ def convert_to_input_spec(inputs, input_spec):
elif isinstance(input_spec, dict):
input_with_spec = {}
check_type_and_len(inputs, input_spec, True)
for name, input in six.iteritems(inputs):
if name in input_spec:
input_with_spec[name] = convert_to_input_spec(input,
input_spec[name])
......
......@@ -14,6 +14,7 @@
from __future__ import print_function
import six
import copy
from collections import defaultdict
......@@ -230,7 +231,7 @@ class NameVisitor(gast.NodeVisitor):
return False
def _update_name_ids(self, new_name_ids):
for name_id, ctxs in six.iteritems(new_name_ids):
self.name_ids[name_id] = ctxs + self.name_ids[name_id]
......@@ -250,7 +251,7 @@ def parse_cond_args(var_ids_dict, return_ids=None, ctx=gast.Load):
"""
name_ids = [
var_id for var_id, var_ctx in six.iteritems(var_ids_dict)
if isinstance(var_ctx[0], ctx)
]
if return_ids:
......@@ -341,7 +342,7 @@ def parse_cond_return(parent_vars_dict, if_vars_dict, else_vars_dict,
def _vars_with_store(ids_dict):
vars = []
for k, ctxs in six.iteritems(ids_dict):
if _is_return_var(ctxs):
vars.append(k)
return vars
......@@ -353,7 +354,7 @@ def parse_cond_return(parent_vars_dict, if_vars_dict, else_vars_dict,
def _vars_loaded_before_store(ids_dict):
new_dict = defaultdict(list)
for k, ctxs in six.iteritems(ids_dict):
for ctx in ctxs:
if isinstance(ctx, gast.Load):
new_dict[k].append(ctx)
......
......@@ -98,8 +98,15 @@ class TranslatorLogger(object):
return level == self.transformed_code_level
def has_verbosity(self, level):
"""
Checks whether the verbosity level set by the user is greater than or equal to the log level.
Args:
level(int): The level of log.
Returns:
True if the verbosity level set by the user is greater than or equal to the log level, otherwise False.
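For example (illustrative): if the user has set the verbosity level to 3,
``has_verbosity(1)`` returns True while ``has_verbosity(5)`` returns False.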
"""
level = self.check_level(level)
return self.verbosity_level >= level
def error(self, msg, *args, **kwargs):
self.logger.error(msg, *args, **kwargs)
......
......@@ -15,6 +15,7 @@
from __future__ import print_function
import numpy as np
import logging
import six
from paddle.fluid import log_helper
from paddle.fluid import framework, backward, core
......@@ -334,7 +335,7 @@ class PartialProgramLayer(layers.Layer):
param_and_buffer_names_set.add(var.name)
for block in main_program.blocks:
for name, var in six.iteritems(block.vars):
if isinstance(var, framework.Parameter):
if name not in param_and_buffer_names_set:
raise ValueError(
......
......@@ -24,6 +24,7 @@ import warnings
import gast
from paddle.fluid import framework
from paddle.fluid import in_dygraph_mode
from paddle.fluid.dygraph import layers
from paddle.fluid.data_feeder import check_type
from paddle.fluid.layers.utils import flatten
......@@ -32,6 +33,7 @@ from paddle.fluid.dygraph.base import switch_to_static_graph
from paddle.fluid.dygraph.dygraph_to_static import DygraphToStaticAst
from paddle.fluid.dygraph.dygraph_to_static.error import ERROR_DATA
from paddle.fluid.dygraph.dygraph_to_static.error import attach_error_data
from paddle.fluid.dygraph.dygraph_to_static import logging_utils
from paddle.fluid.dygraph.dygraph_to_static.origin_info import attach_origin_info
from paddle.fluid.dygraph.dygraph_to_static.origin_info import create_and_update_origin_info_map
from paddle.fluid.dygraph.dygraph_to_static.origin_info import update_op_callstack_with_origin_info
......@@ -283,13 +285,21 @@ class StaticLayer(object):
Return:
Outputs of decorated function.
"""
# 1. call the dygraph function directly if `declarative` is not enabled
if not self._program_trans.enable_declarative:
logging_utils.warn(
"The decorator '@paddle.jit.to_static' does NOT work when setting ProgramTranslator.enable=False. "
"We will just return dygraph output.")
return self._call_dygraph_function(*args, **kwargs)
if not in_dygraph_mode() and self._program_trans.enable_declarative:
raise RuntimeError(
"Failed to run the callable object {} decorated by '@paddle.jit.to_static', "
"because it does NOT in dynamic mode. Please disable the static mode to enter dynamic mode with the "
"following API: paddle.disable_static().".format(
self.dygraph_function))
# 2. trace ops from dygraph layers and cache the generated program.
args, kwargs = self._function_spec.unified_args_and_kwargs(args, kwargs)
try:
......@@ -393,19 +403,43 @@ class StaticLayer(object):
def concrete_program(self):
"""
Returns recent ConcreteProgram instance of decorated function.
Examples:
.. code-block:: python
import paddle
from paddle.jit import to_static
from paddle.static import InputSpec
paddle.disable_static()
def foo(x, y):
z = x + y
return z
# usage 1:
decorated_foo = to_static(foo, input_spec=[InputSpec([10], name='x'), InputSpec([10], name='y')])
print(decorated_foo.concrete_program)
# usage 2:
decorated_foo = to_static(foo)
out_foo = decorated_foo(paddle.rand([10]), paddle.rand([10]))
print(decorated_foo.concrete_program)
"""
# If `input_spec` is specified, the length of program_cache will always be 1,
# else, return the last one.
cached_program_len = len(self._program_cache)
# If `input_spec` is specified, apply conversion from dygraph layers into static Program.
if cached_program_len == 0:
input_spec = self._function_spec.input_spec
has_input_spec = (input_spec is not None and len(input_spec) > 0)
if has_input_spec:
concrete_program, _ = self.get_concrete_program(*input_spec)
return concrete_program
else:
raise ValueError("No valid transformed program for {}".format(
self._function_spec))
raise ValueError(
"No valid transformed program for {}.\n\t Please specific `input_spec` in `@paddle.jit.to_static` or feed input tensor to call the decorated function at once.\n".
format(self._function_spec))
# If more than one programs have been cached, return the recent converted program by default.
elif cached_program_len > 1:
logging.warning(
......@@ -617,7 +651,7 @@ class ProgramCache(object):
return len(self._caches)
def concrete_programs(self):
return [cp for key, (cp, _) in six.iteritems(self._caches)]
def synchronized(func):
......
......@@ -136,9 +136,12 @@ def is_api_in_module(node, module_prefix):
# import_str = "".join(import_statements)
import paddle
import paddle.fluid as fluid
import paddle.fluid.layers as layers
from paddle.fluid.dygraph import to_variable
import paddle.fluid.dygraph as dygraph
from paddle import to_tensor
return eval("_is_api_in_module_helper({}, '{}')".format(func_str,
module_prefix))
except NameError:
......@@ -146,15 +149,18 @@ def is_api_in_module(node, module_prefix):
def is_dygraph_api(node):
# Note: An api in module dygraph_to_static is not a real dygraph api.
if is_api_in_module(node, "paddle.fluid.dygraph.dygraph_to_static"):
return False
# TODO(liym27): A better way to determine whether it is a dygraph api.
# Consider the decorator @dygraph_only
return is_api_in_module(node, "paddle.fluid.dygraph")
def is_paddle_api(node):
return is_api_in_module(node, "paddle.fluid")
return is_api_in_module(node, "paddle")
# numpy_api cannot reuse is_api_in_module because of the numpy module problem
......@@ -233,14 +239,6 @@ def _add_keywords_to(node, dygraph_api_name):
return
def to_static_ast(node, class_node):
assert isinstance(node, gast.Call)
assert isinstance(class_node, gast.Call)
......@@ -268,29 +266,6 @@ def to_static_ast(node, class_node):
return node
def update_args_of_func(node, dygraph_node, method_name):
assert isinstance(node, gast.Call)
if method_name not in ["__init__", "forward"]:
......@@ -493,7 +468,7 @@ def recover_globals_attribute(src_obj, dst_obj):
src_globals = getattr(src_obj, attr_name, {})
dst_globals = getattr(dst_obj, attr_name, {})
for k, v in six.iteritems(src_globals):
# ignore builtin attribute.
if not (k.startswith('__') and k.endswith('__')):
dst_globals[k] = v
......
......@@ -754,7 +754,7 @@ def save(layer, model_path, input_spec=None, configs=None):
# saved to inference program may not need by dygraph Layer,
# we only record the state_dict variable's structured name
state_names_dict = dict()
for structured_name, var in six.iteritems(layer.state_dict()):
state_names_dict[var.name] = structured_name
# 3. share parameters from Layer to scope & record var info
......
......@@ -41,7 +41,7 @@ def monkey_patch_math_varbase():
The difference is, in dygraph mode, use auto-generated op functions for better performance.
"""
@no_grad
def create_tensor(value, dtype, shape):
out = _varbase_creator(dtype=dtype)
out = core.ops.fill_constant(out, 'dtype', dtype, 'shape', shape,
......
......@@ -349,38 +349,53 @@ class DataParallel(layers.Layer):
Examples:
.. code-block:: python
import paddle
import paddle.nn as nn
import paddle.optimizer as opt
import paddle.distributed as dist
class LinearNet(nn.Layer):
def __init__(self):
super(LinearNet, self).__init__()
self._linear1 = nn.Linear(10, 10)
self._linear2 = nn.Linear(10, 1)
def forward(self, x):
return self._linear2(self._linear1(x))
def train():
# 1. enable dynamic mode
paddle.disable_static()
# 2. initialize parallel environment
dist.init_parallel_env()
# 3. create data parallel layer & optimizer
layer = LinearNet()
dp_layer = paddle.DataParallel(layer)
loss_fn = nn.MSELoss()
adam = opt.Adam(
learning_rate=0.001, parameters=dp_layer.parameters())
# 4. run layer
inputs = paddle.randn([10, 10], 'float32')
outputs = dp_layer(inputs)
labels = paddle.randn([10, 1], 'float32')
loss = loss_fn(outputs, labels)
loss = dp_layer.scale_loss(loss)
loss.backward()
dp_layer.apply_collective_grads()
adam.step()
adam.clear_grad()
if __name__ == '__main__':
# 1. start by ``paddle.distributed.spawn`` (default)
dist.spawn(train, nprocs=2)
# 2. start by ``paddle.distributed.launch``
# train()
"""
if not self._is_data_parallel_mode():
return loss
......@@ -430,7 +445,7 @@ class DataParallel(layers.Layer):
self._reshape_inplace(x=g_var, shape=g_shape)
assert g_var.shape == g_shape
@no_grad
def apply_collective_grads(self):
"""
AllReduce the Parameters' gradient.
......@@ -438,38 +453,53 @@ class DataParallel(layers.Layer):
Examples:
.. code-block:: python
import paddle
import paddle.nn as nn
import paddle.optimizer as opt
import paddle.distributed as dist
class LinearNet(nn.Layer):
def __init__(self):
super(LinearNet, self).__init__()
self._linear1 = nn.Linear(10, 10)
self._linear2 = nn.Linear(10, 1)
def forward(self, x):
return self._linear2(self._linear1(x))
def train():
# 1. enable dynamic mode
paddle.disable_static()
# 2. initialize parallel environment
dist.init_parallel_env()
# 3. create data parallel layer & optimizer
layer = LinearNet()
dp_layer = paddle.DataParallel(layer)
loss_fn = nn.MSELoss()
adam = opt.Adam(
learning_rate=0.001, parameters=dp_layer.parameters())
# 4. run layer
inputs = paddle.randn([10, 10], 'float32')
outputs = dp_layer(inputs)
labels = paddle.randn([10, 1], 'float32')
loss = loss_fn(outputs, labels)
loss = dp_layer.scale_loss(loss)
loss.backward()
dp_layer.apply_collective_grads()
adam.step()
adam.clear_grad()
if __name__ == '__main__':
# 1. start by ``paddle.distributed.spawn`` (default)
dist.spawn(train, nprocs=2)
# 2. start by ``paddle.distributed.launch``
# train()
"""
if not self._is_data_parallel_mode():
return
......
......@@ -42,6 +42,9 @@ op_role_attr_name = core.op_proto_and_checker_maker.kOpRoleAttrName()
LR_SCHED_OP_ROLE_ATTR_VALUE = core.op_proto_and_checker_maker.OpRole.LRSched
OPT_OP_ROLE_ATTR_VALUE = core.op_proto_and_checker_maker.OpRole.Optimize
SPARSE_OP_LIST = ["lookup_table", "lookup_table_v2"]
SPARSE_OP_TYPE_DICT = {"lookup_table": "W", "lookup_table_v2": "W"}
def _get_lr_ops(program):
lr_ops = []
......@@ -66,7 +69,7 @@ def _has_global_step(lr_ops):
def is_sparse_op(op):
if op.type == "lookup_table" and op.attr('is_sparse') is True and op.attr(
if op.type in SPARSE_OP_LIST and op.attr('is_sparse') is True and op.attr(
'is_distributed') is False:
return True
......@@ -78,7 +81,7 @@ def is_sparse_op(op):
def is_distributed_sparse_op(op):
if op.type == "lookup_table" and op.attr('is_distributed') is True:
if op.type in SPARSE_OP_LIST and op.attr('is_distributed') is True:
return True
if op.type == "distributed_lookup_table" and op.attr(
......@@ -802,11 +805,10 @@ class CompileTimeStrategy(object):
def _get_sparse_varnames():
varnames = []
op_types = {"lookup_table": "W"}
for op in origin_program.global_block().ops:
if op.type in op_types.keys() \
if op.type in SPARSE_OP_TYPE_DICT.keys() \
and op.attr('remote_prefetch') is True:
param_name = op.input(op_types[op.type])[0]
param_name = op.input(SPARSE_OP_TYPE_DICT[op.type])[0]
varnames.append(param_name)
return list(set(varnames))
......
......@@ -40,6 +40,8 @@ LR_SCHED_OP_ROLE_ATTR_VALUE = core.op_proto_and_checker_maker.OpRole.LRSched
OPT_OP_ROLE_ATTR_VALUE = core.op_proto_and_checker_maker.OpRole.Optimize
op_role_attr_name = core.op_proto_and_checker_maker.kOpRoleAttrName()
SPARSE_OP_TYPE_DICT = {"lookup_table": "W", "lookup_table_v2": "W"}
DEVICE_LIST = ["cpu", "gpu", "xpu"]
COMMUNICATE_OPS_TYPE = ["send", "recv", "fetch_barrier", "send_barrier"]
DEFAULT_DEVICE = 'cpu'
......@@ -81,11 +83,10 @@ def distributed_ops_pass(program, config):
def _get_pull_sparse_ops(_program):
pull_sparse_ops = {}
op_types = {"lookup_table": "W"}
for op in _program.global_block().ops:
if op.type in op_types.keys() \
if op.type in SPARSE_OP_TYPE_DICT.keys() \
and op.attr('remote_prefetch') is True:
param_name = op.input(op_types[op.type])[0]
param_name = op.input(SPARSE_OP_TYPE_DICT[op.type])[0]
ops = pull_sparse_ops.get(param_name, [])
ops.append(op)
pull_sparse_ops[param_name] = ops
......@@ -101,6 +102,7 @@ def distributed_ops_pass(program, config):
w = program.global_block().vars[ops[0].input("W")[0]]
padding_idx = ops[0].attr("padding_idx")
is_distributed = ops[0].attr("is_distributed")
op_type = ops[0].type
outputs = [
program.global_block().vars[op.output("Out")[0]] for op in ops
......@@ -149,7 +151,8 @@ def distributed_ops_pass(program, config):
"is_distributed": is_distributed,
"pserver_num": len(pserver_endpoints),
"padding_idx": padding_idx,
"trainer_id": trainer_id
"trainer_id": trainer_id,
"lookup_table_version": op_type
})
else:
raise ValueError(
......
......@@ -348,6 +348,41 @@ class PSLib(Fleet):
self._fleet_ptr.save_model(dirname, mode)
self._role_maker._barrier_worker()
def save_model_with_whitelist(self,
executor,
dirname,
whitelist_path,
main_program=None,
**kwargs):
"""
save model with whitelist; the mode is consistent with fleet.save_persistables.
When using fleet, both sparse and dense features are saved.
Args:
executor(Executor): fluid executor
dirname(str): save path. It can be hdfs/afs path or local path
whitelist_path(str): path of the whitelist file, can be hdfs/afs or local path
main_program(Program): fluid program, default None
kwargs: user-defined properties, currently supporting the following:
mode(int): 0 means save all pserver model,
1 means save delta pserver model (save diff),
2 means save xbox base,
3 means save batch model.
Example:
.. code-block:: python
fleet.save_model_with_whitelist(exe, "/you/path/to/model", "/you/path/to/whitelist", mode=0)
"""
mode = kwargs.get("mode", 0)
table_id = kwargs.get("table_id", 0)
self._fleet_ptr.client_flush()
self._role_maker._barrier_worker()
if self._role_maker.is_first_worker():
self._fleet_ptr.save_model_with_whitelist(table_id, dirname, mode,
whitelist_path)
self._role_maker._barrier_worker()
def save_cache_model(self, executor, dirname, main_program=None, **kwargs):
"""
save sparse cache table,
......@@ -480,6 +515,51 @@ class PSLib(Fleet):
self._fleet_ptr.clear_model()
self._role_maker._barrier_worker()
def load_pslib_whitelist(self, table_id, model_path, **kwargs):
"""
load pslib model for one table with whitelist
Args:
table_id(int): load table id
model_path(str): load model path, can be local or hdfs/afs path
kwargs(dict): user-defined params, currently supporting the following:
only for load pslib model for one table:
mode(int): load model mode. 0 is for load whole model, 1 is
for load delta model (load diff), default is 0.
only for load params from paddle model:
scope(Scope): Scope object
model_proto_file(str): path of program desc proto binary
file, can be local or hdfs/afs file
var_names(list): var name list
load_combine(bool): load from a file or split param files
default False.
Examples:
.. code-block:: python
# load pslib model for one table
fleet.load_one_table(0, "hdfs:/my_fleet_model/20190714/0/")
fleet.load_one_table(1, "hdfs:/xx/xxx", mode = 0)
# load params from paddle model
fleet.load_one_table(2, "hdfs:/my_paddle_model/",
scope = my_scope,
model_proto_file = "./my_program.bin",
load_combine = False)
# below is how to save proto binary file
with open("my_program.bin", "wb") as fout:
my_program = fluid.default_main_program()
fout.write(my_program.desc.serialize_to_string())
"""
self._role_maker._barrier_worker()
mode = kwargs.get("mode", 0)
if self._role_maker.is_first_worker():
self._fleet_ptr.load_table_with_whitelist(table_id, model_path,
mode)
self._role_maker._barrier_worker()
def load_one_table(self, table_id, model_path, **kwargs):
"""
load pslib model for one table or load params from paddle model
......
......@@ -129,6 +129,7 @@ def one_hot(input, depth, allow_out_of_range=False):
return one_hot_out
@deprecated(since='2.0.0', update_to='paddle.nn.functional.embedding')
def embedding(input,
size,
is_sparse=False,
......
......@@ -147,8 +147,10 @@ class LayerHelper(LayerHelperBase):
if 'use_cudnn' in self.kwargs and self.kwargs.get('use_cudnn'):
act['use_cudnn'] = self.kwargs.get('use_cudnn')
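# prefer an explicit `use_mkldnn` kwarg; otherwise fall back to the global
# FLAGS_use_mkldnn setting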
use_mkldnn = self.kwargs.get(
'use_mkldnn', core.globals().get("FLAGS_use_mkldnn", False))
if use_mkldnn:
act['use_mkldnn'] = use_mkldnn
act_type = act.pop('type')
tmp = self.create_variable_for_type_inference(dtype=input_var.dtype)
......
......@@ -367,6 +367,7 @@ def fc(input,
return helper.append_activation(pre_activation)
@deprecated(since="2.0.0", update_to="paddle.nn.functional.embedding")
def embedding(input,
size,
is_sparse=False,
......@@ -15013,6 +15014,7 @@ def gather_tree(ids, parents):
return out
@deprecated(since="2.0.0", update_to="paddle.uniform")
@templatedoc()
def uniform_random(shape, dtype='float32', min=-1.0, max=1.0, seed=0,
name=None):
......
......@@ -62,7 +62,7 @@ class Optimizer(object):
but need to use one of it's implementation.
"""
@imperative_base.no_grad
def __init__(self,
learning_rate,
parameter_list=None,
......@@ -898,7 +898,7 @@ class Optimizer(object):
if p.trainable:
p.clear_gradient()
@imperative_base.no_grad
def minimize(self,
loss,
startup_program=None,
......@@ -1016,7 +1016,7 @@ class SGDOptimizer(Optimizer):
name=name)
self.type = "sgd"
@no_grad
def _append_optimize_op(self, block, param_and_grad):
lr = self._create_param_lr(param_and_grad)
if framework.in_dygraph_mode():
......@@ -1553,7 +1553,7 @@ class DGCMomentumOptimizer(Optimizer):
dgc_op._set_attr(op_maker.kOpRoleVarAttrName(),
[param_var.name, grad_var.name])
@imperative_base.no_grad
def apply_gradients(self, params_grads):
params_grads = sorted(params_grads, key=lambda x: x[0].name)
params_grads, table_param_and_grad, table_optimize_op = \
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import tarfile
import paddle.fluid as fluid
import paddle
from paddle.fluid import core
URL = 'http://paddle-unittest-data.gz.bcebos.com/python_paddle_fluid_tests_demo_async-executor/train_data.tar.gz'
MD5 = '2a405a31508969b3ab823f42c0f522ca'
def bow_net(data,
label,
dict_dim=89528,
emb_dim=128,
hid_dim=128,
hid_dim2=96,
class_dim=2):
"""
BOW net
This model is from https://github.com/PaddlePaddle/models:
models/fluid/PaddleNLP/text_classification/nets.py
"""
# embedding
emb = fluid.layers.embedding(
input=data, size=[dict_dim, emb_dim], is_sparse=True)
bow = fluid.layers.sequence_pool(input=emb, pool_type='sum')
bowh = fluid.layers.tanh(bow)
# fc layer after conv
fc_1 = fluid.layers.fc(input=bowh, size=hid_dim, act="tanh")
fc_2 = fluid.layers.fc(input=fc_1, size=hid_dim2, act="tanh")
# probability of each class
prediction = fluid.layers.fc(input=[fc_2], size=class_dim, act="softmax")
# cross entropy loss
cost = fluid.layers.cross_entropy(input=prediction, label=label)
# mean loss
avg_cost = fluid.layers.mean(x=cost)
acc = fluid.layers.accuracy(input=prediction, label=label)
return avg_cost, acc, prediction
def train():
# Download data
with tarfile.open(paddle.dataset.common.download(URL, "imdb", MD5)) as tarf:
tarf.extractall(path='./')
tarf.close()
# Initialize dataset description
dataset = fluid.DatasetFactory().create_dataset()
dataset.set_batch_size(128) # See API doc for how to change other fields
# define network
# input text data
data = fluid.layers.data(
name="words", shape=[1], dtype="int64", lod_level=1)
# label data
label = fluid.layers.data(name="label", shape=[1], dtype="int64")
dataset.set_use_var([data, label])
avg_cost, acc, prediction = bow_net(data, label)
sgd_optimizer = fluid.optimizer.Adagrad(learning_rate=0.002)
opt_ops, weight_and_grad = sgd_optimizer.minimize(avg_cost)
# Run startup program
startup_program = fluid.default_startup_program()
place = fluid.CPUPlace()
executor = fluid.Executor(place)
executor.run(startup_program)
main_program = fluid.default_main_program()
epochs = 10
filelist = ["train_data/part-%d" % i for i in range(12)]
dataset.set_filelist(filelist)
for i in range(epochs):
dataset.set_thread(4)
executor.train_from_dataset(
main_program, # This can be changed during iteration
dataset, # This can be changed during iteration
debug=False)
fluid.io.save_inference_model('imdb/epoch%d.model' % i,
[data.name, label.name], [acc], executor)
if __name__ == "__main__":
train()
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import numpy
import six
import paddle
import paddle.dataset.mnist as mnist
import paddle.fluid as fluid
def network(is_train):
reader = fluid.layers.py_reader(
capacity=10,
shapes=((-1, 784), (-1, 1)),
dtypes=('float32', 'int64'),
name="train_reader" if is_train else "test_reader",
use_double_buffer=True)
img, label = fluid.layers.read_file(reader)
hidden = img
for i in six.moves.xrange(2):
hidden = fluid.layers.fc(input=hidden, size=100, act='tanh')
hidden = fluid.layers.dropout(
hidden, dropout_prob=0.5, is_test=not is_train)
prediction = fluid.layers.fc(input=hidden, size=10, act='softmax')
loss = fluid.layers.cross_entropy(input=prediction, label=label)
return fluid.layers.mean(loss), reader
def main():
train_prog = fluid.Program()
startup_prog = fluid.Program()
with fluid.program_guard(train_prog, startup_prog):
with fluid.unique_name.guard():
loss, train_reader = network(True)
adam = fluid.optimizer.Adam(learning_rate=0.01)
adam.minimize(loss)
test_prog = fluid.Program()
test_startup = fluid.Program()
with fluid.program_guard(test_prog, test_startup):
with fluid.unique_name.guard():
test_loss, test_reader = network(False)
use_cuda = fluid.core.is_compiled_with_cuda()
place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
fluid.Executor(place).run(startup_prog)
fluid.Executor(place).run(test_startup)
trainer = fluid.ParallelExecutor(
use_cuda=use_cuda, loss_name=loss.name, main_program=train_prog)
tester = fluid.ParallelExecutor(
use_cuda=use_cuda, share_vars_from=trainer, main_program=test_prog)
train_reader.decorate_paddle_reader(
paddle.reader.shuffle(
paddle.batch(mnist.train(), 512), buf_size=8192))
test_reader.decorate_paddle_reader(paddle.batch(mnist.test(), 512))
for epoch_id in six.moves.xrange(10):
train_reader.start()
try:
while True:
print(
'train_loss',
numpy.array(trainer.run(fetch_list=[loss.name])))
except fluid.core.EOFException:
print('End of epoch', epoch_id)
train_reader.reset()
test_reader.start()
try:
while True:
print(
'test loss',
numpy.array(tester.run(fetch_list=[test_loss.name])))
except fluid.core.EOFException:
print('End of testing')
test_reader.reset()
if __name__ == '__main__':
main()
......@@ -432,8 +432,6 @@ if(WITH_DISTRIBUTE)
list(REMOVE_ITEM DIST_TEST_OPS "test_dist_mnist_lars")
list(REMOVE_ITEM DIST_TEST_OPS "test_dist_mnist_train")
list(REMOVE_ITEM DIST_TEST_OPS "test_dist_save_load")
list(REMOVE_ITEM DIST_TEST_OPS "test_dist_simnet_bow")
list(REMOVE_ITEM DIST_TEST_OPS "test_dist_fleet_ctr")
list(REMOVE_ITEM DIST_TEST_OPS "test_dist_text_classification")
list(REMOVE_ITEM DIST_TEST_OPS "test_dist_train")
list(REMOVE_ITEM DIST_TEST_OPS "test_dist_word2vec")
......@@ -587,8 +585,10 @@ endif()
# setting timeout value for old unittests
# set_tests_properties(test_dist_fleet_sparse_embedding_ctr PROPERTIES TIMEOUT 200)
if(NOT WIN32 AND NOT APPLE)
set_tests_properties(test_fused_elemwise_activation_op PROPERTIES TIMEOUT 150)
set_tests_properties(test_gru_op PROPERTIES TIMEOUT 200)
set_tests_properties(test_layer_norm_op PROPERTIES TIMEOUT 150)
set_tests_properties(test_pool3d_op PROPERTIES TIMEOUT 150)
set_tests_properties(test_regularizer PROPERTIES TIMEOUT 150)
endif()
......@@ -196,8 +196,7 @@ class TestDistCTR2x2(FleetDistRunnerBase):
fleet.stop_worker()
def do_dataset_training(self, fleet):
train_file_list = ctr_dataset_reader.prepare_fake_data()
exe = fluid.Executor(fluid.CPUPlace())
......@@ -206,9 +205,7 @@ class TestDistCTR2x2(FleetDistRunnerBase):
thread_num = 2
batch_size = 128
filelist = train_file_list
# config dataset
dataset = paddle.distributed.fleet.DatasetFactory().create_dataset()
......
......@@ -123,7 +123,7 @@ class TestConvertWithCache(unittest.TestCase):
@declarative
def sum_even_until_limit(max_len, limit):
ret_sum = fluid.dygraph.to_variable(np.zeros((1)).astype('int32'))
for i in range(max_len):
if i % 2 > 0:
......@@ -147,7 +147,7 @@ def sum_under_while(limit):
class TestToOutputWithCache(unittest.TestCase):
def test_output(self):
with fluid.dygraph.guard():
ret = sum_even_until_limit(80, 10)
self.assertEqual(ret.numpy(), 30)
ret = declarative(sum_under_while)(100)
......