提交 71f6ba83 编写于 作者: S sandyhouse

Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into add_timeline

......@@ -34,7 +34,7 @@ if (NOT LITE_SOURCE_DIR OR NOT LITE_BINARY_DIR)
set(LITE_INSTALL_DIR ${THIRD_PARTY_PATH}/install/lite)
if(NOT LITE_GIT_TAG)
set(LITE_GIT_TAG 42ab4d559f6659edfc35040fb30fdcec3dc3f8aa)
set(LITE_GIT_TAG dfdfa6440c83bf0b415f9f5a9ff84842ce0bb0fa)
endif()
if(NOT CUDA_ARCH_NAME)
......
......@@ -154,10 +154,17 @@ func (config *AnalysisConfig) EnableMkldnnQuantizer() {
C.PD_EnableMkldnnQuantizer(config.c)
}
func (config *AnalysisConfig) EnableMkldnnBfloat16() {
C.PD_EnableMkldnnBfloat16(config.c)
}
func (config *AnalysisConfig) MkldnnQuantizerEnabled() bool {
return ConvertCBooleanToGo(C.PD_MkldnnQuantizerEnabled(config.c))
}
func (config *AnalysisConfig) MkldnnBfloat16Enabled() bool {
return ConvertCBooleanToGo(C.PD_MkldnnBfloat16Enabled(config.c))
}
// SetModelBuffer
// ModelFromMemory
......
......@@ -119,7 +119,7 @@ cc_test(data_layout_transform_test SRCS data_layout_transform_test.cc DEPS data_
cc_library(data_transform SRCS data_transform.cc DEPS math_function tensor
framework_proto selected_rows data_device_transform data_type_transform data_layout_transform)
cc_library(attribute SRCS attribute.cc DEPS framework_proto boost)
cc_library(attribute SRCS attribute.cc DEPS framework_proto boost enforce)
cc_test(program_desc_test SRCS program_desc_test.cc DEPS proto_desc
device_context)
......
......@@ -115,6 +115,7 @@ message VarType {
SIZE_T = 19;
UINT8 = 20;
INT8 = 21;
BF16 = 22;
// Other types that may need additional descriptions
LOD_TENSOR = 7;
......
......@@ -12,67 +12,122 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/generator.h"
#include <glog/logging.h>
#include <deque>
#include <memory>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include "paddle/fluid/framework/generator.h"
namespace paddle {
namespace framework {
std::shared_ptr<Generator> Generator::gen_instance_ = NULL;
const std::shared_ptr<Generator>& DefaultCPUGenerator() {
static auto default_cpu_generator =
std::make_shared<Generator>(GetRandomSeed());
VLOG(4) << "initial seed: " << default_cpu_generator->GetCurrentSeed()
<< ", cpu engine: " << default_cpu_generator->GetCPUEngine().get();
return default_cpu_generator;
}
std::shared_ptr<std::mt19937_64> OpDefaultCPUEngine() {
static auto op_default_cpu_engine = std::make_shared<std::mt19937_64>();
return op_default_cpu_engine;
}
// NOTE(zhiqiu): there are 3 conditions:
// (1) op seed is not set and DefaultCPUGenerator is inited, use
// DefaultCPUGenerator
// (2) op seed is not set and DefaultCPUGenerator is not inited, use se
// OpDefaultCPUEngine() and set a radnom seed
// (3) op seed is set, use OpDefaultCPUEngine() and set the seed
std::shared_ptr<std::mt19937_64> GetCPURandomEngine(uint64_t seed) {
if (DefaultCPUGenerator()->GetIsInitPy() && seed == 0) {
VLOG(4) << "Use random engine from generator";
return DefaultCPUGenerator()->GetCPUEngine();
} else {
// NOTE(zhiqiu): creating an engine instance everytime instead of using
// OpDefaultCPUEngine(), this is the legacy behavior of random operators.
// The benefit is that when runing PE with fixed-seed in multiple thrads,
// each thread has their own engine, and doesn't affect each other.
//
// And we need to measure the determinacy of Generator in PE.
auto engine = std::make_shared<std::mt19937_64>();
if (seed == 0) {
seed = GetRandomSeed();
VLOG(4) << "Use default random engine with random seed = " << seed;
} else {
VLOG(4) << "Use default random engine with fixed random seed = " << seed;
}
static std::mutex mu_;
{
std::lock_guard<std::mutex> lock(mu_);
engine->seed(seed);
}
return engine;
}
}
GeneratorState* Generator::GetState() {
std::lock_guard<std::mutex> lock(this->mutex);
return this->state_.get();
GeneratorState Generator::GetState() {
std::lock_guard<std::mutex> lock(this->mu_);
state_.cpu_engine = *engine_;
return this->state_;
}
void Generator::SetState(GeneratorState* state_in) {
std::lock_guard<std::mutex> lock(this->mutex);
*this->state_ = *state_in;
void Generator::SetState(const GeneratorState& state) {
std::lock_guard<std::mutex> lock(this->mu_);
this->state_ = state;
this->engine_ = std::make_shared<std::mt19937_64>(state.cpu_engine);
}
uint64_t Generator::GetCurrentSeed() {
std::lock_guard<std::mutex> lock(this->mutex);
return this->state_->current_seed;
std::lock_guard<std::mutex> lock(this->mu_);
return this->state_.current_seed;
}
uint64_t Generator::Seed() {
std::lock_guard<std::mutex> lock(this->mutex);
std::lock_guard<std::mutex> lock(this->mu_);
uint64_t seed;
std::random_device de;
seed = ((((uint64_t)de()) << 32) + de()) & 0x1FFFFFFFFFFFFF;
this->state_->current_seed = seed;
this->state_.current_seed = seed;
std::seed_seq seq({seed});
this->state_->cpu_engine.seed(seq);
this->engine_->seed(seq);
return this->state_->current_seed;
return this->state_.current_seed;
}
void Generator::SetCurrentSeed(uint64_t seed) {
std::lock_guard<std::mutex> lock(this->mutex);
this->state_->current_seed = uint64_t(seed);
std::lock_guard<std::mutex> lock(this->mu_);
this->state_.current_seed = seed;
std::seed_seq seq({seed});
this->state_->cpu_engine.seed(seq);
this->engine_->seed(seq);
}
std::mt19937_64& Generator::GetCPUEngine() {
std::lock_guard<std::mutex> lock(this->mutex);
return this->state_->cpu_engine;
std::shared_ptr<std::mt19937_64> Generator::GetCPUEngine() {
std::lock_guard<std::mutex> lock(this->mu_);
return this->engine_;
}
void Generator::SetCPUEngine(std::mt19937_64 engine) {
std::lock_guard<std::mutex> lock(this->mutex);
this->state_->cpu_engine = std::mt19937_64(engine);
void Generator::SetCPUEngine(std::shared_ptr<std::mt19937_64> engine) {
std::lock_guard<std::mutex> lock(this->mu_);
this->engine_ = engine;
}
uint64_t Generator::Random64() {
std::lock_guard<std::mutex> lock(this->mutex);
return this->state_->cpu_engine();
std::lock_guard<std::mutex> lock(this->mu_);
auto engine = this->engine_;
return (*engine)();
}
void Generator::SetIsInitPy(bool is_init_py) {
this->is_init_py_ = is_init_py;
VLOG(4) << "SetIsInitPy:" << this->is_init_py_;
}
bool Generator::GetIsInitPy() const { return this->is_init_py_; }
} // namespace framework
} // namespace paddle
......@@ -14,7 +14,9 @@ limitations under the License. */
#pragma once
#include <glog/logging.h>
#include <stdint.h>
#include <atomic>
#include <deque>
#include <iostream> // temp for debug
......@@ -27,6 +29,12 @@ limitations under the License. */
namespace paddle {
namespace framework {
static uint64_t GetRandomSeed() {
std::random_device rd;
// double has 53 bit significant, so limit uint64 to 53 bits
return ((((uint64_t)rd()) << 32) + rd()) & 0x1FFFFFFFFFFFFF;
}
struct GeneratorState {
int64_t device = -1;
uint64_t current_seed = 34342423252;
......@@ -35,62 +43,67 @@ struct GeneratorState {
struct Generator {
Generator() {
GeneratorState default_gen_state_cpu;
default_gen_state_cpu.device = -1;
default_gen_state_cpu.current_seed = 34342423252;
std::seed_seq seq({34342423252});
default_gen_state_cpu.cpu_engine = std::mt19937_64(seq);
this->state_ = std::make_shared<GeneratorState>(default_gen_state_cpu);
auto seed = GetRandomSeed();
std::seed_seq seq({seed});
auto engine = std::make_shared<std::mt19937_64>(seq);
this->state_.cpu_engine = *engine;
this->state_.device = -1;
this->state_.current_seed = seed;
this->engine_ = engine;
VLOG(4) << "initial seed: " << this->state_.current_seed
<< ", cpu engine: " << &this->state_.cpu_engine;
}
explicit Generator(uint64_t seed) {
std::seed_seq seq({seed});
auto engine = std::make_shared<std::mt19937_64>(seq);
this->state_.cpu_engine = *engine;
this->state_.device = -1;
this->state_.current_seed = seed;
this->engine_ = engine;
VLOG(4) << "initial seed: " << this->state_.current_seed
<< ", cpu engine: " << &this->state_.cpu_engine;
this->is_init_py_ = true; // TODO(zhiqiu): remove it in future
}
explicit Generator(GeneratorState state_in)
: state_{std::make_shared<GeneratorState>(state_in)} {}
Generator(const Generator& other)
: Generator(other, std::lock_guard<std::mutex>(other.mutex)) {}
Generator(const Generator& other) = delete;
// get random state
GeneratorState* GetState();
GeneratorState GetState();
// set random state
void SetState(GeneratorState* state_in);
void SetState(const GeneratorState&);
// get current seed
uint64_t GetCurrentSeed();
// random a seed and get
uint64_t Seed();
// set seed
void SetCurrentSeed(uint64_t seed);
// get cpu engine
std::mt19937_64& GetCPUEngine();
std::shared_ptr<std::mt19937_64> GetCPUEngine();
// set cpu engine
void SetCPUEngine(std::mt19937_64 engine);
void SetCPUEngine(std::shared_ptr<std::mt19937_64>);
uint64_t Random64();
bool is_init_py = false;
void SetIsInitPy(bool);
bool GetIsInitPy() const;
// CPU Generator singleton
static std::shared_ptr<Generator> GetInstance() {
if (NULL == gen_instance_) {
gen_instance_.reset(new paddle::framework::Generator());
}
return gen_instance_;
}
private:
GeneratorState state_;
std::shared_ptr<std::mt19937_64> engine_;
mutable std::mutex mu_;
// NOTE(zhiqiu): is_init_py_ is used to make generator be compatible with
// old seed, and it should be removed after all random-related operators
// and unittests upgrades to use generator.
bool is_init_py_ = false;
};
static std::shared_ptr<Generator> GetInstanceX() {
if (NULL == gen_instance_) {
gen_instance_.reset(new paddle::framework::Generator());
}
gen_instance_->is_init_py = true;
return gen_instance_;
}
// The DefaultCPUGenerator is used in manual_seed()
const std::shared_ptr<Generator>& DefaultCPUGenerator();
private:
static std::shared_ptr<Generator> gen_instance_;
std::shared_ptr<GeneratorState> state_;
mutable std::mutex mutex;
// If op seed is set or global is not set, the OpDefaultCPUEngine is used.
std::shared_ptr<std::mt19937_64> OpDefaultCPUEngine();
Generator(const Generator& other, const std::lock_guard<std::mutex>&)
: state_(std::make_shared<GeneratorState>(*(other.state_))) {}
};
std::shared_ptr<std::mt19937_64> GetCPURandomEngine(uint64_t);
} // namespace framework
} // namespace paddle
......@@ -13,6 +13,7 @@
// limitations under the License.
#include "paddle/fluid/framework/ir/conv_affine_channel_fuse_pass.h"
#include <cmath>
#include <functional>
#include <string>
#include <vector>
......@@ -74,12 +75,17 @@ void recompute_bias_and_weights(const Scope* scope, ir::Node* conv_weight,
auto* weights = scope->FindVar(conv_weight->Name())->GetMutable<LoDTensor>();
auto weights_shape = weights->dims();
auto weights_shape_2d = flatten_to_2d(weights_shape, 1);
auto* weights_data = weights->mutable_data<float>(platform::CPUPlace());
EigenMatrixArrayMap weights_array_2d(
weights->mutable_data<float>(platform::CPUPlace()), weights_shape_2d[0],
weights_shape_2d[1]);
EigenMatrixArrayMap weights_array_2d(weights_data, weights_shape_2d[0],
weights_shape_2d[1]);
weights_array_2d.colwise() *= scale_array;
// Check for subnormal values that slows down convolution execution
for (int i = 0; i < weights->numel(); ++i) {
if (std::fpclassify(weights_data[i]) == FP_SUBNORMAL) weights_data[i] = 0;
}
}
void ConvAffineChannelFusePass::ApplyImpl(ir::Graph* graph) const {
......@@ -108,13 +114,6 @@ void ConvAffineChannelFusePass::ApplyImpl(ir::Graph* graph) const {
GET_CONV_BN_NODES(conv_ac_pattern);
// check if fuse can be done and if MKL-DNN should be used
FuseOptions fuse_option = FindFuseOption(*conv, *affine_channel);
if (fuse_option == DO_NOT_FUSE) {
VLOG(3) << "do not perform conv+affinechannel fuse";
return;
}
// Create eltwise_y (conv bias) variable
VarDesc eltwise_y_in_desc(
patterns::PDNodeName(name_scope_, "eltwise_y_in"));
......@@ -143,6 +142,7 @@ void ConvAffineChannelFusePass::ApplyImpl(ir::Graph* graph) const {
desc.SetOutput("Out", std::vector<std::string>({ac_out->Name()}));
desc.SetType("elementwise_add");
desc.SetAttr("axis", 1);
desc.SetAttr("use_mkldnn", conv->Op()->GetAttrIfExists<bool>("use_mkldnn"));
auto eltwise_op = g->CreateOpNode(&desc); // OpDesc will be copied.
GraphSafeRemoveNodes(graph, {ac_scale, ac_bias, affine_channel});
......
......@@ -1879,6 +1879,19 @@ PDNode *patterns::MultipleQuantize::operator()() {
return prev_out;
}
PDNode *patterns::QuantizePlacement::operator()(
const std::unordered_set<std::string> &quantize_enabled_op_types) {
std::unordered_set<std::string> supported_op_types =
std::unordered_set<std::string>({"concat", "conv2d", "elementwise_add",
"fc", "matmul", "pool2d", "prior_box",
"relu", "reshape2", "transpose2"});
if (!quantize_enabled_op_types.empty()) {
supported_op_types = quantize_enabled_op_types;
}
auto *op = pattern->NewNode(op_repr())->assert_is_ops(supported_op_types);
return op;
}
PDNode *patterns::MKLDNNInPlace::operator()() {
const std::unordered_set<std::string> &supported_op_types = {
"abs",
......
......@@ -1120,6 +1120,15 @@ struct MultipleQuantize : public PatternBase {
PATTERN_DECL_NODE(prev_out);
};
struct QuantizePlacement : public PatternBase {
QuantizePlacement(PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern, name_scope, "quantize_placement") {}
PDNode* operator()(
const std::unordered_set<std::string>& quantize_enabled_op_types);
PATTERN_DECL_NODE(op);
};
// Pattern used for enforcing inplace computation for in-place computation
// supporting DNNL ops. softmax, batch_norm and layer_norm
struct MKLDNNInPlace : public PatternBase {
......
......@@ -26,27 +26,33 @@ void CPUQuantizePlacementPass::ApplyImpl(ir::Graph* graph) const {
Get<std::unordered_set<int>>("quantize_excluded_op_ids");
const auto& op_types_list =
Get<std::unordered_set<std::string>>("quantize_enabled_op_types");
for (const Node* n : graph->Nodes()) {
if (n->IsOp()) {
if (std::find(excluded_ids_list.begin(), excluded_ids_list.end(),
n->id()) != excluded_ids_list.end())
continue;
auto* op = n->Op();
if (op->HasAttr("mkldnn_data_type") ||
op->HasProtoAttr("mkldnn_data_type")) {
// use_quantizer is no longer used
// assign value for compatibility
if (op->GetAttrIfExists<bool>("use_quantizer")) {
op->SetAttr("mkldnn_data_type", std::string("int8"));
}
if (std::find(op_types_list.begin(), op_types_list.end(), op->Type()) !=
op_types_list.end()) {
op->SetAttr("mkldnn_data_type", std::string("int8"));
op->SetAttr("use_quantizer", true);
}
Init(name_scope_, graph);
GraphPatternDetector gpd;
patterns::QuantizePlacement quantize_placement_pattern{gpd.mutable_pattern(),
"quantize_placement"};
quantize_placement_pattern(op_types_list);
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
GET_IR_NODE_FROM_SUBGRAPH(op, op, quantize_placement_pattern);
if (std::find(excluded_ids_list.begin(), excluded_ids_list.end(),
op->id()) != excluded_ids_list.end()) {
return;
}
if (op->Op()->HasAttr("mkldnn_data_type") ||
op->Op()->HasProtoAttr("mkldnn_data_type")) {
// use_quantizer is no longer used
// assign value for compatibility
if (op->Op()->GetAttrIfExists<bool>("use_quantizer")) {
op->Op()->SetAttr("mkldnn_data_type", std::string("int8"));
}
op->Op()->SetAttr("mkldnn_data_type", std::string("int8"));
op->Op()->SetAttr("use_quantizer", true);
}
}
};
gpd(graph, handler);
}
} // namespace ir
......@@ -58,10 +64,7 @@ REGISTER_PASS(cpu_quantize_placement_pass,
// a vector of operator type names to be quantized ("conv2d" etc.)
// the second param is the default value for this vector
.DefaultPassAttr("quantize_enabled_op_types",
new std::unordered_set<std::string>(
{"concat", "conv2d", "elementwise_add", "fc", "matmul",
"pool2d", "prior_box", "relu", "reshape2",
"transpose2"}))
new std::unordered_set<std::string>())
// a vector of operator ids that are to be excluded from quantization
// the second param is the default value for this vector
.DefaultPassAttr("quantize_excluded_op_ids", new std::unordered_set<int>());
......@@ -15,7 +15,10 @@ limitations under the License. */
#pragma once
#include <memory>
#include "paddle/fluid/framework/ir/pass.h"
#include <string>
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
namespace paddle {
namespace framework {
......@@ -23,9 +26,10 @@ namespace ir {
/*
* Specifies which operators should be quantized.
*/
class CPUQuantizePlacementPass : public Pass {
class CPUQuantizePlacementPass : public FusePassBase {
protected:
void ApplyImpl(ir::Graph* graph) const override;
const std::string name_scope_{"cpu_quantize_placement_pass"};
};
} // namespace ir
......
......@@ -131,8 +131,8 @@ TEST(QuantizerPlacementPass, enabled_conv_excluded_one) {
}
TEST(QuantizerPlacementPass, empty_list) {
// no operator quantized
MainTest({}, {}, 0);
// all operators quantized
MainTest({}, {}, 6);
}
TEST(QuantizerPlacementPass, default_attr_value) {
......
......@@ -81,7 +81,8 @@ void DeleteQuant(ir::Graph* graph, Scope* scope,
if (quantized_op_type == "conv2d" ||
quantized_op_type == "conv2d_fusion" ||
quantized_op_type == "depthwise_conv2d" ||
quantized_op_type == "fc") {
quantized_op_type == "fc" ||
quantized_op_type == "conv2d_transpose") {
op_desc->SetAttr("Input_scale", scale_value);
} else if (quantized_op_type == "mul") {
op_desc->SetAttr("X_scale", scale_value);
......@@ -111,7 +112,8 @@ void FuseDequant(ir::Graph* graph, Scope* scope,
std::string input_name = "";
if (quantized_op_type == "conv2d" ||
quantized_op_type == "depthwise_conv2d" ||
quantized_op_type == "conv2d_fusion") {
quantized_op_type == "conv2d_fusion" ||
quantized_op_type == "conv2d_transpose") {
weight_name = "Filter";
input_name = "Input";
} else if (quantized_op_type == "mul") {
......@@ -122,7 +124,8 @@ void FuseDequant(ir::Graph* graph, Scope* scope,
input_name = "Input";
} else {
PADDLE_THROW(platform::errors::Unimplemented(
"QuantDequantFuse: We only support conv2d, conv2d_fusion, fc, mul for "
"QuantDequantFuse: We only support conv2d, conv2d_fusion, "
"conv2d_transpose, fc, mul for "
"now."));
}
const std::string pattern_name = "dequant_fuse";
......@@ -192,10 +195,12 @@ void FuseDequant(ir::Graph* graph, Scope* scope,
scope->Var(quantized_op_weight_node->Name())->GetMutable<LoDTensor>();
auto w_dims = weight_tensor->dims();
// If quantized op is fc, weight scale size = 1;
// If quantized op is conv, weight scale size = weight dims[0]
// If quantized op is conv2d, weight scale size = weight dims[0]
// If quantized op is conv2d_transpose, weight scale size = weight dims[1]
bool valid_scale_size =
(weight_scale.size() == 1 ||
weight_scale.size() == static_cast<size_t>(w_dims[0]));
weight_scale.size() == static_cast<size_t>(w_dims[0]) ||
weight_scale.size() == static_cast<size_t>(w_dims[1]));
PADDLE_ENFORCE_EQ(
valid_scale_size, true,
platform::errors::InvalidArgument(
......@@ -206,8 +211,14 @@ void FuseDequant(ir::Graph* graph, Scope* scope,
if (weight_scale.size() == 1) {
quantized_weight_data[j] *= weight_scale[0];
} else {
int inner_size = w_dims[1] * w_dims[2] * w_dims[3];
quantized_weight_data[j] *= weight_scale[j / inner_size];
if (quantized_op_type == "conv2d_transpose") {
int inner_size = w_dims[2] * w_dims[3];
quantized_weight_data[j] *=
weight_scale[(j / inner_size) % w_dims[1]];
} else {
int inner_size = w_dims[1] * w_dims[2] * w_dims[3];
quantized_weight_data[j] *= weight_scale[j / inner_size];
}
}
}
......@@ -220,7 +231,8 @@ void FuseDequant(ir::Graph* graph, Scope* scope,
new_op_desc.SetType(quantized_op_type);
new_op_desc.SetAttr("enable_int8", true);
if (quantized_op_type == "conv2d" || quantized_op_type == "conv2d_fusion" ||
quantized_op_type == "depthwise_conv2d") {
quantized_op_type == "depthwise_conv2d" ||
quantized_op_type == "conv2d_transpose") {
new_op_desc.SetInput("Input", {new_input});
new_op_desc.SetOutput("Output", {new_output});
} else if (quantized_op_type == "fc") {
......@@ -253,7 +265,7 @@ void QuantDequantFusePass::ApplyImpl(ir::Graph* graph) const {
std::unordered_set<std::string> quant_types = {
"fake_quantize_range_abs_max", "fake_quantize_moving_average_abs_max"};
std::unordered_set<std::string> quantized_op_types = {
"conv2d", "mul", "depthwise_conv2d", "fc"};
"conv2d", "mul", "depthwise_conv2d", "fc", "conv2d_transpose"};
auto* scope = param_scope();
for (auto& quant_type : quant_types) {
......
......@@ -34,7 +34,8 @@ struct OpUpdateRecord {
kModifyAttr,
kNewAttr,
kNewInput,
kNewOutput
kNewOutput,
kBugfixWithBehaviorChanged,
};
Type type_;
std::string remark_;
......@@ -82,6 +83,11 @@ struct NewOutput : OpUpdateRecord {
std::string name_;
};
struct BugfixWithBehaviorChanged : OpUpdateRecord {
explicit BugfixWithBehaviorChanged(const std::string& remark)
: OpUpdateRecord({Type::kBugfixWithBehaviorChanged, remark}) {}
};
class OpVersionDesc {
public:
OpVersionDesc& ModifyAttr(const std::string& name, const std::string& remark,
......@@ -110,6 +116,12 @@ class OpVersionDesc {
return *this;
}
OpVersionDesc& BugfixWithBehaviorChanged(const std::string& remark) {
infos_.push_back(std::shared_ptr<OpUpdateRecord>(
new compatible::BugfixWithBehaviorChanged(remark)));
return *this;
}
private:
std::vector<std::shared_ptr<OpUpdateRecord>> infos_;
};
......
......@@ -23,6 +23,10 @@ namespace compatible {
TEST(test_operator_version, test_operator_version) {
REGISTER_OP_VERSION(test__)
.AddCheckpoint(
R"ROC(Fix the bug of reshape op, support the case of axis < 0)ROC",
framework::compatible::OpVersionDesc().BugfixWithBehaviorChanged(
"Support the case of axis < 0"))
.AddCheckpoint(
R"ROC(
Upgrade reshape, modified one attribute [axis] and add a new attribute [size].
......
......@@ -913,10 +913,20 @@ std::ostream& print_tensor(std::ostream& os, const framework::Tensor& tensor) {
auto element_num = tensor.numel();
os << " - data: [";
if (element_num > 0) {
os << inspect[0];
for (int j = 1; j < element_num; ++j) {
os << " " << inspect[j];
// Note: int8_t && uint8_t is typedf of char, ostream unable to print properly
if (typeid(int8_t) == typeid(T) || typeid(uint8_t) == typeid(T)) {
if (element_num > 0) {
os << signed(inspect[0]);
for (int j = 1; j < element_num; ++j) {
os << " " << signed(inspect[j]);
}
}
} else {
if (element_num > 0) {
os << inspect[0];
for (int j = 1; j < element_num; ++j) {
os << " " << inspect[j];
}
}
}
os << "]";
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// Created by Jiabin on 2019-04-25.
//
#pragma once
namespace paddle {
namespace imperative {
namespace detail {
struct BackwardStrategy {
/* DyGraph now support two kinds of backward strategy, one is sorted sum
* gradient, another is sum gradient once they are created */
// TODO(jiabin): add more Strategy when we support
bool sorted_sum_gradient_{false};
};
} // namespace detail
} // namespace imperative
} // namespace paddle
......@@ -30,12 +30,12 @@
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/platform/profiler.h"
DECLARE_bool(sort_sum_gradient);
namespace paddle {
namespace imperative {
void BasicEngine::Init(VarBase* var, const detail::BackwardStrategy& strategy,
bool retain_graph) {
backward_strategy_ = strategy;
void BasicEngine::Init(VarBase* var, bool retain_graph) {
retain_graph_ = retain_graph;
init_node_ = var->GradVarBase()->GradNode();
var->GradVarBase()->ClearGradNode();
......@@ -105,7 +105,7 @@ void BasicEngine::PrepareGradAccumulators(const OpBase& op) {
auto& accumulator = accumulators_[var.get()];
if (!accumulator) {
if (backward_strategy_.sorted_sum_gradient_) {
if (FLAGS_sort_sum_gradient) {
accumulator.reset(new SortedGradientAccumulator(var.get()));
} else {
accumulator.reset(new EagerGradientAccumulator(var.get()));
......
......@@ -18,7 +18,6 @@
#include <unordered_map>
#include <utility>
#include <vector>
#include "paddle/fluid/imperative/backward_strategy.h"
#include "paddle/fluid/imperative/engine.h"
#include "paddle/fluid/imperative/gradient_accumulator.h"
......@@ -30,8 +29,7 @@ class OpBase;
class BasicEngine : public Engine {
public:
void Init(VarBase* var, const detail::BackwardStrategy& strategy,
bool retain_graph = false);
void Init(VarBase* var, bool retain_graph = false);
void Execute() override;
......@@ -46,7 +44,6 @@ class BasicEngine : public Engine {
private:
std::shared_ptr<GradOpNode> init_node_;
detail::BackwardStrategy backward_strategy_;
std::unordered_map<GradOpNode*, size_t> node_deps_;
std::unordered_map<VariableWrapper*, std::unique_ptr<GradientAccumulator>>
accumulators_;
......
......@@ -33,6 +33,8 @@
#include "paddle/fluid/platform/profiler.h"
#include "paddle/fluid/string/string_helper.h"
DECLARE_bool(sort_sum_gradient);
namespace paddle {
namespace imperative {
......@@ -529,8 +531,7 @@ class PartialGradTask {
const std::vector<std::shared_ptr<VarBase>> &output_targets,
const std::vector<std::shared_ptr<VarBase>> &output_grads,
const std::vector<std::shared_ptr<VarBase>> &no_grad_vars,
const platform::Place &place,
const detail::BackwardStrategy &strategy, bool create_graph,
const platform::Place &place, bool create_graph,
bool retain_graph, bool allow_unused, bool only_inputs);
std::vector<std::shared_ptr<VarBase>> Run();
......@@ -577,7 +578,6 @@ class PartialGradTask {
bool retain_graph_;
bool allow_unused_;
bool only_inputs_;
detail::BackwardStrategy strategy_;
};
PartialGradTask::PartialGradTask(
......@@ -585,15 +585,14 @@ PartialGradTask::PartialGradTask(
const std::vector<std::shared_ptr<VarBase>> &output_targets,
const std::vector<std::shared_ptr<VarBase>> &output_grads,
const std::vector<std::shared_ptr<VarBase>> &no_grad_vars,
const platform::Place &place, const detail::BackwardStrategy &strategy,
bool create_graph, bool retain_graph, bool allow_unused, bool only_inputs) {
const platform::Place &place, bool create_graph, bool retain_graph,
bool allow_unused, bool only_inputs) {
input_targets_ = input_targets;
place_ = place;
create_graph_ = create_graph;
retain_graph_ = retain_graph;
allow_unused_ = allow_unused;
only_inputs_ = only_inputs;
strategy_ = strategy;
PADDLE_ENFORCE_EQ(only_inputs_, true,
platform::errors::Unimplemented(
......@@ -981,7 +980,7 @@ void PartialGradTask::PrepareInitialGradientAccumulators(const OpBase *op) {
if (!accumulator) {
accumulator.reset(new GradientAccumulationInfo(
var, strategy_.sorted_sum_gradient_, create_graph_));
var, FLAGS_sort_sum_gradient, create_graph_));
}
accumulator->IncreaseTotalRefCnt();
......@@ -1033,11 +1032,11 @@ PartialGradEngine::PartialGradEngine(
const std::vector<std::shared_ptr<VarBase>> &output_targets,
const std::vector<std::shared_ptr<VarBase>> &output_grads,
const std::vector<std::shared_ptr<VarBase>> &no_grad_vars,
const platform::Place &place, const detail::BackwardStrategy &strategy,
bool create_graph, bool retain_graph, bool allow_unused, bool only_inputs)
const platform::Place &place, bool create_graph, bool retain_graph,
bool allow_unused, bool only_inputs)
: task_(new PartialGradTask(input_targets, output_targets, output_grads,
no_grad_vars, place, strategy, create_graph,
retain_graph, allow_unused, only_inputs)) {}
no_grad_vars, place, create_graph, retain_graph,
allow_unused, only_inputs)) {}
PartialGradEngine::~PartialGradEngine() { Clear(); }
......
......@@ -16,7 +16,6 @@
#include <memory>
#include <vector>
#include "paddle/fluid/imperative/backward_strategy.h"
#include "paddle/fluid/imperative/engine.h"
#include "paddle/fluid/platform/place.h"
......@@ -33,8 +32,7 @@ class PartialGradEngine : public Engine {
const std::vector<std::shared_ptr<VarBase>> &output_targets,
const std::vector<std::shared_ptr<VarBase>> &output_grads,
const std::vector<std::shared_ptr<VarBase>> &no_grad_vars,
const platform::Place &place,
const detail::BackwardStrategy &strategy, bool create_graph,
const platform::Place &place, bool create_graph,
bool retain_graph, bool allow_unused, bool only_inputs);
~PartialGradEngine();
......
......@@ -240,9 +240,8 @@ TEST(test_tracer, test_trace_op_with_multi_device_inputs) {
framework::AttributeMap reduce_attr_map;
tracer.TraceOp("reduce_sum", reduce_in, reduce_out, reduce_attr_map,
gpu_place, true);
detail::BackwardStrategy back_st;
imperative::BasicEngine engine;
engine.Init(reduce_sum_out.get(), back_st);
engine.Init(reduce_sum_out.get());
engine.Execute();
framework::LoDTensor rlt;
......@@ -356,9 +355,8 @@ TEST(test_tracer, test_var_without_grad_var) {
ASSERT_EQ(y_in->GradVarBase()->GradOpNum(), 0UL);
ASSERT_EQ(vout->GradVarBase()->GradOpNum(), 1UL);
detail::BackwardStrategy back_st;
imperative::BasicEngine engine;
engine.Init(vout.get(), back_st);
engine.Init(vout.get());
engine.Execute();
// check the grad
......
......@@ -15,7 +15,6 @@
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/inference/api/paddle_analysis_config.h"
#include "paddle/fluid/inference/api/paddle_inference_api.h"
#include "paddle/fluid/inference/api/paddle_pass_builder.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/gpu_info.h"
......@@ -103,8 +102,8 @@ AnalysisConfig::AnalysisConfig(const AnalysisConfig &other) {
// params_file_ fields.
CP_MEMBER(opt_cache_dir_);
prog_file_ = std::move(other.prog_file_);
params_file_ = std::move(other.params_file_);
CP_MEMBER(prog_file_);
CP_MEMBER(params_file_);
CP_MEMBER(use_fc_padding_);
// GPU related.
......@@ -218,6 +217,17 @@ void AnalysisConfig::EnableMkldnnQuantizer() {
Update();
}
void AnalysisConfig::EnableMkldnnBfloat16() {
#ifdef PADDLE_WITH_MKLDNN
use_mkldnn_bfloat16_ = true;
#else
LOG(ERROR) << "Please compile with MKLDNN first to use MkldnnBfloat16";
use_mkldnn_bfloat16_ = false;
#endif
Update();
}
MkldnnQuantizerConfig *AnalysisConfig::mkldnn_quantizer_config() const {
PADDLE_ENFORCE_NOT_NULL(mkldnn_quantizer_config_,
"MkldnnQuantizer was not enabled yet.");
......@@ -331,6 +341,12 @@ void AnalysisConfig::Update() {
#endif
}
if (use_mkldnn_bfloat16_) {
#ifdef PADDLE_WITH_MKLDNN
pass_builder()->EnableMkldnnBfloat16();
#endif
}
#ifdef PADDLE_WITH_MKLDNN
// Do not optimize when mkldnn is on
if (enable_memory_optim_ && !use_mkldnn_) {
......@@ -399,6 +415,7 @@ std::string AnalysisConfig::SerializeInfoCache() {
ss << ";";
ss << use_mkldnn_quantizer_;
ss << use_mkldnn_bfloat16_;
ss << model_from_memory_;
ss << with_profile_;
......
......@@ -32,7 +32,6 @@
#include "paddle/fluid/inference/analysis/helper.h"
#include "paddle/fluid/inference/analysis/passes/memory_optimize_pass.h"
#include "paddle/fluid/inference/api/helper.h"
#include "paddle/fluid/inference/api/paddle_inference_api.h"
#include "paddle/fluid/inference/api/paddle_inference_pass.h"
#include "paddle/fluid/inference/utils/singleton.h"
#include "paddle/fluid/memory/memcpy.h"
......@@ -517,6 +516,8 @@ void AnalysisPredictor::OptimizeInferenceProgram() {
template <>
std::unique_ptr<PaddlePredictor> CreatePaddlePredictor<
AnalysisConfig, PaddleEngineKind::kAnalysis>(const AnalysisConfig &config) {
// TODO(NHZlX): Should add the link to the doc of
// paddle_infer::CreatePredictor<paddle_infer::Config>
if (config.glog_info_disabled()) {
FLAGS_logtostderr = 1;
FLAGS_minloglevel = 2; // GLOG_ERROR
......@@ -1058,3 +1059,122 @@ USE_TRT_CONVERTER(skip_layernorm);
USE_TRT_CONVERTER(slice);
USE_TRT_CONVERTER(scale);
#endif
namespace paddle_infer {
void Tensor::Reshape(const std::vector<int> &shape) { tensor_->Reshape(shape); }
std::vector<int> Tensor::shape() const { return tensor_->shape(); }
void Tensor::SetLoD(const std::vector<std::vector<size_t>> &x) {
return tensor_->SetLoD(x);
}
std::vector<std::vector<size_t>> Tensor::lod() const { return tensor_->lod(); }
const std::string &Tensor::name() const { return tensor_->name(); }
DataType Tensor::type() const { return tensor_->type(); }
Predictor::Predictor(const Config &config) {
const_cast<Config *>(&config)->SwitchUseFeedFetchOps(false);
// The second parameter indicates that the discard log is not printed
predictor_ = paddle::CreatePaddlePredictor<
Config, paddle::PaddleEngineKind::kAnalysis>(config);
}
std::vector<std::string> Predictor::GetInputNames() {
return predictor_->GetInputNames();
}
std::unique_ptr<Tensor> Predictor::GetInputHandle(const std::string &name) {
auto zero_copy_tensor = predictor_->GetInputTensor(name);
std::unique_ptr<Tensor> tensor(new Tensor(std::move(zero_copy_tensor)));
return tensor;
}
std::vector<std::string> Predictor::GetOutputNames() {
return predictor_->GetOutputNames();
}
std::unique_ptr<Tensor> Predictor::GetOutputHandle(const std::string &name) {
auto zero_copy_tensor = predictor_->GetOutputTensor(name);
std::unique_ptr<Tensor> tensor(new Tensor(std::move(zero_copy_tensor)));
return tensor;
}
bool Predictor::Run() { return predictor_->ZeroCopyRun(); }
std::unique_ptr<Predictor> Predictor::Clone() {
auto analysis_pred = predictor_->Clone();
std::unique_ptr<Predictor> pred(new Predictor(std::move(analysis_pred)));
return pred;
}
void Predictor::ClearIntermediateTensor() {
predictor_->ClearIntermediateTensor();
}
int GetNumBytesOfDataType(DataType dtype) {
switch (dtype) {
case DataType::FLOAT32:
return sizeof(float);
case DataType::INT64:
return sizeof(int64_t);
case DataType::INT32:
return sizeof(int32_t);
case DataType::UINT8:
return sizeof(uint8_t);
default:
assert(false);
return -1;
}
}
std::string GetVersion() { return paddle::get_version(); }
std::string UpdateDllFlag(const char *name, const char *value) {
return paddle::UpdateDllFlag(name, value);
}
} // namespace paddle_infer
namespace paddle_infer {
std::shared_ptr<Predictor> CreatePredictor(const Config &config) { // NOLINT
std::shared_ptr<Predictor> predictor(new Predictor(config));
return predictor;
}
namespace services {
PredictorPool::PredictorPool(const Config &config, size_t size) {
PADDLE_ENFORCE_GE(
size, 1UL,
paddle::platform::errors::InvalidArgument(
"The predictor pool size should be greater than 1, but it's (%d)",
size));
Config copy_config(config);
main_pred_.reset(new Predictor(config));
for (size_t i = 0; i < size - 1; i++) {
if (config.tensorrt_engine_enabled()) {
Config config_tmp(copy_config);
preds_.push_back(
std::move(std::unique_ptr<Predictor>(new Predictor(config_tmp))));
} else {
preds_.push_back(std::move(main_pred_->Clone()));
}
}
}
Predictor *PredictorPool::Retrive(size_t idx) {
PADDLE_ENFORCE_LT(
idx, preds_.size() + 1,
paddle::platform::errors::InvalidArgument(
"There are (%d) predictors in the pool, but the idx is (%d)", idx,
preds_.size() + 1));
if (idx == 0) {
return main_pred_.get();
}
return preds_[idx - 1].get();
}
} // namespace services
} // namespace paddle_infer
......@@ -485,4 +485,25 @@ TEST_F(MkldnnQuantizerTest, kl_scaling_factor_unsigned) {
}
#endif
#ifdef PADDLE_WITH_CUDA
TEST(AnalysisPredictor, bf16_gpu_pass_strategy) {
AnalysisConfig config;
config.SetModel(FLAGS_dirname);
config.SwitchIrOptim(true);
config.EnableUseGpu(100, 0);
config.EnableMkldnnBfloat16();
#ifdef PADDLE_WITH_MKLDNN
ASSERT_EQ(config.mkldnn_bfloat16_enabled(), true);
#else
ASSERT_EQ(config.mkldnn_bfloat16_enabled(), false);
#endif
}
#endif
TEST(AnalysisPredictor, bf16_pass_strategy) {
std::vector<std::string> passes;
PassStrategy passStrategy(passes);
passStrategy.EnableMkldnnBfloat16();
}
} // namespace paddle
......@@ -112,6 +112,12 @@ void PaddleBuf::Free() {
}
}
NativeConfig::NativeConfig() {
LOG(WARNING) << "The paddle::NativeConfig interface is going to be "
"deprecated in the next release, plase use the latest "
"paddle_infer::Config instead.";
}
std::string get_version() {
std::stringstream ss;
ss << "version: " << framework::paddle_version() << "\n";
......
......@@ -15,6 +15,7 @@ limitations under the License. */
#include <glog/logging.h>
#include <algorithm>
#include <map>
#include <memory>
#include <set>
#include <sstream>
#include <string>
......@@ -25,6 +26,7 @@ limitations under the License. */
#include "paddle/fluid/inference/api/api_impl.h"
#include "paddle/fluid/inference/api/details/reset_tensor_array.h"
#include "paddle/fluid/inference/api/helper.h"
#include "paddle/fluid/inference/api/paddle_inference_api.h"
#include "paddle/fluid/memory/memcpy.h"
#include "paddle/fluid/platform/cpu_helper.h"
#include "paddle/fluid/platform/profiler.h"
......@@ -311,6 +313,8 @@ bool NativePaddlePredictor::GetFetch(std::vector<PaddleTensor> *outputs,
template <>
std::unique_ptr<PaddlePredictor> CreatePaddlePredictor<
NativeConfig, PaddleEngineKind::kNative>(const NativeConfig &config) {
// TODO(NHZlX): Should add the link to the doc of
// paddle_infer::CreatePredictor<paddle_infer::Config>
VLOG(3) << "create NativePaddlePredictor";
if (config.use_gpu) {
// 1. GPU memory
......
......@@ -401,6 +401,19 @@ struct PD_INFER_DECL AnalysisConfig {
///
void EnableMkldnnQuantizer();
///
/// \brief Turn on MKLDNN bfloat16.
///
///
void EnableMkldnnBfloat16();
///
/// \brief A boolean state telling whether to use the MKLDNN Bfloat16.
///
/// \return bool Whether to use the MKLDNN Bfloat16.
///
bool mkldnn_bfloat16_enabled() const { return use_mkldnn_bfloat16_; }
///
/// \brief A boolean state telling whether the thread local CUDA stream is
/// enabled.
......@@ -592,6 +605,7 @@ struct PD_INFER_DECL AnalysisConfig {
int mkldnn_cache_capacity_{0};
bool use_mkldnn_quantizer_{false};
std::shared_ptr<MkldnnQuantizerConfig> mkldnn_quantizer_config_;
bool use_mkldnn_bfloat16_{false};
// If the config is already used on a predictor, it becomes invalid.
// Any config can only be used with one predictor.
......
......@@ -347,6 +347,7 @@ class PD_INFER_DECL PaddlePredictor {
/// place of inference, etc.)
///
struct PD_INFER_DECL NativeConfig : public PaddlePredictor::Config {
NativeConfig();
/// GPU related fields.
bool use_gpu{false};
int device{0};
......@@ -421,7 +422,8 @@ enum class PaddleEngineKind {
};
template <typename ConfigT, PaddleEngineKind engine>
std::unique_ptr<PaddlePredictor> CreatePaddlePredictor(const ConfigT& config);
PD_INFER_DECL std::unique_ptr<PaddlePredictor> CreatePaddlePredictor(
const ConfigT& config);
template <>
PD_INFER_DECL std::unique_ptr<PaddlePredictor> CreatePaddlePredictor<
......@@ -437,6 +439,4 @@ PD_INFER_DECL std::string get_version();
PD_INFER_DECL std::string UpdateDllFlag(const char* name, const char* value);
PD_INFER_DECL std::shared_ptr<framework::Cipher> MakeCipher(
const std::string& config_file);
} // namespace paddle
......@@ -22,9 +22,124 @@ limitations under the License. */
#pragma once
#include <cassert>
#include <map>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "paddle_analysis_config.h" // NOLINT
#include "paddle_api.h" // NOLINT
namespace paddle_infer {
using DataType = paddle::PaddleDType;
using PlaceType = paddle::PaddlePlace;
using PrecisionType = paddle::AnalysisConfig::Precision;
using Config = paddle::AnalysisConfig;
class PD_INFER_DECL Tensor {
public:
// Can only be created by predictor->GetInputHandle(cosnt std::string& name)
// or predictor->GetOutputHandle(cosnt std::string& name)
Tensor() = delete;
explicit Tensor(std::unique_ptr<paddle::ZeroCopyTensor>&& tensor)
: tensor_(std::move(tensor)) {}
void Reshape(const std::vector<int>& shape);
template <typename T>
void CopyFromCpu(const T* data);
// should add the place
template <typename T>
T* mutable_data(PlaceType place);
template <typename T>
void CopyToCpu(T* data);
template <typename T>
T* data(PlaceType* place, int* size) const;
void SetLoD(const std::vector<std::vector<size_t>>& x);
std::vector<std::vector<size_t>> lod() const;
DataType type() const;
std::vector<int> shape() const;
const std::string& name() const;
private:
std::unique_ptr<paddle::ZeroCopyTensor> tensor_;
};
class PD_INFER_DECL Predictor {
public:
Predictor() = default;
~Predictor() {}
// Use for clone
explicit Predictor(std::unique_ptr<paddle::PaddlePredictor>&& pred)
: predictor_(std::move(pred)) {}
explicit Predictor(const Config& config);
std::vector<std::string> GetInputNames();
std::unique_ptr<Tensor> GetInputHandle(const std::string& name);
bool Run();
std::vector<std::string> GetOutputNames();
std::unique_ptr<Tensor> GetOutputHandle(const std::string& name);
std::unique_ptr<Predictor> Clone();
void ClearIntermediateTensor();
private:
std::unique_ptr<paddle::PaddlePredictor> predictor_;
};
PD_INFER_DECL std::shared_ptr<Predictor> CreatePredictor(
const Config& config); // NOLINT
PD_INFER_DECL int GetNumBytesOfDataType(DataType dtype);
PD_INFER_DECL std::string GetVersion();
PD_INFER_DECL std::string UpdateDllFlag(const char* name, const char* value);
template <typename T>
void Tensor::CopyFromCpu(const T* data) {
tensor_->copy_from_cpu<T>(data);
}
template <typename T>
void Tensor::CopyToCpu(T* data) {
return tensor_->copy_to_cpu<T>(data);
}
template <typename T>
T* Tensor::mutable_data(PlaceType place) {
return tensor_->mutable_data<T>(place);
}
template <typename T>
T* Tensor::data(PlaceType* place, int* size) const {
return tensor_->data<T>(place, size);
}
} // namespace paddle_infer
namespace paddle_infer {
namespace services {
class PD_INFER_DECL PredictorPool {
public:
PredictorPool() = delete;
PredictorPool(const PredictorPool&) = delete;
PredictorPool& operator=(const PredictorPool&) = delete;
explicit PredictorPool(const Config& config, size_t size = 1);
Predictor* Retrive(size_t idx);
private:
std::shared_ptr<Predictor> main_pred_;
std::vector<std::unique_ptr<Predictor>> preds_;
};
} // namespace services
} // namespace paddle_infer
......@@ -143,6 +143,10 @@ void GpuPassStrategy::EnableMkldnnQuantizer() {
LOG(ERROR) << "GPU not support MKL-DNN quantization";
}
void GpuPassStrategy::EnableMkldnnBfloat16() {
LOG(ERROR) << "GPU not support MKL-DNN bfloat16";
}
CpuPassStrategy::CpuPassStrategy() : PassStrategy({}) {
// NOTE the large fusions should be located in the front, so that they will
// not be damaged by smaller ones.
......@@ -181,12 +185,14 @@ void CpuPassStrategy::EnableMKLDNN() {
passes_.insert(passes_.begin(), "mkldnn_placement_pass");
for (auto &pass : std::vector<std::string>({
"depthwise_conv_mkldnn_pass", //
"conv_bn_fuse_pass", // Execute BN passes again to
"conv_eltwiseadd_bn_fuse_pass", // preserve correct pass order
"conv_transpose_bn_fuse_pass", //
"conv_transpose_eltwiseadd_bn_fuse_pass", //
"conv_bias_mkldnn_fuse_pass", //
"depthwise_conv_mkldnn_pass", //
"conv_bn_fuse_pass", // Execute BN passes again to
"conv_eltwiseadd_bn_fuse_pass", // preserve correct pass order
"conv_affine_channel_fuse_pass", //
"conv_eltwiseadd_affine_channel_fuse_pass", //
"conv_transpose_bn_fuse_pass", //
"conv_transpose_eltwiseadd_bn_fuse_pass", //
"conv_bias_mkldnn_fuse_pass", //
"conv_transpose_bias_mkldnn_fuse_pass",
"conv3d_bias_mkldnn_fuse_pass", //
"conv_elementwise_add_mkldnn_fuse_pass",
......@@ -223,4 +229,12 @@ void CpuPassStrategy::EnableMkldnnQuantizer() {
#endif
}
void CpuPassStrategy::EnableMkldnnBfloat16() {
#ifdef PADDLE_WITH_MKLDNN
use_mkldnn_bfloat16_ = true;
#else
use_mkldnn_bfloat16_ = false;
#endif
}
} // namespace paddle
......@@ -132,6 +132,9 @@ class PD_INFER_DECL PassStrategy : public PaddlePassBuilder {
/// \brief Enable MKLDNN quantize optimization.
virtual void EnableMkldnnQuantizer() {}
/// \brief Enable MKLDNN bfloat16.
virtual void EnableMkldnnBfloat16() {}
/// \brief Check if we are using gpu.
/// \return A bool variable implying whether we are in gpu mode.
bool use_gpu() const { return use_gpu_; }
......@@ -161,6 +164,7 @@ class PD_INFER_DECL CpuPassStrategy : public PassStrategy {
use_gpu_ = other.use_gpu_;
use_mkldnn_ = other.use_mkldnn_;
use_mkldnn_quantizer_ = other.use_mkldnn_quantizer_;
use_mkldnn_bfloat16_ = other.use_mkldnn_bfloat16_;
}
/// \brief Default destructor.
virtual ~CpuPassStrategy() = default;
......@@ -174,9 +178,13 @@ class PD_INFER_DECL CpuPassStrategy : public PassStrategy {
/// \brief Enable MKLDNN quantize optimization.
void EnableMkldnnQuantizer() override;
/// \brief Enable MKLDNN bfloat16.
void EnableMkldnnBfloat16() override;
protected:
/// \cond Protected
bool use_mkldnn_quantizer_{false};
bool use_mkldnn_bfloat16_{false};
/// \endcond
};
......@@ -205,6 +213,9 @@ class PD_INFER_DECL GpuPassStrategy : public PassStrategy {
/// \brief Not supported in GPU mode yet.
void EnableMkldnnQuantizer() override;
/// \brief Not supported in GPU mode yet.
void EnableMkldnnBfloat16() override;
/// \brief Default destructor.
virtual ~GpuPassStrategy() = default;
......
......@@ -235,6 +235,12 @@ PADDLE_CAPI_EXPORT extern void PD_EnableMkldnnQuantizer(
PADDLE_CAPI_EXPORT extern bool PD_MkldnnQuantizerEnabled(
const PD_AnalysisConfig* config);
PADDLE_CAPI_EXPORT extern void PD_EnableMkldnnBfloat16(
PD_AnalysisConfig* config);
PADDLE_CAPI_EXPORT extern bool PD_MkldnnBfloat16Enabled(
const PD_AnalysisConfig* config);
PADDLE_CAPI_EXPORT extern void PD_SetModelBuffer(PD_AnalysisConfig* config,
const char* prog_buffer,
size_t prog_buffer_size,
......
......@@ -207,6 +207,18 @@ bool PD_MkldnnQuantizerEnabled(const PD_AnalysisConfig* config) {
return config->config.mkldnn_quantizer_enabled();
}
void PD_EnableMkldnnBfloat16(PD_AnalysisConfig* config) {
PADDLE_ENFORCE_NOT_NULL(config, paddle::platform::errors::NotFound(
"PD_AnalysisConfig should not be null"));
config->config.EnableMkldnnBfloat16();
}
bool PD_MkldnnBfloat16Enabled(const PD_AnalysisConfig* config) {
PADDLE_ENFORCE_NOT_NULL(config, paddle::platform::errors::NotFound(
"PD_AnalysisConfig should not be null"));
return config->config.mkldnn_bfloat16_enabled();
}
void PD_SetModelBuffer(PD_AnalysisConfig* config, const char* prog_buffer,
size_t prog_buffer_size, const char* params_buffer,
size_t params_buffer_size) {
......
......@@ -51,7 +51,13 @@ void ConvertConv2d(TensorRTEngine* engine, const framework::proto::OpDesc& op,
if (enable_int8) {
#if IS_TRT_VERSION_GE(5000)
CHECK(op_desc.HasAttr("Input_scale"));
if (op_desc.Type() != "conv2d_transpose") {
PADDLE_ENFORCE_EQ(
op_desc.HasAttr("Input_scale"), true,
platform::errors::InvalidArgument("Input scale not found. TRT int8"
" requires conv/deconv to have "
"input quantization scales."));
}
float in_scale =
BOOST_GET_CONST(float, op_desc.GetAttr("Input_scale")) * 127;
auto weight_scale =
......
......@@ -54,7 +54,7 @@ class SkipLayerNormPluginDynamic : public DynamicPluginTensorRT {
auto ptr = new SkipLayerNormPluginDynamic(
bias_.data(), scale_.data(), bias_size_, scale_size_, eps_, ban_fp16_);
ptr->bias_gpu_ = bias_gpu_;
ptr->scale_gpu_ = bias_gpu_;
ptr->scale_gpu_ = scale_gpu_;
return ptr;
}
......
......@@ -192,7 +192,8 @@ download_result(${ERNIE_INSTALL_DIR} "Ernie_large_result.txt.tar.gz")
inference_analysis_test(test_analyzer_ernie_large SRCS analyzer_ernie_tester.cc
EXTRA_DEPS ${INFERENCE_EXTRA_DEPS}
ARGS --infer_model=${ERNIE_INSTALL_DIR}/model --infer_data=${ERNIE_INSTALL_DIR}/data.txt --refer_result=${ERNIE_INSTALL_DIR}/result.txt --ernie_large=true)
set_tests_properties(test_analyzer_ernie_large PROPERTIES TIMEOUT 150)
set_tests_properties(test_analyzer_ernie_large PROPERTIES TIMEOUT 150 LABELS "RUN_TYPE=NIGHTLY")
# text_classification
set(TEXT_CLASSIFICATION_INSTALL_DIR "${INFERENCE_DEMO_INSTALL_DIR}/text_classification")
......@@ -514,3 +515,9 @@ if(WITH_MKLDNN)
inference_analysis_test(test_analyzer_capi_ner SRCS analyzer_capi_ner_tester.cc
EXTRA_DEPS ${INFERENCE_EXTRA_DEPS} paddle_fluid_c
ARGS --infer_model=${CHINESE_NER_INSTALL_DIR}/model)
if(WITH_GPU)
inference_analysis_test(paddle_infer_api_test SRCS paddle_infer_api_test.cc
EXTRA_DEPS ${INFERENCE_EXTRA_DEPS}
ARGS --infer_model=${RESNET50_MODEL_DIR})
endif()
......@@ -54,6 +54,9 @@ TEST(PD_AnalysisConfig, use_gpu) {
PD_SwitchIrOptim(config, true);
bool ir_optim = PD_IrOptim(config);
CHECK(ir_optim) << "NO";
PD_EnableMkldnnBfloat16(config);
bool bfloat16_enable = PD_MkldnnBfloat16Enabled(config);
CHECK(!bfloat16_enable) << "NO";
PD_EnableTensorRtEngine(config, 1 << 20, 1, 3, Precision::kFloat32, false,
false);
bool trt_enable = PD_TensorrtEngineEnabled(config);
......
......@@ -88,6 +88,9 @@ TEST(PD_AnalysisConfig, profile_mkldnn) {
PD_EnableMkldnnQuantizer(config);
bool quantizer_enable = PD_MkldnnQuantizerEnabled(config);
CHECK(quantizer_enable) << "NO";
PD_EnableMkldnnBfloat16(config);
bool bfloat16_enable = PD_MkldnnBfloat16Enabled(config);
CHECK(bfloat16_enable) << "NO";
PD_SetMkldnnCacheCapacity(config, 0);
PD_SetModel(config, prog_file.c_str(), params_file.c_str());
PD_DeleteAnalysisConfig(config);
......
......@@ -72,3 +72,59 @@ TEST(AnalysisPredictor, use_gpu) {
} // namespace inference
} // namespace paddle
namespace paddle_infer {
TEST(Predictor, use_gpu) {
std::string model_dir = FLAGS_infer_model + "/" + "model";
Config config;
config.EnableUseGpu(100, 0);
config.SetModel(model_dir + "/model", model_dir + "/params");
config.EnableLiteEngine(PrecisionType::kFloat32);
auto predictor = CreatePredictor(config);
const int batch = 1;
const int channel = 3;
const int height = 318;
const int width = 318;
const int input_num = batch * channel * height * width;
std::vector<float> input(input_num, 1);
auto input_names = predictor->GetInputNames();
auto input_t = predictor->GetInputHandle(input_names[0]);
input_t->Reshape({1, 3, 318, 318});
input_t->CopyFromCpu(input.data());
predictor->Run();
auto output_names = predictor->GetOutputNames();
auto output_t = predictor->GetOutputHandle(output_names[0]);
std::vector<int> output_shape = output_t->shape();
size_t out_num = std::accumulate(output_shape.begin(), output_shape.end(), 1,
std::multiplies<int>());
std::vector<float> out_data;
out_data.resize(out_num);
output_t->CopyToCpu(out_data.data());
const std::vector<float> truth_values = {
127.780396f, 738.16656f, 1013.2264f, -438.17206f, 366.4022f,
927.66187f, 736.2241f, -633.68567f, -329.92737f, -430.15637f,
-633.0639f, -146.54858f, -1324.2804f, -1349.3661f, -242.67671f,
117.44864f, -801.7251f, -391.51495f, -404.8202f, 454.16132f,
515.48206f, -133.03114f, 69.293076f, 590.09753f, -1434.6917f,
-1070.8903f, 307.0744f, 400.52573f, -316.12177f, -587.1265f,
-161.05742f, 800.3663f, -96.47157f, 748.708f, 868.17645f,
-447.9403f, 112.73656f, 1127.1992f, 47.43518f, 677.7219f,
593.1881f, -336.4011f, 551.3634f, 397.82474f, 78.39835f,
-715.4006f, 405.96988f, 404.25684f, 246.01978f, -8.430191f,
131.36617f, -648.0528f};
float* data_o = out_data.data();
for (size_t j = 0; j < out_num; j += 10) {
EXPECT_NEAR((data_o[j] - truth_values[j / 10]) / truth_values[j / 10], 0.,
10e-5);
}
}
} // namespace paddle_infer
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <cuda_runtime.h>
#include <gflags/gflags.h>
#include <glog/logging.h>
#include <gtest/gtest.h>
#include <cstring>
#include <numeric>
#include "paddle/fluid/inference/tests/api/trt_test_helper.h"
namespace paddle_infer {
TEST(Predictor, use_gpu) {
LOG(INFO) << GetVersion();
UpdateDllFlag("conv_workspace_size_limit", "4000");
std::string model_dir = FLAGS_infer_model + "/model";
Config config;
config.SetModel(model_dir + "/model", model_dir + "/params");
config.EnableUseGpu(100, 0);
auto predictor = CreatePredictor(config);
auto pred_clone = predictor->Clone();
std::vector<int> in_shape = {1, 3, 318, 318};
int in_num = std::accumulate(in_shape.begin(), in_shape.end(), 1,
[](int &a, int &b) { return a * b; });
std::vector<float> input(in_num, 0);
auto input_names = predictor->GetInputNames();
auto input_t = predictor->GetInputHandle(input_names[0]);
input_t->Reshape(in_shape);
input_t->CopyFromCpu(input.data());
predictor->Run();
auto output_names = predictor->GetOutputNames();
auto output_t = predictor->GetOutputHandle(output_names[0]);
std::vector<int> output_shape = output_t->shape();
int out_num = std::accumulate(output_shape.begin(), output_shape.end(), 1,
std::multiplies<int>());
std::vector<float> out_data;
out_data.resize(out_num);
output_t->CopyToCpu(out_data.data());
predictor->ClearIntermediateTensor();
}
TEST(PredictorPool, basic) {
LOG(INFO) << GetVersion();
UpdateDllFlag("conv_workspace_size_limit", "4000");
std::string model_dir = FLAGS_infer_model + "/model";
Config config;
config.SetModel(model_dir + "/model", model_dir + "/params");
config.EnableUseGpu(100, 0);
services::PredictorPool pred_pool(config, 4);
auto pred = pred_pool.Retrive(2);
std::vector<int> in_shape = {1, 3, 318, 318};
int in_num = std::accumulate(in_shape.begin(), in_shape.end(), 1,
[](int &a, int &b) { return a * b; });
std::vector<float> input(in_num, 0);
auto in_names = pred->GetInputNames();
auto input_t = pred->GetInputHandle(in_names[0]);
input_t->name();
input_t->Reshape(in_shape);
input_t->CopyFromCpu(input.data());
pred->Run();
auto out_names = pred->GetOutputNames();
auto output_t = pred->GetOutputHandle(out_names[0]);
auto out_type = output_t->type();
LOG(INFO) << GetNumBytesOfDataType(out_type);
if (out_type == DataType::FLOAT32) {
PlaceType place;
int size;
output_t->data<float>(&place, &size);
}
}
} // namespace paddle_infer
......@@ -41,7 +41,7 @@ TEST(AnalysisPredictor, use_gpu) {
SetFakeImageInput(&inputs_all, model_dir, false, "__model__", "");
std::vector<PaddleTensor> outputs;
for (auto& input : inputs_all) {
for (auto &input : inputs_all) {
ASSERT_TRUE(predictor->Run(input, &outputs));
predictor->ClearIntermediateTensor();
}
......@@ -49,3 +49,27 @@ TEST(AnalysisPredictor, use_gpu) {
} // namespace inference
} // namespace paddle
namespace paddle_infer {
TEST(PredictorPool, use_gpu) {
std::string model_dir = FLAGS_infer_model + "/" + "mobilenet";
Config config;
config.EnableUseGpu(100, 0);
config.SetModel(model_dir);
config.EnableTensorRtEngine();
services::PredictorPool pred_pool(config, 1);
auto predictor = pred_pool.Retrive(0);
auto input_names = predictor->GetInputNames();
auto input_t = predictor->GetInputHandle(input_names[0]);
std::vector<int> in_shape = {1, 3, 224, 224};
int in_num = std::accumulate(in_shape.begin(), in_shape.end(), 1,
[](int &a, int &b) { return a * b; });
std::vector<float> input(in_num, 0);
input_t->Reshape(in_shape);
input_t->CopyFromCpu(input.data());
predictor->Run();
}
} // namespace paddle_infer
......@@ -20,6 +20,7 @@ limitations under the License. */
#include <unordered_map>
#include <vector>
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/operators/common_infer_shape_functions.h"
#include "paddle/fluid/operators/mkldnn/mkldnn_activation_op.h"
#include "paddle/fluid/platform/port.h"
......@@ -1231,3 +1232,24 @@ REGISTER_OP_CPU_KERNEL(
ops::ActivationGradKernel<paddle::platform::CPUDeviceContext,
ops::AbsGradFunctor<int64_t>>);
/* ========================================================================== */
/* ========================== register checkpoint ===========================*/
REGISTER_OP_VERSION(leaky_relu)
.AddCheckpoint(
R"ROC(fix leaky_relu, bahavior changed when alpha < 0 or alpha > 1)ROC",
paddle::framework::compatible::OpVersionDesc()
.BugfixWithBehaviorChanged(
"leaky_relu calculate formula before checkponit: out = max(x, "
"alpha * x); after checkpoint: out = x if x > 0 else alpha * "
"x"));
REGISTER_OP_VERSION(hard_shrink)
.AddCheckpoint(
R"ROC(fix hard_shrink, bahavior changed when threshold<0)ROC",
paddle::framework::compatible::OpVersionDesc()
.BugfixWithBehaviorChanged(
"hard_shrink calculate formula before checkponit: out = x * "
"((x < -threshold) + (x > threshold)); after checkpoint: out = "
"x * (((x < -threshold) + (x > threshold)) > 0)"));
/* ========================================================================== */
......@@ -64,11 +64,11 @@ class BernoulliOpKernel<platform::CPUDeviceContext, T>
int64_t size = x->numel();
std::uniform_real_distribution<T> dist(0.0, 1.0);
auto gen_ptr = framework::Generator::GetInstance();
std::mt19937_64 &gen_engine = gen_ptr->GetCPUEngine();
auto gen_ptr = framework::DefaultCPUGenerator();
auto engine = gen_ptr->GetCPUEngine();
for (int64_t i = 0; i < size; ++i) {
out_data[i] = BernoulliFunctor(in_data[i], dist(gen_engine));
out_data[i] = BernoulliFunctor(in_data[i], dist(*engine));
}
}
}; // namespace operators
......
......@@ -66,7 +66,7 @@ template <typename DeviceContext, typename T>
class ClipKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto max = static_cast<T>(context.Attr<float>("max"));
auto max = context.Attr<T>("max");
Tensor max_cpu;
if (context.HasInput("Max")) {
auto* max_t = context.Input<Tensor>("Max");
......@@ -77,9 +77,8 @@ class ClipKernel : public framework::OpKernel<T> {
}
max = max_data[0];
}
max = static_cast<T>(max);
auto min = context.Attr<float>("min");
auto min = context.Attr<T>("min");
Tensor min_cpu;
if (context.HasInput("Min")) {
auto* min_t = context.Input<Tensor>("Min");
......@@ -90,11 +89,12 @@ class ClipKernel : public framework::OpKernel<T> {
}
min = min_data[0];
}
min = static_cast<T>(min);
PADDLE_ENFORCE_LT(min, max, platform::errors::InvalidArgument(
"max should be greater than min. "
"But received min = %f, max = %f",
min, max));
PADDLE_ENFORCE_LE(min, max,
platform::errors::InvalidArgument(
"max should be greater than or equal to min. "
"But received min = %f, max = %f",
min, max));
auto* x_var = context.InputVar("X");
if (x_var->IsType<framework::LoDTensor>()) {
......@@ -141,7 +141,7 @@ template <typename DeviceContext, typename T>
class ClipGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto max = static_cast<T>(context.Attr<float>("max"));
auto max = context.Attr<T>("max");
Tensor max_cpu;
if (context.HasInput("Max")) {
auto* max_t = context.Input<Tensor>("Max");
......@@ -152,9 +152,8 @@ class ClipGradKernel : public framework::OpKernel<T> {
}
max = max_data[0];
}
max = static_cast<T>(max);
auto min = context.Attr<float>("min");
auto min = context.Attr<T>("min");
Tensor min_cpu;
if (context.HasInput("Min")) {
auto* min_t = context.Input<Tensor>("Min");
......@@ -165,7 +164,6 @@ class ClipGradKernel : public framework::OpKernel<T> {
}
min = min_data[0];
}
min = static_cast<T>(min);
auto* d_out =
context.Input<framework::LoDTensor>(framework::GradVarName("Out"));
......
......@@ -13,7 +13,9 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/controlflow/logical_op.h"
#include <algorithm>
#include <string>
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
namespace paddle {
......@@ -97,19 +99,19 @@ class BinaryLogicalOp : public LogicalOp {
OP_INOUT_CHECK(context->HasInput("Y"), "Input", "Y", comment.type);
auto dim_x = context->GetInputDim("X");
auto dim_y = context->GetInputDim("Y");
int product_x = framework::product(dim_x);
int product_y = framework::product(dim_y);
bool check = context->IsRuntime() || (product_x >= 0 && product_y >= 0);
if (check) {
PADDLE_ENFORCE_EQ(product_x, product_y,
platform::errors::InvalidArgument(
"The number of elements in X and Y should be same, "
"but received %d != %d",
product_x, product_y));
if (dim_x == dim_y) {
context->SetOutputDim("Out", dim_x);
} else {
int max_dim = std::max(dim_x.size(), dim_y.size());
int axis = std::abs(dim_x.size() - dim_y.size());
std::vector<int> x_dims_array(max_dim);
std::vector<int> y_dims_array(max_dim);
std::vector<int> out_dims_array(max_dim);
GetBroadcastDimsArrays(dim_x, dim_y, x_dims_array.data(),
y_dims_array.data(), out_dims_array.data(),
max_dim, axis);
context->SetOutputDim("Out", framework::make_ddim(out_dims_array));
}
context->SetOutputDim("Out", context->GetInputDim("X"));
context->ShareLoD("X", "Out");
}
};
......
......@@ -16,6 +16,7 @@ limitations under the License. */
#include <math.h>
#include <type_traits>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_function.h"
#include "paddle/fluid/platform/transform.h"
namespace paddle {
......@@ -57,10 +58,8 @@ class BinaryLogicalOpKernel
auto* y = context.Input<framework::Tensor>("Y");
auto* out = context.Output<framework::Tensor>("Out");
Functor binary_func;
platform::Transform<DeviceContext> trans;
trans(context.template device_context<DeviceContext>(), x->data<T>(),
x->data<T>() + x->numel(), y->data<T>(),
out->mutable_data<bool>(context.GetPlace()), binary_func);
ElementwiseComputeEx<Functor, DeviceContext, T, bool>(context, x, y, -1,
binary_func, out);
}
};
......
......@@ -17,6 +17,7 @@ limitations under the License. */
#include <string>
#include <vector>
#include "paddle/fluid/framework/data_layout.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/platform/cudnn_workspace_helper.h"
#ifdef PADDLE_WITH_MKLDNN
......@@ -567,3 +568,14 @@ REGISTER_OP_CPU_KERNEL(
ops::GemmConvTransposeGradKernel<paddle::platform::CPUDeviceContext, float>,
ops::GemmConvTransposeGradKernel<paddle::platform::CPUDeviceContext,
double>);
REGISTER_OP_VERSION(conv_transpose)
.AddCheckpoint(
R"ROC(
Upgrade convtranspose add a new attribute [output_padding].
)ROC",
paddle::framework::compatible::OpVersionDesc().NewAttr(
"output_padding",
"In order to add additional size to one side of each dimension "
"in the output",
{}));
......@@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/operators/cum_op.h"
namespace paddle {
......@@ -95,3 +96,14 @@ REGISTER_OP_CPU_KERNEL(cumsum, ops::CumKernel<CPU, ops::CumsumFunctor<float>>,
ops::CumKernel<CPU, ops::CumsumFunctor<double>>,
ops::CumKernel<CPU, ops::CumsumFunctor<int>>,
ops::CumKernel<CPU, ops::CumsumFunctor<int64_t>>);
REGISTER_OP_VERSION(cumsum)
.AddCheckpoint(
R"ROC(
Upgrade cumsum add a new attribute [flatten].
)ROC",
paddle::framework::compatible::OpVersionDesc().NewAttr(
"flatten",
"In order to compute the cumsum over the flattened array when the "
"argument `axis` in python API is None.",
false));
......@@ -56,7 +56,7 @@ endif()
cc_test(rpc_server_test SRCS rpc_server_test.cc
DEPS ${RPC_DEPS} executor scope proto_desc lookup_sparse_table_read_op)
DEPS ${RPC_DEPS} executor scope proto_desc lookup_sparse_table_read_op scale_op)
cc_test(varhandle_test SRCS varhandle_test.cc DEPS profiler scope)
cc_library(parameter_prefetch SRCS parameter_prefetch.cc DEPS sendrecvop_rpc memory)
cc_library(parameter_send SRCS parameter_send.cc DEPS sendrecvop_rpc memory)
......
......@@ -132,6 +132,15 @@ void ProcGetResponse(const VarHandle& var_h,
&trainer_id);
}
void ProcGetRecvResponse(const VarHandle& var_h,
const ::grpc::ByteBuffer& ret_msg) {
VLOG(4) << "ProcGetRecvResponse";
framework::Variable* outvar = nullptr;
int trainer_id;
DeserializeRecvFromByteBuffer(ret_msg, *var_h.ctx(), var_h.scope(), &outvar,
&trainer_id);
}
template <typename T>
void RequestToByteBuffer(const T& proto, ::grpc::ByteBuffer* result) {
::grpc::Slice slice(proto.ByteSizeLong());
......@@ -482,6 +491,79 @@ VarHandlePtr GRPCClient::AsyncDistributeNotify(
return h;
}
VarHandlePtr GRPCClient::AsyncSendAndRecv(const std::string& ep,
const platform::DeviceContext& ctx,
const framework::Scope& scope,
const std::string& send_var_name,
const std::string& recv_var_name,
const std::string& table_name,
int64_t time_out) {
const platform::DeviceContext* p_ctx = &ctx;
const std::string ep_val = ep;
const std::string send_var_name_val = send_var_name;
const std::string recv_var_name_val = recv_var_name;
const std::string table_name_val = table_name;
const framework::Scope* p_scope = &scope;
const auto ch = GetChannel(ep_val);
const std::string method = kSendAndRecvRPC;
VLOG(4) << "GRPCClient::SendAndRecv Begin ,Send_var_name: "
<< send_var_name_val << " Recv_var_name: " << recv_var_name_val;
int retry_times_ = 0;
while (true) {
SendAndRecvProcessor* s = new SendAndRecvProcessor(ch);
VarHandlePtr h(
new VarHandle(ep, method, send_var_name_val, p_ctx, p_scope));
VarHandlePtr h_recv(
new VarHandle(ep, method, recv_var_name_val, p_ctx, p_scope));
s->Prepare(h, time_out);
s->RecvPrepare(h_recv);
framework::AsyncIO([send_var_name_val, recv_var_name_val, table_name_val,
p_scope, p_ctx, s, method, h, this] {
auto* send_var = p_scope->FindVar(send_var_name_val);
send_var->GetMutable<framework::LoDTensor>()->set_lod({});
::grpc::ByteBuffer buf;
VLOG(4) << "SerializeToByteBuffer: send_var_name_val: "
<< send_var_name_val
<< " recv_var_name_val: " << recv_var_name_val;
SerializeToByteBuffer(send_var_name_val, send_var, *p_ctx, &buf,
recv_var_name_val, trainer_id_, table_name_val);
VLOG(3) << s->GetVarHandlePtr()->String() << " begin";
// stub context
s->response_call_back_ = ProcGetRecvResponse;
platform::RecordRPCEvent record_event(method);
auto call = s->stub_g_.PrepareUnaryCall(
s->context_.get(), "/sendrecv.SendRecvService/SendAndRecvVariable",
buf, &cq_);
call->StartCall();
call->Finish(&s->reply_, &s->status_, reinterpret_cast<void*>(s));
if (UNLIKELY(platform::IsProfileEnabled())) {
h->Wait();
}
});
req_count_++;
if (FLAGS_rpc_retry_times > 0 && retry_times_ < FLAGS_rpc_retry_times) {
h->Wait();
if (h->should_retry) {
VLOG(3) << "rpc call failed, retry times " << retry_times_;
retry_times_++;
std::random_device rd;
std::this_thread::sleep_for(std::chrono::milliseconds(rd() % 5));
continue;
}
}
return h;
}
}
bool GRPCClient::Wait() {
std::unique_lock<std::mutex> lk(sync_mutex_);
sync_cond_.wait(lk, [this] { return (req_count_ == 0 || ok_ == false); });
......
......@@ -53,6 +53,8 @@ namespace distributed {
void ProcGetResponse(const VarHandle& var_h, const grpc::ByteBuffer& msg);
void ProcGetRecvResponse(const VarHandle& var_h, const grpc::ByteBuffer& msg);
class BaseProcessor {
public:
BaseProcessor() { context_ = nullptr; }
......@@ -131,6 +133,28 @@ class GetProcessor : public BaseProcessor {
RequestGetCallBack response_call_back_ = ProcGetResponse;
};
class SendAndRecvProcessor : public BaseProcessor {
public:
explicit SendAndRecvProcessor(std::shared_ptr<grpc::Channel> ch)
: BaseProcessor(), stub_g_(ch) {}
virtual ~SendAndRecvProcessor() {}
void ProcessImpl() override {
if (response_call_back_) {
response_call_back_(*var_h_recv_.get(), reply_);
var_h_recv_->Finish(true);
}
}
void RecvPrepare(VarHandlePtr h_recv) { var_h_recv_ = h_recv; }
::grpc::ByteBuffer reply_;
::grpc::GenericStub stub_g_;
RequestGetCallBack response_call_back_ = ProcGetResponse;
VarHandlePtr var_h_recv_;
};
class BatchBarrierProcessor : public BaseProcessor {
public:
explicit BatchBarrierProcessor(std::shared_ptr<grpc::Channel> ch)
......@@ -231,6 +255,14 @@ class GRPCClient : public RPCClient {
const framework::Scope& scope, const std::string& var_name,
int64_t time_out = FLAGS_rpc_deadline) override;
VarHandlePtr AsyncSendAndRecv(const std::string& ep,
const platform::DeviceContext& ctx,
const framework::Scope& scope,
const std::string& send_var_name,
const std::string& recv_var_name,
const std::string& table_name = "",
int64_t time_out = FLAGS_rpc_deadline) override;
VarHandlePtr AsyncSendComplete(
const std::string& ep, int64_t time_out = FLAGS_rpc_deadline) override;
......
......@@ -76,7 +76,6 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
PADDLE_THROW("Serialize does not support type: %s",
typeid(var->Type()).name());
}
std::string header;
request.AppendToString(&header);
auto buffer = std::unique_ptr<char[]>(new char[1024]);
......@@ -101,7 +100,6 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
}
#endif
PADDLE_ENFORCE_NOT_NULL(payload);
e.WriteVarlengthBeginning(VarMsg::kSerializedFieldNumber,
payload->memory_size());
if (payload->memory_size() >= std::numeric_limits<int>::max()) {
......@@ -140,7 +138,6 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
::grpc::Slice::STEAL_REF);
num_slices = 4;
}
::grpc::ByteBuffer tmp(&slices[0], num_slices);
msg->Swap(&tmp);
}
......@@ -156,6 +153,19 @@ void DeserializeFromByteBuffer(const ::grpc::ByteBuffer& msg,
*trainer_id = resp.GetTrainerId();
}
void DeserializeRecvFromByteBuffer(const ::grpc::ByteBuffer& msg,
const platform::DeviceContext& ctx,
const framework::Scope* scope,
framework::Variable** var, int* trainer_id) {
platform::RecordRPCEvent record_event("deserial");
operators::distributed::GRPCVariableResponse resp(scope, &ctx);
PADDLE_ENFORCE_EQ(
resp.Parse(msg), 0,
platform::errors::InvalidArgument("parse bytebuffer to tensor error!"));
*var = resp.GetRecvVar();
*trainer_id = resp.GetTrainerId();
}
} // namespace distributed
} // namespace operators
} // namespace paddle
......@@ -47,6 +47,11 @@ void DeserializeFromByteBuffer(const ::grpc::ByteBuffer& msg,
const framework::Scope* scope,
framework::Variable** var, int* trainer_id);
void DeserializeRecvFromByteBuffer(const ::grpc::ByteBuffer& msg,
const platform::DeviceContext& ctx,
const framework::Scope* scope,
framework::Variable** var, int* trainer_id);
} // namespace distributed
} // namespace operators
} // namespace paddle
......@@ -28,6 +28,7 @@ DECLARE_int32(rpc_retry_bind_port);
namespace paddle {
namespace operators {
namespace distributed {
enum CallStatus { PROCESS = 0, FINISH };
// reference:
......@@ -433,6 +434,51 @@ class RequestNotify final : public RequestBase {
ServerAsyncResponseWriter<sendrecv::VoidMessage> responder_;
};
class RequestSendAndRecv final : public RequestBase {
public:
explicit RequestSendAndRecv(GrpcService::AsyncService* service,
::grpc::ServerCompletionQueue* cq,
RequestHandler* request_handler, int req_id)
: RequestBase(service, cq, request_handler, req_id), responder_(&ctx_) {
request_.reset(new GRPCVariableResponse(
request_handler->scope(), request_handler->dev_ctx(),
request_handler->distributed_mode()));
int method_id =
static_cast<int>(distributed::GrpcMethod::kRequestSendAndRecv);
service_->RequestAsyncUnary(
method_id, &ctx_, request_.get(), &responder_, cq_, cq_,
reinterpret_cast<void*>(static_cast<intptr_t>(req_id)));
}
virtual ~RequestSendAndRecv() {}
std::string GetReqName() override { return request_->Varname(); }
void Process() override {
std::string in_var_name = request_->Varname();
std::string out_var_name = request_->OutVarname();
std::string table_name = request_->TableName();
int trainer_id = request_->GetTrainerId();
VLOG(4) << "RequestSendAndRecv, in_var_name: " << in_var_name
<< " out_var_name: " << out_var_name << " trainer: " << trainer_id;
auto scope = request_->GetMutableLocalScope();
auto invar = scope->FindVar(in_var_name);
framework::Variable* outvar = nullptr;
request_handler_->Handle(in_var_name, scope, invar, &outvar, trainer_id,
out_var_name, table_name);
SerializeToByteBuffer(out_var_name, outvar, *request_handler_->dev_ctx(),
&reply_);
Finish(reply_, &responder_);
}
protected:
std::shared_ptr<GRPCVariableResponse> request_;
::grpc::ByteBuffer reply_;
ServerAsyncResponseWriter<::grpc::ByteBuffer> responder_;
};
void AsyncGRPCServer::WaitServerReady() {
VLOG(4) << "AsyncGRPCServer is waiting server ready";
std::unique_lock<std::mutex> lock(this->mutex_ready_);
......@@ -586,6 +632,8 @@ void AsyncGRPCServer::TryToRegisterNewOne(const std::string& rpc_name,
b = new RequestCheckpointNotify(service_.get(), cq.get(), handler, req_id);
} else if (rpc_name == kRequestNotify) {
b = new RequestNotify(service_.get(), cq.get(), handler, req_id);
} else if (rpc_name == kRequestSendAndRecv) {
b = new RequestSendAndRecv(service_.get(), cq.get(), handler, req_id);
} else {
PADDLE_ENFORCE(false, "not supported rpc");
}
......
......@@ -85,10 +85,12 @@ enum class GrpcMethod {
kGetMonomerVariable,
kGetMonomerBarrier,
kRequestNotify,
kRequestSendAndRecv,
// when you add new handler, change kGrpcNumMethods at the same time!
};
static const int kGrpcNumMethods =
static_cast<int>(GrpcMethod::kRequestNotify) + 1;
static_cast<int>(GrpcMethod::kRequestSendAndRecv) + 1;
inline const char* GrpcMethodName(GrpcMethod id) {
switch (id) {
......@@ -108,6 +110,8 @@ inline const char* GrpcMethodName(GrpcMethod id) {
return "/sendrecv.SendRecvService/CheckpointNotify";
case GrpcMethod::kRequestNotify:
return "/sendrecv.SendRecvService/DistributeNotify";
case GrpcMethod::kRequestSendAndRecv:
return "/sendrecv.SendRecvService/SendAndRecvVariable";
}
// Shouldn't be reached.
......
......@@ -14,20 +14,19 @@
#pragma once
#include <ThreadPool.h>
#include <gflags/gflags.h>
#include <functional>
#include <future> // NOLINT
#include <memory>
#include <string>
#include <thread> // NOLINT
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
#include <thread> // NOLINT
#include <ThreadPool.h>
#include "paddle/fluid/framework/generator.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/rw_lock.h"
......@@ -89,26 +88,17 @@ class UniformInitializer : public Initializer {
min_ = std::stof(attrs[2]);
max_ = std::stof(attrs[3]);
if (seed_ == 0) {
seed_ = std::random_device()();
}
random_engine_.seed(seed_);
dist_ = std::uniform_real_distribution<float>(min_, max_);
random_engine_ = framework::GetCPURandomEngine(seed_);
}
float GetValue() override {
return framework::Generator::GetInstance()->is_init_py
? dist_(framework::Generator::GetInstance()->GetCPUEngine())
: dist_(random_engine_);
// return dist_(random_engine_);
}
float GetValue() override { return dist_(*random_engine_); }
private:
float min_;
float max_;
std::minstd_rand random_engine_;
std::shared_ptr<std::mt19937_64> random_engine_;
std::uniform_real_distribution<float> dist_;
};
......@@ -139,26 +129,18 @@ class GaussianInitializer : public Initializer {
mean_ = std::stof(attrs[2]);
std_ = std::stof(attrs[3]);
if (seed_ == 0) {
seed_ = std::random_device()();
}
random_engine_ = framework::GetCPURandomEngine(seed_);
random_engine_.seed(seed_);
dist_ = std::normal_distribution<float>(mean_, std_);
}
float GetValue() override {
return framework::Generator::GetInstance()->is_init_py
? dist_(framework::Generator::GetInstance()->GetCPUEngine())
: dist_(random_engine_);
// return dist_(random_engine_);
}
float GetValue() override { return dist_(*random_engine_); }
private:
float std_;
float mean_;
std::minstd_rand random_engine_;
std::shared_ptr<std::mt19937_64> random_engine_;
std::normal_distribution<float> dist_;
};
......
......@@ -46,6 +46,7 @@ constexpr char kRequestCheckpoint[] = "RequestCheckpoint";
constexpr char kRequestPassBarrier[] = "RequestPassBarrier";
constexpr char kRequestGetNoBarrier[] = "GetVariableNoBarrier";
constexpr char kRequestNotify[] = "RequestNotify";
constexpr char kRequestSendAndRecv[] = "RequestSendAndRecv";
constexpr char kSendRPC[] = "SendRPC";
constexpr char kGetRPC[] = "GetRPC";
......@@ -57,6 +58,7 @@ constexpr char kFetchBarrierRPC[] = "FetchBarrierRPC";
constexpr char kSendMonomerFetchBarrierRPC[] = "SendMonomerFetchBarrierRPC";
constexpr char kSendCompleteRPC[] = "SendCompleteRPC";
constexpr char kCheckPointNotifyRPC[] = "CheckPointNotifyRPC";
constexpr char kSendAndRecvRPC[] = "SendAndRecvRPC";
constexpr int64_t kPrefetchTimeout = 60000;
#define LISTEN_TERMINATE_MESSAGE "TERMINATE@RECV"
......
......@@ -325,6 +325,22 @@ bool RequestNotifyHandler::Handle(const std::string &varname,
return true;
}
bool RequestSendAndRecvHandler::Handle(const std::string &varname,
framework::Scope *Scope,
framework::Variable *var,
framework::Variable **outvar,
const int trainer_id,
const std::string &out_var_name,
const std::string &table_name) {
VLOG(3) << "SendAndRecvHandle: " << varname
<< " out_var_name: " << out_var_name
<< " , trainer_id: " << trainer_id;
executor_->RunPreparedContext((*grad_to_prepared_ctx_)[varname].get(), Scope);
*outvar = Scope->FindVar(out_var_name);
return true;
}
} // namespace distributed
} // namespace operators
} // namespace paddle
......@@ -176,6 +176,17 @@ class RequestNotifyHandler final : public RequestHandler {
std::unordered_map<int, int64_t> decay_counters;
};
class RequestSendAndRecvHandler final : public RequestHandler {
public:
explicit RequestSendAndRecvHandler(int distributed_mode)
: RequestHandler(distributed_mode) {}
virtual ~RequestSendAndRecvHandler() {}
bool Handle(const std::string& varname, framework::Scope* Scope,
framework::Variable* var, framework::Variable** outvar,
const int trainer_id, const std::string& out_var_name = "",
const std::string& table_name = "") override;
};
} // namespace distributed
} // namespace operators
} // namespace paddle
......@@ -85,6 +85,12 @@ class RPCClient {
const framework::Scope& scope, const std::string& var_name,
int64_t time_out = FLAGS_rpc_deadline) = 0;
virtual VarHandlePtr AsyncSendAndRecv(
const std::string& ep, const platform::DeviceContext& ctx,
const framework::Scope& scope, const std::string& send_var_name,
const std::string& recv_var_name, const std::string& table_name = "",
int64_t time_out = FLAGS_rpc_deadline) = 0;
virtual VarHandlePtr AsyncSendComplete(
const std::string& ep, int64_t time_out = FLAGS_rpc_deadline) = 0;
......
......@@ -35,27 +35,24 @@ namespace platform = paddle::platform;
namespace distributed = paddle::operators::distributed;
USE_NO_KERNEL_OP(lookup_sparse_table_read);
USE_OP(scale);
std::unique_ptr<distributed::RPCServer> g_rpc_service;
std::unique_ptr<distributed::RequestHandler> g_req_handler;
framework::BlockDesc* AppendPrefetchBlcok(framework::ProgramDesc* program) {
framework::BlockDesc* AppendSendAndRecvBlock(framework::ProgramDesc* program) {
auto root_block = program->MutableBlock(0);
auto* block = program->AppendBlock(*root_block);
framework::VariableNameMap input({{"W", {"w"}}, {"Ids", {"ids"}}});
framework::VariableNameMap output({{"Output", {"out"}}});
auto op = block->AppendOp();
op->SetType("lookup_sparse_table_read");
op->SetInput("W", {"w"});
op->SetInput("Ids", {"ids"});
op->SetOutput("Out", {"out"});
op->SetAttr("tablename", {"w"});
op->SetAttr("value_names", {"Param"});
auto& out = *root_block->Var("out");
framework::OpDesc* op = block->AppendOp();
op->SetType("scale");
op->SetInput("X", {"x"});
op->SetOutput("Out", {"res"});
op->SetAttr("scale", 0.5f);
auto& out = *root_block->Var("res");
out.SetType(framework::proto::VarType::LOD_TENSOR);
out.SetShape({10, 10});
out.SetShape({1, 10});
return block;
}
......@@ -69,6 +66,12 @@ void CreateVarsOnScope(framework::Scope* scope, platform::CPUPlace* place) {
auto ids_var = scope->Var("ids");
ids_var->GetMutable<framework::LoDTensor>();
auto x_var = scope->Var("x");
x_var->GetMutable<framework::LoDTensor>();
auto res_var = scope->Var("res");
res_var->GetMutable<framework::LoDTensor>();
}
void InitTensorsOnClient(framework::Scope* scope, platform::CPUPlace* place,
......@@ -78,6 +81,11 @@ void InitTensorsOnClient(framework::Scope* scope, platform::CPUPlace* place,
int64_t* ids_ptr =
ids_var->mutable_data<int64_t>(framework::DDim({rows_numel, 1}), *place);
for (int64_t i = 0; i < rows_numel; ++i) ids_ptr[i] = i * 2;
auto x_var = scope->Var("x")->GetMutable<framework::LoDTensor>();
float* x_ptr =
x_var->mutable_data<float>(framework::DDim({1, rows_numel}), *place);
for (int64_t i = 0; i < rows_numel; ++i) x_ptr[i] = 1.0;
}
void InitTensorsOnServer(framework::Scope* scope, platform::CPUPlace* place,
......@@ -124,6 +132,38 @@ void StartServer(const std::string& rpc_name) {
server_thread.join();
}
void StartSendAndRecvServer(const std::string& rpc_name) {
framework::ProgramDesc program;
framework::Scope scope;
platform::CPUPlace place;
framework::Executor exe(place);
platform::CPUDeviceContext ctx(place);
auto block = AppendSendAndRecvBlock(&program);
std::string in_var_name("x");
std::vector<int> prefetch_block_ids{block->ID()};
auto prepared = exe.Prepare(program, prefetch_block_ids);
InitTensorsOnServer(&scope, &place, 10);
std::unordered_map<std::string,
std::shared_ptr<framework::ExecutorPrepareContext>>
grad_to_prepared_ctx;
grad_to_prepared_ctx[in_var_name] = prepared[0];
g_req_handler->SetProgram(&program);
g_req_handler->SetGradToPreparedCtx(&grad_to_prepared_ctx);
g_req_handler->SetDevCtx(&ctx);
g_req_handler->SetScope(&scope);
g_req_handler->SetExecutor(&exe);
g_rpc_service->RegisterRPC(rpc_name, g_req_handler.get());
g_req_handler->SetRPCServer(g_rpc_service.get());
std::thread server_thread(
std::bind(&distributed::RPCServer::StartServer, g_rpc_service.get()));
server_thread.join();
}
TEST(COMPLETE, CPU) {
setenv("http_proxy", "", 1);
setenv("https_proxy", "", 1);
......@@ -147,3 +187,46 @@ TEST(COMPLETE, CPU) {
g_rpc_service.reset(nullptr);
g_req_handler.reset(nullptr);
}
TEST(SENDANDRECV, CPU) {
setenv("http_proxy", "", 1);
setenv("https_proxy", "", 1);
g_req_handler.reset(new distributed::RequestSendAndRecvHandler(
distributed::DistributedMode::kAsync));
g_rpc_service.reset(new RPCSERVER_T("127.0.0.1:0", 1));
distributed::RPCClient* client =
distributed::RPCClient::GetInstance<RPCCLIENT_T>(0);
PADDLE_ENFORCE_NE(client, nullptr,
platform::errors::InvalidArgument(
"Client Start Fail, Check Your Code & Env"));
std::thread server_thread(StartSendAndRecvServer,
distributed::kRequestSendAndRecv);
g_rpc_service->WaitServerReady();
int port = g_rpc_service->GetSelectedPort();
std::string ep = paddle::string::Sprintf("127.0.0.1:%d", port);
framework::Scope scope;
platform::CPUPlace place;
platform::CPUDeviceContext ctx(place);
// create var on local scope
int64_t rows_numel = 10;
InitTensorsOnClient(&scope, &place, rows_numel);
std::string in_var_name("x");
std::string out_var_name("res");
client->AsyncSendAndRecv(ep, ctx, scope, in_var_name, out_var_name);
client->Wait();
auto var = scope.Var(out_var_name);
auto value = var->GetMutable<framework::LoDTensor>();
auto ptr = value->mutable_data<float>(place);
for (int64_t i = 0; i < rows_numel; ++i) {
EXPECT_EQ(ptr[i], 0.5);
}
g_rpc_service->ShutDown();
server_thread.join();
LOG(INFO) << "begin reset";
g_rpc_service.reset(nullptr);
g_req_handler.reset(nullptr);
}
......@@ -29,7 +29,7 @@ service SendRecvService {
rpc CheckpointNotify(VariableMessage) returns (VoidMessage) {}
rpc DistributeNotify(VariableMessage) returns (VoidMessage) {}
rpc SendAndRecvVariable(VariableMessage) returns (VariableMessage) {}
rpc GetMonomerVariable(VariableMessage) returns (VariableMessage) {}
rpc GetMonomerBarrier(VariableMessage) returns (VoidMessage) {}
}
......
......@@ -96,6 +96,13 @@ class VariableResponse {
return scope_->FindVar(meta_.varname());
}
framework::Variable* GetRecvVar() {
if (create_scope_) {
return local_scope_->Var(meta_.out_varname());
}
return scope_->FindVar(meta_.out_varname());
}
int GetTrainerId() { return static_cast<int>(meta_.trainer_id()); }
protected:
......
......@@ -268,7 +268,6 @@ void ListenAndServOp::RunAsyncLoop(framework::Executor *executor,
size_t num_blocks = program->Size();
PADDLE_ENFORCE_GE(num_blocks, 2,
"server program should have at least 2 blocks");
std::vector<int> block_list;
for (size_t blkid = 1; blkid < num_blocks; ++blkid) {
block_list.push_back(blkid);
......@@ -295,6 +294,7 @@ void ListenAndServOp::RunAsyncLoop(framework::Executor *executor,
request_send_handler_->SetGradToPreparedCtx(&grad_to_prepared_ctx);
request_get_handler_->SetGradToPreparedCtx(&grad_to_prepared_ctx);
request_prefetch_handler_->SetGradToPreparedCtx(&grad_to_prepared_ctx);
request_send_and_recv_handler_->SetGradToPreparedCtx(&grad_to_prepared_ctx);
while (true) {
if (rpc_service_->IsExit()) {
......@@ -394,6 +394,8 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
new distributed::RequestGetNoBarrierHandler());
request_notify_handler_.reset(
new distributed::RequestNotifyHandler(distributed_mode, fan_in));
request_send_and_recv_handler_.reset(
new distributed::RequestSendAndRecvHandler(distributed_mode));
rpc_service_->RegisterRPC(distributed::kRequestSend,
request_send_handler_.get(), rpc_send_thread_num);
......@@ -408,6 +410,9 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
request_get_no_barrier_handler_.get());
rpc_service_->RegisterRPC(distributed::kRequestNotify,
request_notify_handler_.get(), rpc_send_thread_num);
rpc_service_->RegisterRPC(distributed::kRequestSendAndRecv,
request_send_and_recv_handler_.get(),
rpc_get_thread_num);
auto optimize_blocks =
Attr<std::vector<framework::BlockDesc *>>(kOptimizeBlocks);
......@@ -416,6 +421,7 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
"optimize blocks is less than 1. Optimize blocks "
"should be 1 at least on the pserver side."));
auto *program = optimize_blocks[0]->Program();
framework::Executor executor(dev_place);
std::shared_ptr<framework::ExecutorPrepareContext> ckpt_pre_context = nullptr;
......@@ -488,6 +494,7 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
f(request_checkpoint_handler_.get());
f(request_get_no_barrier_handler_.get());
f(request_notify_handler_.get());
f(request_send_and_recv_handler_.get());
// register SIGINT(from ctrl+C) and SIGTERM(from kill) signal handlers
signal(SIGINT, SignalHandler::StopAndExit);
......
......@@ -99,6 +99,8 @@ class ListenAndServOp : public framework::OperatorBase {
mutable std::shared_ptr<distributed::RequestHandler>
request_checkpoint_handler_;
mutable std::shared_ptr<distributed::RequestHandler> request_notify_handler_;
mutable std::shared_ptr<distributed::RequestHandler>
request_send_and_recv_handler_;
mutable std::shared_ptr<std::thread> server_thread_;
mutable std::vector<std::string> sparse_vars_;
......
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <future> // NOLINT
#include <ostream>
#include "paddle/fluid/framework/blocking_queue.h"
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/distributed/communicator.h"
#include "paddle/fluid/operators/distributed/communicator_common.h"
#include "paddle/fluid/operators/distributed/distributed.h"
#include "paddle/fluid/operators/distributed/parameter_send.h"
#include "paddle/fluid/operators/distributed_ops/send_recv_util.h"
#include "paddle/fluid/platform/profiler.h"
namespace paddle {
namespace operators {
template <typename DeviceContext, typename T>
class SendAndRecvKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto& scope = ctx.scope();
const auto& place = ctx.GetPlace();
auto send_var_name = ctx.Attr<std::string>("send_var_name");
auto recv_var_name = ctx.Attr<std::string>("recv_var_name");
auto epmap = ctx.Attr<std::string>("endpoint");
auto trainer_id = ctx.Attr<int>("trainer_id");
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
auto& context = *pool.Get(place);
distributed::RPCClient* rpc_client =
distributed::RPCClient::GetInstance<RPCCLIENT_T>(trainer_id);
VLOG(3) << "SendAndRecvOp Send_var_name: " << send_var_name
<< " Recv_var_name: " << recv_var_name;
distributed::VarHandlePtr rets = rpc_client->AsyncSendAndRecv(
epmap, context, scope, send_var_name, recv_var_name);
rets->Wait();
}
};
class SendAndRecvOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
return framework::OpKernelType(data_type, platform::CPUPlace());
}
};
class SendAndRecvOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() {
AddInput("X", "Tensor Input variable to be sent").AsDuplicable();
AddOutput("Out", "Tensor Output varibale to be recv").AsDuplicable();
AddAttr<std::string>("send_var_name", "Send Tensor's name")
.SetDefault(std::string(""));
AddAttr<std::string>("recv_var_name", "Recv Tensor's name")
.SetDefault(std::string(""));
AddAttr<int>("trainer_id", "trainer id from 0 ~ worker_num.").SetDefault(0);
AddAttr<std::string>("endpoint", "Server endpoint")
.SetDefault({"127.0.0.1:6164"});
AddComment(R"DOC(
SendAndRecv operator
This operator will send variables to listen_and_serve op at the parameter server.
And recv variable from parameter server of send variable's scope.
)DOC");
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(send_and_recv, ops::SendAndRecvOp, ops::SendAndRecvOpMaker);
REGISTER_OP_CPU_KERNEL(
send_and_recv,
ops::SendAndRecvKernel<paddle::platform::CPUDeviceContext, float>)
......@@ -55,30 +55,22 @@ class CPUDropoutKernel : public framework::OpKernel<T> {
std::memset(mask_data, 0, size * sizeof(*mask_data)); // NOLINT
return;
}
bool init_generator_py = framework::Generator::GetInstance()->is_init_py;
// std::minstd_rand engine;
// NOTE: fixed seed should only be used in unittest or for debug.
// Guarantee to use random seed in training.
std::random_device rnd;
std::minstd_rand engine;
int seed_data;
int seed_data = 0;
if (seed) {
seed_data = *(seed->data<int>());
} else {
seed_data =
context.Attr<bool>("fix_seed") ? context.Attr<int>("seed") : rnd();
context.Attr<bool>("fix_seed") ? context.Attr<int>("seed") : 0;
}
engine.seed(seed_data);
auto engine = framework::GetCPURandomEngine(seed_data);
std::uniform_real_distribution<float> dist(0, 1);
for (size_t i = 0; i < size; ++i) {
float cur_random =
init_generator_py
? dist(framework::Generator::GetInstance()->GetCPUEngine())
: dist(engine);
if (cur_random < dropout_prob) {
if (dist(*engine) < dropout_prob) {
mask_data[i] = 0;
y_data[i] = 0;
} else {
......
......@@ -26,14 +26,34 @@ namespace operators {
template <typename T>
struct FloorDivFunctor {
inline HOSTDEVICE T operator()(T a, T b) const {
return static_cast<T>(floor(a / b));
#ifdef __CUDA_ARCH__
if (b == 0) {
printf("Error: Divide by zero encounter in floor_divide\n");
asm("trap;");
}
#else
if (b == 0)
PADDLE_THROW(platform::errors::InvalidArgument(
"Divide by zero encounter in floor_divide"));
#endif
return static_cast<T>(std::trunc(a / b));
}
};
template <typename T>
struct InverseFloorDivFunctor {
inline HOSTDEVICE T operator()(T a, T b) const {
return static_cast<T>(floor(b / a));
#ifdef __CUDA_ARCH__
if (a == 0) {
printf("Error: Divide by zero encounter in floor_divide\n");
asm("trap;");
}
#else
if (a == 0)
PADDLE_THROW(platform::errors::InvalidArgument(
"Divide by zero encounter in floor_divide"));
#endif
return static_cast<T>(std::trunc(b / a));
}
};
......
......@@ -24,13 +24,19 @@ namespace operators {
template <typename T>
struct ModFunctor {
inline HOSTDEVICE T operator()(T a, T b) const { return a % b; }
inline HOSTDEVICE T operator()(T a, T b) const {
T res = a % b;
if ((res != 0) && ((res < 0) != (b < 0))) res += b;
return res;
}
};
template <typename T>
struct ModFunctorFP {
inline HOSTDEVICE T operator()(T a, T b) const {
return fmod(b + fmod(a, b), b);
T res = fmod(a, b);
if ((res != 0) && ((b < 0) != (res < 0))) res += b;
return res;
}
};
......
......@@ -22,15 +22,20 @@ namespace operators {
template <typename T>
struct PowFunctor {
inline HOSTDEVICE T operator()(T a, T b) const {
#ifdef __CUDA_ARCH__
// On CUDAPlace, std::pow(3, 1) calls pow(float, float), and
// it will return a float number like 2.99... , which floor to 2
// when cast to int by default and it is wrong.
// Use llrint to cast it to the nearest integer, which is 3.
// TODO(wujionghao): A potential speed improvement is supporting different
// types in C++.
// #ifdef __CUDA_ARCH__
// // On CUDAPlace, std::pow(3, 1) calls pow(float, float), and
// // it will return a float number like 2.99... , which floor to 2
// // when cast to int by default and it is wrong.
// // Use llrint to cast it to the nearest integer, which is 3.
// if (std::is_integral<T>::value) {
// return std::llrint(std::pow(a, b));
// }
// #endif
if (std::is_integral<T>::value) {
return std::llrint(std::pow(a, b));
}
#endif
return std::pow(a, b);
}
};
......
......@@ -17,7 +17,7 @@ limitations under the License. */
#include <string>
#include <vector>
#include "paddle/fluid/framework/ddim.h"
#include "paddle/fluid/framework/op_version_registry.h"
namespace paddle {
namespace operators {
......@@ -152,3 +152,7 @@ REGISTER_OP_CPU_KERNEL(gather_grad, ops::GatherGradientOpKernel<float>,
ops::GatherGradientOpKernel<int>,
ops::GatherGradientOpKernel<uint8_t>,
ops::GatherGradientOpKernel<int64_t>);
REGISTER_OP_VERSION(gather)
.AddCheckpoint(R"ROC(upgrad gather, add attribut [axis])ROC",
paddle::framework::compatible::OpVersionDesc().NewAttr(
"axis", "Specify the axis of gather operation.", {}));
......@@ -39,26 +39,14 @@ class CPUGaussianRandomKernel : public framework::OpKernel<T> {
tensor->Resize(shape);
int64_t size = tensor->numel();
T* data = tensor->mutable_data<T>(context.GetPlace());
unsigned int seed = static_cast<unsigned int>(context.Attr<int>("seed"));
auto engine = framework::GetCPURandomEngine(seed);
if (framework::Generator::GetInstance()->is_init_py) {
std::mt19937_64& gen_engine =
framework::Generator::GetInstance()->GetCPUEngine();
for (int64_t i = 0; i < size; ++i) {
data[i] = dist(gen_engine);
}
} else {
unsigned int seed = static_cast<unsigned int>(context.Attr<int>("seed"));
std::minstd_rand engine;
if (seed == 0) {
seed = std::random_device()();
}
engine.seed(seed);
for (int64_t i = 0; i < size; ++i) {
data[i] = dist(engine);
}
for (int64_t i = 0; i < size; ++i) {
data[i] = dist(*engine);
}
}
};
}; // namespace operators
template <typename T>
class CPUGaussianRandomBatchSizeLikeKernel : public framework::OpKernel<T> {
......
此差异已折叠。
此差异已折叠。
此差异已折叠。
......@@ -26,8 +26,8 @@ namespace math {
// TODO(wanghaoshuang): Support for GPU
/**
* Sample integers from [0, range).
*/
* Sample integers from [0, range).
*/
class Sampler {
public:
explicit Sampler(int64_t range, unsigned int seed = 0UL) : range_(range) {
......@@ -117,7 +117,7 @@ class CustomSampler : public Sampler {
const int* alias_;
const float* probs_;
const int exceptional_val = -1;
std::shared_ptr<std::mt19937> random_engine_;
std::shared_ptr<std::mt19937_64> random_engine_;
std::shared_ptr<std::uniform_real_distribution<>> real_dist_;
std::shared_ptr<std::uniform_int_distribution<>> int_dist_;
};
......
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
文件模式从 100644 更改为 100755
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
文件模式从 100644 更改为 100755
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册