提交 9e60c586 编写于 作者: P peizhilin

Merge remote-tracking branch 'upstream/develop' into windows/mkl

test=develop
...@@ -208,6 +208,7 @@ paddle.fluid.layers.bilinear_tensor_product ArgSpec(args=['x', 'y', 'size', 'act ...@@ -208,6 +208,7 @@ paddle.fluid.layers.bilinear_tensor_product ArgSpec(args=['x', 'y', 'size', 'act
paddle.fluid.layers.merge_selected_rows ArgSpec(args=['x', 'name'], varargs=None, keywords=None, defaults=(None,)) paddle.fluid.layers.merge_selected_rows ArgSpec(args=['x', 'name'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.layers.get_tensor_from_selected_rows ArgSpec(args=['x', 'name'], varargs=None, keywords=None, defaults=(None,)) paddle.fluid.layers.get_tensor_from_selected_rows ArgSpec(args=['x', 'name'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.layers.lstm ArgSpec(args=['input', 'init_h', 'init_c', 'max_len', 'hidden_size', 'num_layers', 'dropout_prob', 'is_bidirec', 'is_test', 'name', 'default_initializer', 'seed'], varargs=None, keywords=None, defaults=(0.0, False, False, None, None, -1)) paddle.fluid.layers.lstm ArgSpec(args=['input', 'init_h', 'init_c', 'max_len', 'hidden_size', 'num_layers', 'dropout_prob', 'is_bidirec', 'is_test', 'name', 'default_initializer', 'seed'], varargs=None, keywords=None, defaults=(0.0, False, False, None, None, -1))
paddle.fluid.layers.py_func ArgSpec(args=['func', 'x', 'out', 'backward_func', 'skip_vars_in_backward_input'], varargs=None, keywords=None, defaults=(None, None))
paddle.fluid.layers.psroi_pool ArgSpec(args=['input', 'rois', 'output_channels', 'spatial_scale', 'pooled_height', 'pooled_width', 'name'], varargs=None, keywords=None, defaults=(None,)) paddle.fluid.layers.psroi_pool ArgSpec(args=['input', 'rois', 'output_channels', 'spatial_scale', 'pooled_height', 'pooled_width', 'name'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.layers.huber_loss ArgSpec(args=['input', 'label', 'delta'], varargs=None, keywords=None, defaults=None) paddle.fluid.layers.huber_loss ArgSpec(args=['input', 'label', 'delta'], varargs=None, keywords=None, defaults=None)
paddle.fluid.layers.data ArgSpec(args=['name', 'shape', 'append_batch_size', 'dtype', 'lod_level', 'type', 'stop_gradient'], varargs=None, keywords=None, defaults=(True, 'float32', 0, VarType.LOD_TENSOR, True)) paddle.fluid.layers.data ArgSpec(args=['name', 'shape', 'append_batch_size', 'dtype', 'lod_level', 'type', 'stop_gradient'], varargs=None, keywords=None, defaults=(True, 'float32', 0, VarType.LOD_TENSOR, True))
...@@ -350,6 +351,22 @@ paddle.fluid.contrib.QuantizeTranspiler.__init__ ArgSpec(args=['self', 'weight_b ...@@ -350,6 +351,22 @@ paddle.fluid.contrib.QuantizeTranspiler.__init__ ArgSpec(args=['self', 'weight_b
paddle.fluid.contrib.QuantizeTranspiler.convert_to_int8 ArgSpec(args=['self', 'program', 'place', 'scope'], varargs=None, keywords=None, defaults=(None,)) paddle.fluid.contrib.QuantizeTranspiler.convert_to_int8 ArgSpec(args=['self', 'program', 'place', 'scope'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.contrib.QuantizeTranspiler.freeze_program ArgSpec(args=['self', 'program', 'place', 'fuse_bn', 'scope'], varargs=None, keywords=None, defaults=(False, None)) paddle.fluid.contrib.QuantizeTranspiler.freeze_program ArgSpec(args=['self', 'program', 'place', 'fuse_bn', 'scope'], varargs=None, keywords=None, defaults=(False, None))
paddle.fluid.contrib.QuantizeTranspiler.training_transpile ArgSpec(args=['self', 'program', 'startup_program'], varargs=None, keywords=None, defaults=(None, None)) paddle.fluid.contrib.QuantizeTranspiler.training_transpile ArgSpec(args=['self', 'program', 'startup_program'], varargs=None, keywords=None, defaults=(None, None))
paddle.fluid.contrib.load_persistables_for_increment ArgSpec(args=['dirname', 'executor', 'program', 'lookup_table_var', 'lookup_table_var_path'], varargs=None, keywords=None, defaults=None)
paddle.fluid.contrib.load_persistables_for_inference ArgSpec(args=['dirname', 'executor', 'program', 'lookup_table_var_name'], varargs=None, keywords=None, defaults=None)
paddle.fluid.contrib.convert_dist_to_sparse_program ArgSpec(args=['program'], varargs=None, keywords=None, defaults=None)
paddle.fluid.contrib.HDFSClient.__init__ ArgSpec(args=['self', 'hadoop_home', 'configs'], varargs=None, keywords=None, defaults=None)
paddle.fluid.contrib.HDFSClient.delete ArgSpec(args=['self', 'hdfs_path'], varargs=None, keywords=None, defaults=None)
paddle.fluid.contrib.HDFSClient.download ArgSpec(args=['self', 'hdfs_path', 'local_path', 'overwrite', 'unzip'], varargs=None, keywords=None, defaults=(False, False))
paddle.fluid.contrib.HDFSClient.is_dir ArgSpec(args=['self', 'hdfs_path'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.contrib.HDFSClient.is_exist ArgSpec(args=['self', 'hdfs_path'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.contrib.HDFSClient.ls ArgSpec(args=['self', 'hdfs_path'], varargs=None, keywords=None, defaults=None)
paddle.fluid.contrib.HDFSClient.lsr ArgSpec(args=['self', 'hdfs_path', 'only_file', 'sort'], varargs=None, keywords=None, defaults=(True, True))
paddle.fluid.contrib.HDFSClient.make_local_dirs ArgSpec(args=['local_path'], varargs=None, keywords=None, defaults=None)
paddle.fluid.contrib.HDFSClient.makedirs ArgSpec(args=['self', 'hdfs_path'], varargs=None, keywords=None, defaults=None)
paddle.fluid.contrib.HDFSClient.rename ArgSpec(args=['self', 'hdfs_src_path', 'hdfs_dst_path', 'overwrite'], varargs=None, keywords=None, defaults=(False,))
paddle.fluid.contrib.HDFSClient.upload ArgSpec(args=['self', 'hdfs_path', 'local_path', 'overwrite', 'retry_times'], varargs=None, keywords=None, defaults=(False, 5))
paddle.fluid.contrib.multi_download ArgSpec(args=['client', 'hdfs_path', 'local_path', 'trainer_id', 'trainers', 'multi_processes'], varargs=None, keywords=None, defaults=(5,))
paddle.fluid.contrib.multi_upload ArgSpec(args=['client', 'hdfs_path', 'local_path', 'multi_processes', 'overwrite', 'sync'], varargs=None, keywords=None, defaults=(5, False, True))
paddle.fluid.transpiler.DistributeTranspiler.__init__ ArgSpec(args=['self', 'config'], varargs=None, keywords=None, defaults=(None,)) paddle.fluid.transpiler.DistributeTranspiler.__init__ ArgSpec(args=['self', 'config'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.transpiler.DistributeTranspiler.get_pserver_program ArgSpec(args=['self', 'endpoint'], varargs=None, keywords=None, defaults=None) paddle.fluid.transpiler.DistributeTranspiler.get_pserver_program ArgSpec(args=['self', 'endpoint'], varargs=None, keywords=None, defaults=None)
paddle.fluid.transpiler.DistributeTranspiler.get_pserver_programs ArgSpec(args=['self', 'endpoint'], varargs=None, keywords=None, defaults=None) paddle.fluid.transpiler.DistributeTranspiler.get_pserver_programs ArgSpec(args=['self', 'endpoint'], varargs=None, keywords=None, defaults=None)
......
...@@ -131,9 +131,7 @@ std::shared_ptr<ir::PassBuilder> BuildStrategy::CreatePassesFromStrategy( ...@@ -131,9 +131,7 @@ std::shared_ptr<ir::PassBuilder> BuildStrategy::CreatePassesFromStrategy(
std::unique_ptr<ir::Graph> BuildStrategy::Apply( std::unique_ptr<ir::Graph> BuildStrategy::Apply(
const ProgramDesc &main_program, const std::vector<platform::Place> &places, const ProgramDesc &main_program, const std::vector<platform::Place> &places,
const std::string &loss_var_name, const std::string &loss_var_name, const std::vector<Scope *> &local_scopes,
const std::unordered_set<std::string> &param_names,
const std::vector<Scope *> &local_scopes,
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32) #if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
const bool use_cuda, platform::NCCLContextMap *nccl_ctxs) const { const bool use_cuda, platform::NCCLContextMap *nccl_ctxs) const {
#else #else
...@@ -149,9 +147,6 @@ std::unique_ptr<ir::Graph> BuildStrategy::Apply( ...@@ -149,9 +147,6 @@ std::unique_ptr<ir::Graph> BuildStrategy::Apply(
pass->SetNotOwned<const std::vector<platform::Place>>("places", &places); pass->SetNotOwned<const std::vector<platform::Place>>("places", &places);
pass->Erase("loss_var_name"); pass->Erase("loss_var_name");
pass->SetNotOwned<const std::string>("loss_var_name", &loss_var_name); pass->SetNotOwned<const std::string>("loss_var_name", &loss_var_name);
pass->Erase("params");
pass->SetNotOwned<const std::unordered_set<std::string>>("params",
&param_names);
pass->Erase("local_scopes"); pass->Erase("local_scopes");
pass->SetNotOwned<const std::vector<Scope *>>("local_scopes", pass->SetNotOwned<const std::vector<Scope *>>("local_scopes",
&local_scopes); &local_scopes);
......
...@@ -106,16 +106,15 @@ struct BuildStrategy { ...@@ -106,16 +106,15 @@ struct BuildStrategy {
// Apply the passes built by the pass_builder_. The passes will be // Apply the passes built by the pass_builder_. The passes will be
// applied to the Program and output an ir::Graph. // applied to the Program and output an ir::Graph.
std::unique_ptr<ir::Graph> Apply( std::unique_ptr<ir::Graph> Apply(const ProgramDesc &main_program,
const ProgramDesc &main_program, const std::vector<platform::Place> &places,
const std::vector<platform::Place> &places, const std::string &loss_var_name,
const std::string &loss_var_name, const std::vector<Scope *> &local_scopes,
const std::unordered_set<std::string> &param_names,
const std::vector<Scope *> &local_scopes,
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32) #if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
const bool use_cuda, platform::NCCLContextMap *nccl_ctxs) const; const bool use_cuda,
platform::NCCLContextMap *nccl_ctxs) const;
#else #else
const bool use_cuda) const; const bool use_cuda) const;
#endif #endif
private: private:
......
...@@ -130,7 +130,6 @@ void AddOutputToLeafOps(ir::Graph *graph) { ...@@ -130,7 +130,6 @@ void AddOutputToLeafOps(ir::Graph *graph) {
static const char kLossVarName[] = "loss_var_name"; static const char kLossVarName[] = "loss_var_name";
static const char kPlaces[] = "places"; static const char kPlaces[] = "places";
static const char kParams[] = "params";
static const char kLocalScopes[] = "local_scopes"; static const char kLocalScopes[] = "local_scopes";
static const char kStrategy[] = "strategy"; static const char kStrategy[] = "strategy";
static const char kNumTrainers[] = "num_trainers"; static const char kNumTrainers[] = "num_trainers";
...@@ -147,9 +146,6 @@ void MultiDevSSAGraphBuilder::Init() const { ...@@ -147,9 +146,6 @@ void MultiDevSSAGraphBuilder::Init() const {
nccl_ctxs_ = &Get<platform::NCCLContextMap>("nccl_ctxs"); nccl_ctxs_ = &Get<platform::NCCLContextMap>("nccl_ctxs");
#endif #endif
for (auto &p : Get<const std::unordered_set<std::string>>(kParams)) {
grad_names_.insert(GradVarName(p));
}
balance_vars_.resize(places_.size(), 0); balance_vars_.resize(places_.size(), 0);
if (strategy_.enable_data_balance_ && places_.size() == 1) { if (strategy_.enable_data_balance_ && places_.size() == 1) {
LOG(WARNING) << "It is no need to enable data balance when there is only " LOG(WARNING) << "It is no need to enable data balance when there is only "
...@@ -896,7 +892,6 @@ REGISTER_PASS(multi_devices_pass, ...@@ -896,7 +892,6 @@ REGISTER_PASS(multi_devices_pass,
paddle::framework::details::MultiDevSSAGraphBuilder) paddle::framework::details::MultiDevSSAGraphBuilder)
.RequirePassAttr(paddle::framework::details::kLossVarName) .RequirePassAttr(paddle::framework::details::kLossVarName)
.RequirePassAttr(paddle::framework::details::kPlaces) .RequirePassAttr(paddle::framework::details::kPlaces)
.RequirePassAttr(paddle::framework::details::kParams)
.RequirePassAttr(paddle::framework::details::kLocalScopes) .RequirePassAttr(paddle::framework::details::kLocalScopes)
.RequirePassAttr(paddle::framework::details::kStrategy) .RequirePassAttr(paddle::framework::details::kStrategy)
.RequirePassAttr(paddle::framework::details::kNumTrainers); .RequirePassAttr(paddle::framework::details::kNumTrainers);
...@@ -102,7 +102,6 @@ class MultiDevSSAGraphBuilder : public ir::Pass { ...@@ -102,7 +102,6 @@ class MultiDevSSAGraphBuilder : public ir::Pass {
mutable std::string loss_var_name_; mutable std::string loss_var_name_;
mutable std::vector<platform::Place> places_; mutable std::vector<platform::Place> places_;
mutable std::vector<Scope *> local_scopes_; mutable std::vector<Scope *> local_scopes_;
mutable std::unordered_set<std::string> grad_names_;
mutable BuildStrategy strategy_; mutable BuildStrategy strategy_;
mutable std::unordered_map<std::string, VarDesc *> all_vars_; mutable std::unordered_map<std::string, VarDesc *> all_vars_;
......
...@@ -24,35 +24,6 @@ namespace paddle { ...@@ -24,35 +24,6 @@ namespace paddle {
namespace framework { namespace framework {
namespace ir { namespace ir {
// The function keeps the graph consistent by replacing
// a node 'from' in the set of inputs nodes
// of the visited node by a node 'to'.
void CorrectGraphEdges(Graph* graph, Node* from, Node* to) {
for (auto& node : GraphTraits::DFS(*graph)) {
auto from_in_inputs =
std::find(std::begin(node.inputs), std::end(node.inputs), from);
if (from_in_inputs != std::end(node.inputs)) {
IR_NODE_LINK_TO(to, (&node));
auto inputs = node.Op()->Inputs();
using input_type = VariableNameMap::value_type;
std::for_each(std::begin(inputs), std::end(inputs),
[from, to, &node](const input_type& i) -> void {
auto param_names = i.second;
auto pi = std::find(std::begin(param_names),
std::end(param_names), from->Name());
if (pi != std::end(param_names)) {
node.Op()->SetInput(i.first, {to->Name()});
}
});
}
}
}
bool IsReachable(ir::Graph* graph, Node* from, Node* to) { bool IsReachable(ir::Graph* graph, Node* from, Node* to) {
auto find_node = [](ir::Graph* graph, const Node* node) -> Node* { auto find_node = [](ir::Graph* graph, const Node* node) -> Node* {
for (auto n : graph->Nodes()) { for (auto n : graph->Nodes()) {
...@@ -99,25 +70,12 @@ bool IsReachable(ir::Graph* graph, Node* from, Node* to) { ...@@ -99,25 +70,12 @@ bool IsReachable(ir::Graph* graph, Node* from, Node* to) {
return false; return false;
} }
boost::optional<Node*> HasBias(const Node& op, const std::string& bias_name) { template <typename T>
auto bias_input_names = op.Op()->Inputs(); boost::optional<T> HasAttribute(const Node& op, const std::string& attr) {
auto bias_it = bias_input_names.find(bias_name); if (op.Op()->HasAttr(attr))
return boost::get<T>(op.Op()->GetAttr(attr));
if (bias_it != std::end(bias_input_names)) { else
bool has_bias = !bias_it->second.empty(); return boost::none;
if (has_bias) {
auto bias_names = bias_it->second;
auto bias_names_it =
std::find_if(std::begin(op.inputs), std::end(op.inputs),
[&bias_names](Node* n) -> bool {
return n->Name() == bias_names[0];
});
return *bias_names_it;
}
}
return boost::none;
} }
ResidualConnectionMKLDNNFusePass::IdentityFuseHandle::IdentityFuseHandle( ResidualConnectionMKLDNNFusePass::IdentityFuseHandle::IdentityFuseHandle(
...@@ -151,40 +109,18 @@ void ResidualConnectionMKLDNNFusePass::IdentityFuseHandle::operator()( ...@@ -151,40 +109,18 @@ void ResidualConnectionMKLDNNFusePass::IdentityFuseHandle::operator()(
if (!IsReachable(graph, elementwise_add_identity, conv_output)) return; if (!IsReachable(graph, elementwise_add_identity, conv_output)) return;
OpDesc op_desc; auto fuse_relu = HasAttribute<bool>(*conv_op, "fuse_relu");
op_desc.SetType("conv2d"); if (fuse_relu && *fuse_relu) return;
op_desc.SetInput("Input", {conv_input->Name()});
op_desc.SetInput("Filter", {conv_filter->Name()});
op_desc.SetInput("ResidualData", {elementwise_add_identity->Name()});
op_desc.SetOutput("Output", {conv_output->Name()});
auto conv_bias = HasBias(*conv_op, "Bias"); conv_op->Op()->SetInput("ResidualData", {elementwise_add_identity->Name()});
conv_op->Op()->SetOutput("Output", {elementwise_add_out->Name()});
conv_op->Op()->SetAttr("fuse_residual_connection", true);
if (conv_bias) { GraphSafeRemoveNodes(graph, {conv_output, elementwise_add_op});
op_desc.SetInput("Bias", {(*conv_bias)->Name()});
}
for (const auto& attr : conv_op->Op()->GetAttrMap()) {
op_desc.SetAttr(attr.first, attr.second);
}
op_desc.SetAttr("fuse_residual_connection", true);
auto fused_conv_op = graph->CreateOpNode(&op_desc); IR_NODE_LINK_TO(elementwise_add_identity, conv_op);
IR_NODE_LINK_TO(conv_op, elementwise_add_out);
IR_NODE_LINK_TO(conv_input, fused_conv_op);
IR_NODE_LINK_TO(conv_filter, fused_conv_op);
IR_NODE_LINK_TO(elementwise_add_identity, fused_conv_op);
IR_NODE_LINK_TO(fused_conv_op, conv_output);
if (conv_bias) {
IR_NODE_LINK_TO((*conv_bias), fused_conv_op);
}
CorrectGraphEdges(graph, elementwise_add_out, conv_output);
GraphSafeRemoveNodes(graph,
{elementwise_add_out, conv_op, elementwise_add_op});
(*fusion_stats)++; (*fusion_stats)++;
} }
...@@ -229,60 +165,33 @@ void ResidualConnectionMKLDNNFusePass::ProjectionFuseHandle::operator()( ...@@ -229,60 +165,33 @@ void ResidualConnectionMKLDNNFusePass::ProjectionFuseHandle::operator()(
Node* projection_node; Node* projection_node;
Node* residual_conv_op; Node* residual_conv_op;
Node* residual_conv_input;
Node* residual_conv_filter;
Node* residual_conv_output; Node* residual_conv_output;
if (IsReachable(graph, conv_x_input, conv_y_output)) { if (IsReachable(graph, conv_x_input, conv_y_output)) {
projection_node = conv_x_output; projection_node = conv_x_output;
residual_conv_op = conv_y_op; residual_conv_op = conv_y_op;
residual_conv_input = conv_y_input;
residual_conv_filter = conv_y_filter;
residual_conv_output = conv_y_output; residual_conv_output = conv_y_output;
} else if (IsReachable(graph, conv_y_input, conv_x_output)) { } else if (IsReachable(graph, conv_y_input, conv_x_output)) {
projection_node = conv_y_output; projection_node = conv_y_output;
residual_conv_op = conv_x_op; residual_conv_op = conv_x_op;
residual_conv_input = conv_x_input;
residual_conv_filter = conv_x_filter;
residual_conv_output = conv_x_output; residual_conv_output = conv_x_output;
} else { } else {
return; return;
} }
OpDesc op_desc; auto fuse_relu = HasAttribute<bool>(*residual_conv_op, "fuse_relu");
op_desc.SetType("conv2d"); if (fuse_relu && *fuse_relu) return;
op_desc.SetInput("Input", {residual_conv_input->Name()}); residual_conv_op->Op()->SetInput("ResidualData", {projection_node->Name()});
op_desc.SetInput("Filter", {residual_conv_filter->Name()}); residual_conv_op->Op()->SetOutput("Output", {elementwise_add_out->Name()});
op_desc.SetInput("ResidualData", {projection_node->Name()});
op_desc.SetOutput("Output", {residual_conv_output->Name()});
auto residual_conv_bias = HasBias(*residual_conv_op, "Bias"); residual_conv_op->Op()->SetAttr("fuse_residual_connection", true);
if (residual_conv_bias) { GraphSafeRemoveNodes(graph, {residual_conv_output, elementwise_add_op});
op_desc.SetInput("Bias", {(*residual_conv_bias)->Name()});
}
for (const auto& attr : residual_conv_op->Op()->GetAttrMap()) {
op_desc.SetAttr(attr.first, attr.second);
}
op_desc.SetAttr("fuse_residual_connection", true);
auto fused_conv_op = graph->CreateOpNode(&op_desc); IR_NODE_LINK_TO(projection_node, residual_conv_op);
IR_NODE_LINK_TO(residual_conv_op, elementwise_add_out);
IR_NODE_LINK_TO(residual_conv_input, fused_conv_op);
IR_NODE_LINK_TO(residual_conv_filter, fused_conv_op);
IR_NODE_LINK_TO(projection_node, fused_conv_op);
IR_NODE_LINK_TO(fused_conv_op, residual_conv_output);
if (residual_conv_bias) {
IR_NODE_LINK_TO((*residual_conv_bias), fused_conv_op);
}
CorrectGraphEdges(graph, elementwise_add_out, residual_conv_output);
GraphSafeRemoveNodes(
graph, {elementwise_add_out, residual_conv_op, elementwise_add_op});
(*fusion_stats)++; (*fusion_stats)++;
} }
......
...@@ -16,100 +16,25 @@ limitations under the License. */ ...@@ -16,100 +16,25 @@ limitations under the License. */
#include <functional> #include <functional>
#include <vector> #include <vector>
#include "ngraph/ngraph.hpp"
#include "paddle/fluid/framework/ngraph_bridge.h" #include "paddle/fluid/framework/ngraph_bridge.h"
#include "paddle/fluid/framework/operator.h" #include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/operators/ngraph/ngraph_ops.h"
#include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/ngraph_helper.h"
#include "ngraph/ngraph.hpp"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
static std::shared_ptr<ngraph::Node> GetNode(
const std::shared_ptr<OperatorBase>& op, const std::string name,
const VariableNameMap& var_map,
std::shared_ptr<
std::unordered_map<std::string, std::shared_ptr<ngraph::Node>>>
ngb_node_map) {
auto& var_names = var_map.at(name);
PADDLE_ENFORCE_EQ(var_names.size(), 1,
"op %s name %s expects one associated var", op->Type(),
name);
if (ngb_node_map->find(var_names[0]) != ngb_node_map->end()) {
return (*ngb_node_map)[var_names[0]];
} else {
return nullptr;
}
}
static std::shared_ptr<ngraph::Node> GetInputNode(
const std::shared_ptr<OperatorBase>& op, const std::string name,
std::shared_ptr<
std::unordered_map<std::string, std::shared_ptr<ngraph::Node>>>
ngb_node_map) {
return GetNode(op, name, op->Inputs(), ngb_node_map);
}
static std::shared_ptr<ngraph::Node> GetOutputNode(
const std::shared_ptr<OperatorBase>& op, const std::string name,
std::shared_ptr<
std::unordered_map<std::string, std::shared_ptr<ngraph::Node>>>
ngb_node_map) {
return GetNode(op, name, op->Outputs(), ngb_node_map);
}
static void SetOutputNode(
const std::shared_ptr<OperatorBase>& op, const std::string name,
std::shared_ptr<ngraph::Node> node,
std::shared_ptr<
std::unordered_map<std::string, std::shared_ptr<ngraph::Node>>>
ngb_node_map) {
auto& var_names = op->Outputs().at(name);
if (var_names.size() == 1) {
(*ngb_node_map)[var_names[0]] = node;
} else if (var_names.size() == 0) {
(*ngb_node_map)[""] = node;
} else {
PADDLE_THROW("name %s has more than 1 var_names.", name);
}
}
static bool HasOutput(const std::shared_ptr<OperatorBase>& op,
const std::string name) {
auto& outputs = op->Outputs();
if (outputs.find(name) == outputs.end()) return false;
return outputs.at(name).size() > 0;
}
template <typename T>
static void BuildBinaryNode(
const std::shared_ptr<OperatorBase>& op,
std::shared_ptr<
std::unordered_map<std::string, std::shared_ptr<ngraph::Node>>>
ngb_node_map) {
auto x = GetInputNode(op, "X", ngb_node_map);
auto y = GetInputNode(op, "Y", ngb_node_map);
auto out = std::make_shared<T>(x, y);
SetOutputNode(op, "Out", out, ngb_node_map);
}
template <typename T>
static void BuildUnaryNode(
const std::shared_ptr<OperatorBase>& op,
std::shared_ptr<
std::unordered_map<std::string, std::shared_ptr<ngraph::Node>>>
ngb_node_map) {
auto input = GetInputNode(op, "X", ngb_node_map);
auto out = std::make_shared<T>(input);
SetOutputNode(op, "Out", out, ngb_node_map);
}
std::map<std::string, std::map<std::string,
std::function<void(const std::shared_ptr<OperatorBase>&, std::function<void(const std::shared_ptr<OperatorBase>&,
std::shared_ptr<std::unordered_map< std::shared_ptr<std::unordered_map<
std::string, std::shared_ptr<ngraph::Node>>>)>> std::string, std::shared_ptr<ngraph::Node>>>)>>
NgraphBridge::NG_NODE_MAP = {{"relu", BuildUnaryNode<ngraph::op::Relu>}, NgraphBridge::NG_NODE_MAP = {
{"tanh", BuildUnaryNode<ngraph::op::Tanh>}}; {"mul", paddle::operators::ngraphs::BuildMulNode},
{"mul_grad", paddle::operators::ngraphs::BuildMulGradNode},
{"relu", paddle::operators::ngraphs::BuildUnaryNode<ngraph::op::Relu>},
{"tanh", paddle::operators::ngraphs::BuildUnaryNode<ngraph::op::Tanh>}};
void NgraphBridge::BuildNgNode(const std::shared_ptr<OperatorBase>& op) { void NgraphBridge::BuildNgNode(const std::shared_ptr<OperatorBase>& op) {
auto& op_type = op->Type(); auto& op_type = op->Type();
......
...@@ -110,22 +110,125 @@ class CompileTimeInferShapeContext : public InferShapeContext { ...@@ -110,22 +110,125 @@ class CompileTimeInferShapeContext : public InferShapeContext {
} }
} }
std::vector<InferShapeVarPtr> GetInputVarPtrs(
const std::string &name) override {
const std::vector<std::string> arg_names = Inputs(name);
std::vector<InferShapeVarPtr> res;
res.reserve(arg_names.size());
std::transform(arg_names.begin(), arg_names.end(), std::back_inserter(res),
[this](const std::string &name) {
return block_.FindVarRecursive(name);
});
return res;
}
std::vector<InferShapeVarPtr> GetOutputVarPtrs(
const std::string &name) override {
const std::vector<std::string> arg_names = Outputs(name);
std::vector<InferShapeVarPtr> res;
res.reserve(arg_names.size());
std::transform(arg_names.begin(), arg_names.end(), std::back_inserter(res),
[this](const std::string &name) {
return block_.FindVarRecursive(name);
});
return res;
}
DDim GetInputDim(const std::string &name) const override {
const std::vector<std::string> &arg_names = Inputs(name);
PADDLE_ENFORCE_EQ(arg_names.size(), 1UL,
"Input(%s) should hold one element, but now it holds %d",
name, arg_names.size());
return this->GetDim(arg_names[0]);
}
std::vector<DDim> GetInputsDim(const std::string &name) const override {
const std::vector<std::string> &arg_names = Inputs(name);
return GetDims(arg_names);
}
bool IsRuntime() const override; bool IsRuntime() const override;
std::vector<proto::VarType::Type> GetInputsVarType(
const std::string &name) const override {
return GetVarTypes(Inputs(name));
}
std::vector<proto::VarType::Type> GetOutputsVarType(
const std::string &name) const override {
return GetVarTypes(Outputs(name));
}
void SetOutputDim(const std::string &name, const DDim &dim) override {
auto &arg_names = Outputs(name);
PADDLE_ENFORCE_EQ(arg_names.size(), 1UL,
"Output(%s) should hold one element, but now it holds %d",
name, arg_names.size());
SetDim(arg_names[0], dim);
}
void SetOutputsDim(const std::string &name,
const std::vector<DDim> &dims) override {
auto &names = Outputs(name);
SetDims(names, dims);
}
protected: protected:
proto::VarType::Type GetVarType(const std::string &name) const override; std::vector<proto::VarType::Type> GetVarTypes(
const std::vector<std::string> &names) const {
std::vector<proto::VarType::Type> retv;
retv.resize(names.size());
std::transform(
names.begin(), names.end(), retv.begin(),
std::bind(std::mem_fn(&CompileTimeInferShapeContext::GetVarType), this,
std::placeholders::_1));
return retv;
}
proto::VarType::Type GetVarType(const std::string &name) const;
DDim GetDim(const std::string &name) const {
auto var = block_.FindVarRecursive(name);
PADDLE_ENFORCE(var != nullptr, "Cannot find variable %s", name);
DDim res;
try {
auto shape = var->GetShape();
res = shape.empty() ? make_ddim({0UL}) : make_ddim(shape);
} catch (...) {
VLOG(5) << "GetDim of variable " << name << " error";
std::rethrow_exception(std::current_exception());
}
return res;
}
DDim GetDim(const std::string &name) const override; std::vector<DDim> GetDims(const std::vector<std::string> &names) const {
std::vector<DDim> ret;
ret.reserve(names.size());
std::transform(
names.begin(), names.end(), std::back_inserter(ret),
[this](const std::string &name) { return this->GetDim(name); });
return ret;
}
void SetDim(const std::string &name, const DDim &dim);
void SetDim(const std::string &name, const DDim &dim) override; void SetDims(const std::vector<std::string> &names,
const std::vector<DDim> &dims) {
size_t length = names.size();
PADDLE_ENFORCE_EQ(length, dims.size());
for (size_t i = 0; i < length; ++i) {
if (names[i] == framework::kEmptyVarName) {
continue;
}
SetDim(names[i], dims[i]);
}
}
std::vector<DDim> GetRepeatedDims(const std::string &name) const override; std::vector<DDim> GetRepeatedDims(const std::string &name) const override;
void SetRepeatedDims(const std::string &name, void SetRepeatedDims(const std::string &name,
const std::vector<DDim> &dims) override; const std::vector<DDim> &dims) override;
InferShapeVarPtr GetVarPtr(const std::string &name) override;
const OpDesc &op_; const OpDesc &op_;
const BlockDesc &block_; const BlockDesc &block_;
}; };
...@@ -644,20 +747,6 @@ const std::vector<std::string> &CompileTimeInferShapeContext::Outputs( ...@@ -644,20 +747,6 @@ const std::vector<std::string> &CompileTimeInferShapeContext::Outputs(
return op_.Output(name); return op_.Output(name);
} }
DDim CompileTimeInferShapeContext::GetDim(const std::string &name) const {
auto var = block_.FindVarRecursive(name);
PADDLE_ENFORCE(var != nullptr, "Cannot find variable %s", name);
DDim res;
try {
auto shape = var->GetShape();
res = shape.empty() ? make_ddim({0UL}) : make_ddim(shape);
} catch (...) {
VLOG(5) << "GetDim of variable " << name << " error";
std::rethrow_exception(std::current_exception());
}
return res;
}
std::vector<DDim> CompileTimeInferShapeContext::GetRepeatedDims( std::vector<DDim> CompileTimeInferShapeContext::GetRepeatedDims(
const std::string &name) const { const std::string &name) const {
auto var = block_.FindVarRecursive(name); auto var = block_.FindVarRecursive(name);
...@@ -696,10 +785,5 @@ proto::VarType::Type CompileTimeInferShapeContext::GetVarType( ...@@ -696,10 +785,5 @@ proto::VarType::Type CompileTimeInferShapeContext::GetVarType(
return block_.FindVarRecursive(name)->GetType(); return block_.FindVarRecursive(name)->GetType();
} }
InferShapeVarPtr CompileTimeInferShapeContext::GetVarPtr(
const std::string &name) {
return block_.FindVarRecursive(name);
}
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -123,6 +123,8 @@ class OpDesc { ...@@ -123,6 +123,8 @@ class OpDesc {
BlockDesc *Block() { return this->block_; } BlockDesc *Block() { return this->block_; }
const BlockDesc *Block() const { return this->block_; }
private: private:
template <typename MapType> template <typename MapType>
static std::vector<typename MapType::key_type> MapKeys(const MapType &map) { static std::vector<typename MapType::key_type> MapKeys(const MapType &map) {
......
...@@ -142,12 +142,14 @@ RuntimeContext::RuntimeContext(const VariableNameMap& innames, ...@@ -142,12 +142,14 @@ RuntimeContext::RuntimeContext(const VariableNameMap& innames,
const Scope& scope) { const Scope& scope) {
for (auto& var_name_item : innames) { for (auto& var_name_item : innames) {
std::vector<Variable*>& input_vars = inputs[var_name_item.first]; std::vector<Variable*>& input_vars = inputs[var_name_item.first];
input_vars.reserve(var_name_item.second.size());
for (auto& var_name : var_name_item.second) { for (auto& var_name : var_name_item.second) {
input_vars.push_back(scope.FindVar(var_name)); input_vars.push_back(scope.FindVar(var_name));
} }
} }
for (auto& var_name_item : outnames) { for (auto& var_name_item : outnames) {
std::vector<Variable*>& output_vars = outputs[var_name_item.first]; std::vector<Variable*>& output_vars = outputs[var_name_item.first];
output_vars.reserve(var_name_item.second.size());
for (auto& var_name : var_name_item.second) { for (auto& var_name : var_name_item.second) {
output_vars.push_back(scope.FindVar(var_name)); output_vars.push_back(scope.FindVar(var_name));
} }
...@@ -556,30 +558,28 @@ class RuntimeInferShapeContext : public InferShapeContext { ...@@ -556,30 +558,28 @@ class RuntimeInferShapeContext : public InferShapeContext {
bool HasOutput(const std::string& name) const override { bool HasOutput(const std::string& name) const override {
// has only one output // has only one output
const auto& outs = op_.Outputs(); const auto& outs = ctx_.outputs;
auto it = outs.find(name); auto it = outs.find(name);
if (it == outs.end()) { if (it == outs.end()) {
return false; return false;
} }
const auto& out = it->second; const auto& out = it->second;
if (out.size() == 0 || out[0] == kEmptyVarName) { if (out.size() == 0) {
return false; return false;
} }
PADDLE_ENFORCE_EQ(out.size(), 1UL, PADDLE_ENFORCE_EQ(out.size(), 1UL,
"Output %s should not have more than one outputs", name); "Output %s should not have more than one outputs", name);
return scope_.FindVar(out[0]) != nullptr; return out[0] != nullptr;
} }
bool HasInputs(const std::string& name) const override { bool HasInputs(const std::string& name) const override {
if (!op_.HasInputs(name)) { const auto& ins = ctx_.inputs;
return false; auto it = ins.find(name);
} if (it == ins.end() || it->second.empty()) {
auto inputs = op_.Inputs(name);
if (inputs.empty()) {
return false; return false;
} }
for (auto& input : inputs) { for (auto& input : it->second) {
if (scope_.FindVar(input) == nullptr) { if (input == nullptr) {
return false; return false;
} }
} }
...@@ -587,15 +587,13 @@ class RuntimeInferShapeContext : public InferShapeContext { ...@@ -587,15 +587,13 @@ class RuntimeInferShapeContext : public InferShapeContext {
} }
bool HasOutputs(const std::string& name) const override { bool HasOutputs(const std::string& name) const override {
if (!op_.HasOutputs(name)) { const auto& outs = ctx_.outputs;
return false; auto it = outs.find(name);
} if (it == outs.end() || it->second.empty()) {
auto outputs = op_.Outputs(name);
if (outputs.empty()) {
return false; return false;
} }
for (auto& output : outputs) { for (auto& output : it->second) {
if (scope_.FindVar(output) == nullptr) { if (output == nullptr) {
return false; return false;
} }
} }
...@@ -616,16 +614,18 @@ class RuntimeInferShapeContext : public InferShapeContext { ...@@ -616,16 +614,18 @@ class RuntimeInferShapeContext : public InferShapeContext {
void ShareDim(const std::string& in, const std::string& out, size_t i = 0, void ShareDim(const std::string& in, const std::string& out, size_t i = 0,
size_t j = 0) override { size_t j = 0) override {
PADDLE_ENFORCE_LT(i, Inputs(in).size()); auto in_it = ctx_.inputs.find(in);
PADDLE_ENFORCE_LT(j, Outputs(out).size()); auto out_it = ctx_.outputs.find(out);
const std::string& input_n = Inputs(in)[i]; PADDLE_ENFORCE(in_it != ctx_.inputs.end() && in_it->second.size() > i,
const std::string& output_n = Outputs(out)[j]; "Inputs %s should have %llu argument", in, i);
PADDLE_ENFORCE(out_it != ctx_.outputs.end() && out_it->second.size() > j,
"Outputs %s should have %llu argument", out, j);
Variable* in_var = in_it->second[i];
Variable* out_var = out_it->second[j];
Variable* in_var = scope_.FindVar(input_n);
Variable* out_var = scope_.FindVar(output_n);
PADDLE_ENFORCE(in_var->Type() == out_var->Type(), PADDLE_ENFORCE(in_var->Type() == out_var->Type(),
"The type of %s and %s is not the same.", output_n, "The type of %s and %s is not the same.", in, out);
GetDim(input_n));
if (in_var->IsType<framework::SelectedRows>()) { if (in_var->IsType<framework::SelectedRows>()) {
auto& in_sele_rows = in_var->Get<framework::SelectedRows>(); auto& in_sele_rows = in_var->Get<framework::SelectedRows>();
...@@ -646,13 +646,16 @@ class RuntimeInferShapeContext : public InferShapeContext { ...@@ -646,13 +646,16 @@ class RuntimeInferShapeContext : public InferShapeContext {
void ShareLoD(const std::string& in, const std::string& out, size_t i = 0, void ShareLoD(const std::string& in, const std::string& out, size_t i = 0,
size_t j = 0) const override { size_t j = 0) const override {
const std::vector<std::string>& inputs = Inputs(in); auto in_it = ctx_.inputs.find(in);
const std::vector<std::string>& outputs = Outputs(out); auto out_it = ctx_.outputs.find(out);
PADDLE_ENFORCE_LT(i, inputs.size()); PADDLE_ENFORCE(in_it != ctx_.inputs.end() && in_it->second.size() > i,
PADDLE_ENFORCE_LT(j, outputs.size()); "Inputs %s should have %llu argument", in, i);
Variable* in_var = scope_.FindVar(inputs.at(i)); PADDLE_ENFORCE(out_it != ctx_.outputs.end() && out_it->second.size() > j,
"Outputs %s should have %llu argument", out, j);
Variable* in_var = in_it->second.at(i);
if (!in_var->IsType<LoDTensor>()) return; if (!in_var->IsType<LoDTensor>()) return;
Variable* out_var = scope_.FindVar(outputs.at(j)); Variable* out_var = out_it->second.at(j);
PADDLE_ENFORCE(out_var->IsType<LoDTensor>(), PADDLE_ENFORCE(out_var->IsType<LoDTensor>(),
"The %d-th output of Output(%s) must be LoDTensor.", j, out); "The %d-th output of Output(%s) must be LoDTensor.", j, out);
auto in_tensor = in_var->Get<LoDTensor>(); auto in_tensor = in_var->Get<LoDTensor>();
...@@ -687,9 +690,64 @@ class RuntimeInferShapeContext : public InferShapeContext { ...@@ -687,9 +690,64 @@ class RuntimeInferShapeContext : public InferShapeContext {
bool IsRuntime() const override { return true; } bool IsRuntime() const override { return true; }
// TODO(paddle-dev): Can this be template?
std::vector<InferShapeVarPtr> GetInputVarPtrs(
const std::string& name) override {
const std::vector<Variable*>& vars = InputVars(name);
std::vector<InferShapeVarPtr> res;
res.reserve(vars.size());
res.insert(res.begin(), vars.begin(), vars.end());
return res;
}
std::vector<InferShapeVarPtr> GetOutputVarPtrs(
const std::string& name) override {
const std::vector<Variable*>& vars = OutputVars(name);
std::vector<InferShapeVarPtr> res;
res.reserve(vars.size());
res.insert(res.begin(), vars.begin(), vars.end());
return res;
}
DDim GetInputDim(const std::string& name) const override {
const std::vector<Variable*>& vars = InputVars(name);
PADDLE_ENFORCE_EQ(vars.size(), 1UL,
"Input(%s) should hold one element, but now it holds %d",
name, vars.size());
return this->GetDim(vars[0]);
}
std::vector<DDim> GetInputsDim(const std::string& name) const override {
const std::vector<Variable*>& vars = InputVars(name);
return GetDims(vars);
}
std::vector<proto::VarType::Type> GetInputsVarType(
const std::string& name) const override {
return GetVarTypes(InputVars(name));
}
std::vector<proto::VarType::Type> GetOutputsVarType(
const std::string& name) const override {
return GetVarTypes(OutputVars(name));
}
void SetOutputDim(const std::string& name, const DDim& dim) override {
auto& vars = OutputVars(name);
PADDLE_ENFORCE_EQ(vars.size(), 1UL,
"Output(%s) should hold one element, but now it holds %d",
name, vars.size());
SetDim(vars[0], dim);
}
void SetOutputsDim(const std::string& name,
const std::vector<DDim>& dims) override {
auto& vars = OutputVars(name);
SetDims(vars, dims);
}
protected: protected:
DDim GetDim(const std::string& name) const override { DDim GetDim(Variable* var) const {
Variable* var = scope_.FindVar(name);
PADDLE_ENFORCE_NOT_NULL(var); PADDLE_ENFORCE_NOT_NULL(var);
if (var->IsType<LoDTensor>()) { if (var->IsType<LoDTensor>()) {
return var->Get<LoDTensor>().dims(); return var->Get<LoDTensor>().dims();
...@@ -697,25 +755,44 @@ class RuntimeInferShapeContext : public InferShapeContext { ...@@ -697,25 +755,44 @@ class RuntimeInferShapeContext : public InferShapeContext {
return var->Get<SelectedRows>().GetCompleteDims(); return var->Get<SelectedRows>().GetCompleteDims();
} else { } else {
PADDLE_THROW( PADDLE_THROW(
"Only LoDTensor/SelectedRows support 'GetDim', but Variable %s's " "Only LoDTensor/SelectedRows support 'GetDim', but Variables "
"type_id is %s.", "type_id is %s.",
name, var->Type().name()); var->Type().name());
} }
} }
std::vector<DDim> GetDims(const std::vector<Variable*>& vars) const {
std::vector<DDim> ret;
ret.reserve(vars.size());
std::transform(vars.begin(), vars.end(), std::back_inserter(ret),
[this](Variable* var) { return this->GetDim(var); });
return ret;
}
std::vector<DDim> GetRepeatedDims(const std::string& name) const override { std::vector<DDim> GetRepeatedDims(const std::string& name) const override {
PADDLE_THROW("Only compile time support this method"); PADDLE_THROW("Only compile time support this method");
} }
void SetDim(const std::string& name, const DDim& dim) override { void SetDim(Variable* var, const DDim& dim) {
Variable* var = scope_.FindVar(name);
if (var->IsType<LoDTensor>()) { if (var->IsType<LoDTensor>()) {
var->GetMutable<LoDTensor>()->Resize(dim); var->GetMutable<LoDTensor>()->Resize(dim);
} else if (var->IsType<SelectedRows>()) { } else if (var->IsType<SelectedRows>()) {
var->GetMutable<SelectedRows>()->set_height(dim[0]); var->GetMutable<SelectedRows>()->set_height(dim[0]);
} else { } else {
PADDLE_THROW("Variable %s type_id %s, expect LoDTensor/SelectedRows.", PADDLE_THROW("Variable type_id %s, expect LoDTensor/SelectedRows.",
name, var->Type().name()); var->Type().name());
}
}
void SetDims(const std::vector<Variable*>& vars,
const std::vector<DDim>& dims) {
size_t length = vars.size();
PADDLE_ENFORCE_EQ(length, dims.size());
for (size_t i = 0; i < length; ++i) {
if (vars[i] == nullptr) {
continue;
}
SetDim(vars[i], dims[i]);
} }
} }
...@@ -724,16 +801,36 @@ class RuntimeInferShapeContext : public InferShapeContext { ...@@ -724,16 +801,36 @@ class RuntimeInferShapeContext : public InferShapeContext {
PADDLE_THROW("Only compile time support this method"); PADDLE_THROW("Only compile time support this method");
} }
proto::VarType::Type GetVarType(const std::string& name) const override { std::vector<proto::VarType::Type> GetVarTypes(
auto* var = scope_.FindVar(name); const std::vector<Variable*>& vars) const {
return ToVarType(var->Type()); std::vector<proto::VarType::Type> retv;
retv.resize(vars.size());
std::transform(vars.begin(), vars.end(), retv.begin(),
std::bind(std::mem_fn(&RuntimeInferShapeContext::GetVarType),
this, std::placeholders::_1));
return retv;
} }
InferShapeVarPtr GetVarPtr(const std::string& name) override { proto::VarType::Type GetVarType(Variable* var) const {
return scope_.FindVar(name); return ToVarType(var->Type());
} }
private: private:
const std::vector<Variable*>& InputVars(const std::string& name) const {
auto it = ctx_.inputs.find(name);
PADDLE_ENFORCE(it != ctx_.inputs.end(),
"Operator %s does not have the input %s.", op_.Type(), name);
return it->second;
}
const std::vector<Variable*>& OutputVars(const std::string& name) const {
auto it = ctx_.outputs.find(name);
PADDLE_ENFORCE(it != ctx_.outputs.end(),
"Operator %s does not have the outputs %s.", op_.Type(),
name);
return it->second;
}
const OperatorBase& op_; const OperatorBase& op_;
const Scope& scope_; const Scope& scope_;
const RuntimeContext& ctx_; const RuntimeContext& ctx_;
...@@ -864,8 +961,7 @@ Scope* OperatorWithKernel::PrepareData( ...@@ -864,8 +961,7 @@ Scope* OperatorWithKernel::PrepareData(
for (size_t i = 0; i < var_name_item.second.size(); ++i) { for (size_t i = 0; i < var_name_item.second.size(); ++i) {
auto& var_name = var_name_item.second[i]; auto& var_name = var_name_item.second[i];
auto* var = scope.FindVar(var_name); auto* var = input_vars[i];
input_vars[i] = var;
// Only tensor can be tranfer to another device. // Only tensor can be tranfer to another device.
if (var == nullptr || !VarIsTensor(*var)) { if (var == nullptr || !VarIsTensor(*var)) {
......
...@@ -190,7 +190,6 @@ std::vector<Scope *> &ParallelExecutor::GetLocalScopes() { ...@@ -190,7 +190,6 @@ std::vector<Scope *> &ParallelExecutor::GetLocalScopes() {
ParallelExecutor::ParallelExecutor( ParallelExecutor::ParallelExecutor(
const std::vector<platform::Place> &places, const std::vector<platform::Place> &places,
const std::unordered_set<std::string> &params,
const std::unordered_set<std::string> &bcast_vars, const std::unordered_set<std::string> &bcast_vars,
const ProgramDesc &main_program, const std::string &loss_var_name, const ProgramDesc &main_program, const std::string &loss_var_name,
Scope *scope, const std::vector<Scope *> &local_scopes, Scope *scope, const std::vector<Scope *> &local_scopes,
...@@ -209,7 +208,7 @@ ParallelExecutor::ParallelExecutor( ...@@ -209,7 +208,7 @@ ParallelExecutor::ParallelExecutor(
"the number of places must be greater than 1."); "the number of places must be greater than 1.");
} }
// Step 1. Bcast the params to devs. // Step 1. Bcast the bcast_vars to devs.
// Create local scopes // Create local scopes
if (local_scopes.empty()) { if (local_scopes.empty()) {
member_->own_local_scope_ = true; member_->own_local_scope_ = true;
...@@ -249,12 +248,12 @@ ParallelExecutor::ParallelExecutor( ...@@ -249,12 +248,12 @@ ParallelExecutor::ParallelExecutor(
// ncclOp // ncclOp
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32) #if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
std::unique_ptr<ir::Graph> graph = build_strategy.Apply( std::unique_ptr<ir::Graph> graph = build_strategy.Apply(
main_program, member_->places_, loss_var_name, params, main_program, member_->places_, loss_var_name, member_->local_scopes_,
member_->local_scopes_, member_->use_cuda_, member_->nccl_ctxs_.get()); member_->use_cuda_, member_->nccl_ctxs_.get());
#else #else
std::unique_ptr<ir::Graph> graph = std::unique_ptr<ir::Graph> graph =
build_strategy.Apply(main_program, member_->places_, loss_var_name, build_strategy.Apply(main_program, member_->places_, loss_var_name,
params, member_->local_scopes_, member_->use_cuda_); member_->local_scopes_, member_->use_cuda_);
#endif #endif
auto max_memory_size = GetEagerDeletionThreshold(); auto max_memory_size = GetEagerDeletionThreshold();
if (max_memory_size >= 0) { if (max_memory_size >= 0) {
......
...@@ -41,7 +41,6 @@ class ParallelExecutor { ...@@ -41,7 +41,6 @@ class ParallelExecutor {
public: public:
explicit ParallelExecutor(const std::vector<platform::Place> &places, explicit ParallelExecutor(const std::vector<platform::Place> &places,
const std::unordered_set<std::string> &params,
const std::unordered_set<std::string> &bcast_vars, const std::unordered_set<std::string> &bcast_vars,
const ProgramDesc &main_program, const ProgramDesc &main_program,
const std::string &loss_var_name, Scope *scope, const std::string &loss_var_name, Scope *scope,
......
...@@ -22,20 +22,6 @@ limitations under the License. */ ...@@ -22,20 +22,6 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace framework { namespace framework {
DDim InferShapeContext::GetInputDim(const std::string &name) const {
const std::vector<std::string> &arg_names = Inputs(name);
PADDLE_ENFORCE_EQ(arg_names.size(), 1UL,
"Input(%s) should hold one element, but now it holds %d",
name, arg_names.size());
return this->GetDim(arg_names[0]);
}
std::vector<DDim> InferShapeContext::GetInputsDim(
const std::string &name) const {
const std::vector<std::string> &arg_names = Inputs(name);
return GetDims(arg_names);
}
std::vector<DDim> InferShapeContext::GetReaderDims( std::vector<DDim> InferShapeContext::GetReaderDims(
const std::string &name) const { const std::string &name) const {
const std::vector<std::string> &arg_names = Inputs(name); const std::vector<std::string> &arg_names = Inputs(name);
...@@ -46,26 +32,6 @@ std::vector<DDim> InferShapeContext::GetReaderDims( ...@@ -46,26 +32,6 @@ std::vector<DDim> InferShapeContext::GetReaderDims(
return this->GetRepeatedDims(arg_names[0]); return this->GetRepeatedDims(arg_names[0]);
} }
DDim InferShapeContext::GetInputsElementDim(const std::string &name,
int idx) const {
const std::vector<std::string> &names = Inputs(name);
return this->GetDim(names[idx]);
}
void InferShapeContext::SetOutputDim(const std::string &name, const DDim &dim) {
auto &arg_names = Outputs(name);
PADDLE_ENFORCE_EQ(arg_names.size(), 1UL,
"Output(%s) should hold one element, but now it holds %d",
name, arg_names.size());
SetDim(arg_names[0], dim);
}
void InferShapeContext::SetOutputsDim(const std::string &name,
const std::vector<DDim> &dims) {
auto &names = Outputs(name);
SetDims(names, dims);
}
void InferShapeContext::SetReaderDims(const std::string &name, void InferShapeContext::SetReaderDims(const std::string &name,
const std::vector<DDim> &dims) { const std::vector<DDim> &dims) {
const std::vector<std::string> &arg_names = Outputs(name); const std::vector<std::string> &arg_names = Outputs(name);
...@@ -76,69 +42,5 @@ void InferShapeContext::SetReaderDims(const std::string &name, ...@@ -76,69 +42,5 @@ void InferShapeContext::SetReaderDims(const std::string &name,
return this->SetRepeatedDims(arg_names[0], dims); return this->SetRepeatedDims(arg_names[0], dims);
} }
std::vector<InferShapeVarPtr> InferShapeContext::GetInputVarPtrs(
const std::string &name) {
const std::vector<std::string> arg_names = Inputs(name);
std::vector<InferShapeVarPtr> res;
res.reserve(arg_names.size());
std::transform(
arg_names.begin(), arg_names.end(), std::back_inserter(res),
[this](const std::string &name) { return this->GetVarPtr(name); });
return res;
}
std::vector<InferShapeVarPtr> InferShapeContext::GetOutputVarPtrs(
const std::string &name) {
const std::vector<std::string> arg_names = Outputs(name);
std::vector<InferShapeVarPtr> res;
res.reserve(arg_names.size());
std::transform(
arg_names.begin(), arg_names.end(), std::back_inserter(res),
[this](const std::string &name) { return this->GetVarPtr(name); });
return res;
}
std::vector<DDim> InferShapeContext::GetDims(
const std::vector<std::string> &names) const {
std::vector<DDim> ret;
ret.reserve(names.size());
std::transform(
names.begin(), names.end(), std::back_inserter(ret),
[this](const std::string &name) { return this->GetDim(name); });
return ret;
}
void InferShapeContext::SetDims(const std::vector<std::string> &names,
const std::vector<DDim> &dims) {
size_t length = names.size();
PADDLE_ENFORCE_EQ(length, dims.size());
for (size_t i = 0; i < length; ++i) {
if (names[i] == framework::kEmptyVarName) {
continue;
}
SetDim(names[i], dims[i]);
}
}
std::vector<proto::VarType::Type> InferShapeContext::GetInputsVarType(
const std::string &name) const {
return GetVarTypes(Inputs(name));
}
std::vector<proto::VarType::Type> InferShapeContext::GetOutputsVarType(
const std::string &name) const {
return GetVarTypes(Outputs(name));
}
std::vector<proto::VarType::Type> InferShapeContext::GetVarTypes(
const std::vector<std::string> &names) const {
std::vector<proto::VarType::Type> retv;
retv.resize(names.size());
std::transform(names.begin(), names.end(), retv.begin(),
std::bind(std::mem_fn(&InferShapeContext::GetVarType), this,
std::placeholders::_1));
return retv;
}
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -25,6 +25,8 @@ limitations under the License. */ ...@@ -25,6 +25,8 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace framework { namespace framework {
class OperatorBase;
using InferShapeVarPtr = boost::variant<VarDesc *, Variable *>; using InferShapeVarPtr = boost::variant<VarDesc *, Variable *>;
class InferShapeContext { class InferShapeContext {
...@@ -33,22 +35,23 @@ class InferShapeContext { ...@@ -33,22 +35,23 @@ class InferShapeContext {
virtual bool HasInput(const std::string &name) const = 0; virtual bool HasInput(const std::string &name) const = 0;
virtual bool HasOutput(const std::string &name) const = 0; virtual bool HasOutput(const std::string &name) const = 0;
std::vector<proto::VarType::Type> GetInputsVarType( virtual std::vector<proto::VarType::Type> GetInputsVarType(
const std::string &name) const; const std::string &name) const = 0;
std::vector<proto::VarType::Type> GetOutputsVarType( virtual std::vector<proto::VarType::Type> GetOutputsVarType(
const std::string &name) const; const std::string &name) const = 0;
virtual bool HasInputs(const std::string &name) const = 0; virtual bool HasInputs(const std::string &name) const = 0;
virtual bool HasOutputs(const std::string &name) const = 0; virtual bool HasOutputs(const std::string &name) const = 0;
DDim GetInputDim(const std::string &name) const; virtual DDim GetInputDim(const std::string &name) const = 0;
std::vector<DDim> GetInputsDim(const std::string &name) const; virtual std::vector<DDim> GetInputsDim(const std::string &name) const = 0;
std::vector<DDim> GetReaderDims(const std::string &name) const; virtual std::vector<DDim> GetReaderDims(const std::string &name) const;
DDim GetInputsElementDim(const std::string &name, int idx) const;
void SetOutputDim(const std::string &name, const DDim &dim); virtual void SetOutputDim(const std::string &name, const DDim &dim) = 0;
void SetOutputsDim(const std::string &name, const std::vector<DDim> &dims); virtual void SetOutputsDim(const std::string &name,
void SetReaderDims(const std::string &name, const std::vector<DDim> &dims); const std::vector<DDim> &dims) = 0;
virtual void SetReaderDims(const std::string &name,
const std::vector<DDim> &dims);
virtual AttrReader Attrs() const = 0; virtual AttrReader Attrs() const = 0;
virtual const std::vector<std::string> &Inputs( virtual const std::vector<std::string> &Inputs(
...@@ -67,27 +70,15 @@ class InferShapeContext { ...@@ -67,27 +70,15 @@ class InferShapeContext {
virtual bool IsRuntime() const = 0; virtual bool IsRuntime() const = 0;
std::vector<InferShapeVarPtr> GetInputVarPtrs(const std::string &name); virtual std::vector<InferShapeVarPtr> GetInputVarPtrs(
std::vector<InferShapeVarPtr> GetOutputVarPtrs(const std::string &name); const std::string &name) = 0;
virtual InferShapeVarPtr GetVarPtr(const std::string &name) = 0; virtual std::vector<InferShapeVarPtr> GetOutputVarPtrs(
const std::string &name) = 0;
// Note: In while op, we need this to be public
void SetDims(const std::vector<std::string> &names,
const std::vector<DDim> &dims);
protected: protected:
virtual DDim GetDim(const std::string &name) const = 0;
virtual void SetDim(const std::string &name, const DDim &dim) = 0;
virtual std::vector<DDim> GetRepeatedDims(const std::string &name) const = 0; virtual std::vector<DDim> GetRepeatedDims(const std::string &name) const = 0;
virtual void SetRepeatedDims(const std::string &name, virtual void SetRepeatedDims(const std::string &name,
const std::vector<DDim> &dims) = 0; const std::vector<DDim> &dims) = 0;
std::vector<DDim> GetDims(const std::vector<std::string> &names) const;
std::vector<proto::VarType::Type> GetVarTypes(
const std::vector<std::string> &names) const;
virtual proto::VarType::Type GetVarType(const std::string &name) const = 0;
}; };
} // namespace framework } // namespace framework
......
...@@ -188,11 +188,13 @@ std::vector<Variable*> OpBase::ApplyGrad(framework::Scope* scope) { ...@@ -188,11 +188,13 @@ std::vector<Variable*> OpBase::ApplyGrad(framework::Scope* scope) {
std::vector<Variable*> ret; std::vector<Variable*> ret;
for (size_t i = 0; i < input_vars_->size(); ++i) { for (size_t i = 0; i < input_vars_->size(); ++i) {
bool found = false; bool found = false;
VarBase* origin_var = (*input_vars_)[i];
for (const std::string& outvar : grad_op_desc_->OutputArgumentNames()) { for (const std::string& outvar : grad_op_desc_->OutputArgumentNames()) {
Variable* var = scope->FindVar(outvar); Variable* var = scope->FindVar(outvar);
VarBase* origin_var = (*input_vars_)[i];
std::string orig_var = grad_to_var_->at(outvar); std::string orig_var = grad_to_var_->at(outvar);
PADDLE_ENFORCE(origin_var->var_desc_->Name() == orig_var); if (origin_var->var_desc_->Name() != orig_var) {
continue;
}
VLOG(3) << "apply grad " << outvar << " with origin " << orig_var; VLOG(3) << "apply grad " << outvar << " with origin " << orig_var;
origin_var->ApplyGrad(scope, var); origin_var->ApplyGrad(scope, var);
found = true; found = true;
......
...@@ -43,9 +43,12 @@ void CreateGradOp(const framework::OpDesc& op_desc, ...@@ -43,9 +43,12 @@ void CreateGradOp(const framework::OpDesc& op_desc,
class Tracer { class Tracer {
public: public:
explicit Tracer(framework::BlockDesc* root_block) : root_block_(root_block) { explicit Tracer(framework::BlockDesc* root_block,
framework::BlockDesc* startup_block)
: root_block_(root_block), startup_block_(startup_block) {
root_scope_ = new framework::Scope(); root_scope_ = new framework::Scope();
scopes_[root_block_] = root_scope_; scopes_[root_block_] = root_scope_;
scopes_[startup_block_] = root_scope_;
} }
virtual ~Tracer() { delete root_scope_; } virtual ~Tracer() { delete root_scope_; }
...@@ -80,6 +83,8 @@ class Tracer { ...@@ -80,6 +83,8 @@ class Tracer {
} else { } else {
op->pre_ops_->push_back(nullptr); op->pre_ops_->push_back(nullptr);
} }
VLOG(3) << "input vname " << vname << " "
<< var->Get<framework::LoDTensor>().dims().size();
} }
*op->output_vars_ = outputs; *op->output_vars_ = outputs;
...@@ -98,12 +103,19 @@ class Tracer { ...@@ -98,12 +103,19 @@ class Tracer {
outputs[i]->pre_op_ = op; outputs[i]->pre_op_ = op;
outputs[i]->pre_op_out_idx_ = i; outputs[i]->pre_op_out_idx_ = i;
} }
VLOG(3) << "tracer running " << op_desc->Type();
op_base->Run(*scope, platform::CPUPlace()); op_base->Run(*scope, platform::CPUPlace());
framework::OpDesc* grad_op_desc; if (block == startup_block_) {
auto grad_to_var = new std::unordered_map<std::string, std::string>(); op->grad_op_desc_ = nullptr;
CreateGradOp(*op_desc, {}, {block}, &grad_op_desc, grad_to_var); op->grad_to_var_ = nullptr;
op->grad_op_desc_ = grad_op_desc; } else {
op->grad_to_var_ = grad_to_var; framework::OpDesc* grad_op_desc;
auto grad_to_var = new std::unordered_map<std::string, std::string>();
CreateGradOp(*op_desc, {}, {block}, &grad_op_desc, grad_to_var);
op->grad_op_desc_ = grad_op_desc;
op->grad_to_var_ = grad_to_var;
}
op->block_ = block; op->block_ = block;
} }
...@@ -121,6 +133,7 @@ class Tracer { ...@@ -121,6 +133,7 @@ class Tracer {
private: private:
std::map<framework::BlockDesc*, framework::Scope*> scopes_; std::map<framework::BlockDesc*, framework::Scope*> scopes_;
framework::BlockDesc* root_block_; framework::BlockDesc* root_block_;
framework::BlockDesc* startup_block_;
framework::Scope* root_scope_; framework::Scope* root_scope_;
}; };
......
...@@ -254,5 +254,16 @@ TEST(Analyzer_dam, compare) { compare(); } ...@@ -254,5 +254,16 @@ TEST(Analyzer_dam, compare) { compare(); }
TEST(Analyzer_dam, compare_mkldnn) { compare(true /* use_mkldnn */); } TEST(Analyzer_dam, compare_mkldnn) { compare(true /* use_mkldnn */); }
#endif #endif
// Compare Deterministic result
TEST(Analyzer_dam, compare_determine) {
AnalysisConfig cfg;
SetConfig(&cfg);
std::vector<std::vector<PaddleTensor>> input_slots_all;
SetInput(&input_slots_all);
CompareDeterministic(reinterpret_cast<const PaddlePredictor::Config *>(&cfg),
input_slots_all);
}
} // namespace inference } // namespace inference
} // namespace paddle } // namespace paddle
...@@ -180,6 +180,17 @@ TEST(Analyzer_LAC, compare) { ...@@ -180,6 +180,17 @@ TEST(Analyzer_LAC, compare) {
reinterpret_cast<const PaddlePredictor::Config *>(&cfg), input_slots_all); reinterpret_cast<const PaddlePredictor::Config *>(&cfg), input_slots_all);
} }
// Compare Deterministic result
TEST(Analyzer_LAC, compare_determine) {
AnalysisConfig cfg;
SetConfig(&cfg);
std::vector<std::vector<PaddleTensor>> input_slots_all;
SetInput(&input_slots_all);
CompareDeterministic(reinterpret_cast<const PaddlePredictor::Config *>(&cfg),
input_slots_all);
}
} // namespace analysis } // namespace analysis
} // namespace inference } // namespace inference
} // namespace paddle } // namespace paddle
...@@ -179,5 +179,16 @@ TEST(Analyzer_Chinese_ner, compare) { ...@@ -179,5 +179,16 @@ TEST(Analyzer_Chinese_ner, compare) {
reinterpret_cast<const PaddlePredictor::Config *>(&cfg), input_slots_all); reinterpret_cast<const PaddlePredictor::Config *>(&cfg), input_slots_all);
} }
// Compare Deterministic result
TEST(Analyzer_Chinese_ner, compare_determine) {
AnalysisConfig cfg;
SetConfig(&cfg);
std::vector<std::vector<PaddleTensor>> input_slots_all;
SetInput(&input_slots_all);
CompareDeterministic(reinterpret_cast<const PaddlePredictor::Config *>(&cfg),
input_slots_all);
}
} // namespace inference } // namespace inference
} // namespace paddle } // namespace paddle
...@@ -85,6 +85,17 @@ TEST(Analyzer_resnet50, compare) { compare(); } ...@@ -85,6 +85,17 @@ TEST(Analyzer_resnet50, compare) { compare(); }
TEST(Analyzer_resnet50, compare_mkldnn) { compare(true /* use_mkldnn */); } TEST(Analyzer_resnet50, compare_mkldnn) { compare(true /* use_mkldnn */); }
#endif #endif
// Compare Deterministic result
TEST(Analyzer_resnet50, compare_determine) {
AnalysisConfig cfg;
SetConfig(&cfg);
std::vector<std::vector<PaddleTensor>> input_slots_all;
SetInput(&input_slots_all);
CompareDeterministic(reinterpret_cast<const PaddlePredictor::Config *>(&cfg),
input_slots_all);
}
} // namespace analysis } // namespace analysis
} // namespace inference } // namespace inference
} // namespace paddle } // namespace paddle
...@@ -265,6 +265,17 @@ TEST(Analyzer_rnn1, compare) { ...@@ -265,6 +265,17 @@ TEST(Analyzer_rnn1, compare) {
reinterpret_cast<const PaddlePredictor::Config *>(&cfg), input_slots_all); reinterpret_cast<const PaddlePredictor::Config *>(&cfg), input_slots_all);
} }
// Compare Deterministic result
TEST(Analyzer_rnn1, compare_determine) {
AnalysisConfig cfg;
SetConfig(&cfg);
std::vector<std::vector<PaddleTensor>> input_slots_all;
SetInput(&input_slots_all);
CompareDeterministic(reinterpret_cast<const PaddlePredictor::Config *>(&cfg),
input_slots_all);
}
// Test Multi-Thread. // Test Multi-Thread.
TEST(Analyzer_rnn1, multi_thread) { TEST(Analyzer_rnn1, multi_thread) {
contrib::AnalysisConfig cfg; contrib::AnalysisConfig cfg;
......
...@@ -158,5 +158,16 @@ TEST(Analyzer_rnn2, compare) { ...@@ -158,5 +158,16 @@ TEST(Analyzer_rnn2, compare) {
reinterpret_cast<const PaddlePredictor::Config *>(&cfg), input_slots_all); reinterpret_cast<const PaddlePredictor::Config *>(&cfg), input_slots_all);
} }
// Compare Deterministic result
TEST(Analyzer_rnn2, compare_determine) {
AnalysisConfig cfg;
SetConfig(&cfg);
std::vector<std::vector<PaddleTensor>> input_slots_all;
SetInput(&input_slots_all);
CompareDeterministic(reinterpret_cast<const PaddlePredictor::Config *>(&cfg),
input_slots_all);
}
} // namespace inference } // namespace inference
} // namespace paddle } // namespace paddle
...@@ -204,5 +204,16 @@ TEST(Analyzer_seq_conv1, compare) { ...@@ -204,5 +204,16 @@ TEST(Analyzer_seq_conv1, compare) {
reinterpret_cast<const PaddlePredictor::Config *>(&cfg), input_slots_all); reinterpret_cast<const PaddlePredictor::Config *>(&cfg), input_slots_all);
} }
// Compare Deterministic result
TEST(Analyzer_seq_conv1, compare_determine) {
AnalysisConfig cfg;
SetConfig(&cfg);
std::vector<std::vector<PaddleTensor>> input_slots_all;
SetInput(&input_slots_all);
CompareDeterministic(reinterpret_cast<const PaddlePredictor::Config *>(&cfg),
input_slots_all);
}
} // namespace inference } // namespace inference
} // namespace paddle } // namespace paddle
...@@ -106,6 +106,17 @@ TEST(Analyzer_Text_Classification, compare) { ...@@ -106,6 +106,17 @@ TEST(Analyzer_Text_Classification, compare) {
reinterpret_cast<const PaddlePredictor::Config *>(&cfg), input_slots_all); reinterpret_cast<const PaddlePredictor::Config *>(&cfg), input_slots_all);
} }
// Compare Deterministic result
TEST(Analyzer_Text_Classification, compare_determine) {
AnalysisConfig cfg;
SetConfig(&cfg);
std::vector<std::vector<PaddleTensor>> input_slots_all;
SetInput(&input_slots_all);
CompareDeterministic(reinterpret_cast<const PaddlePredictor::Config *>(&cfg),
input_slots_all);
}
TEST(Analyzer_Text_Classification, compare_against_embedding_fc_lstm_fused) { TEST(Analyzer_Text_Classification, compare_against_embedding_fc_lstm_fused) {
AnalysisConfig cfg; AnalysisConfig cfg;
SetConfig(&cfg); SetConfig(&cfg);
......
...@@ -145,6 +145,17 @@ TEST(Analyzer_vis, compare) { compare(); } ...@@ -145,6 +145,17 @@ TEST(Analyzer_vis, compare) { compare(); }
TEST(Analyzer_vis, compare_mkldnn) { compare(true /* use_mkldnn */); } TEST(Analyzer_vis, compare_mkldnn) { compare(true /* use_mkldnn */); }
#endif #endif
// Compare Deterministic result
TEST(Analyzer_vis, compare_determine) {
AnalysisConfig cfg;
SetConfig(&cfg);
std::vector<std::vector<PaddleTensor>> input_slots_all;
SetInput(&input_slots_all);
CompareDeterministic(reinterpret_cast<const PaddlePredictor::Config *>(&cfg),
input_slots_all);
}
} // namespace analysis } // namespace analysis
} // namespace inference } // namespace inference
} // namespace paddle } // namespace paddle
...@@ -45,6 +45,7 @@ DEFINE_bool(use_analysis, true, ...@@ -45,6 +45,7 @@ DEFINE_bool(use_analysis, true,
"Running the inference program in analysis mode."); "Running the inference program in analysis mode.");
DEFINE_bool(record_benchmark, false, DEFINE_bool(record_benchmark, false,
"Record benchmark after profiling the model"); "Record benchmark after profiling the model");
DEFINE_double(accuracy, 1e-3, "Result Accuracy.");
DECLARE_bool(profile); DECLARE_bool(profile);
DECLARE_int32(paddle_num_threads); DECLARE_int32(paddle_num_threads);
...@@ -85,7 +86,7 @@ void CompareResult(const std::vector<PaddleTensor> &outputs, ...@@ -85,7 +86,7 @@ void CompareResult(const std::vector<PaddleTensor> &outputs,
float *pdata = static_cast<float *>(out.data.data()); float *pdata = static_cast<float *>(out.data.data());
float *pdata_ref = static_cast<float *>(ref_out.data.data()); float *pdata_ref = static_cast<float *>(ref_out.data.data());
for (size_t j = 0; j < size; ++j) { for (size_t j = 0; j < size; ++j) {
EXPECT_NEAR(pdata_ref[j], pdata[j], 1e-3); EXPECT_NEAR(pdata_ref[j], pdata[j], FLAGS_accuracy);
} }
break; break;
} }
...@@ -283,6 +284,26 @@ void TestPrediction(const PaddlePredictor::Config *config, ...@@ -283,6 +284,26 @@ void TestPrediction(const PaddlePredictor::Config *config,
} }
} }
void CompareDeterministic(
const PaddlePredictor::Config *config,
const std::vector<std::vector<PaddleTensor>> &inputs) {
int batch_size = FLAGS_batch_size;
int num_times = FLAGS_repeat;
auto predictor = CreateTestPredictor(config, FLAGS_use_analysis);
// warmup run
std::vector<PaddleTensor> warmup_outputs, outputs;
predictor->Run(inputs[0], &warmup_outputs, batch_size);
// run num_times to Compare Deterministic Result.
for (int i = 0; i < num_times; i++) {
for (size_t j = 0; j < inputs.size(); j++) {
predictor->Run(inputs[j], &outputs, batch_size);
CompareResult(outputs, warmup_outputs);
}
}
}
void CompareNativeAndAnalysis( void CompareNativeAndAnalysis(
const PaddlePredictor::Config *config, const PaddlePredictor::Config *config,
const std::vector<std::vector<PaddleTensor>> &inputs) { const std::vector<std::vector<PaddleTensor>> &inputs) {
......
...@@ -16,6 +16,7 @@ add_subdirectory(metrics) ...@@ -16,6 +16,7 @@ add_subdirectory(metrics)
add_subdirectory(optimizers) add_subdirectory(optimizers)
add_subdirectory(reduce_ops) add_subdirectory(reduce_ops)
add_subdirectory(sequence_ops) add_subdirectory(sequence_ops)
add_subdirectory(jit)
if(WITH_DISTRIBUTE) if(WITH_DISTRIBUTE)
add_subdirectory(distributed) add_subdirectory(distributed)
...@@ -42,8 +43,7 @@ if (WITH_DISTRIBUTE) ...@@ -42,8 +43,7 @@ if (WITH_DISTRIBUTE)
SET(OP_PREFETCH_DEPS ${OP_PREFETCH_DEPS} parameter_prefetch) SET(OP_PREFETCH_DEPS ${OP_PREFETCH_DEPS} parameter_prefetch)
endif() endif()
register_operators(EXCLUDES warpctc_op conv_fusion_op DEPS ${OP_HEADER_DEPS} ${OP_PREFETCH_DEPS}) register_operators(EXCLUDES py_func_op warpctc_op conv_fusion_op DEPS ${OP_HEADER_DEPS} ${OP_PREFETCH_DEPS})
# warpctc_op needs cudnn 7 above # warpctc_op needs cudnn 7 above
if (WITH_GPU AND NOT WIN32) if (WITH_GPU AND NOT WIN32)
...@@ -65,7 +65,7 @@ set(COMMON_OP_DEPS ${OP_HEADER_DEPS}) ...@@ -65,7 +65,7 @@ set(COMMON_OP_DEPS ${OP_HEADER_DEPS})
set(COMMON_OP_DEPS ${COMMON_OP_DEPS} selected_rows_functor selected_rows lod_tensor maxouting unpooling pooling lod_rank_table context_project sequence_pooling executor) set(COMMON_OP_DEPS ${COMMON_OP_DEPS} selected_rows_functor selected_rows lod_tensor maxouting unpooling pooling lod_rank_table context_project sequence_pooling executor)
set(COMMON_OP_DEPS ${COMMON_OP_DEPS} dynload_warpctc) set(COMMON_OP_DEPS ${COMMON_OP_DEPS} dynload_warpctc)
set(COMMON_OP_DEPS ${COMMON_OP_DEPS} sequence_padding sequence_scale cos_sim_functor memory jit_kernel concat_and_split cross_entropy softmax vol2col im2col sampler) set(COMMON_OP_DEPS ${COMMON_OP_DEPS} sequence_padding sequence_scale cos_sim_functor memory jit_kernel_helper concat_and_split cross_entropy softmax vol2col im2col sampler)
set(COMMON_OP_DEPS ${COMMON_OP_DEPS} sequence2batch lstm_compute matrix_bit_code gru_compute activation_functions) set(COMMON_OP_DEPS ${COMMON_OP_DEPS} sequence2batch lstm_compute matrix_bit_code gru_compute activation_functions)
if (WITH_GPU) if (WITH_GPU)
set(COMMON_OP_DEPS ${COMMON_OP_DEPS} depthwise_conv prelu) set(COMMON_OP_DEPS ${COMMON_OP_DEPS} depthwise_conv prelu)
...@@ -92,4 +92,8 @@ cc_test(save_load_op_test SRCS save_load_op_test.cc DEPS save_op load_op) ...@@ -92,4 +92,8 @@ cc_test(save_load_op_test SRCS save_load_op_test.cc DEPS save_op load_op)
cc_test(save_load_combine_op_test SRCS save_load_combine_op_test.cc DEPS save_combine_op load_combine_op) cc_test(save_load_combine_op_test SRCS save_load_combine_op_test.cc DEPS save_combine_op load_combine_op)
nv_test(dropout_op_test SRCS dropout_op_test.cc DEPS dropout_op tensor) nv_test(dropout_op_test SRCS dropout_op_test.cc DEPS dropout_op tensor)
if (WITH_PYTHON)
cc_library(py_func_op SRCS py_func_op.cc DEPS op_registry python pybind)
endif()
set(GLOB_OP_LIB ${OP_LIBRARY} CACHE INTERNAL "Global OP library") set(GLOB_OP_LIB ${OP_LIBRARY} CACHE INTERNAL "Global OP library")
...@@ -399,26 +399,41 @@ class WhileGradOpShapeInference : public framework::InferShapeBase { ...@@ -399,26 +399,41 @@ class WhileGradOpShapeInference : public framework::InferShapeBase {
ctx->HasInputs(kOutputs); ctx->HasInputs(kOutputs);
ctx->HasInputs(framework::GradVarName(kOutputs)); ctx->HasInputs(framework::GradVarName(kOutputs));
auto p_names = ctx->Inputs(kX);
auto pg_ig_names = ctx->Outputs(kXGRAD); auto pg_ig_names = ctx->Outputs(kXGRAD);
auto var_types = ctx->GetInputsVarType(kX); std::vector<framework::InferShapeVarPtr> in_var_ptrs =
std::vector<std::string> names_to_set; ctx->GetInputVarPtrs(kX);
std::vector<framework::DDim> dims_to_set; std::vector<framework::InferShapeVarPtr> out_var_ptrs =
for (size_t i = 0; i < p_names.size(); ++i) { ctx->GetOutputVarPtrs(kXGRAD);
PADDLE_ENFORCE(in_var_ptrs.size() == out_var_ptrs.size());
for (size_t i = 0; i < in_var_ptrs.size(); ++i) {
if (pg_ig_names[i] == framework::kEmptyVarName) { if (pg_ig_names[i] == framework::kEmptyVarName) {
continue; continue;
} }
auto dims = ctx->GetInputsElementDim(kX, i); if (ctx->IsRuntime()) {
if (var_types[i] == framework::proto::VarType::LOD_TENSOR) { framework::Variable *in_var =
names_to_set.push_back(pg_ig_names[i]); boost::get<framework::Variable *>(in_var_ptrs[i]);
dims_to_set.push_back(dims); framework::Variable *out_var =
} else if (var_types[i] == framework::proto::VarType::LOD_TENSOR_ARRAY) { boost::get<framework::Variable *>(out_var_ptrs[i]);
// not sure how to set the dim of LOD_TENSOR_ARRAY
names_to_set.push_back(pg_ig_names[i]); auto type = framework::ToVarType(in_var->Type());
dims_to_set.push_back(dims); if (type == framework::proto::VarType::LOD_TENSOR) {
out_var->GetMutable<LoDTensor>()->Resize(
in_var->Get<framework::LoDTensor>().dims());
} else if (type == framework::proto::VarType::SELECTED_ROWS) {
out_var->GetMutable<framework::SelectedRows>()->set_height(
in_var->Get<framework::SelectedRows>().GetCompleteDims()[0]);
} else if (type == framework::proto::VarType::LOD_TENSOR_ARRAY) {
PADDLE_THROW("WhileGradOp doesn't support type %d",
static_cast<int>(type));
}
} else {
framework::VarDesc *in_var =
boost::get<framework::VarDesc *>(in_var_ptrs[i]);
boost::get<framework::VarDesc *>(out_var_ptrs[i])
->SetShape(in_var->GetShape());
} }
} }
ctx->SetDims(names_to_set, dims_to_set);
} }
}; };
......
...@@ -155,11 +155,14 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -155,11 +155,14 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
auto chosen_memory_format = auto chosen_memory_format =
platform::data_format_to_memory_format(data_format); platform::data_format_to_memory_format(data_format);
if (is_conv3d) { weights_format = mkldnn::memory::format::any;
chosen_memory_format = // Check the format for user's special output
platform::MKLDNNFormatForSize(src_tz.size(), chosen_memory_format); if (chosen_memory_format != mkldnn::memory::format::any) {
if (is_conv3d) {
chosen_memory_format =
platform::MKLDNNFormatForSize(src_tz.size(), chosen_memory_format);
}
} }
weights_format = GetWeightsFormat(chosen_memory_format, g, is_conv3d);
auto src_md = platform::MKLDNNMemDesc( auto src_md = platform::MKLDNNMemDesc(
src_tz, platform::MKLDNNGetDataType<T>(), chosen_memory_format); src_tz, platform::MKLDNNGetDataType<T>(), chosen_memory_format);
...@@ -435,11 +438,14 @@ class ConvMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> { ...@@ -435,11 +438,14 @@ class ConvMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
auto chosen_memory_format = auto chosen_memory_format =
platform::data_format_to_memory_format(data_format); platform::data_format_to_memory_format(data_format);
if (is_conv3d) { weights_format = mkldnn::memory::format::any;
chosen_memory_format = // Check the format for user's special output
platform::MKLDNNFormatForSize(src_tz.size(), chosen_memory_format); if (chosen_memory_format != mkldnn::memory::format::any) {
if (is_conv3d) {
chosen_memory_format =
platform::MKLDNNFormatForSize(src_tz.size(), chosen_memory_format);
}
} }
weights_format = GetWeightsFormat(chosen_memory_format, g, is_conv3d);
auto src_md = platform::MKLDNNMemDesc( auto src_md = platform::MKLDNNMemDesc(
src_tz, platform::MKLDNNGetDataType<T>(), chosen_memory_format); src_tz, platform::MKLDNNGetDataType<T>(), chosen_memory_format);
......
...@@ -16,7 +16,7 @@ limitations under the License. */ ...@@ -16,7 +16,7 @@ limitations under the License. */
#include <limits> #include <limits>
#include "paddle/fluid/framework/eigen.h" #include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/jit_kernel.h" #include "paddle/fluid/operators/jit/kernels.h"
#include "paddle/fluid/operators/math/math_function.h" #include "paddle/fluid/operators/math/math_function.h"
namespace paddle { namespace paddle {
...@@ -82,10 +82,9 @@ class CRFDecodingOpKernel : public framework::OpKernel<T> { ...@@ -82,10 +82,9 @@ class CRFDecodingOpKernel : public framework::OpKernel<T> {
Tensor track; Tensor track;
int* track_value = int* track_value =
track.mutable_data<int>(emission_dims, platform::CPUPlace()); track.mutable_data<int>(emission_dims, platform::CPUPlace());
const auto& ker = math::jitkernel::KernelPool::Instance() auto ker = jit::Get<jit::kCRFDecoding, jit::CRFDecodingTuples<T>,
.template Get<math::jitkernel::CRFDecodeKernel<T>>( platform::CPUPlace>(tag_num);
static_cast<int>(tag_num)); ker(static_cast<int>(seq_len), x, w, alpha_value, track_value, tag_num);
ker->Compute(static_cast<int>(seq_len), x, w, alpha_value, track_value);
T max_score = -std::numeric_limits<T>::max(); T max_score = -std::numeric_limits<T>::max();
int max_i = 0; int max_i = 0;
for (size_t i = 0; i < tag_num; ++i) { for (size_t i = 0; i < tag_num; ++i) {
......
...@@ -16,6 +16,7 @@ limitations under the License. */ ...@@ -16,6 +16,7 @@ limitations under the License. */
#include <nccl.h> #include <nccl.h>
#endif #endif
#include <sys/time.h> #include <sys/time.h>
#include <limits>
#include <thread> // NOLINT #include <thread> // NOLINT
#include "paddle/fluid/framework/data_type.h" #include "paddle/fluid/framework/data_type.h"
...@@ -31,7 +32,12 @@ namespace distributed { ...@@ -31,7 +32,12 @@ namespace distributed {
class IOBufWriter { class IOBufWriter {
public: public:
static void Append(butil::IOBuf* iobuf, int k, const char* v, int64_t vlen) { static void Append(const std::string& varname, butil::IOBuf* iobuf, int k,
const char* v, int64_t vlen) {
if (vlen >= std::numeric_limits<int>::max() || vlen < 0) {
LOG(FATAL) << "AppendZeroCopy varname:" << varname << ", vlen:" << vlen;
}
iobuf->append(reinterpret_cast<char*>(&k), 4); iobuf->append(reinterpret_cast<char*>(&k), 4);
iobuf->append(reinterpret_cast<char*>(&vlen), 8); iobuf->append(reinterpret_cast<char*>(&vlen), 8);
iobuf->append(v, vlen); iobuf->append(v, vlen);
...@@ -87,6 +93,10 @@ class IOBufWriter { ...@@ -87,6 +93,10 @@ class IOBufWriter {
int k, const char* v, int64_t vlen, int k, const char* v, int64_t vlen,
bool in_cuda_pinned, void (*destroy)(void*), bool in_cuda_pinned, void (*destroy)(void*),
void* user_data) { void* user_data) {
if (vlen >= std::numeric_limits<int>::max() || vlen < 0) {
LOG(FATAL) << "AppendZeroCopy varname:" << varname << ", vlen:" << vlen;
}
#ifdef PADDLE_WITH_BRPC_RDMA #ifdef PADDLE_WITH_BRPC_RDMA
IOBufWriter::AppendRdmaZeroCopy(varname, iobuf, k, v, vlen, in_cuda_pinned, IOBufWriter::AppendRdmaZeroCopy(varname, iobuf, k, v, vlen, in_cuda_pinned,
destroy, user_data); destroy, user_data);
...@@ -134,7 +144,7 @@ void SerializeToIOBuf(const std::string& name, framework::Variable* var, ...@@ -134,7 +144,7 @@ void SerializeToIOBuf(const std::string& name, framework::Variable* var,
request->set_type(::sendrecv::NCCL_ID); request->set_type(::sendrecv::NCCL_ID);
const ncclUniqueId& uid = var->Get<ncclUniqueId>(); const ncclUniqueId& uid = var->Get<ncclUniqueId>();
// TODO(gongwb): use append_zero to avoid data copy. // TODO(gongwb): use append_zero to avoid data copy.
IOBufWriter::Append(iobuf, IOBufWriter::Append(name, iobuf,
sendrecv::VariableMessage::kSerializedFieldNumber, sendrecv::VariableMessage::kSerializedFieldNumber,
uid.internal, NCCL_UNIQUE_ID_BYTES); uid.internal, NCCL_UNIQUE_ID_BYTES);
return; return;
...@@ -149,7 +159,7 @@ void SerializeToIOBuf(const std::string& name, framework::Variable* var, ...@@ -149,7 +159,7 @@ void SerializeToIOBuf(const std::string& name, framework::Variable* var,
// FIXME(gongwb): it seems that can use zero copy. // FIXME(gongwb): it seems that can use zero copy.
if (var_is_not_stable) { if (var_is_not_stable) {
IOBufWriter::Append( IOBufWriter::Append(
iobuf, ::sendrecv::VariableMessage::kSerializedFieldNumber, name, iobuf, ::sendrecv::VariableMessage::kSerializedFieldNumber,
static_cast<const char*>(payload->ptr()), payload->memory_size()); static_cast<const char*>(payload->ptr()), payload->memory_size());
} else { } else {
if (platform::is_gpu_place(ctx.GetPlace())) { if (platform::is_gpu_place(ctx.GetPlace())) {
...@@ -171,10 +181,11 @@ void SerializeToIOBuf(const std::string& name, framework::Variable* var, ...@@ -171,10 +181,11 @@ void SerializeToIOBuf(const std::string& name, framework::Variable* var,
if (var->IsType<framework::SelectedRows>()) { if (var->IsType<framework::SelectedRows>()) {
auto* slr = var->GetMutable<framework::SelectedRows>(); auto* slr = var->GetMutable<framework::SelectedRows>();
size_t rows_memory_size = PADDLE_ENFORCE(VectorElemName(slr->rows()) == typeid(int64_t).name());
slr->rows().size() * framework::SizeOfType(typeid(int64_t)); size_t rows_memory_size = slr->rows().size() * sizeof(int64_t);
IOBufWriter::Append(iobuf, ::sendrecv::VariableMessage::kRowsFieldNumber, IOBufWriter::Append(name, iobuf,
::sendrecv::VariableMessage::kRowsFieldNumber,
reinterpret_cast<const char*>(slr->rows().data()), reinterpret_cast<const char*>(slr->rows().data()),
static_cast<int64_t>(rows_memory_size)); static_cast<int64_t>(rows_memory_size));
} }
......
...@@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include <stdlib.h>
#include <limits> #include <limits>
#include "glog/logging.h" // For VLOG #include "glog/logging.h" // For VLOG
...@@ -420,7 +421,15 @@ void GRPCClient::Proceed() { ...@@ -420,7 +421,15 @@ void GRPCClient::Proceed() {
sync_cond_.notify_all(); sync_cond_.notify_all();
} }
} }
VLOG(3) << "GRPCClient Proceed end";
// Last log message
// Avoid using VLOG() and LOG(): in the destructor of google::LogMessage() a
// static Mutex log_mutex is used for synchronization, which might have been
// destructed at this moment.
if (FLAGS_v >= 3) {
std::string msg("GRPCClient Proceed end");
fwrite(msg.c_str(), msg.length(), 1, stdout);
}
} }
std::shared_ptr<grpc::Channel> GRPCClient::GetChannel(const std::string& ep) { std::shared_ptr<grpc::Channel> GRPCClient::GetChannel(const std::string& ep) {
......
...@@ -15,6 +15,7 @@ limitations under the License. */ ...@@ -15,6 +15,7 @@ limitations under the License. */
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
#include <nccl.h> #include <nccl.h>
#endif #endif
#include <limits>
#include <thread> // NOLINT #include <thread> // NOLINT
#include "google/protobuf/io/coded_stream.h" #include "google/protobuf/io/coded_stream.h"
...@@ -102,6 +103,10 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var, ...@@ -102,6 +103,10 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
e.WriteVarlengthBeginning(VarMsg::kSerializedFieldNumber, e.WriteVarlengthBeginning(VarMsg::kSerializedFieldNumber,
payload->memory_size()); payload->memory_size());
if (payload->memory_size() >= std::numeric_limits<int>::max()) {
LOG(FATAL) << "AppendZeroCopy varname:" << name
<< ", vlen:" << payload->memory_size();
}
// steal reference of tensor data // steal reference of tensor data
::grpc::Slice slices[4]; // metadata, tensor, rows meta, rows ::grpc::Slice slices[4]; // metadata, tensor, rows meta, rows
int num_slices = 2; // only SelectedRows have rows buffer int num_slices = 2; // only SelectedRows have rows buffer
...@@ -115,7 +120,10 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var, ...@@ -115,7 +120,10 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
if (var->IsType<framework::SelectedRows>()) { if (var->IsType<framework::SelectedRows>()) {
auto* slr = var->GetMutable<framework::SelectedRows>(); auto* slr = var->GetMutable<framework::SelectedRows>();
ProtoEncodeHelper e2(static_cast<char*>(buf), 128); ProtoEncodeHelper e2(static_cast<char*>(buf), 128);
PADDLE_ENFORCE(VectorElemName(slr->rows()) == typeid(int64_t).name());
size_t rows_memory_size = slr->rows().size() * sizeof(int64_t); size_t rows_memory_size = slr->rows().size() * sizeof(int64_t);
e2.WriteVarlengthBeginning(VarMsg::kRowsFieldNumber, rows_memory_size); e2.WriteVarlengthBeginning(VarMsg::kRowsFieldNumber, rows_memory_size);
slices[2] = ::grpc::Slice(e2.size()); slices[2] = ::grpc::Slice(e2.size());
memcpy(const_cast<uint8_t*>(slices[2].begin()), e2.data(), e2.size()); memcpy(const_cast<uint8_t*>(slices[2].begin()), e2.data(), e2.size());
......
...@@ -15,6 +15,7 @@ limitations under the License. */ ...@@ -15,6 +15,7 @@ limitations under the License. */
#pragma once #pragma once
#include <iostream> #include <iostream>
#include <string> #include <string>
#include <typeindex>
#include <vector> #include <vector>
#include "paddle/fluid/framework/data_type.h" #include "paddle/fluid/framework/data_type.h"
...@@ -23,9 +24,8 @@ limitations under the License. */ ...@@ -23,9 +24,8 @@ limitations under the License. */
#include "paddle/fluid/framework/selected_rows.h" #include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/framework/tensor_util.h" #include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/framework/var_type.h" #include "paddle/fluid/framework/var_type.h"
#include "paddle/fluid/platform/port.h"
#include "paddle/fluid/operators/distributed/send_recv.pb.h" #include "paddle/fluid/operators/distributed/send_recv.pb.h"
#include "paddle/fluid/platform/port.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -83,6 +83,11 @@ inline framework::proto::VarType::Type ToVarType( ...@@ -83,6 +83,11 @@ inline framework::proto::VarType::Type ToVarType(
} }
} }
template <template <typename> class T, typename Elem>
std::string VectorElemName(const T<Elem>& arg) {
return typeid(Elem).name();
}
} // namespace distributed } // namespace distributed
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
...@@ -118,7 +118,7 @@ bool VariableResponse::CopyLodTensorData( ...@@ -118,7 +118,7 @@ bool VariableResponse::CopyLodTensorData(
VLOG(6) << "Tensor.memory_size = " << tensor->memory_size() VLOG(6) << "Tensor.memory_size = " << tensor->memory_size()
<< ", Buffer Size = " << length; << ", Buffer Size = " << length;
PADDLE_ENFORCE_EQ(tensor->memory_size(), length); PADDLE_ENFORCE_EQ(tensor->memory_size(), static_cast<unsigned int>(length));
return ReadRaw(input, ctx, tensor->place(), tensor_data, length); return ReadRaw(input, ctx, tensor->place(), tensor_data, length);
} }
......
...@@ -17,8 +17,8 @@ limitations under the License. */ ...@@ -17,8 +17,8 @@ limitations under the License. */
#include "paddle/fluid/operators/elementwise/elementwise_op_function.h" #include "paddle/fluid/operators/elementwise/elementwise_op_function.h"
#include "paddle/fluid/platform/mkldnn_helper.h" #include "paddle/fluid/platform/mkldnn_helper.h"
#include "paddle/fluid/operators/jit/kernels.h"
#include "paddle/fluid/operators/math/jit_kernel.h"
#ifdef PADDLE_WITH_XBYAK #ifdef PADDLE_WITH_XBYAK
#include "xbyak/xbyak.h" #include "xbyak/xbyak.h"
#include "xbyak/xbyak_util.h" #include "xbyak/xbyak_util.h"
...@@ -109,10 +109,8 @@ class ElementwiseMulMKLDNNKernel : public framework::OpKernel<T> { ...@@ -109,10 +109,8 @@ class ElementwiseMulMKLDNNKernel : public framework::OpKernel<T> {
constexpr int simd_width = 16; constexpr int simd_width = 16;
int C = c / simd_width; int C = c / simd_width;
const auto& multiply = auto multiply = jit::Get<jit::kNCHW16CMulNC, jit::NCHW16CMulNCTuples<T>,
math::jitkernel::KernelPool::Instance() platform::CPUPlace>(0);
.template Get<math::jitkernel::EltwiseMulnChw16cNCKernel<T>>(n);
#pragma omp parallel for collapse(2) #pragma omp parallel for collapse(2)
for (int ni = 0; ni < n; ni++) { for (int ni = 0; ni < n; ni++) {
for (int ci = 0; ci < C; ci++) { for (int ci = 0; ci < C; ci++) {
...@@ -123,7 +121,7 @@ class ElementwiseMulMKLDNNKernel : public framework::OpKernel<T> { ...@@ -123,7 +121,7 @@ class ElementwiseMulMKLDNNKernel : public framework::OpKernel<T> {
auto ptr_z = auto ptr_z =
z_data + ni * C * h * w * simd_width + ci * h * w * simd_width; z_data + ni * C * h * w * simd_width + ci * h * w * simd_width;
multiply->Compute(ptr_x, ptr_y, ptr_z, h, w); multiply(ptr_x, ptr_y, ptr_z, h, w);
} }
} }
} }
......
...@@ -15,9 +15,9 @@ limitations under the License. */ ...@@ -15,9 +15,9 @@ limitations under the License. */
#include "paddle/fluid/operators/fused/fusion_gru_op.h" #include "paddle/fluid/operators/fused/fusion_gru_op.h"
#include <cstring> // for memcpy #include <cstring> // for memcpy
#include <string> #include <string>
#include "paddle/fluid/operators/jit/kernels.h"
#include "paddle/fluid/operators/math/blas.h" #include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/operators/math/fc_compute.h" #include "paddle/fluid/operators/math/fc_compute.h"
#include "paddle/fluid/operators/math/jit_kernel.h"
#include "paddle/fluid/operators/math/sequence2batch.h" #include "paddle/fluid/operators/math/sequence2batch.h"
namespace paddle { namespace paddle {
...@@ -182,27 +182,29 @@ class FusionGRUKernel : public framework::OpKernel<T> { ...@@ -182,27 +182,29 @@ class FusionGRUKernel : public framework::OpKernel<T> {
const int total_T = x_dims[0]; \ const int total_T = x_dims[0]; \
const int D3 = wh_dims[1] const int D3 = wh_dims[1]
#define INIT_OTHER_DEFINES \ #define INIT_OTHER_DEFINES \
auto* h0 = ctx.Input<Tensor>("H0"); \ auto* h0 = ctx.Input<Tensor>("H0"); \
auto* wx = ctx.Input<Tensor>("WeightX"); \ auto* wx = ctx.Input<Tensor>("WeightX"); \
auto* bias = ctx.Input<Tensor>("Bias"); \ auto* bias = ctx.Input<Tensor>("Bias"); \
auto* hidden_out = ctx.Output<LoDTensor>("Hidden"); \ auto* hidden_out = ctx.Output<LoDTensor>("Hidden"); \
bool is_reverse = ctx.Attr<bool>("is_reverse"); \ bool is_reverse = ctx.Attr<bool>("is_reverse"); \
const int M = x_dims[1]; \ const int M = x_dims[1]; \
const int D = wh_dims[0]; \ const int D = wh_dims[0]; \
const int D2 = D * 2; \ const int D2 = D * 2; \
const math::jitkernel::gru_attr_t attr( \ const jit::gru_attr_t attr( \
D, ctx.Attr<std::string>("gate_activation"), \ D, jit::to_kerneltype(ctx.Attr<std::string>("gate_activation")), \
ctx.Attr<std::string>("activation")); \ jit::to_kerneltype(ctx.Attr<std::string>("activation"))); \
math::jitkernel::gru_t one_step; \ jit::gru_t one_step; \
const auto& ker = \ auto ComputeH1 = \
math::jitkernel::KernelPool::Instance() \ jit::Get<jit::kGRUH1, jit::GRUTuples<T>, platform::CPUPlace>(attr); \
.template Get<math::jitkernel::GRUKernel<T>, \ auto ComputeHtPart1 = \
const math::jitkernel::gru_attr_t&>(attr); \ jit::Get<jit::kGRUHtPart1, jit::GRUTuples<T>, platform::CPUPlace>(attr); \
const T* x_data = x->data<T>(); \ auto ComputeHtPart2 = \
const T* wx_data = wx->data<T>(); \ jit::Get<jit::kGRUHtPart2, jit::GRUTuples<T>, platform::CPUPlace>(attr); \
const T* wh_data = wh->data<T>(); \ const T* x_data = x->data<T>(); \
auto place = ctx.GetPlace(); \ const T* wx_data = wx->data<T>(); \
const T* wh_data = wh->data<T>(); \
auto place = ctx.GetPlace(); \
T* xx_data = xx->mutable_data<T>(place) T* xx_data = xx->mutable_data<T>(place)
void SeqCompute(const framework::ExecutionContext& ctx) const { void SeqCompute(const framework::ExecutionContext& ctx) const {
...@@ -241,7 +243,7 @@ class FusionGRUKernel : public framework::OpKernel<T> { ...@@ -241,7 +243,7 @@ class FusionGRUKernel : public framework::OpKernel<T> {
} else { } else {
one_step.gates = xx_data; one_step.gates = xx_data;
one_step.ht = hidden_out_data; one_step.ht = hidden_out_data;
ker->ComputeH1(&one_step, &attr); ComputeH1(&one_step, &attr);
prev_hidden_data = hidden_out_data; prev_hidden_data = hidden_out_data;
tstart = 1; tstart = 1;
move_step(); move_step();
...@@ -254,12 +256,12 @@ class FusionGRUKernel : public framework::OpKernel<T> { ...@@ -254,12 +256,12 @@ class FusionGRUKernel : public framework::OpKernel<T> {
one_step.gates = xx_data; one_step.gates = xx_data;
one_step.ht_1 = prev_hidden_data; one_step.ht_1 = prev_hidden_data;
one_step.ht = hidden_out_data; one_step.ht = hidden_out_data;
ker->ComputeHtPart1(&one_step, &attr); ComputeHtPart1(&one_step, &attr);
// gemm rt * Ws // gemm rt * Ws
blas.GEMM(CblasNoTrans, CblasNoTrans, 1, D, D, static_cast<T>(1), blas.GEMM(CblasNoTrans, CblasNoTrans, 1, D, D, static_cast<T>(1),
hidden_out_data, D, wh_state_data, D, static_cast<T>(1), hidden_out_data, D, wh_state_data, D, static_cast<T>(1),
xx_data + D2, D3); xx_data + D2, D3);
ker->ComputeHtPart2(&one_step, &attr); ComputeHtPart2(&one_step, &attr);
// save prev // save prev
prev_hidden_data = hidden_out_data; prev_hidden_data = hidden_out_data;
move_step(); move_step();
...@@ -323,7 +325,7 @@ class FusionGRUKernel : public framework::OpKernel<T> { ...@@ -323,7 +325,7 @@ class FusionGRUKernel : public framework::OpKernel<T> {
for (int i = 0; i < max_bs; ++i) { for (int i = 0; i < max_bs; ++i) {
one_step.gates = cur_in_data; one_step.gates = cur_in_data;
one_step.ht = cur_out_data; one_step.ht = cur_out_data;
ker->ComputeH1(&one_step, &attr); ComputeH1(&one_step, &attr);
// add offset // add offset
cur_in_data += D3; cur_in_data += D3;
cur_out_data += D; cur_out_data += D;
...@@ -351,7 +353,7 @@ class FusionGRUKernel : public framework::OpKernel<T> { ...@@ -351,7 +353,7 @@ class FusionGRUKernel : public framework::OpKernel<T> {
one_step.gates = cur_batched_data; one_step.gates = cur_batched_data;
one_step.ht_1 = cur_prev_hidden_data; one_step.ht_1 = cur_prev_hidden_data;
one_step.ht = cur_out_data; one_step.ht = cur_out_data;
ker->ComputeHtPart1(&one_step, &attr); ComputeHtPart1(&one_step, &attr);
cur_batched_data += D3; cur_batched_data += D3;
cur_prev_hidden_data += D; cur_prev_hidden_data += D;
...@@ -369,7 +371,7 @@ class FusionGRUKernel : public framework::OpKernel<T> { ...@@ -369,7 +371,7 @@ class FusionGRUKernel : public framework::OpKernel<T> {
one_step.gates = cur_batched_data; one_step.gates = cur_batched_data;
one_step.ht_1 = cur_prev_hidden_data; one_step.ht_1 = cur_prev_hidden_data;
one_step.ht = cur_out_data; one_step.ht = cur_out_data;
ker->ComputeHtPart2(&one_step, &attr); ComputeHtPart2(&one_step, &attr);
cur_batched_data += D3; cur_batched_data += D3;
cur_prev_hidden_data += D; cur_prev_hidden_data += D;
cur_out_data += D; cur_out_data += D;
......
...@@ -14,9 +14,9 @@ limitations under the License. */ ...@@ -14,9 +14,9 @@ limitations under the License. */
#include "paddle/fluid/operators/fused/fusion_lstm_op.h" #include "paddle/fluid/operators/fused/fusion_lstm_op.h"
#include <string> #include <string>
#include "paddle/fluid/operators/jit/kernels.h"
#include "paddle/fluid/operators/math/blas.h" #include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/operators/math/fc_compute.h" #include "paddle/fluid/operators/math/fc_compute.h"
#include "paddle/fluid/operators/math/jit_kernel.h"
#include "paddle/fluid/operators/math/sequence2batch.h" #include "paddle/fluid/operators/math/sequence2batch.h"
namespace paddle { namespace paddle {
...@@ -235,31 +235,32 @@ class FuisonLSTMKernel : public framework::OpKernel<T> { ...@@ -235,31 +235,32 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
const int D = wh_dims[0]; \ const int D = wh_dims[0]; \
const int D4 = wh_dims[1] const int D4 = wh_dims[1]
#define INIT_OTHER_DEFINES \ #define INIT_OTHER_DEFINES \
const T* x_data = x->data<T>(); \ const T* x_data = x->data<T>(); \
const T* wx_data = wx->data<T>(); \ const T* wx_data = wx->data<T>(); \
const T* wh_data = wh->data<T>(); \ const T* wh_data = wh->data<T>(); \
/* diagonal weight*/ \ /* diagonal weight*/ \
const T* wp_data = bias->data<T>() + D4; \ const T* wp_data = bias->data<T>() + D4; \
/* for peephole only*/ \ /* for peephole only*/ \
T* checked_cell_data = nullptr; \ T* checked_cell_data = nullptr; \
auto place = ctx.GetPlace(); \ auto place = ctx.GetPlace(); \
if (use_peepholes) { \ if (use_peepholes) { \
/* w_ic * Ct-1, w_fc * Ct-1 ; w_oc * Ct => ih*/ \ /* w_ic * Ct-1, w_fc * Ct-1 ; w_oc * Ct => ih*/ \
auto* checked_cell = ctx.Output<Tensor>("CheckedCell"); \ auto* checked_cell = ctx.Output<Tensor>("CheckedCell"); \
checked_cell_data = checked_cell->mutable_data<T>(place); \ checked_cell_data = checked_cell->mutable_data<T>(place); \
} \ } \
const math::jitkernel::lstm_attr_t attr( \ const jit::lstm_attr_t attr( \
D, ctx.Attr<std::string>("gate_activation"), \ D, jit::to_kerneltype(ctx.Attr<std::string>("gate_activation")), \
ctx.Attr<std::string>("candidate_activation"), \ jit::to_kerneltype(ctx.Attr<std::string>("candidate_activation")), \
ctx.Attr<std::string>("cell_activation"), use_peepholes); \ jit::to_kerneltype(ctx.Attr<std::string>("cell_activation")), \
math::jitkernel::lstm_t one_step; \ use_peepholes); \
one_step.wp = wp_data; \ jit::lstm_t one_step; \
one_step.checked = checked_cell_data; \ one_step.wp = wp_data; \
const auto& ker = \ one_step.checked = checked_cell_data; \
math::jitkernel::KernelPool::Instance() \ auto ComputeC1H1 = \
.template Get<math::jitkernel::LSTMKernel<T>, \ jit::Get<jit::kLSTMC1H1, jit::LSTMTuples<T>, platform::CPUPlace>(attr); \
const math::jitkernel::lstm_attr_t&>(attr) auto ComputeCtHt = \
jit::Get<jit::kLSTMCtHt, jit::LSTMTuples<T>, platform::CPUPlace>(attr)
// Wh GEMM // Wh GEMM
#define GEMM_WH_ADDON(bs, prev, out) \ #define GEMM_WH_ADDON(bs, prev, out) \
...@@ -305,7 +306,7 @@ class FuisonLSTMKernel : public framework::OpKernel<T> { ...@@ -305,7 +306,7 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
one_step.gates = xx_data; one_step.gates = xx_data;
one_step.ct = c_out_data; one_step.ct = c_out_data;
one_step.ht = h_out_data; one_step.ht = h_out_data;
ker->ComputeC1H1(&one_step, &attr); ComputeC1H1(&one_step, &attr);
tstart = 1; tstart = 1;
// move one step // move one step
prev_h_data = h_out_data; prev_h_data = h_out_data;
...@@ -321,7 +322,7 @@ class FuisonLSTMKernel : public framework::OpKernel<T> { ...@@ -321,7 +322,7 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
one_step.ct_1 = prev_c_data; one_step.ct_1 = prev_c_data;
one_step.ct = c_out_data; one_step.ct = c_out_data;
one_step.ht = h_out_data; one_step.ht = h_out_data;
ker->ComputeCtHt(&one_step, &attr); ComputeCtHt(&one_step, &attr);
// move one step // move one step
prev_h_data = h_out_data; prev_h_data = h_out_data;
prev_c_data = c_out_data; prev_c_data = c_out_data;
...@@ -401,7 +402,7 @@ class FuisonLSTMKernel : public framework::OpKernel<T> { ...@@ -401,7 +402,7 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
one_step.gates = cur_in_data; one_step.gates = cur_in_data;
one_step.ct = cur_c_out_data; one_step.ct = cur_c_out_data;
one_step.ht = cur_h_out_data; one_step.ht = cur_h_out_data;
ker->ComputeC1H1(&one_step, &attr); ComputeC1H1(&one_step, &attr);
cur_in_data += D4; cur_in_data += D4;
cur_c_out_data += D; cur_c_out_data += D;
...@@ -431,7 +432,7 @@ class FuisonLSTMKernel : public framework::OpKernel<T> { ...@@ -431,7 +432,7 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
one_step.ct_1 = cur_prev_c_data; one_step.ct_1 = cur_prev_c_data;
one_step.ct = cur_c_out_data; one_step.ct = cur_c_out_data;
one_step.ht = cur_h_out_data; one_step.ht = cur_h_out_data;
ker->ComputeCtHt(&one_step, &attr); ComputeCtHt(&one_step, &attr);
// move one batch // move one batch
cur_in_data += D4; cur_in_data += D4;
......
set(jit_file ${PADDLE_BINARY_DIR}/paddle/fluid/operators/jit/kernels.h)
file(WRITE ${jit_file} "// Generated by the paddle/fluid/operators/jit/CMakeLists.txt. DO NOT EDIT!\n\n")
file(APPEND ${jit_file} "\#pragma once\n")
file(APPEND ${jit_file} "\#include \"paddle/fluid/operators/jit/helper.h\"\n")
file(APPEND ${jit_file} "\#include \"paddle/fluid/operators/jit/registry.h\"\n\n")
set(JIT_KERNEL_DEPS cpu_info cblas gflags enforce place)
file(GLOB jit_kernel_cc_srcs RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "*.cc")
list(REMOVE_ITEM jit_kernel_cc_srcs test.cc benchmark.cc)
cc_library(jit_kernel_base SRCS ${jit_kernel_cc_srcs} DEPS ${JIT_KERNEL_DEPS})
# refer must go first
add_subdirectory(refer)
add_subdirectory(more)
if(WITH_XBYAK)
add_subdirectory(gen)
endif()
cc_library(jit_kernel_helper SRCS ${jit_kernel_cc_srcs} DEPS ${JIT_KERNEL_DEPS})
cc_test(jit_kernel_test SRCS test.cc DEPS jit_kernel_helper)
if(NOT WIN32)
cc_binary(jit_kernel_benchmark SRCS benchmark.cc DEPS jit_kernel_helper)
endif()
# JIT Kernel
结合函数模板和JIT生成需要的kernel函数。
这里的kernel是比Operator中kernel更小级别的算子单元,更侧重的是在不同硬件上的性能。可以有多重第三方库的实现,每种实现有自己的`UseMe`函数负责什么条件下可以被调用。
这里实现的函数可以非常细粒度的函数方法,比如Vector MUL, 也可以是一个复杂的逻辑比如LSTM等。复杂的逻辑也可以由自己的底层函数拼接而成。
目前仅支持CPU上的高性能计算。
## 目录结构
```txt
PaddlePaddle/Paddle/paddle/fluid/
├── ...
├── operator/
│ ├── .../
└── jit/
├── ...
├── gen/
│ └── ...
|── more/
│ ├── ...
│ ├── mkl/
│ │ └── ...
│ ├── mkldnn/
│ │ └── ...
│ ├── mix/
│ │ └── ...
│ ├── intrinsic/
│ │ └── ...
│ └── openblas/
│ └── ...
└── refer/
└── ...
```
基本类的定义都放在根目录下,根目录下包括gen,more和refer三个目录。每个目录下都是一种或者多种实现,每种kernel算子都需要有reference的实现,用作单元测试的基准,其他的实现都是可选的。
- gen: 代表使用jit生成的code,需要依赖xbyak库。该实现最关心的就是性能。
- refer: 代表reference的实现,每种kernel算子都需要有在CPU上的reference的实现,他主要关心的算法逻辑的正确性。
- more: 下面可以放入跟多实现,可以包括mkl,mkldnn,intrinsic,openblas等,也可以是自身已有的kernel组合。
## 动态获取
提供一个`jit::Get`方法,根据kernel类别获取,每种实现都有自己的使用范围,根据范围动态和当前条件选择需要的kernel函数。
## 测试
- 逻辑测试
所有实现都要与refer的code对比,需要满足精度要求, 包括float和double的数据类型
- 性能测试
所有实现的性能对比,并且与最终的`jit::Get`方法对比,该方法拿到的性能需要在各种条件下都是最好的。
# 如何添加新的算子
-`KernelType` 中添加 `your_key` .
- 实现Reference 的逻辑,这个是必须是在CPU上的实现,并且不能依赖任何第三方库。实现后在`refer/CmakeLists.txt`中添加`USE_JITKERNEL_REFER(your_key)`来使用该kernel.
- (optional) 实现更多的算法在`more`目录下,可以依赖mkl,intrinsic或者mkldnn等第三方库。
- (optional) 实现基于Xbyak的生成code,在`gen`目下。 jitcode需要实现自己的`JitCodeCreator`,并注册在与refer相同的`KernelType`上。
- 必要时可以添加新的`KernelTuples`,可以参考`XYZNTuples`,新加的Attr类型需要特例化`JitCodeKey`方法。
-`test.cc`中添加unit test,至少需要测试`float``double`两种数据类型,如有必要需要支持额外的数据类型,比如`int8`的相关函数。
-`benchmark.cc`中添加相应的性能对比,同一种kernel需要对比所有实现,并且确保`jit::Get`得到的实现一直是速度最快的。
# 优点
- 统一的Get方法,接口简单。
- 同一套逻辑可以有多套实现,可以依赖多套第三方库,互不影响。
- 目录结构清晰,不会在某个文件中有多个宏定义,导致的可读性差问题。
- 优化方便,可以直接针对某种属性针对性优化,并不影响其他属性下的性能。
- 可以支持多种平台,包括Linux,Mac 和 Windows,至少可以保证每种平台都可以正常work。后期也可以针对不同平台有针对的优化。框架层面可以使用统一接口,不必关心底层实现。
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License. */
#include <iostream>
#include <random>
#include <string>
#include <vector>
#include "gflags/gflags.h"
#include "glog/logging.h"
#include "paddle/fluid/operators/jit/kernels.h"
#include "paddle/fluid/platform/device_tracer.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/fluid/platform/port.h"
DEFINE_int32(burning, 10, "Burning times.");
DEFINE_int32(repeat, 3000, "Repeat times.");
DEFINE_int32(max_size, 1000, "The Max size would be tested.");
template <typename T>
void RandomVec(const int n, T* a, const T lower = static_cast<T>(-20.f),
const T upper = static_cast<T>(20.f), unsigned int seed = 100) {
std::mt19937 rng(seed);
std::uniform_real_distribution<double> uniform_dist(0, 1);
for (int i = 0; i < n; ++i) {
a[i] = static_cast<T>(uniform_dist(rng) * (upper - lower) + lower);
}
}
std::vector<int> TestSizes() {
std::vector<int> s;
for (int i = 1; i <= FLAGS_max_size; ++i) {
s.push_back(i);
}
return s;
}
template <typename KernelTuples, typename... Args>
struct BenchFunc {
// return this function avg time
double operator()(const typename KernelTuples::func_type tgt, Args... args) {
for (int i = 0; i < FLAGS_burning; ++i) {
tgt(args...);
}
auto start = paddle::platform::PosixInNsec() / 1e-3;
for (int i = 0; i < FLAGS_repeat; ++i) {
tgt(args...);
}
auto end = paddle::platform::PosixInNsec() / 1e-3;
return static_cast<double>(end - start) / FLAGS_repeat;
}
};
namespace jit = paddle::operators::jit;
template <jit::KernelType KT, typename KernelTuples, typename PlaceType,
typename... Args>
void BenchAllImpls(const typename KernelTuples::attr_type& attr, Args... args) {
BenchFunc<KernelTuples, Args...> benchmark;
std::vector<std::pair<std::string, double>> infos;
// test refer
auto refer = jit::GetRefer<KT, KernelTuples>();
if (!refer) {
LOG(FATAL) << "Refer can not be empty!";
}
infos.push_back(std::make_pair("Refer", benchmark(refer, args...)));
// test jitcode
auto jitcode = jit::GetJitCode<KT, KernelTuples, PlaceType>(attr);
if (jitcode) {
infos.push_back(std::make_pair("JitCode", benchmark(jitcode, args...)));
}
// test all impls in more
jit::KernelKey kkey(KT, PlaceType());
auto& pool = jit::KernelPool().Instance().AllKernels();
auto iter = pool.find(kkey);
if (iter != pool.end()) {
auto& impls = iter->second;
for (auto& impl : impls) {
auto i = dynamic_cast<const jit::KernelMore<KernelTuples>*>(impl.get());
if (i && i->UseMe(attr)) {
auto more = i->GetFunc();
infos.push_back(
std::make_pair(i->ImplType(), benchmark(more, args...)));
}
}
}
// Test result from Get function
auto tgt = jit::Get<KT, KernelTuples, PlaceType>(attr);
if (!tgt) {
LOG(FATAL) << "Target can not be empty!";
}
infos.push_back(std::make_pair("Target", benchmark(tgt, args...)));
// print
std::ostringstream loginfos;
loginfos << "Kernel Type " << jit::to_string(KT) << ": " << attr << ": ";
for (auto pair : infos) {
loginfos << pair.first << " takes " << pair.second << " us; ";
}
LOG(INFO) << loginfos.str();
}
template <paddle::operators::jit::KernelType KT, typename T, typename PlaceType>
void BenchXYZNKernel() {
for (int d : TestSizes()) {
std::vector<T> x(d), y(d), z(d);
RandomVec<T>(d, x.data());
RandomVec<T>(d, y.data());
BenchAllImpls<KT, jit::XYZNTuples<T>, PlaceType>(d, x.data(), y.data(),
z.data(), d);
}
}
template <paddle::operators::jit::KernelType KT, typename T, typename PlaceType>
void BenchAXYNKernel() {
for (int d : TestSizes()) {
const T a = static_cast<T>(3);
std::vector<T> x(d), y(d);
RandomVec<T>(d, x.data());
BenchAllImpls<KT, jit::AXYNTuples<T>, PlaceType>(d, &a, x.data(), y.data(),
d);
}
}
template <paddle::operators::jit::KernelType KT, typename T, typename PlaceType>
void BenchXYNKernel() {
for (int d : TestSizes()) {
std::vector<T> x(d), y(d);
RandomVec<T>(d, x.data());
BenchAllImpls<KT, jit::XYNTuples<T>, PlaceType>(d, x.data(), y.data(), d);
}
}
template <paddle::operators::jit::KernelType KT, typename T, typename PlaceType>
void BenchLSTMKernel() {
for (bool use_peephole : {true, false}) {
for (int d : TestSizes()) {
const jit::lstm_attr_t attr(d, jit::kVSigmoid, jit::kVTanh, jit::kVTanh,
use_peephole);
std::vector<T> x(4 * d), ct_1(d), ct(d), ht(d), wp(3 * d), checked(2 * d);
RandomVec<T>(4 * d, x.data(), -2.f, 2.f);
RandomVec<T>(3 * d, wp.data(), -2.f, 2.f);
RandomVec<T>(d, ct_1.data(), -2.f, 2.f);
const T* ct_1_data = ct_1.data();
const T* wp_data = wp.data();
T* x_data = x.data();
T* checked_data = checked.data();
T* ct_data = ct.data();
T* ht_data = ht.data();
jit::lstm_t step;
step.gates = x_data;
step.ct_1 = ct_1_data;
step.ct = ct_data;
step.ht = ht_data;
if (use_peephole) {
step.wp = wp_data;
step.checked = checked_data;
}
BenchAllImpls<KT, jit::LSTMTuples<T>, PlaceType>(attr, &step, &attr);
}
}
}
template <paddle::operators::jit::KernelType KT, typename T, typename PlaceType>
void BenchGRUKernel() {
for (int d : TestSizes()) {
const jit::gru_attr_t attr(d, jit::kVSigmoid, jit::kVTanh);
std::vector<T> x(3 * d), ht_1(d), ht(d);
RandomVec<T>(3 * d, x.data(), -2.f, 2.f);
RandomVec<T>(d, ht_1.data(), -2.f, 2.f);
const T* ht_1_data = ht_1.data();
T* x_data = x.data();
T* ht_data = ht.data();
jit::gru_t step;
step.gates = x_data;
step.ht_1 = ht_1_data;
step.ht = ht_data;
BenchAllImpls<KT, jit::GRUTuples<T>, PlaceType>(attr, &step, &attr);
}
}
// Benchmark all jit kernels including jitcode, mkl and refer.
// To use this tool, run command: ./benchmark [options...]
// Options:
// --burning: the burning time before count
// --repeat: the repeat times
// --max_size: the max size would be tested
int main(int argc, char* argv[]) {
gflags::ParseCommandLineFlags(&argc, &argv, true);
google::InitGoogleLogging(argv[0]);
LOG(INFO) << "Burning " << FLAGS_burning << " times, Repeat " << FLAGS_repeat
<< " times.";
using T = float;
using PlaceType = paddle::platform::CPUPlace;
// xyzn
BenchXYZNKernel<jit::kVMul, T, PlaceType>();
BenchXYZNKernel<jit::kVAdd, T, PlaceType>();
BenchXYZNKernel<jit::kVAddRelu, T, PlaceType>();
BenchXYZNKernel<jit::kVSub, T, PlaceType>();
// axyn
BenchAXYNKernel<jit::kVScal, T, PlaceType>();
BenchAXYNKernel<jit::kVAddBias, T, PlaceType>();
// xyn
BenchXYNKernel<jit::kVRelu, T, PlaceType>();
BenchXYNKernel<jit::kVIdentity, T, PlaceType>();
BenchXYNKernel<jit::kVExp, T, PlaceType>();
BenchXYNKernel<jit::kVSigmoid, T, PlaceType>();
BenchXYNKernel<jit::kVTanh, T, PlaceType>();
// lstm and peephole
BenchLSTMKernel<jit::kLSTMCtHt, T, PlaceType>();
BenchLSTMKernel<jit::kLSTMC1H1, T, PlaceType>();
// gru functions
BenchGRUKernel<jit::kGRUH1, T, PlaceType>();
BenchGRUKernel<jit::kGRUHtPart1, T, PlaceType>();
BenchGRUKernel<jit::kGRUHtPart2, T, PlaceType>();
}
file(GLOB jitcode_cc_srcs RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "*.cc")
cc_library(jit_kernel_jitcode SRCS ${jitcode_cc_srcs} DEPS jit_kernel_base xbyak)
set(JIT_KERNEL_DEPS ${JIT_KERNEL_DEPS} xbyak jit_kernel_jitcode PARENT_SCOPE)
function(USE_JITKERNEL_GEN TARGET)
file(APPEND ${jit_file} "USE_JITKERNEL_GEN(${TARGET});\n")
endfunction()
# use gen jitcode kernel by name
USE_JITKERNEL_GEN(kVMul)
USE_JITKERNEL_GEN(kVAdd)
#USE_JITKERNEL_GEN(kVSub) # TODO(TJ): enable me
USE_JITKERNEL_GEN(kVAddRelu)
USE_JITKERNEL_GEN(kVScal)
USE_JITKERNEL_GEN(kVAddBias)
USE_JITKERNEL_GEN(kVRelu)
USE_JITKERNEL_GEN(kVIdentity)
USE_JITKERNEL_GEN(kVExp)
USE_JITKERNEL_GEN(kVSigmoid)
USE_JITKERNEL_GEN(kVTanh)
USE_JITKERNEL_GEN(kLSTMCtHt)
USE_JITKERNEL_GEN(kLSTMC1H1)
USE_JITKERNEL_GEN(kGRUH1)
USE_JITKERNEL_GEN(kGRUHtPart1)
USE_JITKERNEL_GEN(kGRUHtPart2)
USE_JITKERNEL_GEN(kNCHW16CMulNC)
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License. */
#include "paddle/fluid/operators/jit/gen/act.h"
#include "paddle/fluid/operators/jit/registry.h"
#include "paddle/fluid/platform/cpu_info.h"
namespace paddle {
namespace operators {
namespace jit {
namespace gen {
const float ALIGN32_BEG exp_float_consts[] ALIGN32_END = {
REPEAT_8TIMES(1.f),
REPEAT_8TIMES(2.f),
REPEAT_8TIMES(0.5f),
REPEAT_8TIMES(EXP_HIG),
REPEAT_8TIMES(EXP_LOW),
REPEAT_8TIMES(CEPHES_LOG2EF),
REPEAT_8TIMES(CEPHES_EXP_C1),
REPEAT_8TIMES(CEPHES_EXP_C2),
REPEAT_8TIMES(CEPHES_EXP_P0),
REPEAT_8TIMES(CEPHES_EXP_P1),
REPEAT_8TIMES(CEPHES_EXP_P2),
REPEAT_8TIMES(CEPHES_EXP_P3),
REPEAT_8TIMES(CEPHES_EXP_P4),
REPEAT_8TIMES(CEPHES_EXP_P5),
REPEAT_8TIMES(EXP_MAX_INPUT),
REPEAT_8TIMES(SIGMOID_THRESHOLD_MAX),
REPEAT_8TIMES(SIGMOID_THRESHOLD_MIN)};
const int ALIGN32_BEG exp_int_0x7f[] ALIGN32_END = {REPEAT_8TIMES(0x7f)};
int ALIGN32_BEG g_tmp_mem[16] ALIGN32_END = {0};
void VActJitCode::genCode() {
int offset = 0;
for (int i = 0; i < num_ / YMM_FLOAT_BLOCK; ++i) {
vmovups(ymm_src, ptr[param1 + offset]);
act<ymm_t>(ymm_dst, ymm_src, type_);
vmovups(ptr[param2 + offset], ymm_dst);
offset += sizeof(float) * YMM_FLOAT_BLOCK;
}
int rest = num_ % YMM_FLOAT_BLOCK;
while (rest > 0) {
int block = XMM_FLOAT_BLOCK;
if (rest >= 4) {
block = 4;
vmovups(xmm_src, ptr[param1 + offset]);
} else if (rest >= 2) {
block = 2;
vmovq(xmm_src, ptr[param1 + offset]);
} else {
block = 1;
vmovss(xmm_src, ptr[param1 + offset]);
}
act<xmm_t>(xmm_dst, xmm_src, type_);
if (rest >= 4) {
vmovups(ptr[param2 + offset], xmm_dst);
} else if (rest >= 2) {
vmovq(ptr[param2 + offset], xmm_dst);
} else {
vmovss(ptr[param2 + offset], xmm_dst);
}
offset += sizeof(float) * block;
rest -= block;
}
ret();
}
#define DECLARE_ACT_CREATOR(name) \
class name##Creator : public JitCodeCreator<int> { \
public: \
bool UseMe(const int& attr) const override { \
return platform::MayIUse(platform::avx); \
} \
size_t CodeSize(const int& d) const override; \
std::unique_ptr<GenBase> CreateJitCode(const int& attr) const override { \
return make_unique<name##JitCode>(attr, CodeSize(attr)); \
} \
}
DECLARE_ACT_CREATOR(VRelu);
DECLARE_ACT_CREATOR(VIdentity);
DECLARE_ACT_CREATOR(VExp);
DECLARE_ACT_CREATOR(VSigmoid);
DECLARE_ACT_CREATOR(VTanh);
// TODO(TJ): tuning use me
size_t VReluCreator::CodeSize(const int& d) const {
return 96 /* init size */ +
(d / YMM_FLOAT_BLOCK + 3) * 4 /* instructions */ *
8 /* average bytes for each instruction */;
}
size_t VIdentityCreator::CodeSize(const int& d) const {
return 96 + (d / YMM_FLOAT_BLOCK + 3) * 4 * 8;
}
size_t VExpCreator::CodeSize(const int& d) const {
return 96 + (d / YMM_FLOAT_BLOCK + 3) * 70 * 8;
}
size_t VSigmoidCreator::CodeSize(const int& d) const {
return 96 + (d / YMM_FLOAT_BLOCK + 3) * 82 * 8;
}
size_t VTanhCreator::CodeSize(const int& d) const {
return 96 + (d / YMM_FLOAT_BLOCK + 3) * 84 * 8;
}
#undef DECLARE_ACT_CREATOR
} // namespace gen
} // namespace jit
} // namespace operators
} // namespace paddle
namespace gen = paddle::operators::jit::gen;
REGISTER_JITKERNEL_GEN(kVRelu, gen::VReluCreator);
REGISTER_JITKERNEL_GEN(kVIdentity, gen::VIdentityCreator);
REGISTER_JITKERNEL_GEN(kVExp, gen::VExpCreator);
REGISTER_JITKERNEL_GEN(kVSigmoid, gen::VSigmoidCreator);
REGISTER_JITKERNEL_GEN(kVTanh, gen::VTanhCreator);
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. /* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
*
Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
You may obtain a copy of the License at * You may obtain a copy of the License at
*
http://www.apache.org/licenses/LICENSE-2.0 * http://www.apache.org/licenses/LICENSE-2.0
*
Unless required by applicable law or agreed to in writing, software * Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS, * distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and * See the License for the specific language governing permissions and
limitations under the License. */ * limitations under the License. */
#pragma once #pragma once
#include <string> #include <string>
#include "paddle/fluid/operators/math/jit_gen.h" #include "glog/logging.h"
#include "paddle/fluid/operators/math/jit_kernel_impl.h" #include "paddle/fluid/operators/jit/gen/jitcode.h"
#include "paddle/fluid/platform/cpu_info.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
namespace math { namespace jit {
namespace jitkernel {
namespace gen { namespace gen {
using reg64_t = const Xbyak::Reg64;
using reg32_t = const Xbyak::Reg32;
using xmm_t = const Xbyak::Xmm;
using ymm_t = const Xbyak::Ymm;
using zmm_t = const Xbyak::Zmm;
using Label = Xbyak::Label;
typedef enum {
mul = 0,
add,
sub,
relu,
exp,
sigmoid,
tanh,
identity
} operand_type;
extern const float exp_float_consts[]; extern const float exp_float_consts[];
extern const int exp_int_0x7f[]; extern const int exp_int_0x7f[];
extern int g_tmp_mem[]; extern int g_tmp_mem[];
...@@ -79,94 +59,15 @@ extern int g_tmp_mem[]; ...@@ -79,94 +59,15 @@ extern int g_tmp_mem[];
#define OFFSET_SIGMOID_MAX 15 * YMM_FLOAT_BLOCK * sizeof(float) #define OFFSET_SIGMOID_MAX 15 * YMM_FLOAT_BLOCK * sizeof(float)
#define OFFSET_SIGMOID_MIN 16 * YMM_FLOAT_BLOCK * sizeof(float) #define OFFSET_SIGMOID_MIN 16 * YMM_FLOAT_BLOCK * sizeof(float)
// function: vec = Operand(vec(or scalar), vec(or scalar)) (maybe with relu) class VActFunc : public JitCode {
class VXXJitCode : public JitCode {
public:
const char* name() const override {
std::string base = "VXXJitCode";
if (scalar_index_ == 1) {
base += "_Scalar";
} else {
base += "_Vec";
}
if (type_ == operand_type::mul) {
base += "_Mul";
} else if (type_ == operand_type::add) {
base += "_Add";
}
if (scalar_index_ == 2) {
base += "_Scalar";
} else {
base += "_Vec";
}
base += (with_relu_ ? "_Relu" : "");
return base.c_str();
}
explicit VXXJitCode(int d, operand_type type, int scalar_index,
bool with_relu, size_t code_size = 256 * 1024,
void* code_ptr = nullptr)
: JitCode(code_size, code_ptr),
num_(d),
type_(type),
scalar_index_(scalar_index),
with_relu_(with_relu) {}
static bool init(int d, int scalar_index = 0);
void generate() override;
private:
int num_;
operand_type type_;
int scalar_index_;
bool with_relu_;
reg64_t param1{abi_param1};
reg64_t param2{abi_param2};
reg64_t param3{abi_param3};
xmm_t xmm_src1 = xmm_t(0);
xmm_t xmm_src2 = xmm_t(1);
xmm_t xmm_dst = xmm_t(2);
xmm_t xmm_zero = xmm_t(3);
ymm_t ymm_src1 = ymm_t(0);
ymm_t ymm_src2 = ymm_t(1);
ymm_t ymm_dst = ymm_t(2);
ymm_t ymm_zero = ymm_t(3);
};
class VActJitCode : public JitCode {
public: public:
const char* name() const override { explicit VActFunc(size_t code_size, void* code_ptr)
std::string base = "VActJitCode"; : JitCode(code_size, code_ptr) {}
switch (type_) { virtual const char* name() const = 0;
case operand_type::relu: virtual void genCode() = 0;
base += "_Relu";
break;
case operand_type::exp:
base += "_Exp";
break;
case operand_type::sigmoid:
base += "_Sigmoid";
break;
case operand_type::tanh:
base += "_Tanh";
break;
case operand_type::identity:
base += "_Identity";
break;
default:
break;
}
return base.c_str();
}
explicit VActJitCode(int d, operand_type type, size_t code_size = 256 * 1024,
void* code_ptr = nullptr)
: JitCode(code_size, code_ptr), num_(d), type_(type) {}
static bool init(int d, operand_type type);
void generate() override;
protected: protected:
// compute relu with ymm, xmm // compute RELU with ymm, xmm
template <typename JMM> template <typename JMM>
void relu_jmm(JMM& dst, JMM& src, int zero_idx = 15) { // NOLINT void relu_jmm(JMM& dst, JMM& src, int zero_idx = 15) { // NOLINT
JMM zero = JMM(zero_idx); JMM zero = JMM(zero_idx);
...@@ -174,7 +75,7 @@ class VActJitCode : public JitCode { ...@@ -174,7 +75,7 @@ class VActJitCode : public JitCode {
vmaxps(dst, src, zero); vmaxps(dst, src, zero);
} }
// compute exp with ymm, xmm // compute EXP with ymm, xmm
template <typename JMM> template <typename JMM>
void exp_jmm(JMM& dst, JMM& src, int src_idx = 11, int fx_idx = 12, // NOLINT void exp_jmm(JMM& dst, JMM& src, int src_idx = 11, int fx_idx = 12, // NOLINT
int fy_idx = 13, int mask_idx = 14, int tmp_idx = 15) { int fy_idx = 13, int mask_idx = 14, int tmp_idx = 15) {
...@@ -258,7 +159,7 @@ class VActJitCode : public JitCode { ...@@ -258,7 +159,7 @@ class VActJitCode : public JitCode {
pop(reg_ptr_global); pop(reg_ptr_global);
} }
// compute sigmoid with ymm, xmm // compute SIGMOID with ymm, xmm
template <typename JMM> template <typename JMM>
void sigmoid_jmm(JMM& dst, JMM& src, int src_idx = 11, // NOLINT void sigmoid_jmm(JMM& dst, JMM& src, int src_idx = 11, // NOLINT
int fx_idx = 12, int fy_idx = 13, int mask_idx = 14, int fx_idx = 12, int fy_idx = 13, int mask_idx = 14,
...@@ -283,7 +184,7 @@ class VActJitCode : public JitCode { ...@@ -283,7 +184,7 @@ class VActJitCode : public JitCode {
pop(reg_ptr_global); pop(reg_ptr_global);
} }
// compute tanh with ymm, xmm // compute TANH with ymm, xmm
template <typename JMM> template <typename JMM>
void tanh_jmm(JMM& dst, JMM& src, int src_idx = 11, // NOLINT void tanh_jmm(JMM& dst, JMM& src, int src_idx = 11, // NOLINT
int fx_idx = 12, int fy_idx = 13, int mask_idx = 14, int fx_idx = 12, int fy_idx = 13, int mask_idx = 14,
...@@ -310,223 +211,109 @@ class VActJitCode : public JitCode { ...@@ -310,223 +211,109 @@ class VActJitCode : public JitCode {
pop(reg_ptr_global); pop(reg_ptr_global);
} }
// compute IDENTITY with ymm, xmm
template <typename JMM>
void identity_jmm(JMM& dst, JMM& src, int zero_idx) { // NOLINT
JMM zero = JMM(zero_idx);
vxorps(zero, zero, zero);
vaddps(dst, src, zero);
// TODO(TJ): use below
// dst.setIdx(src.getIdx());
}
template <typename JMM> template <typename JMM>
void act(JMM& dst, JMM& src, operand_type type) { // NOLINT void act(JMM& dst, JMM& src, operand_type type) { // NOLINT
// use 11~15 // use 11~15
switch (type) { switch (type) {
case operand_type::relu: case operand_type::RELU:
relu_jmm<JMM>(dst, src, 15); relu_jmm<JMM>(dst, src, 15);
break; break;
case operand_type::exp: case operand_type::EXP:
exp_jmm<JMM>(dst, src, 11, 12, 13, 14, 15); exp_jmm<JMM>(dst, src, 11, 12, 13, 14, 15);
break; break;
case operand_type::sigmoid: case operand_type::SIGMOID:
sigmoid_jmm<JMM>(dst, src, 11, 12, 13, 14, 15); sigmoid_jmm<JMM>(dst, src, 11, 12, 13, 14, 15);
break; break;
case operand_type::tanh: case operand_type::TANH:
tanh_jmm<JMM>(dst, src, 11, 12, 13, 14, 15); tanh_jmm<JMM>(dst, src, 11, 12, 13, 14, 15);
break; break;
case operand_type::identity: case operand_type::IDENTITY:
identity_jmm<JMM>(dst, src, 15);
break; break;
default: default:
// throw error LOG(FATAL) << "Do not support this operand type: " << type;
break; break;
} }
} }
protected:
int num_;
operand_type type_;
reg64_t param1{abi_param1};
reg64_t param2{abi_param2};
xmm_t xmm_src = xmm_t(0);
ymm_t ymm_src = ymm_t(0);
xmm_t xmm_dst = xmm_t(1);
ymm_t ymm_dst = ymm_t(1);
}; };
class LSTMJitCode : public VActJitCode { class VActJitCode : public VActFunc {
public: public:
const char* name() const override { explicit VActJitCode(int d, operand_type type, size_t code_size,
std::string base = "LSTMJitCode"; void* code_ptr = nullptr)
if (use_peephole_) { : VActFunc(code_size, code_ptr), num_(d), type_(type) {
base += "_Peephole"; if (!(type_ == operand_type::RELU || type_ == operand_type::EXP ||
} type_ == operand_type::SIGMOID || type_ == operand_type::TANH ||
if (compute_c1h1_) { type_ == operand_type::IDENTITY)) {
base += "_C1H1"; LOG(FATAL) << "Do not support this operand type: " << type_;
} }
auto AddTypeStr = [&](operand_type type) { this->genCode();
switch (type) {
case operand_type::relu:
base += "_Relu";
break;
case operand_type::exp:
base += "_Exp";
break;
case operand_type::sigmoid:
base += "_Sigmoid";
break;
case operand_type::tanh:
base += "_Tanh";
break;
case operand_type::identity:
base += "_Identity";
break;
default:
break;
}
};
AddTypeStr(act_gate_);
AddTypeStr(act_cand_);
AddTypeStr(act_cell_);
return base.c_str();
}
explicit LSTMJitCode(bool compute_c1h1, const lstm_attr_t& attr,
size_t code_size = 256 * 1024, void* code_ptr = nullptr)
: VActJitCode(attr.d, operand_type::sigmoid /* this is bugy*/, code_size,
code_ptr),
compute_c1h1_(compute_c1h1) {
auto typeExchange = [](const std::string& type) -> gen::operand_type {
if (type == "sigmoid") {
return operand_type::sigmoid;
} else if (type == "relu") {
return operand_type::relu;
} else if (type == "tanh") {
return operand_type::tanh;
} else if (type == "identity" || type == "") {
return operand_type::identity;
} // else throw error
return operand_type::identity;
};
num_ = attr.d;
use_peephole_ = attr.use_peephole;
act_gate_ = typeExchange(attr.act_gate);
act_cand_ = typeExchange(attr.act_cand);
act_cell_ = typeExchange(attr.act_cell);
} }
static bool init(int d);
void generate() override;
protected:
int num_;
bool compute_c1h1_;
bool use_peephole_;
operand_type act_gate_;
operand_type act_cand_;
operand_type act_cell_;
reg64_t param1{abi_param1};
};
class GRUJitCode : public VActJitCode {
public:
const char* name() const override { const char* name() const override {
std::string base = "GRUJitCode"; std::string base = "VActJitCode";
if (id_ == 0) { switch (type_) {
base += "_H1"; case operand_type::RELU:
} else if (id_ == 1) { base += "_Relu";
base += "_HtPart1"; break;
} else if (id_ == 2) { case operand_type::EXP:
base += "_HtPart2"; base += "_Exp";
break;
case operand_type::SIGMOID:
base += "_Sigmoid";
break;
case operand_type::TANH:
base += "_Tanh";
break;
case operand_type::IDENTITY:
base += "_Identity";
break;
default:
break;
} }
auto AddTypeStr = [&](operand_type type) {
switch (type) {
case operand_type::relu:
base += "_Relu";
break;
case operand_type::exp:
base += "_Exp";
break;
case operand_type::sigmoid:
base += "_Sigmoid";
break;
case operand_type::tanh:
base += "_Tanh";
break;
case operand_type::identity:
base += "_Identity";
break;
default:
break;
}
};
AddTypeStr(act_gate_);
AddTypeStr(act_cand_);
return base.c_str(); return base.c_str();
} }
void genCode() override;
explicit GRUJitCode(int id, const gru_attr_t& attr,
size_t code_size = 256 * 1024, void* code_ptr = nullptr)
: VActJitCode(attr.d, operand_type::sigmoid /* this is bugy*/, code_size,
code_ptr),
id_(id) {
auto typeExchange = [](const std::string& type) -> gen::operand_type {
if (type == "sigmoid") {
return operand_type::sigmoid;
} else if (type == "relu") {
return operand_type::relu;
} else if (type == "tanh") {
return operand_type::tanh;
} else if (type == "identity" || type == "") {
return operand_type::identity;
} // else throw error
return operand_type::identity;
};
num_ = attr.d;
act_gate_ = typeExchange(attr.act_gate);
act_cand_ = typeExchange(attr.act_cand);
}
static bool init(int d);
void generate() override;
protected: protected:
int id_;
int num_; int num_;
operand_type act_gate_; operand_type type_;
operand_type act_cand_;
reg64_t param1{abi_param1}; reg64_t param1{abi_param1};
}; reg64_t param2{abi_param2};
#ifdef PADDLE_WITH_MKLDNN xmm_t xmm_src = xmm_t(0);
struct EltwiseMulnChw16cNC : public Xbyak::CodeGenerator { ymm_t ymm_src = ymm_t(0);
explicit EltwiseMulnChw16cNC(size_t code_size = 256 * 1024)
: Xbyak::CodeGenerator(code_size) {
// RDI is ptr x_input
// RSI is ptr y_input
// RDX is ptr output
// RCX is height
// r8 is width
push(rbx); xmm_t xmm_dst = xmm_t(1);
ymm_t ymm_dst = ymm_t(1);
};
xor_(rax, rax); #define DECLARE_ACT_JITCODE(name, op_type) \
xor_(r10, r10); class name##JitCode : public VActJitCode { \
vmovups(zmm3, ptr[rsi]); public: \
explicit name##JitCode(int d, size_t code_size, void* code_ptr = nullptr) \
: VActJitCode(d, op_type, code_size, code_ptr) {} \
};
L("h_loop"); DECLARE_ACT_JITCODE(VRelu, operand_type::RELU);
xor_(rbx, rbx); DECLARE_ACT_JITCODE(VIdentity, operand_type::IDENTITY);
L("w_loop"); DECLARE_ACT_JITCODE(VExp, operand_type::EXP);
vmovups(zmm2, ptr[rdi + rax]); DECLARE_ACT_JITCODE(VSigmoid, operand_type::SIGMOID);
vmulps(zmm1, zmm2, zmm3); DECLARE_ACT_JITCODE(VTanh, operand_type::TANH);
vmovups(ptr[rdx + rax], zmm1);
add(rax, 64);
inc(rbx);
cmp(r8, rbx);
jnz("w_loop");
inc(r10);
cmp(r10, rcx);
jnz("h_loop");
pop(rbx); #undef DECLARE_ACT_JITCODE
ret();
}
};
#endif
} // namespace gen } // namespace gen
} // namespace jitkernel } // namespace jit
} // namespace math
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License. */
#include "paddle/fluid/operators/jit/gen/blas.h"
#include "paddle/fluid/operators/jit/registry.h"
#include "paddle/fluid/platform/cpu_info.h"
namespace paddle {
namespace operators {
namespace jit {
namespace gen {
void VXXJitCode::genCode() {
// do not need push stack, and do not need save avx512reg if do not use avx512
int offset = 0;
if (with_relu_) {
vxorps(ymm_zero, ymm_zero, ymm_zero);
}
if (scalar_index_ == 1) {
vbroadcastss(ymm_src1, ptr[param1]);
} else if (scalar_index_ == 2) {
vbroadcastss(ymm_src2, ptr[param2]);
}
for (int i = 0; i < num_ / YMM_FLOAT_BLOCK; ++i) {
if (scalar_index_ != 1) {
vmovups(ymm_src1, ptr[param1 + offset]);
}
if (scalar_index_ != 2) {
vmovups(ymm_src2, ptr[param2 + offset]);
}
if (type_ == operand_type::MUL) {
vmulps(ymm_dst, ymm_src1, ymm_src2);
} else if (type_ == operand_type::ADD) {
vaddps(ymm_dst, ymm_src1, ymm_src2);
}
if (with_relu_) {
vmaxps(ymm_dst, ymm_zero, ymm_dst);
}
vmovups(ptr[param3 + offset], ymm_dst);
offset += sizeof(float) * YMM_FLOAT_BLOCK;
}
int rest = num_ % YMM_FLOAT_BLOCK;
while (rest > 0) {
int block = XMM_FLOAT_BLOCK;
if (rest >= 4) {
block = 4;
if (scalar_index_ != 1) {
vmovups(xmm_src1, ptr[param1 + offset]);
}
if (scalar_index_ != 2) {
vmovups(xmm_src2, ptr[param2 + offset]);
}
} else if (rest >= 2) {
block = 2;
if (scalar_index_ != 1) {
vmovq(xmm_src1, ptr[param1 + offset]);
}
if (scalar_index_ != 2) {
vmovq(xmm_src2, ptr[param2 + offset]);
}
} else {
block = 1;
if (scalar_index_ != 1) {
vmovss(xmm_src1, ptr[param1 + offset]);
}
if (scalar_index_ != 2) {
vmovss(xmm_src2, ptr[param2 + offset]);
}
}
switch (type_) {
case operand_type::MUL:
vmulps(xmm_dst, xmm_src1, xmm_src2);
break;
case operand_type::ADD:
vaddps(xmm_dst, xmm_src1, xmm_src2);
break;
default:
break;
}
if (with_relu_) {
vmaxps(xmm_dst, xmm_zero, xmm_dst);
}
if (rest >= 4) {
vmovups(ptr[param3 + offset], xmm_dst);
} else if (rest >= 2) {
vmovq(ptr[param3 + offset], xmm_dst);
} else {
vmovss(ptr[param3 + offset], xmm_dst);
}
offset += sizeof(float) * block;
rest -= block;
}
ret();
}
void NCHW16CMulNCJitCode::genCode() {
// RDI is ptr x_input
// RSI is ptr y_input
// RDX is ptr output
// RCX is height
// r8 is width
push(rbx);
xor_(rax, rax);
xor_(r10, r10);
vmovups(zmm3, ptr[rsi]);
L("h_loop");
xor_(rbx, rbx);
L("w_loop");
vmovups(zmm2, ptr[rdi + rax]);
vmulps(zmm1, zmm2, zmm3);
vmovups(ptr[rdx + rax], zmm1);
add(rax, 64);
inc(rbx);
cmp(r8, rbx);
jnz("w_loop");
inc(r10);
cmp(r10, rcx);
jnz("h_loop");
pop(rbx);
ret();
}
class NCHW16CMulNCCreator : public JitCodeCreator<int> {
public:
bool UseMe(const int& attr) const override {
return platform::MayIUse(platform::avx512f);
}
size_t CodeSize(const int& d) const override { return 256 * 1024; }
std::unique_ptr<GenBase> CreateJitCode(const int& attr) const override {
return make_unique<NCHW16CMulNCJitCode>(attr, CodeSize(attr));
}
};
#define DECLARE_BLAS_CREATOR(name) \
class name##Creator : public JitCodeCreator<int> { \
public: \
bool UseMe(const int& attr) const override { \
return platform::MayIUse(platform::avx); \
} \
size_t CodeSize(const int& d) const override { \
return 96 + d / YMM_FLOAT_BLOCK * 4 * 8; \
} \
std::unique_ptr<GenBase> CreateJitCode(const int& attr) const override { \
return make_unique<name##JitCode>(attr, CodeSize(attr)); \
} \
}
DECLARE_BLAS_CREATOR(VMul);
DECLARE_BLAS_CREATOR(VAdd);
DECLARE_BLAS_CREATOR(VSub);
DECLARE_BLAS_CREATOR(VAddRelu);
DECLARE_BLAS_CREATOR(VScal);
DECLARE_BLAS_CREATOR(VAddBias);
#undef DECLARE_BLAS_CREATOR
} // namespace gen
} // namespace jit
} // namespace operators
} // namespace paddle
namespace gen = paddle::operators::jit::gen;
REGISTER_JITKERNEL_GEN(kVMul, gen::VMulCreator);
REGISTER_JITKERNEL_GEN(kVAdd, gen::VAddCreator);
// TODO(TJ): enable sub
// REGISTER_JITKERNEL_GEN(kVSub, gen::VSubCreator);
REGISTER_JITKERNEL_GEN(kVAddRelu, gen::VAddReluCreator);
REGISTER_JITKERNEL_GEN(kVScal, gen::VScalCreator);
REGISTER_JITKERNEL_GEN(kVAddBias, gen::VAddBiasCreator);
REGISTER_JITKERNEL_GEN(kNCHW16CMulNC, gen::NCHW16CMulNCCreator);
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License. */
#pragma once
#include <string>
#include "glog/logging.h"
#include "paddle/fluid/operators/jit/gen/jitcode.h"
namespace paddle {
namespace operators {
namespace jit {
namespace gen {
// function: vec = Operand(vec(or scalar), vec(or scalar)) (maybe with relu)
class VXXJitCode : public JitCode {
public:
explicit VXXJitCode(int d, operand_type type, int scalar_index,
bool with_relu, size_t code_size = 256 * 1024,
void* code_ptr = nullptr)
: JitCode(code_size, code_ptr),
num_(d),
type_(type),
scalar_index_(scalar_index),
with_relu_(with_relu) {
if (!(type_ == operand_type::MUL || type_ == operand_type::ADD)) {
LOG(FATAL) << "Do not support this operand type: " << type_;
}
this->genCode();
}
virtual const char* name() const {
std::string base = "VXXJitCode";
if (scalar_index_ == 1) {
base += "_Scalar";
} else {
base += "_Vec";
}
if (type_ == operand_type::MUL) {
base += "_Mul";
} else if (type_ == operand_type::ADD) {
base += "_Add";
}
if (scalar_index_ == 2) {
base += "_Scalar";
} else {
base += "_Vec";
}
base += (with_relu_ ? "_Relu" : "");
return base.c_str();
}
void genCode() override;
private:
int num_;
operand_type type_;
int scalar_index_;
bool with_relu_;
reg64_t param1{abi_param1};
reg64_t param2{abi_param2};
reg64_t param3{abi_param3};
xmm_t xmm_src1 = xmm_t(0);
xmm_t xmm_src2 = xmm_t(1);
xmm_t xmm_dst = xmm_t(2);
xmm_t xmm_zero = xmm_t(3);
ymm_t ymm_src1 = ymm_t(0);
ymm_t ymm_src2 = ymm_t(1);
ymm_t ymm_dst = ymm_t(2);
ymm_t ymm_zero = ymm_t(3);
};
#define DECLARE_BLAS_JITCODE(name, op_type, scalar_idx, with_relu) \
class name##JitCode : public VXXJitCode { \
public: \
explicit name##JitCode(int d, size_t code_size, void* code_ptr = nullptr) \
: VXXJitCode(d, op_type, scalar_idx, with_relu, code_size, code_ptr) { \
} \
};
DECLARE_BLAS_JITCODE(VMul, operand_type::MUL, 0, false);
DECLARE_BLAS_JITCODE(VAdd, operand_type::ADD, 0, false);
DECLARE_BLAS_JITCODE(VSub, operand_type::SUB, 0, false);
DECLARE_BLAS_JITCODE(VAddRelu, operand_type::ADD, 0, true);
DECLARE_BLAS_JITCODE(VScal, operand_type::MUL, 1, false);
DECLARE_BLAS_JITCODE(VAddBias, operand_type::ADD, 1, false);
#undef DECLARE_BLAS_JITCODE
// nChw16c = nChw16c .* NC
class NCHW16CMulNCJitCode : public JitCode {
public:
DECLARE_JIT_CODE(NCHW16CMulNCJitCode);
explicit NCHW16CMulNCJitCode(int d /*unused*/, size_t code_size,
void* code_ptr = nullptr)
: JitCode(code_size, code_ptr) {
this->genCode();
}
void genCode() override;
};
} // namespace gen
} // namespace jit
} // namespace operators
} // namespace paddle
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License. */
#include "paddle/fluid/operators/jit/gen/gru.h"
#include <stddef.h> // offsetof
#include "paddle/fluid/operators/jit/registry.h"
#include "paddle/fluid/platform/cpu_info.h"
namespace paddle {
namespace operators {
namespace jit {
namespace gen {
void GRUJitCode::genCode() {
reg64_t reg_ptr_gates = rax;
reg64_t reg_ptr_ht_1 = r9;
reg64_t reg_ptr_ht = r10;
mov(reg_ptr_gates, ptr[param1 + offsetof(gru_t, gates)]);
mov(reg_ptr_ht_1, ptr[param1 + offsetof(gru_t, ht_1)]);
mov(reg_ptr_ht, ptr[param1 + offsetof(gru_t, ht)]);
ymm_t ymm_one = ymm_t(0);
if (id_ == 2) {
reg64_t reg_ptr_tmp = r11;
mov(reg_ptr_tmp, reinterpret_cast<size_t>(exp_float_consts));
vmovaps(ymm_one, ptr[reg_ptr_tmp + OFFSET_EXP_ONE]);
}
int offset = 0;
int d = num_ * sizeof(float);
for (int i = 0; i < num_ / YMM_FLOAT_BLOCK; ++i) {
ymm_t ymm_u = ymm_t(1);
ymm_t ymm_r = ymm_t(2);
ymm_t ymm_s = ymm_t(3);
ymm_t ymm_ht_1 = ymm_t(4);
// W: {W_update, W_reset; W_state}
if (id_ == 0 || id_ == 2) {
vmovups(ymm_u, ptr[reg_ptr_gates + offset]);
vmovups(ymm_s, ptr[reg_ptr_gates + offset + 2 * d]);
}
if (id_ == 1) {
vmovups(ymm_r, ptr[reg_ptr_gates + offset + d]);
}
if (id_ == 1 || id_ == 2) {
vmovups(ymm_ht_1, ptr[reg_ptr_ht_1 + offset]);
}
if (id_ == 0) {
// ht = act_gate(u) * act_cand(s)
act<ymm_t>(ymm_u, ymm_u, act_gate_);
act<ymm_t>(ymm_s, ymm_s, act_cand_);
vmulps(ymm_s, ymm_s, ymm_u);
vmovups(ptr[reg_ptr_ht + offset], ymm_s);
} else if (id_ == 1) {
// ht = act_gate(r) * ht_1
act<ymm_t>(ymm_r, ymm_r, act_gate_);
vmulps(ymm_r, ymm_r, ymm_ht_1);
vmovups(ptr[reg_ptr_ht + offset], ymm_r);
} else if (id_ == 2) {
// ht = act_gate(u) * act_cand(s) + (1-act_gate(u)) * ht_1
ymm_t ymm_one_inner = ymm_t(ymm_one.getIdx());
act<ymm_t>(ymm_u, ymm_u, act_gate_);
act<ymm_t>(ymm_s, ymm_s, act_cand_);
vmulps(ymm_s, ymm_s, ymm_u);
vsubps(ymm_u, ymm_one_inner, ymm_u);
vmulps(ymm_u, ymm_ht_1, ymm_u);
vaddps(ymm_u, ymm_s, ymm_u);
vmovups(ptr[reg_ptr_ht + offset], ymm_u);
}
offset += sizeof(float) * YMM_FLOAT_BLOCK;
}
ret();
}
#define DECLARE_GRU_CREATOR(name) \
class name##Creator : public JitCodeCreator<gru_attr_t> { \
public: \
/* TODO(TJ): enable more */ \
bool UseMe(const gru_attr_t& attr) const override { \
return platform::MayIUse(platform::avx) && attr.d % 8 == 0; \
} \
size_t CodeSize(const gru_attr_t& attr) const override { \
return 96 + attr.d / YMM_FLOAT_BLOCK * 96 * 2 * 8; \
} \
std::unique_ptr<GenBase> CreateJitCode( \
const gru_attr_t& attr) const override { \
return make_unique<name##JitCode>(attr, CodeSize(attr)); \
} \
}
DECLARE_GRU_CREATOR(GRUH1);
DECLARE_GRU_CREATOR(GRUHtPart1);
DECLARE_GRU_CREATOR(GRUHtPart2);
#undef DECLARE_GRU_CREATOR
} // namespace gen
} // namespace jit
} // namespace operators
} // namespace paddle
namespace gen = paddle::operators::jit::gen;
REGISTER_JITKERNEL_GEN(kGRUH1, gen::GRUH1Creator);
REGISTER_JITKERNEL_GEN(kGRUHtPart1, gen::GRUHtPart1Creator);
REGISTER_JITKERNEL_GEN(kGRUHtPart2, gen::GRUHtPart2Creator);
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License. */
#pragma once
#include <string>
#include "glog/logging.h"
#include "paddle/fluid/operators/jit/gen/act.h"
#include "paddle/fluid/operators/jit/gen/jitcode.h"
namespace paddle {
namespace operators {
namespace jit {
namespace gen {
class GRUJitCode : public VActFunc {
public:
explicit GRUJitCode(int id, const gru_attr_t& attr, size_t code_size,
void* code_ptr = nullptr)
: VActFunc(code_size, code_ptr), id_(id), num_(attr.d) {
auto typeExchange = [](KernelType type) -> gen::operand_type {
if (type == KernelType::kVSigmoid) {
return operand_type::SIGMOID;
} else if (type == KernelType::kVRelu) {
return operand_type::RELU;
} else if (type == KernelType::kVTanh) {
return operand_type::TANH;
} else if (type == KernelType::kVIdentity) {
return operand_type::IDENTITY;
} else {
LOG(FATAL) << "Do not support this jit::KernelType: " << type;
}
return operand_type::IDENTITY;
};
act_gate_ = typeExchange(attr.act_gate);
act_cand_ = typeExchange(attr.act_cand);
this->genCode();
}
const char* name() const override {
std::string base = "GRUJitCode";
if (id_ == 0) {
base += "_H1";
} else if (id_ == 1) {
base += "_HtPart1";
} else if (id_ == 2) {
base += "_HtPart2";
}
auto AddTypeStr = [&](operand_type type) {
switch (type) {
case operand_type::RELU:
base += "_Relu";
break;
case operand_type::EXP:
base += "_Exp";
break;
case operand_type::SIGMOID:
base += "_Sigmoid";
break;
case operand_type::TANH:
base += "_Tanh";
break;
case operand_type::IDENTITY:
base += "_Identity";
break;
default:
break;
}
};
AddTypeStr(act_gate_);
AddTypeStr(act_cand_);
return base.c_str();
}
void genCode() override;
protected:
int id_;
int num_;
operand_type act_gate_;
operand_type act_cand_;
reg64_t param1{abi_param1};
};
#define DECLARE_GRU_JITCODE(name, id) \
class name##JitCode : public GRUJitCode { \
public: \
explicit name##JitCode(const gru_attr_t& attr, size_t code_size, \
void* code_ptr = nullptr) \
: GRUJitCode(id, attr, code_size, code_ptr) {} \
};
DECLARE_GRU_JITCODE(GRUH1, 0);
DECLARE_GRU_JITCODE(GRUHtPart1, 1);
DECLARE_GRU_JITCODE(GRUHtPart2, 2);
#undef DECLARE_GRU_JITCODE
} // namespace gen
} // namespace jit
} // namespace operators
} // namespace paddle
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License. */
#pragma once
#include <type_traits>
#include "paddle/fluid/operators/jit/gen_base.h"
#include "paddle/fluid/platform/cpu_info.h"
#define XBYAK_USE_MMAP_ALLOCATOR
#include "xbyak/xbyak.h"
#include "xbyak/xbyak_util.h"
namespace paddle {
namespace operators {
namespace jit {
namespace gen {
// Application Binary Interface
constexpr Xbyak::Operand::Code abi_param1(Xbyak::Operand::RDI),
abi_param2(Xbyak::Operand::RSI), abi_param3(Xbyak::Operand::RDX),
abi_param4(Xbyak::Operand::RCX);
constexpr Xbyak::Operand::Code g_abi_regs[] = {
Xbyak::Operand::RBX, Xbyak::Operand::RBP, Xbyak::Operand::R12,
Xbyak::Operand::R13, Xbyak::Operand::R14, Xbyak::Operand::R15};
constexpr int num_g_abi_regs = sizeof(g_abi_regs) / sizeof(g_abi_regs[0]);
using reg64_t = const Xbyak::Reg64;
using reg32_t = const Xbyak::Reg32;
using xmm_t = const Xbyak::Xmm;
using ymm_t = const Xbyak::Ymm;
using zmm_t = const Xbyak::Zmm;
using Label = Xbyak::Label;
typedef enum {
MUL = 0,
ADD,
SUB,
RELU,
EXP,
SIGMOID,
TANH,
IDENTITY
} operand_type;
#define DECLARE_JIT_CODE(codename) \
const char* name() const override { return #codename; }
class JitCode : public GenBase, public Xbyak::CodeGenerator {
public:
explicit JitCode(size_t code_size, void* code_ptr = nullptr)
: Xbyak::CodeGenerator(
(code_size % 4096 != 0 ? (code_size / 4096 + 1) * 4096 : code_size),
code_ptr) {}
virtual const char* name() const = 0;
virtual void genCode() = 0;
size_t getSize() const override { return CodeGenerator::getSize(); }
const unsigned char* getCodeInternal() override {
const Xbyak::uint8* code = CodeGenerator::getCode();
return code;
}
protected:
Xbyak::Reg64 param1{abi_param1};
const int EVEX_max_8b_offt = 0x200;
const Xbyak::Reg64 reg_EVEX_max_8b_offt = rbp;
virtual void preCode() {
for (int i = 0; i < num_g_abi_regs; ++i) {
push(Xbyak::Reg64(g_abi_regs[i]));
}
if (platform::MayIUse(platform::avx512f)) {
mov(reg_EVEX_max_8b_offt, 2 * EVEX_max_8b_offt);
}
}
virtual void postCode() {
for (int i = 0; i < num_g_abi_regs; ++i) {
pop(Xbyak::Reg64(g_abi_regs[num_g_abi_regs - 1 - i]));
}
ret();
}
void L(const char* label) { Xbyak::CodeGenerator::L(label); }
void L(const Xbyak::Label& label) { Xbyak::CodeGenerator::L(label); }
// Enhanced vector extension
Xbyak::Address EVEX_compress_addr(Xbyak::Reg64 base, int offt,
bool bcast = false) {
int scale = 0;
// Learn from https://github.com/intel/mkl-dnn
if (EVEX_max_8b_offt <= offt && offt < 3 * EVEX_max_8b_offt) {
offt = offt - 2 * EVEX_max_8b_offt;
scale = 1;
} else if (3 * EVEX_max_8b_offt <= offt && offt < 5 * EVEX_max_8b_offt) {
offt = offt - 4 * EVEX_max_8b_offt;
scale = 2;
}
auto re = Xbyak::RegExp() + base + offt;
if (scale) {
re = re + reg_EVEX_max_8b_offt * scale;
}
if (bcast) {
return zword_b[re];
} else {
return zword[re];
}
}
};
} // namespace gen
} // namespace jit
} // namespace operators
} // namespace paddle
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License. */
#include "paddle/fluid/operators/jit/gen/lstm.h"
#include <stddef.h> // offsetof
#include "paddle/fluid/operators/jit/registry.h"
#include "paddle/fluid/platform/cpu_info.h"
namespace paddle {
namespace operators {
namespace jit {
namespace gen {
void LSTMJitCode::genCode() {
if (use_peephole_) {
preCode();
}
reg64_t reg_ptr_gates = rax;
reg64_t reg_ptr_ct_1 = r9;
reg64_t reg_ptr_ct = r10;
reg64_t reg_ptr_ht = r11;
reg64_t reg_ptr_wp = r12;
mov(reg_ptr_gates, ptr[param1 + offsetof(lstm_t, gates)]);
mov(reg_ptr_ct_1, ptr[param1 + offsetof(lstm_t, ct_1)]);
mov(reg_ptr_ct, ptr[param1 + offsetof(lstm_t, ct)]);
mov(reg_ptr_ht, ptr[param1 + offsetof(lstm_t, ht)]);
if (use_peephole_) {
mov(reg_ptr_wp, ptr[param1 + offsetof(lstm_t, wp)]);
}
int offset = 0;
int d = num_ * sizeof(float);
for (int i = 0; i < num_ / YMM_FLOAT_BLOCK; ++i) {
/* gates: W_ch, W_ih, W_fh, W_oh */
ymm_t ymm_c = ymm_t(0);
ymm_t ymm_i = ymm_t(1);
ymm_t ymm_f = ymm_t(2);
ymm_t ymm_o = ymm_t(3);
ymm_t ymm_ct_1 = ymm_t(4);
ymm_t ymm_wp0 = ymm_t(5);
ymm_t ymm_wp1 = ymm_t(6);
ymm_t ymm_wp2 = ymm_t(7);
vmovups(ymm_c, ptr[reg_ptr_gates + offset]);
vmovups(ymm_i, ptr[reg_ptr_gates + offset + d]);
vmovups(ymm_f, ptr[reg_ptr_gates + offset + 2 * d]);
vmovups(ymm_o, ptr[reg_ptr_gates + offset + 3 * d]);
if (!compute_c1h1_) {
vmovups(ymm_ct_1, ptr[reg_ptr_ct_1 + offset]);
}
if (use_peephole_) {
vmovups(ymm_wp0, ptr[reg_ptr_wp + offset]);
vmovups(ymm_wp1, ptr[reg_ptr_wp + offset + d]);
vmovups(ymm_wp2, ptr[reg_ptr_wp + offset + 2 * d]);
}
/* C_t = act_cand(c) * act_gate(i) + C_t-1 * act_gate(f) */
// act_cand(c)
act<ymm_t>(ymm_c, ymm_c, act_cand_);
// act_gate(i) or act_gate(ct_1 * wp0 + i)
if (!compute_c1h1_ && use_peephole_) {
vmulps(ymm_wp0, ymm_ct_1, ymm_wp0);
vaddps(ymm_i, ymm_i, ymm_wp0);
}
act<ymm_t>(ymm_i, ymm_i, act_gate_);
vmulps(ymm_c, ymm_c, ymm_i);
if (!compute_c1h1_) {
// act_gate(f) or act_gate(ct_1 * wp1 + f)
if (use_peephole_) {
vmulps(ymm_wp1, ymm_ct_1, ymm_wp1);
vaddps(ymm_f, ymm_f, ymm_wp1);
}
act<ymm_t>(ymm_f, ymm_f, act_gate_);
// ct
vmulps(ymm_f, ymm_f, ymm_ct_1);
vaddps(ymm_f, ymm_f, ymm_c);
}
/* H_t = act_cell(C_t) * act_gate(o) */
// act_cell(C_t)
ymm_t ymm_ct = compute_c1h1_ ? ymm_c : ymm_f;
ymm_t ymm_tmp = ymm_i;
act<ymm_t>(ymm_tmp, ymm_ct, act_cell_);
// act_gate(o) or act_gate(ct * wp2 + o)
if (use_peephole_) {
vmulps(ymm_wp2, ymm_ct, ymm_wp2);
vaddps(ymm_o, ymm_o, ymm_wp2);
}
act<ymm_t>(ymm_o, ymm_o, act_gate_);
// ht
vmulps(ymm_o, ymm_o, ymm_tmp);
// save ct and ht
vmovups(ptr[reg_ptr_ct + offset], ymm_ct);
vmovups(ptr[reg_ptr_ht + offset], ymm_o);
offset += sizeof(float) * YMM_FLOAT_BLOCK;
}
if (use_peephole_) {
postCode();
} else {
ret();
}
}
#define DECLARE_LSTM_CREATOR(name) \
class name##Creator : public JitCodeCreator<lstm_attr_t> { \
public: \
/* TODO(TJ): enable more */ \
bool UseMe(const lstm_attr_t& attr) const override { \
return platform::MayIUse(platform::avx) && attr.d % 8 == 0; \
} \
size_t CodeSize(const lstm_attr_t& attr) const override { \
return 96 + attr.d / YMM_FLOAT_BLOCK * 90 * 4 * 8; \
} \
std::unique_ptr<GenBase> CreateJitCode( \
const lstm_attr_t& attr) const override { \
return make_unique<name##JitCode>(attr, CodeSize(attr)); \
} \
}
DECLARE_LSTM_CREATOR(LSTMCtHt);
DECLARE_LSTM_CREATOR(LSTMC1H1);
#undef DECLARE_LSTM_CREATOR
} // namespace gen
} // namespace jit
} // namespace operators
} // namespace paddle
namespace gen = paddle::operators::jit::gen;
REGISTER_JITKERNEL_GEN(kLSTMCtHt, gen::LSTMCtHtCreator);
REGISTER_JITKERNEL_GEN(kLSTMC1H1, gen::LSTMC1H1Creator);
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License. */
#pragma once
#include <string>
#include "glog/logging.h"
#include "paddle/fluid/operators/jit/gen/act.h"
#include "paddle/fluid/operators/jit/gen/jitcode.h"
namespace paddle {
namespace operators {
namespace jit {
namespace gen {
class LSTMJitCode : public VActFunc {
public:
explicit LSTMJitCode(bool compute_c1h1, const lstm_attr_t& attr,
size_t code_size, void* code_ptr = nullptr)
: VActFunc(code_size, code_ptr),
num_(attr.d),
compute_c1h1_(compute_c1h1),
use_peephole_(attr.use_peephole) {
auto typeExchange = [](KernelType type) -> gen::operand_type {
if (type == KernelType::kVSigmoid) {
return operand_type::SIGMOID;
} else if (type == KernelType::kVRelu) {
return operand_type::RELU;
} else if (type == KernelType::kVTanh) {
return operand_type::TANH;
} else if (type == KernelType::kVIdentity) {
return operand_type::IDENTITY;
} else {
LOG(FATAL) << "Do not support this jit::KernelType: " << type;
}
return operand_type::IDENTITY;
};
act_gate_ = typeExchange(attr.act_gate);
act_cand_ = typeExchange(attr.act_cand);
act_cell_ = typeExchange(attr.act_cell);
this->genCode();
}
const char* name() const override {
std::string base = "LSTMJitCode";
if (use_peephole_) {
base += "_Peephole";
}
if (compute_c1h1_) {
base += "_C1H1";
}
auto AddTypeStr = [&](operand_type type) {
switch (type) {
case operand_type::RELU:
base += "_Relu";
break;
case operand_type::EXP:
base += "_Exp";
break;
case operand_type::SIGMOID:
base += "_Sigmoid";
break;
case operand_type::TANH:
base += "_Tanh";
break;
case operand_type::IDENTITY:
base += "_Identity";
break;
default:
break;
}
};
AddTypeStr(act_gate_);
AddTypeStr(act_cand_);
AddTypeStr(act_cell_);
return base.c_str();
}
void genCode() override;
protected:
int num_;
bool compute_c1h1_;
bool use_peephole_;
operand_type act_gate_;
operand_type act_cand_;
operand_type act_cell_;
reg64_t param1{abi_param1};
};
#define DECLARE_LSTM_JITCODE(name, compute_c1h1) \
class name##JitCode : public LSTMJitCode { \
public: \
explicit name##JitCode(const lstm_attr_t& attr, size_t code_size, \
void* code_ptr = nullptr) \
: LSTMJitCode(compute_c1h1, attr, code_size, code_ptr) {} \
};
DECLARE_LSTM_JITCODE(LSTMCtHt, false);
DECLARE_LSTM_JITCODE(LSTMC1H1, true);
#undef DECLARE_LSTM_JITCODE
} // namespace gen
} // namespace jit
} // namespace operators
} // namespace paddle
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License. */
#include "paddle/fluid/operators/jit/gen_base.h"
#include <fstream>
#include <iostream>
#include <sstream>
DEFINE_bool(dump_jitcode, false, "Whether to dump the jitcode to file");
namespace paddle {
namespace operators {
namespace jit {
// refer do not need useme, it would be the last one.
void GenBase::dumpCode(const unsigned char* code) const {
if (code) {
static int counter = 0;
std::ostringstream filename;
filename << "paddle_jitcode_" << name() << "." << counter << ".bin";
counter++;
std::ofstream fout(filename.str(), std::ios::out);
if (fout.is_open()) {
fout.write(reinterpret_cast<const char*>(code), this->getSize());
fout.close();
}
}
}
} // namespace jit
} // namespace operators
} // namespace paddle
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License. */
#pragma once
#include <gflags/gflags.h>
#include <memory> // for unique_ptr
#include "paddle/fluid/operators/jit/kernel_base.h"
DECLARE_bool(dump_jitcode);
namespace paddle {
namespace operators {
namespace jit {
class GenBase : public Kernel {
public:
virtual ~GenBase() = default;
virtual const char* name() const = 0;
virtual size_t getSize() const = 0;
virtual const unsigned char* getCodeInternal() = 0;
template <typename Func>
Func getCode() {
const unsigned char* code = this->getCodeInternal();
if (FLAGS_dump_jitcode) {
this->dumpCode(code);
}
return reinterpret_cast<Func>(const_cast<unsigned char*>(code));
}
protected:
void dumpCode(const unsigned char* code) const;
};
// Creator is used to creat the jitcode and save in pool.
// Every JitCode should have one creator.
class GenCreator {
public:
virtual ~GenCreator() = default;
};
template <typename Attr>
class JitCodeCreator : public GenCreator {
public:
virtual ~JitCodeCreator() = default;
// condition when this jit code can be used.
virtual bool UseMe(const Attr& attr) const = 0;
// estimate this code size
virtual size_t CodeSize(const Attr& attr) const = 0;
// create this code
virtual std::unique_ptr<GenBase> CreateJitCode(const Attr& attr) const = 0;
};
} // namespace jit
} // namespace operators
} // namespace paddle
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License. */
#include "paddle/fluid/operators/jit/helper.h"
#include <algorithm> // tolower
#include "paddle/fluid/platform/enforce.h"
namespace paddle {
namespace operators {
namespace jit {
#define ONE_CASE(key) \
case key: \
return #key
const char* to_string(KernelType kt) {
switch (kt) {
ONE_CASE(kVMul);
ONE_CASE(kVAdd);
ONE_CASE(kVAddRelu);
ONE_CASE(kVSub);
ONE_CASE(kVScal);
ONE_CASE(kVAddBias);
ONE_CASE(kVRelu);
ONE_CASE(kVIdentity);
ONE_CASE(kVExp);
ONE_CASE(kVSigmoid);
ONE_CASE(kVTanh);
ONE_CASE(kLSTMCtHt);
ONE_CASE(kLSTMC1H1);
ONE_CASE(kGRUH1);
ONE_CASE(kGRUHtPart1);
ONE_CASE(kGRUHtPart2);
ONE_CASE(kCRFDecoding);
ONE_CASE(kLayerNorm);
ONE_CASE(kNCHW16CMulNC);
default:
PADDLE_THROW("Not support type: %d, or forget to add it.", kt);
return "NOT JITKernel";
}
return nullptr;
}
#undef ONE_CASE
KernelType to_kerneltype(const std::string& act) {
std::string lower = act;
std::transform(lower.begin(), lower.end(), lower.begin(), ::tolower);
if (lower == "relu" || lower == "vrelu") {
return kVRelu;
} else if (lower == "identity" || lower == "videntity" || lower == "") {
return kVIdentity;
} else if (lower == "exp" || lower == "vexp") {
return kVExp;
} else if (lower == "sigmoid" || lower == "vsigmoid") {
return kVSigmoid;
} else if (lower == "tanh" || lower == "vtanh") {
return kVTanh;
}
PADDLE_THROW("Not support type: %s, or forget to add this case", act);
return kNone;
}
} // namespace jit
} // namespace operators
} // namespace paddle
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License. */
#pragma once
#include <iostream>
#include <string>
#include <vector>
#include "paddle/fluid/operators/jit/gen_base.h"
#include "paddle/fluid/operators/jit/kernel_base.h"
#include "paddle/fluid/operators/jit/kernel_key.h"
#include "paddle/fluid/operators/jit/kernel_pool.h"
#include "paddle/fluid/platform/place.h"
namespace paddle {
namespace operators {
namespace jit {
template <KernelType KT, typename KernelTuples, typename PlaceType>
inline typename std::enable_if<
std::is_same<typename KernelTuples::data_type, float>::value &&
std::is_same<PlaceType, platform::CPUPlace>::value,
typename KernelTuples::func_type>::type
GetJitCode(const typename KernelTuples::attr_type& attr) {
using Func = typename KernelTuples::func_type;
using Attr = typename KernelTuples::attr_type;
size_t key = JitCodeKey<Attr>(attr);
auto& codes = JitCodePool<KT>().Instance();
if (codes.Has(key)) {
return codes.AllKernels().at(key)->template getCode<Func>();
}
// creator is not related with attr, so can use KernelKey as key
KernelKey kkey(KT, PlaceType());
// pool: (KernelKey(type, place), vector<GenCreatorPtr>)
auto& creator_map = JitCodeCreatorPool().Instance().AllCreators();
auto iter = creator_map.find(kkey);
if (iter != creator_map.end()) {
auto& creators = iter->second;
for (auto& cur : creators) {
auto i = dynamic_cast<const JitCodeCreator<Attr>*>(cur.get());
if (i && i->UseMe(attr)) {
auto p = i->CreateJitCode(attr);
if (p) {
auto f = p->template getCode<Func>();
codes.Insert(key, std::move(p));
return f;
}
}
}
}
return nullptr;
}
template <KernelType KT, typename KernelTuples, typename PlaceType>
inline typename std::enable_if<
!std::is_same<typename KernelTuples::data_type, float>::value ||
!std::is_same<PlaceType, platform::CPUPlace>::value,
typename KernelTuples::func_type>::type
GetJitCode(const typename KernelTuples::attr_type& attr) {
return nullptr;
}
// Refer code do not related with attr, which is just for cast
// Refer is always on CPUPlace
template <KernelType KT, typename KernelTuples>
inline typename KernelTuples::func_type GetRefer() {
auto& ref_pool = ReferKernelPool().Instance().AllKernels();
KernelKey kkey(KT, platform::CPUPlace());
auto ref_iter = ref_pool.find(kkey);
PADDLE_ENFORCE(ref_iter != ref_pool.end(),
"Every Kernel should have reference function.");
auto& ref_impls = ref_iter->second;
for (auto& impl : ref_impls) {
auto i = dynamic_cast<const ReferKernel<KernelTuples>*>(impl.get());
if (i) {
return i->GetFunc();
}
}
return nullptr;
}
template <KernelType KT, typename KernelTuples,
typename PlaceType = platform::CPUPlace>
typename KernelTuples::func_type Get(
const typename KernelTuples::attr_type& attr) {
auto jitfunc = GetJitCode<KT, KernelTuples, PlaceType>(attr);
if (jitfunc) {
return jitfunc;
}
// pool: (KernelKey(type, place), vector<KernelPtr>)
KernelKey kkey(KT, PlaceType());
auto& pool = KernelPool().Instance().AllKernels();
auto iter = pool.find(kkey);
if (iter != pool.end()) {
auto& impls = iter->second;
for (auto& impl : impls) {
auto i = dynamic_cast<const KernelMore<KernelTuples>*>(impl.get());
if (i && i->UseMe(attr)) {
return i->GetFunc();
}
}
}
// The last implementation should be reference function on CPUPlace.
return GetRefer<KT, KernelTuples>();
}
const char* to_string(KernelType kt);
KernelType to_kerneltype(const std::string& act);
inline std::ostream& operator<<(std::ostream& os, const lstm_attr_t& attr) {
os << "dim_size[" << attr.d << "],act_gate[" << to_string(attr.act_gate)
<< "],act_cand[" << to_string(attr.act_cand) << "],act_cell["
<< to_string(attr.act_cell) << "],use_peephole["
<< (attr.use_peephole ? "True" : "False") << "]";
return os;
}
inline std::ostream& operator<<(std::ostream& os, const gru_attr_t& attr) {
os << "dim_size[" << attr.d << "],act_gate[" << to_string(attr.act_gate)
<< "],act_cand[" << to_string(attr.act_cand) << "]";
return os;
}
} // namespace jit
} // namespace operators
} // namespace paddle
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License. */
#pragma once
#include "paddle/fluid/operators/jit/macro.h"
#include "paddle/fluid/platform/macros.h"
namespace paddle {
namespace operators {
namespace jit {
typedef enum {
kNone = 0,
kVMul = 1,
kVAdd = 2,
kVAddRelu,
kVSub,
kVScal,
kVAddBias,
kVRelu,
kVIdentity,
kVExp,
kVSigmoid,
kVTanh,
kLSTMCtHt,
kLSTMC1H1,
kGRUH1,
kGRUHtPart1,
kGRUHtPart2,
kCRFDecoding,
kLayerNorm,
kNCHW16CMulNC,
} KernelType;
template <typename T>
struct XYZNTuples {
typedef T data_type;
typedef int attr_type;
typedef void (*func_type)(const T*, const T*, T*, int);
};
template <typename T>
struct AXYNTuples : public XYZNTuples<T> {};
template <typename T>
struct XYNTuples {
typedef T data_type;
typedef int attr_type;
typedef void (*func_type)(const T*, T*, int);
};
typedef struct {
void* gates; // gates: x_ch, x_ih, x_fh, x_oh
const void* ct_1;
void* ct;
void* ht;
/* weight_peephole and checked data are only used in peephole*/
const void* wp{nullptr}; // W_ic, W_fc, W_oc
void* checked{nullptr}; // size: 2 * d
} lstm_t;
typedef struct {
void* gates; // gates: {x_update, x_reset; x_state}
const void* ht_1;
void* ht;
} gru_t;
struct rnn_attr_s {
int d;
KernelType act_gate, act_cand;
rnn_attr_s() = default;
explicit rnn_attr_s(int _d, KernelType _act_gate, KernelType _act_cand)
: d(_d), act_gate(_act_gate), act_cand(_act_cand) {}
};
struct lstm_attr_s : public rnn_attr_s {
bool use_peephole;
KernelType act_cell;
lstm_attr_s() = default;
explicit lstm_attr_s(int _d, KernelType _act_gate, KernelType _act_cand,
KernelType _act_cell, bool _use_peephole = false)
: rnn_attr_s(_d, _act_gate, _act_cand),
use_peephole(_use_peephole),
act_cell(_act_cell) {}
};
typedef struct rnn_attr_s gru_attr_t;
typedef struct lstm_attr_s lstm_attr_t;
template <typename T>
struct LSTMTuples {
typedef T data_type;
typedef lstm_attr_t attr_type;
typedef void (*func_type)(lstm_t*, const lstm_attr_t*);
};
template <typename T>
struct GRUTuples {
typedef T data_type;
typedef gru_attr_t attr_type;
typedef void (*func_type)(gru_t*, const gru_attr_t*);
};
template <typename T>
struct CRFDecodingTuples {
typedef T data_type;
typedef int attr_type;
typedef void (*func_type)(const int, const T*, const T*, T*, int*, int);
};
template <typename T>
struct LayerNormTuples {
typedef T data_type;
typedef int attr_type;
typedef void (*func_type)(T*, T*, T*, T*, const T*, const T*, int,
const float, int);
};
// nChw16c = nChw16c .* NC
template <typename T>
struct NCHW16CMulNCTuples {
typedef T data_type;
typedef int attr_type;
typedef void (*func_type)(const T*, const T*, T*, int, int);
};
// Just for adding to kernel pool without template
class Kernel {
public:
Kernel() = default;
virtual ~Kernel() = default;
DISABLE_COPY_AND_ASSIGN(Kernel);
};
template <typename KernelTuples>
class KernelMore : public Kernel {
public:
using T = typename KernelTuples::data_type;
using Func = typename KernelTuples::func_type;
using Attr = typename KernelTuples::attr_type;
virtual Func GetFunc() const { return func; }
virtual bool UseMe(const Attr& attr) const = 0;
virtual const char* ImplType() const = 0;
protected:
Func func{nullptr};
};
template <typename KernelTuples>
class ReferKernel : public KernelMore<KernelTuples> {
public:
// Refer code can always be used
bool UseMe(const typename KernelTuples::attr_type& attr) const override {
return true;
}
const char* ImplType() const override { return "Refer"; }
};
} // namespace jit
} // namespace operators
} // namespace paddle
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License. */
#include "paddle/fluid/operators/jit/kernel_key.h"
namespace paddle {
namespace operators {
namespace jit {
template <>
size_t JitCodeKey<int>(const int& d) {
return d;
}
constexpr int act_type_shift = 3; // suppot 2^3 act types
template <>
size_t JitCodeKey<lstm_attr_t>(const lstm_attr_t& attr) {
size_t key = attr.d;
int gate_key = static_cast<int>(attr.act_gate) << 1;
int cand_key = static_cast<int>(attr.act_cand) << (1 + act_type_shift);
int cell_key = static_cast<int>(attr.act_cell) << (1 + act_type_shift * 2);
return (key << (1 + act_type_shift * 3)) + gate_key + cand_key + cell_key +
attr.use_peephole;
}
template <>
size_t JitCodeKey<gru_attr_t>(const gru_attr_t& attr) {
size_t key = attr.d;
return (key << (act_type_shift * 2)) + static_cast<int>(attr.act_gate) +
(static_cast<int>(attr.act_cand) << act_type_shift);
}
} // namespace jit
} // namespace operators
} // namespace paddle
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License. */
#pragma once
#include "paddle/fluid/operators/jit/kernel_base.h"
#include "paddle/fluid/platform/place.h"
namespace paddle {
namespace operators {
namespace jit {
struct KernelKey {
struct Hash {
size_t operator()(const KernelKey& key) const {
int place = key.place_.which(); // less than 2^8
int type = static_cast<int>(key.type_) << 8; // less than 2^(32-8)
std::hash<int> hasher;
return hasher(place + type);
}
};
KernelType type_;
platform::Place place_;
KernelKey(KernelType type, platform::Place place)
: type_(type), place_(place) {}
size_t hash_key() const { return Hash()(*this); }
bool operator==(const KernelKey& o) const {
return platform::places_are_same_class(place_, o.place_) &&
type_ == o.type_;
}
bool operator!=(const KernelKey& o) const { return !(*this == o); }
};
// Every JitCode should have a method to get the key from attribution
template <typename Attr>
size_t JitCodeKey(const Attr& attr);
} // namespace jit
} // namespace operators
} // namespace paddle
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License. */
#include "paddle/fluid/operators/jit/kernel_pool.h"
#include <memory> // for shared_ptr
#include <string>
#include <unordered_map>
namespace paddle {
namespace operators {
namespace jit {
JitCodeCreatorPool& JitCodeCreatorPool::Instance() {
static JitCodeCreatorPool g_creator_pool;
return g_creator_pool;
}
KernelPool& KernelPool::Instance() {
static KernelPool g_kernel_pool;
return g_kernel_pool;
}
ReferKernelPool& ReferKernelPool::Instance() {
static ReferKernelPool g_refer_kernel_pool;
return g_refer_kernel_pool;
}
} // namespace jit
} // namespace operators
} // namespace paddle
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License. */
#pragma once
#include <memory> // for unique_ptr
#include <string>
#include <unordered_map>
#include <vector>
#include "paddle/fluid/operators/jit/gen_base.h"
#include "paddle/fluid/operators/jit/kernel_base.h"
#include "paddle/fluid/operators/jit/kernel_key.h"
#include "paddle/fluid/platform/place.h"
namespace paddle {
namespace operators {
namespace jit {
template <KernelType KT>
class JitCodePool {
typedef std::unique_ptr<GenBase> GenBasePtr;
typedef std::unordered_map<size_t, GenBasePtr> JitCodeMap;
public:
JitCodePool() = default;
static JitCodePool& Instance() {
static thread_local JitCodePool<KT> g_jit_codes;
return g_jit_codes;
}
const JitCodeMap& AllKernels() { return codes_; }
bool Has(size_t key) const { return codes_.find(key) != codes_.end(); }
void Insert(size_t key, GenBasePtr value) {
codes_.emplace(key, std::move(value));
}
private:
JitCodeMap codes_;
DISABLE_COPY_AND_ASSIGN(JitCodePool);
};
class JitCodeCreatorPool {
typedef std::unique_ptr<const GenCreator> GenCreatorPtr;
typedef std::unordered_map<KernelKey, std::vector<GenCreatorPtr>,
KernelKey::Hash>
GenCreatorPtrMap;
public:
JitCodeCreatorPool() = default;
static JitCodeCreatorPool& Instance();
GenCreatorPtrMap& AllCreators() { return creators_; }
void Insert(const KernelKey& key, GenCreatorPtr value) {
if (creators_.find(key) == creators_.end()) {
creators_.emplace(key, std::vector<GenCreatorPtr>());
}
creators_.at(key).emplace_back(std::move(value));
}
private:
GenCreatorPtrMap creators_;
DISABLE_COPY_AND_ASSIGN(JitCodeCreatorPool);
};
typedef std::unique_ptr<const Kernel> KernelPtr;
typedef std::unordered_map<KernelKey, std::vector<KernelPtr>, KernelKey::Hash>
KernelMap;
class KernelPool {
public:
static KernelPool& Instance();
KernelPool() = default;
KernelMap& AllKernels() { return pool_; }
void Insert(const KernelKey& key, KernelPtr value) {
if (pool_.find(key) == pool_.end()) {
pool_.emplace(key, std::vector<KernelPtr>());
}
pool_.at(key).emplace_back(std::move(value));
}
private:
KernelMap pool_;
DISABLE_COPY_AND_ASSIGN(KernelPool);
};
// Every kernel should have refer code and it should be used in unit tests,
// so refer kernels should have it's independent kernel pool
class ReferKernelPool {
public:
static ReferKernelPool& Instance();
ReferKernelPool() = default;
KernelMap& AllKernels() { return pool_; }
void Insert(const KernelKey& key, KernelPtr value) {
if (pool_.find(key) == pool_.end()) {
pool_.emplace(key, std::vector<KernelPtr>());
}
pool_.at(key).emplace_back(std::move(value));
}
private:
KernelMap pool_;
DISABLE_COPY_AND_ASSIGN(ReferKernelPool);
};
} // namespace jit
} // namespace operators
} // namespace paddle
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License. */
#pragma once
#include <type_traits>
namespace paddle {
namespace operators {
namespace jit {
#define SIGMOID_THRESHOLD_MIN -40.0
#define SIGMOID_THRESHOLD_MAX 13.0
#define EXP_MAX_INPUT 40.0
#define XMM_FLOAT_BLOCK 4
#define YMM_FLOAT_BLOCK 8
#define ZMM_FLOAT_BLOCK 16
} // namespace jit
} // namespace operators
} // namespace paddle
function(USE_JITKERNEL_MORE TARGET TYPE)
file(APPEND ${jit_file} "USE_JITKERNEL_MORE(${TARGET} ${TYPE});\n")
endfunction()
if(WITH_MKLML)
add_subdirectory(mkl)
endif()
if(WITH_AVX)
add_subdirectory(intrinsic)
endif()
# mix should be last
add_subdirectory(mix)
set(JIT_KERNEL_DEPS ${JIT_KERNEL_DEPS} PARENT_SCOPE)
file(GLOB jit_kernel_cc_intrinsic RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "*.cc")
cc_library(jit_kernel_intrinsic SRCS ${jit_kernel_cc_intrinsic} DEPS jit_kernel_base)
set(JIT_KERNEL_DEPS ${JIT_KERNEL_DEPS} jit_kernel_intrinsic PARENT_SCOPE)
# use mkl kernels by name and type
USE_JITKERNEL_MORE(kCRFDecoding, intrinsic)
USE_JITKERNEL_MORE(kLayerNorm, intrinsic)
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License. */
#include "paddle/fluid/operators/jit/more/intrinsic/crf_decoding.h"
#include <limits>
#include "paddle/fluid/operators/jit/registry.h"
#include "paddle/fluid/platform/cpu_info.h"
namespace paddle {
namespace operators {
namespace jit {
namespace more {
namespace intrinsic {
// Note: intrinsic code is not runtime build.
// For example, if you build code on AVX, and run on AVX512 it can only use AVX
void CRFDecoding(const int seq_len, const float* x, const float* w,
float* alpha, int* track, int tag_num) {
#ifdef __AVX512F__
const int step_size = ZMM_FLOAT_BLOCK;
#else
const int step_size = YMM_FLOAT_BLOCK;
#endif
const int end = tag_num / step_size;
const int rest = tag_num % step_size;
/* Setup the alpha initial value.*/
int i_offset = 0;
int last_offset = rest - step_size;
for (int i = 0; i <= end; ++i) {
#ifdef __AVX512F__
// Declare the variable for the content of weights, input and alpha values.
__m512 w_content, x_content, alpha_content;
// Load the relevant data into the variables from un-aligned address.
w_content = _mm512_loadu_ps(w + i_offset);
x_content = _mm512_loadu_ps(x + i_offset);
alpha_content = _mm512_add_ps(w_content, x_content);
// Save the alpha value.
_mm512_storeu_ps(alpha_value + i_offset, alpha_content);
#else
// AVX or AVX2
// weights, input and alpha values.
__m256 w_content, x_content, alpha_content;
// Load the relevant data into the variables from un-aligned address.
w_content = _mm256_loadu_ps(w + i_offset);
x_content = _mm256_loadu_ps(x + i_offset);
alpha_content = _mm256_add_ps(w_content, x_content);
_mm256_storeu_ps(alpha + i_offset, alpha_content);
#endif
i_offset += step_size;
if (i == end - 1) {
if (rest > 0) {
i_offset += last_offset;
} else {
break;
}
}
}
// Use the column-major strategy to get the location of maximum score.
int seq_offset = 0;
constexpr int state_trans_base_idx = 2;
for (int k = 1; k < seq_len; ++k) {
int j_offset = 0;
for (int j = 0; j <= end; ++j) {
/* Initialize the variables of maximum score and location.*/
#ifdef __AVX512F__
__m512 max_score = _mm512_set1_ps(-std::numeric_limits<float>::max());
__m512i max_j = _mm512_setzero_si512();
#else
__m256 max_score = _mm256_set1_ps(-std::numeric_limits<float>::max());
__m256i max_j = _mm256_set1_epi32(0);
#endif
/* Calculate the offset of transition_weights.*/
int trans_offset = state_trans_base_idx * tag_num + j_offset;
for (int i = 0; i < tag_num; ++i) {
/* Initalize the content of alpha variable with related offset.*/
#ifdef __AVX512F__
__m512 alpha_content = _mm512_set1_ps(*(alpha + seq_offset + i));
/* Obtain the content of weights from un-aligned address.*/
__m512 w_content = _mm512_loadu_ps(w + trans_offset);
__m512 score_v = _mm512_add_ps(alpha_content, w_content);
__mmask16 mask = _mm512_cmp_ps_mask(score_v, max_score, _CMP_GT_OS);
/* AVX512 instructions.*/
max_j = _mm512_mask_set1_epi32(max_j, mask, i);
/* Update the max_score value.*/
max_score = _mm512_max_ps(max_score, score_v);
#else
__m256 alpha_content = _mm256_broadcast_ss(alpha + seq_offset + i);
/* Obtain the content of weights from un-aligned address.*/
__m256 w_content = _mm256_loadu_ps(w + trans_offset);
__m256 score_v = _mm256_add_ps(alpha_content, w_content);
__m256 mask = _mm256_cmp_ps(score_v, max_score, _CMP_GT_OS);
/* According to the mask value, update the index of the max_score.*/
#ifdef __AVX2__
max_j = _mm256_or_si256(
_mm256_andnot_si256((__m256i)mask, max_j),
_mm256_and_si256((__m256i)mask, _mm256_set1_epi32(i)));
#else
__m128i lo_max_j = _mm256_extractf128_si256(max_j, 0);
__m128i hi_max_j = _mm256_extractf128_si256(max_j, 1);
__m128i lo_mask =
_mm256_extractf128_si256(*(__m256i*)&mask, 0); // NOLINT
__m128i hi_mask =
_mm256_extractf128_si256(*(__m256i*)&mask, 1); // NOLINT
lo_max_j = _mm_andnot_si128(lo_mask, lo_max_j);
hi_max_j = _mm_andnot_si128(hi_mask, hi_max_j);
lo_mask = _mm_and_si128(lo_mask, _mm_set1_epi32(i));
hi_mask = _mm_and_si128(hi_mask, _mm_set1_epi32(i));
lo_max_j = _mm_or_si128(lo_mask, lo_max_j);
hi_max_j = _mm_or_si128(hi_mask, hi_max_j);
max_j = _mm256_insertf128_si256(max_j, lo_max_j, 0);
max_j = _mm256_insertf128_si256(max_j, hi_max_j, 1);
#endif
/* Update the max_score value.*/
max_score = _mm256_max_ps(max_score, score_v);
#endif
trans_offset += tag_num;
}
/* Update the alpha and track values. */
#ifdef __AVX512F__
__m512 x_content =
_mm512_loadu_ps(x + seq_offset + this->num_ + j_offset);
max_score = _mm512_add_ps(max_score, x_content);
_mm512_storeu_ps(alpha + seq_offset + this->num_ + j_offset, max_score);
_mm512_storeu_si512(reinterpret_cast<__m512i*>(track + seq_offset +
this->num_ + j_offset),
max_j);
#else
__m256 x_content = _mm256_loadu_ps(x + seq_offset + tag_num + j_offset);
max_score = _mm256_add_ps(max_score, x_content);
_mm256_storeu_ps(alpha + seq_offset + tag_num + j_offset, max_score);
_mm256_storeu_si256(
reinterpret_cast<__m256i*>(track + seq_offset + tag_num + j_offset),
max_j);
#endif
/* Calculate the offset of next step*/
j_offset += step_size;
if (j == end - 1) {
if (rest > 0) {
j_offset += last_offset;
} else {
break;
}
}
}
seq_offset += tag_num;
}
}
bool CRFDecodingKernel::UseMe(const int& d) const {
#ifdef __AVX512F__
constexpr int block = ZMM_FLOAT_BLOCK;
#else
constexpr int block = YMM_FLOAT_BLOCK;
#endif
return platform::MayIUse(platform::avx) && d >= block;
}
} // namespace intrinsic
} // namespace more
} // namespace jit
} // namespace operators
} // namespace paddle
namespace intrinsic = paddle::operators::jit::more::intrinsic;
REGISTER_JITKERNEL_MORE(kCRFDecoding, intrinsic, intrinsic::CRFDecodingKernel);
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License. */
#pragma once
#include <type_traits>
#include "paddle/fluid/operators/jit/kernel_base.h"
namespace paddle {
namespace operators {
namespace jit {
namespace more {
namespace intrinsic {
void CRFDecoding(const int seq_len, const float* x, const float* w,
float* alpha, int* track, int tag_num);
class CRFDecodingKernel : public KernelMore<CRFDecodingTuples<float>> {
public:
CRFDecodingKernel() { this->func = CRFDecoding; }
bool UseMe(
const typename CRFDecodingTuples<float>::attr_type&) const override;
const char* ImplType() const override { return "Intrinsic"; }
};
} // namespace intrinsic
} // namespace more
} // namespace jit
} // namespace operators
} // namespace paddle
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License. */
#include "paddle/fluid/operators/jit/more/intrinsic/layer_norm.h"
#include <limits>
#include "paddle/fluid/operators/jit/registry.h"
#include "paddle/fluid/platform/cpu_info.h"
namespace paddle {
namespace operators {
namespace jit {
namespace more {
namespace intrinsic {
void LayerNorm(float* x, float* out, float* mean, float* var,
const float* scale, const float* bias, int height,
const float epsilon, int right) {
__m256 sum;
__m256 mean_vec, var_vec;
__m128 hi, lo;
__m256 tmp;
size_t offset;
size_t j;
int block = YMM_FLOAT_BLOCK;
const int rest = right % block;
const int end = right - rest;
__m256 reverse_num_vec =
_mm256_div_ps(_mm256_set1_ps(1.0), _mm256_set1_ps(right));
__m256 epsilon_vec = _mm256_set1_ps(epsilon);
int rest_mask =
((-1) & (~((~0U) >> (sizeof(int) * 8 - (block - rest))))) & 0x0ff;
__m256i mask_vec = _mm256_set_epi32(
rest_mask & 0x80 ? 0xffffffff : 0, rest_mask & 0x40 ? 0xffffffff : 0,
rest_mask & 0x20 ? 0xffffffff : 0, rest_mask & 0x10 ? 0xffffffff : 0,
rest_mask & 0x8 ? 0xffffffff : 0, rest_mask & 0x4 ? 0xffffffff : 0,
rest_mask & 0x2 ? 0xffffffff : 0, rest_mask & 0x1 ? 0xffffffff : 0);
for (int i = 0; i < height; ++i) {
offset = i * right;
/* get mean */
sum = _mm256_setzero_ps();
for (j = offset; j < end + offset; j += block) {
sum = _mm256_add_ps(sum, _mm256_loadu_ps((const float*)x + j));
}
if (rest != 0) {
j = offset + right - block;
tmp = _mm256_loadu_ps((const float*)x + j);
tmp = _mm256_blendv_ps(_mm256_setzero_ps(), tmp,
*(__m256*)&mask_vec); // NOLINT
sum = _mm256_add_ps(sum, tmp);
}
hi = _mm256_extractf128_ps(sum, 1);
lo = _mm256_extractf128_ps(sum, 0);
sum = _mm256_add_ps(
sum, _mm256_insertf128_ps(
_mm256_insertf128_ps(_mm256_setzero_ps(), hi, 0), lo, 1));
sum = _mm256_hadd_ps(sum, sum);
sum = _mm256_hadd_ps(sum, sum);
mean_vec = _mm256_mul_ps(sum, reverse_num_vec);
mean[i] = *reinterpret_cast<float*>(&mean_vec);
/* get variance */
sum = _mm256_setzero_ps();
for (j = offset; j < end + offset; j += block) {
tmp = _mm256_sub_ps(_mm256_loadu_ps((const float*)x + j), mean_vec);
tmp = _mm256_mul_ps(tmp, tmp);
sum = _mm256_add_ps(sum, tmp);
}
if (rest != 0) {
j = offset + right - block;
tmp = _mm256_sub_ps(_mm256_loadu_ps((const float*)x + j), mean_vec);
tmp = _mm256_mul_ps(tmp, tmp);
tmp = _mm256_blendv_ps(_mm256_setzero_ps(), tmp,
*(__m256*)&mask_vec); // NOLINT
sum = _mm256_add_ps(sum, tmp);
}
hi = _mm256_extractf128_ps(sum, 1);
lo = _mm256_extractf128_ps(sum, 0);
sum = _mm256_add_ps(
sum, _mm256_insertf128_ps(
_mm256_insertf128_ps(_mm256_setzero_ps(), hi, 0), lo, 1));
sum = _mm256_hadd_ps(sum, sum);
sum = _mm256_hadd_ps(sum, sum);
var_vec = _mm256_mul_ps(sum, reverse_num_vec);
var[i] = *reinterpret_cast<float*>(&var_vec);
/* get x_norm and calculate output*/
for (j = offset; j < end + offset; j += block) {
tmp = _mm256_sub_ps(_mm256_loadu_ps((const float*)x + j), mean_vec);
tmp = _mm256_div_ps(tmp,
_mm256_sqrt_ps(_mm256_add_ps(var_vec, epsilon_vec)));
_mm256_storeu_ps(reinterpret_cast<float*>(out) + j, tmp);
}
if (rest != 0) {
j = offset + right - block;
tmp = _mm256_sub_ps(_mm256_loadu_ps((const float*)x + j), mean_vec);
tmp = _mm256_div_ps(tmp,
_mm256_sqrt_ps(_mm256_add_ps(var_vec, epsilon_vec)));
_mm256_storeu_ps(reinterpret_cast<float*>(out) + j, tmp);
}
if (scale) {
if (rest != 0) {
j = offset + right - block;
tmp = _mm256_loadu_ps((const float*)out + j);
}
for (j = offset; j < end + offset; j += block) {
_mm256_storeu_ps(
reinterpret_cast<float*>(out) + j,
_mm256_mul_ps(_mm256_loadu_ps((const float*)out + j),
_mm256_loadu_ps((const float*)scale + j - offset)));
}
if (rest != 0) {
j = offset + right - block;
_mm256_storeu_ps(
reinterpret_cast<float*>(out) + j,
_mm256_mul_ps(tmp,
_mm256_loadu_ps((const float*)scale + j - offset)));
}
}
if (bias) {
if (rest != 0) {
j = offset + right - block;
tmp = _mm256_loadu_ps((const float*)out + j);
}
for (j = offset; j < end + offset; j += block) {
_mm256_storeu_ps(
reinterpret_cast<float*>(out) + j,
_mm256_add_ps(_mm256_loadu_ps((const float*)out + j),
_mm256_loadu_ps((const float*)bias + j - offset)));
}
if (rest != 0) {
j = offset + right - block;
_mm256_storeu_ps(reinterpret_cast<float*>(out) + j,
_mm256_add_ps(tmp, _mm256_loadu_ps((const float*)bias +
j - offset)));
}
}
}
}
bool LayerNormKernel::UseMe(const int& d) const {
return platform::MayIUse(platform::avx) && d >= YMM_FLOAT_BLOCK;
}
} // namespace intrinsic
} // namespace more
} // namespace jit
} // namespace operators
} // namespace paddle
namespace intrinsic = paddle::operators::jit::more::intrinsic;
REGISTER_JITKERNEL_MORE(kLayerNorm, intrinsic, intrinsic::LayerNormKernel);
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License. */
#pragma once
#include <type_traits>
#include "paddle/fluid/operators/jit/kernel_base.h"
namespace paddle {
namespace operators {
namespace jit {
namespace more {
namespace intrinsic {
void LayerNorm(float* x, float* out, float* mean, float* var,
const float* scale, const float* bias, int height,
const float epsilon, int right);
class LayerNormKernel : public KernelMore<LayerNormTuples<float>> {
public:
LayerNormKernel() { this->func = LayerNorm; }
bool UseMe(const typename LayerNormTuples<float>::attr_type&) const override;
const char* ImplType() const override { return "Intrinsic"; }
};
} // namespace intrinsic
} // namespace more
} // namespace jit
} // namespace operators
} // namespace paddle
file(GLOB jit_kernel_mix_cc RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "*.cc")
cc_library(jit_kernel_mix SRCS ${jit_kernel_mix_cc} DEPS jit_kernel_base)
set(JIT_KERNEL_DEPS ${JIT_KERNEL_DEPS} jit_kernel_mix PARENT_SCOPE)
USE_JITKERNEL_MORE(kVSigmoid, mix)
USE_JITKERNEL_MORE(kVTanh, mix)
USE_JITKERNEL_MORE(kLSTMCtHt, mix)
USE_JITKERNEL_MORE(kLSTMC1H1, mix)
USE_JITKERNEL_MORE(kGRUH1, mix)
USE_JITKERNEL_MORE(kGRUHtPart1, mix)
USE_JITKERNEL_MORE(kGRUHtPart2, mix)
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License. */
#include "paddle/fluid/operators/jit/more/mix/mix.h"
#include "paddle/fluid/operators/jit/kernels.h"
#include "paddle/fluid/operators/jit/registry.h"
#include "paddle/fluid/platform/cpu_info.h"
namespace paddle {
namespace operators {
namespace jit {
namespace more {
namespace mix {
void VSigmoid(const T* x, T* y, int n) {
const float min = SIGMOID_THRESHOLD_MIN;
const float max = SIGMOID_THRESHOLD_MAX;
for (int i = 0; i < n; ++i) {
y[i] = (x[i] < min) ? min : ((x[i] > max) ? max : x[i]);
y[i] = static_cast<T>(0) - y[i];
}
auto compute = Get<KernelType::kVExp, XYNTuples<T>, platform::CPUPlace>(n);
compute(y, y, n);
for (int i = 0; i < n; ++i) {
y[i] = static_cast<T>(1) / (static_cast<T>(1) + y[i]);
}
}
void VTanh(const T* x, T* y, int n) {
const T a = 2, b = -1;
auto compute_scal = Get<kVScal, AXYNTuples<T>, platform::CPUPlace>(n);
auto compute_addbias = Get<kVAddBias, AXYNTuples<T>, platform::CPUPlace>(n);
auto compute_sigmoid = Get<kVSigmoid, XYNTuples<T>, platform::CPUPlace>(n);
compute_scal(&a, x, y, n);
compute_sigmoid(y, y, n);
compute_scal(&a, y, y, n);
compute_addbias(&b, y, y, n);
}
void (*getActFunc(KernelType type, int d))(const T*, T*, int) { // NOLINT
if (type == kVSigmoid) {
return Get<kVSigmoid, XYNTuples<T>, platform::CPUPlace>(d);
} else if (type == kVRelu) {
return Get<kVRelu, XYNTuples<T>, platform::CPUPlace>(d);
} else if (type == kVTanh) {
return Get<kVTanh, XYNTuples<T>, platform::CPUPlace>(d);
} else if (type == kVIdentity) {
return Get<kVIdentity, XYNTuples<T>, platform::CPUPlace>(d);
}
PADDLE_THROW("Not support type: %s", type);
return nullptr;
}
void LSTMCtHt(lstm_t* step, const lstm_attr_t* attr) {
T* gates = reinterpret_cast<T*>(step->gates);
const T* ct_1 = reinterpret_cast<const T*>(step->ct_1);
T* ct = reinterpret_cast<T*>(step->ct);
T* ht = reinterpret_cast<T*>(step->ht);
const T* wp = reinterpret_cast<const T*>(step->wp);
T* checked = reinterpret_cast<T*>(step->checked);
const int d = attr->d;
const int d2 = d * 2;
const int d3 = d * 3;
auto vmul_d = Get<kVMul, XYZNTuples<T>, platform::CPUPlace>(d);
auto vadd_d = Get<kVAdd, XYZNTuples<T>, platform::CPUPlace>(d);
auto vadd_d2 = Get<kVAdd, XYZNTuples<T>, platform::CPUPlace>(d2);
auto act_gate_d = getActFunc(attr->act_gate, d);
auto act_gate_d2 = getActFunc(attr->act_gate, d2);
auto act_gate_d3 = getActFunc(attr->act_gate, d3);
auto act_cand_d = getActFunc(attr->act_cand, d);
auto act_cell_d = getActFunc(attr->act_cell, d);
if (attr->use_peephole) {
vmul_d(wp, ct_1, checked, d);
vmul_d(wp + d, ct_1, checked + d, d);
vadd_d2(checked, gates + d, gates + d, d2);
act_gate_d2(gates + d, gates + d, d2);
} else {
act_gate_d3(gates + d, gates + d, d3);
}
// C_t = C_t-1 * fgated + cand_gated * igated
act_cand_d(gates, gates, d);
vmul_d(gates, gates + d, gates + d, d);
vmul_d(ct_1, gates + d2, gates + d2, d);
vadd_d(gates + d, gates + d2, ct, d);
if (attr->use_peephole) {
// get ogated
vmul_d(wp + d2, ct, gates + d, d);
vadd_d(gates + d, gates + d3, gates + d3, d);
act_gate_d(gates + d3, gates + d3, d);
}
// H_t = act_cell(C_t) * ogated
act_cell_d(ct, gates + d2, d);
vmul_d(gates + d2, gates + d3, ht, d);
}
void LSTMC1H1(lstm_t* step, const lstm_attr_t* attr) {
T* gates = reinterpret_cast<T*>(step->gates);
T* ct = reinterpret_cast<T*>(step->ct);
T* ht = reinterpret_cast<T*>(step->ht);
int d = attr->d;
int d2 = d * 2;
int d3 = d * 3;
auto vmul_d = Get<kVMul, XYZNTuples<T>, platform::CPUPlace>(d);
auto vadd_d = Get<kVAdd, XYZNTuples<T>, platform::CPUPlace>(d);
auto act_gate_d = getActFunc(attr->act_gate, d);
auto act_cand_d = getActFunc(attr->act_cand, d);
auto act_cell_d = getActFunc(attr->act_cell, d);
/* C_t = igated * cgated*/
act_gate_d(gates + d, gates + d, d);
act_cand_d(gates, gates, d);
vmul_d(gates, gates + d, ct, d);
if (attr->use_peephole) {
// get outgated, put W_oc * C_t on igated
const T* wp = reinterpret_cast<const T*>(step->wp);
vmul_d(wp + d2, ct, gates + d, d);
vadd_d(gates + d, gates + d3, gates + d3, d);
}
/* H_t = act_cell(C_t) * ogated */
act_gate_d(gates + d3, gates + d3, d);
act_cell_d(ct, gates + d2, d);
vmul_d(gates + d2, gates + d3, ht, d);
}
// compute h1 without h0
void GRUH1(gru_t* step, const gru_attr_t* attr) {
T* gates = reinterpret_cast<T*>(step->gates);
T* ht = reinterpret_cast<T*>(step->ht);
int d = attr->d;
int d2 = d * 2;
auto act_gate = getActFunc(attr->act_gate, d);
auto act_cand = getActFunc(attr->act_cand, d);
auto vmul_d = Get<kVMul, XYZNTuples<T>, platform::CPUPlace>(d);
act_gate(gates, gates, d);
act_cand(gates + d2, gates + d2, d);
vmul_d(gates, gates + d2, ht, d);
}
// compute the first part of GRU: ht = act_gate(r) * ht_1
void GRUHtPart1(gru_t* step, const gru_attr_t* attr) {
// W: {W_update, W_reset; W_state}
T* gates = reinterpret_cast<T*>(step->gates);
T* ht = reinterpret_cast<T*>(step->ht);
const T* ht_1 = reinterpret_cast<const T*>(step->ht_1);
auto act_gate = getActFunc(attr->act_gate, attr->d);
auto vmul_d = Get<kVMul, XYZNTuples<T>, platform::CPUPlace>(attr->d);
act_gate(gates + attr->d, gates + attr->d, attr->d);
vmul_d(ht_1, gates + attr->d, ht, attr->d);
}
// compute the second part of GRU:
// ht = act_gate(u) * act_cand(s) + (1-act_gate(u)) * ht_1
void GRUHtPart2(gru_t* step, const gru_attr_t* attr) {
T* gates = reinterpret_cast<T*>(step->gates);
T* ht = reinterpret_cast<T*>(step->ht);
const T* ht_1 = reinterpret_cast<const T*>(step->ht_1);
int d = attr->d;
auto act_gate = getActFunc(attr->act_gate, d);
auto act_cand = getActFunc(attr->act_cand, d);
T* y = gates + d * 2;
act_gate(gates, gates, d);
act_cand(y, y, d);
// out = zt*ht~ + (1-zt)*ht_1
for (int i = 0; i < d; ++i) {
ht[i] = gates[i] * y[i] + (static_cast<T>(1) - gates[i]) * ht_1[i];
}
}
// TODO(TJ): tuning me
bool VSigmoidKernel::UseMe(const int& d) const { return true; }
bool VTanhKernel::UseMe(const int& d) const { return true; }
bool LSTMCtHtKernel::UseMe(const lstm_attr_t& attr) const { return true; }
bool LSTMC1H1Kernel::UseMe(const lstm_attr_t& attr) const { return true; }
bool GRUH1Kernel::UseMe(const gru_attr_t& attr) const { return true; }
bool GRUHtPart1Kernel::UseMe(const gru_attr_t& attr) const { return true; }
bool GRUHtPart2Kernel::UseMe(const gru_attr_t& attr) const { return true; }
} // namespace mix
} // namespace more
} // namespace jit
} // namespace operators
} // namespace paddle
namespace mix = paddle::operators::jit::more::mix;
#define REGISTER_MORE_KERNEL(key, func) \
REGISTER_JITKERNEL_MORE(key, mix, mix::func##Kernel)
REGISTER_MORE_KERNEL(kVSigmoid, VSigmoid);
REGISTER_MORE_KERNEL(kVTanh, VTanh);
REGISTER_MORE_KERNEL(kLSTMCtHt, LSTMCtHt);
REGISTER_MORE_KERNEL(kLSTMC1H1, LSTMC1H1);
REGISTER_MORE_KERNEL(kGRUH1, GRUH1);
REGISTER_MORE_KERNEL(kGRUHtPart1, GRUHtPart1);
REGISTER_MORE_KERNEL(kGRUHtPart2, GRUHtPart2);
#undef REGISTER_MORE_KERNEL
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License. */
#pragma once
#include <type_traits>
#include "paddle/fluid/operators/jit/kernel_base.h"
namespace paddle {
namespace operators {
namespace jit {
namespace more {
namespace mix {
using T = float;
void VSigmoid(const T* x, T* y, int n);
void VTanh(const T* x, T* y, int n);
void LSTMCtHt(lstm_t* step, const lstm_attr_t* attr);
void LSTMC1H1(lstm_t* step, const lstm_attr_t* attr);
void GRUH1(gru_t* step, const gru_attr_t* attr);
void GRUHtPart1(gru_t* step, const gru_attr_t* attr);
void GRUHtPart2(gru_t* step, const gru_attr_t* attr);
#define DECLARE_MORE_KERNEL(name, tuples) \
class name##Kernel : public KernelMore<tuples<T>> { \
public: \
name##Kernel() { this->func = name; } \
bool UseMe(const typename tuples<T>::attr_type&) const override; \
const char* ImplType() const override { return "Mixed"; } \
}
// XYN
DECLARE_MORE_KERNEL(VSigmoid, XYNTuples);
DECLARE_MORE_KERNEL(VTanh, XYNTuples);
DECLARE_MORE_KERNEL(LSTMCtHt, LSTMTuples);
DECLARE_MORE_KERNEL(LSTMC1H1, LSTMTuples);
DECLARE_MORE_KERNEL(GRUH1, GRUTuples);
DECLARE_MORE_KERNEL(GRUHtPart1, GRUTuples);
DECLARE_MORE_KERNEL(GRUHtPart2, GRUTuples);
#undef DECLARE_MORE_KERNEL
} // namespace mix
} // namespace more
} // namespace jit
} // namespace operators
} // namespace paddle
cc_library(jit_kernel_mkl SRCS mkl.cc DEPS jit_kernel_base dynload_mklml)
set(JIT_KERNEL_DEPS ${JIT_KERNEL_DEPS} dynload_mklml jit_kernel_mkl PARENT_SCOPE)
# use mkl kernels by name and type
USE_JITKERNEL_MORE(kVMul, mkl)
USE_JITKERNEL_MORE(kVAdd, mkl)
USE_JITKERNEL_MORE(kVScal, mkl)
USE_JITKERNEL_MORE(kVExp, mkl)
USE_JITKERNEL_MORE(kVSigmoid, mkl)
USE_JITKERNEL_MORE(kVTanh, mkl)
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License. */
#include "paddle/fluid/operators/jit/more/mkl/mkl.h"
#include "paddle/fluid/operators/jit/refer/refer.h"
#include "paddle/fluid/operators/jit/registry.h"
#include "paddle/fluid/platform/cpu_info.h"
#include "paddle/fluid/platform/dynload/mklml.h"
namespace paddle {
namespace operators {
namespace jit {
namespace more {
namespace mkl {
template <>
void VMul<float>(const float* x, const float* y, float* z, int n) {
platform::dynload::vsMul(n, x, y, z);
}
template <>
void VMul<double>(const double* x, const double* y, double* z, int n) {
platform::dynload::vdMul(n, x, y, z);
}
template <>
void VAdd<float>(const float* x, const float* y, float* z, int n) {
platform::dynload::vsAdd(n, x, y, z);
}
template <>
void VAdd<double>(const double* x, const double* y, double* z, int n) {
platform::dynload::vdAdd(n, x, y, z);
}
template <>
void VScal<float>(const float* a, const float* x, float* y, int n) {
if (x == y) {
platform::dynload::cblas_sscal(n, *a, y, 1);
} else {
refer::VScal<float>(a, x, y, n);
}
}
template <>
void VScal<double>(const double* a, const double* x, double* y, int n) {
if (x == y) {
platform::dynload::cblas_dscal(n, *a, y, 1);
} else {
refer::VScal<double>(a, x, y, n);
}
}
template <>
void VExp<float>(const float* x, float* y, int n) {
platform::dynload::vsExp(n, x, y);
}
template <>
void VExp<double>(const double* x, double* y, int n) {
platform::dynload::vdExp(n, x, y);
}
// TODO(TJ): tuning me carefully on AVX, AVX2 and AVX512
template <>
bool VMulKernel<float>::UseMe(const int& d) const {
return platform::MayIUse(platform::avx512f) && d > 512;
}
template <>
bool VAddKernel<float>::UseMe(const int& d) const {
return platform::MayIUse(platform::avx512f) && d > 512;
}
template <>
bool VScalKernel<float>::UseMe(const int& d) const {
return platform::MayIUse(platform::avx512f) && d > 512;
}
template <>
bool VExpKernel<float>::UseMe(const int& d) const {
return d > 7;
}
template <>
bool VSigmoidKernel<float>::UseMe(const int& d) const {
return d > 7;
}
template <>
bool VTanhKernel<float>::UseMe(const int& d) const {
return d > 7;
}
#define AWALYS_USE_ME_WITH_DOUBLE(func) \
template <> \
bool func##Kernel<double>::UseMe(const int& d) const { \
return true; \
}
AWALYS_USE_ME_WITH_DOUBLE(VMul);
AWALYS_USE_ME_WITH_DOUBLE(VAdd);
AWALYS_USE_ME_WITH_DOUBLE(VScal);
AWALYS_USE_ME_WITH_DOUBLE(VExp);
AWALYS_USE_ME_WITH_DOUBLE(VSigmoid);
AWALYS_USE_ME_WITH_DOUBLE(VTanh);
#undef AWALYS_USE_ME_WITH_DOUBLE
} // namespace mkl
} // namespace more
} // namespace jit
} // namespace operators
} // namespace paddle
namespace mkl = paddle::operators::jit::more::mkl;
#define REGISTER_MKL_KERNEL(key, func) \
REGISTER_JITKERNEL_MORE(key, mkl, mkl::func##Kernel<float>, \
mkl::func##Kernel<double>)
REGISTER_MKL_KERNEL(kVMul, VMul);
REGISTER_MKL_KERNEL(kVAdd, VAdd);
REGISTER_MKL_KERNEL(kVScal, VScal);
REGISTER_MKL_KERNEL(kVExp, VExp);
REGISTER_MKL_KERNEL(kVSigmoid, VSigmoid);
REGISTER_MKL_KERNEL(kVTanh, VTanh);
#undef REGISTER_MKL_KERNEL
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License. */
#pragma once
#include <type_traits>
#include "paddle/fluid/operators/jit/kernel_base.h"
namespace paddle {
namespace operators {
namespace jit {
namespace more {
namespace mkl {
template <typename T>
void VMul(const T* x, const T* y, T* z, int n);
template <typename T>
void VAdd(const T* x, const T* y, T* z, int n);
template <typename T>
void VScal(const T* a, const T* x, T* y, int n);
template <typename T>
void VExp(const T* x, T* y, int n);
template <typename T>
void VSigmoid(const T* x, T* y, int n) {
const T min = SIGMOID_THRESHOLD_MIN;
const T max = SIGMOID_THRESHOLD_MAX;
for (int i = 0; i < n; ++i) {
y[i] = (x[i] < min) ? min : ((x[i] > max) ? max : x[i]);
y[i] = static_cast<T>(0) - y[i];
}
VExp(y, y, n);
for (int i = 0; i < n; ++i) {
y[i] = static_cast<T>(1) / (static_cast<T>(1) + y[i]);
}
}
template <typename T>
void VTanh(const T* x, T* y, int n) {
for (int i = 0; i < n; ++i) {
y[i] = static_cast<T>(2) * x[i];
}
VSigmoid(y, y, n);
for (int i = 0; i < n; ++i) {
y[i] = static_cast<T>(2) * y[i] - static_cast<T>(1);
}
}
#define DECLARE_MKL_KERNEL(name, tuples) \
template <typename T> \
class name##Kernel : public KernelMore<tuples<T>> { \
public: \
name##Kernel() { this->func = name<T>; } \
bool UseMe(const typename tuples<T>::attr_type&) const override; \
const char* ImplType() const override { return "MKL"; } \
}
// XYZN
DECLARE_MKL_KERNEL(VMul, XYZNTuples);
DECLARE_MKL_KERNEL(VAdd, XYZNTuples);
// AXYN
DECLARE_MKL_KERNEL(VScal, AXYNTuples);
// XYN
DECLARE_MKL_KERNEL(VExp, XYNTuples);
DECLARE_MKL_KERNEL(VSigmoid, XYNTuples);
DECLARE_MKL_KERNEL(VTanh, XYNTuples);
#undef DECLARE_MKL_KERNEL
} // namespace mkl
} // namespace more
} // namespace jit
} // namespace operators
} // namespace paddle
cc_library(jit_kernel_refer SRCS refer.cc DEPS jit_kernel_base)
set(JIT_KERNEL_DEPS ${JIT_KERNEL_DEPS} jit_kernel_refer PARENT_SCOPE)
function(USE_JITKERNEL_REFER TARGET)
file(APPEND ${jit_file} "USE_JITKERNEL_REFER(${TARGET});\n")
endfunction()
# use refer kernel by name
USE_JITKERNEL_REFER(kVMul)
USE_JITKERNEL_REFER(kVAdd)
USE_JITKERNEL_REFER(kVAddRelu)
USE_JITKERNEL_REFER(kVSub)
USE_JITKERNEL_REFER(kVScal)
USE_JITKERNEL_REFER(kVAddBias)
USE_JITKERNEL_REFER(kVRelu)
USE_JITKERNEL_REFER(kVIdentity)
USE_JITKERNEL_REFER(kVExp)
USE_JITKERNEL_REFER(kVSigmoid)
USE_JITKERNEL_REFER(kVTanh)
USE_JITKERNEL_REFER(kLSTMCtHt)
USE_JITKERNEL_REFER(kLSTMC1H1)
USE_JITKERNEL_REFER(kGRUH1)
USE_JITKERNEL_REFER(kGRUHtPart1)
USE_JITKERNEL_REFER(kGRUHtPart2)
USE_JITKERNEL_REFER(kCRFDecoding)
USE_JITKERNEL_REFER(kLayerNorm)
USE_JITKERNEL_REFER(kNCHW16CMulNC)
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License. */
#include "paddle/fluid/operators/jit/refer/refer.h"
#include "paddle/fluid/operators/jit/registry.h"
namespace refer = paddle::operators::jit::refer;
#define REGISTER_REFER_KERNEL(key, func) \
REGISTER_JITKERNEL_REFER(key, refer::func##Kernel<float>, \
refer::func##Kernel<double>)
REGISTER_REFER_KERNEL(kVMul, VMul);
REGISTER_REFER_KERNEL(kVAdd, VAdd);
REGISTER_REFER_KERNEL(kVAddRelu, VAddRelu);
REGISTER_REFER_KERNEL(kVSub, VSub);
REGISTER_REFER_KERNEL(kVScal, VScal);
REGISTER_REFER_KERNEL(kVAddBias, VAddBias);
REGISTER_REFER_KERNEL(kVRelu, VRelu);
REGISTER_REFER_KERNEL(kVIdentity, VIdentity);
REGISTER_REFER_KERNEL(kVExp, VExp);
REGISTER_REFER_KERNEL(kVSigmoid, VSigmoid);
REGISTER_REFER_KERNEL(kVTanh, VTanh);
REGISTER_REFER_KERNEL(kLSTMCtHt, LSTMCtHt);
REGISTER_REFER_KERNEL(kLSTMC1H1, LSTMC1H1);
REGISTER_REFER_KERNEL(kGRUH1, GRUH1);
REGISTER_REFER_KERNEL(kGRUHtPart1, GRUHtPart1);
REGISTER_REFER_KERNEL(kGRUHtPart2, GRUHtPart2);
REGISTER_REFER_KERNEL(kCRFDecoding, CRFDecoding);
REGISTER_REFER_KERNEL(kLayerNorm, LayerNorm);
REGISTER_REFER_KERNEL(kNCHW16CMulNC, NCHW16CMulNC);
#undef REGISTER_REFER_KERNEL
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. /* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
*
Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
You may obtain a copy of the License at * You may obtain a copy of the License at
*
http://www.apache.org/licenses/LICENSE-2.0 * http://www.apache.org/licenses/LICENSE-2.0
*
Unless required by applicable law or agreed to in writing, software * Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS, * distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and * See the License for the specific language governing permissions and
limitations under the License. */ * limitations under the License. */
#pragma once #pragma once
#include <cmath> #include <cmath>
#include <string> #include <limits>
#include "paddle/fluid/operators/math/jit_kernel_impl.h" #include "paddle/fluid/operators/jit/helper.h"
#include "paddle/fluid/operators/jit/kernel_base.h"
#include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/enforce.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
namespace math { namespace jit {
namespace jitkernel {
namespace refer { namespace refer {
/* Refer code only focus on correctness */
// Refer code only focus on correctness
template <typename T> template <typename T>
void VMul(const T* x, const T* y, T* z, int n) { void VMul(const T* x, const T* y, T* z, int n) {
for (int i = 0; i < n; ++i) { for (int i = 0; i < n; ++i) {
...@@ -47,6 +48,13 @@ void VAddRelu(const T* x, const T* y, T* z, int n) { ...@@ -47,6 +48,13 @@ void VAddRelu(const T* x, const T* y, T* z, int n) {
} }
} }
template <typename T>
void VSub(const T* x, const T* y, T* z, int n) {
for (int i = 0; i < n; ++i) {
z[i] = x[i] - y[i];
}
}
template <typename T> template <typename T>
void VScal(const T* a, const T* x, T* y, int n) { void VScal(const T* a, const T* x, T* y, int n) {
for (int i = 0; i < n; ++i) { for (int i = 0; i < n; ++i) {
...@@ -69,7 +77,11 @@ void VRelu(const T* x, T* y, int n) { ...@@ -69,7 +77,11 @@ void VRelu(const T* x, T* y, int n) {
} }
template <typename T> template <typename T>
inline void VIdentity(const T* x, T* y, int n) {} inline void VIdentity(const T* x, T* y, int n) {
for (int i = 0; i < n; ++i) {
y[i] = x[i];
}
}
template <typename T> template <typename T>
void VExp(const T* x, T* y, int n) { void VExp(const T* x, T* y, int n) {
...@@ -102,20 +114,22 @@ void VTanh(const T* x, T* y, int n) { ...@@ -102,20 +114,22 @@ void VTanh(const T* x, T* y, int n) {
} }
template <typename T> template <typename T>
void (*getActFunc(const std::string& type))(const T*, T*, int) { // NOLINT void (*getActFunc(KernelType type))(const T*, T*, int) { // NOLINT
if (type == "sigmoid") { if (type == kVSigmoid) {
return VSigmoid<T>; return VSigmoid<T>;
} else if (type == "relu") { } else if (type == kVRelu) {
return VRelu<T>; return VRelu<T>;
} else if (type == "tanh") { } else if (type == kVTanh) {
return VTanh<T>; return VTanh<T>;
} else if (type == "identity" || type == "") { } else if (type == kVIdentity) {
return VIdentity<T>; return VIdentity<T>;
} }
PADDLE_THROW("Not support type: %s", type); PADDLE_THROW("Not support type: %s", type);
return nullptr; return nullptr;
} }
// TODO(TJ): add refer gemm and make LSTM kernels combine as same GRU kernels
// compute ct and ht // compute ct and ht
template <typename T> template <typename T>
void LSTMCtHt(lstm_t* step, const lstm_attr_t* attr) { void LSTMCtHt(lstm_t* step, const lstm_attr_t* attr) {
...@@ -231,8 +245,134 @@ void GRUHtPart2(gru_t* step, const gru_attr_t* attr) { ...@@ -231,8 +245,134 @@ void GRUHtPart2(gru_t* step, const gru_attr_t* attr) {
} }
} }
template <typename T>
void CRFDecoding(const int seq_len, const T* x, const T* w, T* alpha,
int* track, int right) {
constexpr int state_trans_base_idx = 2;
for (int i = 0; i < right; ++i) {
alpha[i] = w[i] + x[i];
}
for (int k = 1; k < seq_len; ++k) {
for (int i = 0; i < right; ++i) {
T max_score = -std::numeric_limits<T>::max();
int max_j = 0;
for (int j = 0; j < right; ++j) {
T score = alpha[(k - 1) * right + j] +
w[(j + state_trans_base_idx) * right + i];
if (score > max_score) {
max_score = score;
max_j = j;
}
}
alpha[k * right + i] = max_score + x[k * right + i];
track[k * right + i] = max_j;
}
}
}
template <typename T>
void LayerNorm(T* x, T* out, T* mean, T* var, const T* scale, const T* bias,
int height, const float epsilon, int right) {
// get mean
for (int i = 0; i < height; i++) {
T sum = 0.0;
int offset = i * right;
for (int j = 0; j < right; j++) {
sum += x[offset + j];
}
mean[i] = sum / right;
}
// get variance
for (int i = 0; i < height; i++) {
T sum = 0.0;
int offset = i * right;
for (int j = 0; j < right; j++) {
sum += (x[offset + j] - mean[i]) * (x[offset + j] - mean[i]);
}
var[i] = sum / right;
}
for (int i = 0; i < height; i++) {
int offset = i * right;
T sqrt_var = std::sqrt(var[i] + (T)epsilon);
for (int j = 0; j < right; j++) {
out[offset + j] = (x[offset + j] - mean[i]) / sqrt_var;
}
}
if (scale) {
for (int i = 0; i < height; i++) {
int offset = i * right;
for (int j = 0; j < right; j++) {
out[offset + j] *= scale[j];
}
}
}
if (bias) {
for (int i = 0; i < height; i++) {
int offset = i * right;
for (int j = 0; j < right; j++) {
out[offset + j] += bias[j];
}
}
}
}
template <typename T>
void NCHW16CMulNC(const T* x, const T* y, T* z, int height, int width) {
int offset = 0;
for (int h = 0; h < height; ++h) {
for (int w = 0; w < width; ++w) {
for (int i = 0; i < 16; ++i) {
z[i + offset] = y[i] * x[i + offset];
}
offset += ZMM_FLOAT_BLOCK;
}
}
}
#define DECLARE_REFER_KERNEL(name, tuples) \
template <typename T> \
class name##Kernel : public ReferKernel<tuples<T>> { \
public: \
name##Kernel() { this->func = name<T>; } \
}
// const T* x, const T* y, T* z, int n
DECLARE_REFER_KERNEL(VMul, XYZNTuples);
DECLARE_REFER_KERNEL(VAdd, XYZNTuples);
DECLARE_REFER_KERNEL(VAddRelu, XYZNTuples);
DECLARE_REFER_KERNEL(VSub, XYZNTuples);
// const T* a, const T* x, T* y, int n
DECLARE_REFER_KERNEL(VScal, AXYNTuples);
DECLARE_REFER_KERNEL(VAddBias, AXYNTuples);
// const T* x, T* y, int n
DECLARE_REFER_KERNEL(VRelu, XYNTuples);
DECLARE_REFER_KERNEL(VIdentity, XYNTuples);
DECLARE_REFER_KERNEL(VExp, XYNTuples);
DECLARE_REFER_KERNEL(VSigmoid, XYNTuples);
DECLARE_REFER_KERNEL(VTanh, XYNTuples);
// lstm_t*, const lstm_attr_t*
DECLARE_REFER_KERNEL(LSTMCtHt, LSTMTuples);
DECLARE_REFER_KERNEL(LSTMC1H1, LSTMTuples);
// gru_t*, const gru_attr_t*
DECLARE_REFER_KERNEL(GRUH1, GRUTuples);
DECLARE_REFER_KERNEL(GRUHtPart1, GRUTuples);
DECLARE_REFER_KERNEL(GRUHtPart2, GRUTuples);
DECLARE_REFER_KERNEL(CRFDecoding, CRFDecodingTuples);
DECLARE_REFER_KERNEL(LayerNorm, LayerNormTuples);
DECLARE_REFER_KERNEL(NCHW16CMulNC, NCHW16CMulNCTuples);
#undef DECLARE_REFER_KERNEL
} // namespace refer } // namespace refer
} // namespace jitkernel } // namespace jit
} // namespace math
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License. */
#pragma once
#include <memory>
#include <tuple>
#include <type_traits>
#include "paddle/fluid/operators/jit/kernel_base.h"
#include "paddle/fluid/operators/jit/kernel_pool.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/fluid/platform/variant.h" // for UNUSED
namespace paddle {
namespace operators {
namespace jit {
// make_unique is supported since c++14
template <typename T, typename... Args>
inline std::unique_ptr<T> make_unique(Args&&... args) {
static_assert(!std::is_array<T>::value, "T must not be array");
return std::unique_ptr<T>(new T(std::forward<Args>(args)...));
}
template <typename Pool, typename PlaceType, bool IsEnd, size_t I,
typename... KernelImpls>
struct JitKernelRegistrarFunctor;
template <typename Pool, typename PlaceType, size_t I, typename... KernelImpls>
struct JitKernelRegistrarFunctor<Pool, PlaceType, true, I, KernelImpls...> {
void operator()(KernelType kt) const {}
};
template <typename Pool, typename PlaceType, size_t I, typename... KernelImpls>
struct JitKernelRegistrarFunctor<Pool, PlaceType, false, I, KernelImpls...> {
using KERNEL_IMPL_TYPE =
typename std::tuple_element<I, std::tuple<KernelImpls...>>::type;
void operator()(KernelType kt) const {
KernelKey kkey(kt, PlaceType());
Pool().Instance().Insert(kkey,
std::move(make_unique<const KERNEL_IMPL_TYPE>()));
constexpr auto size = std::tuple_size<std::tuple<KernelImpls...>>::value;
JitKernelRegistrarFunctor<Pool, PlaceType, I + 1 == size, I + 1,
KernelImpls...>
func;
func(kt);
}
};
template <typename Pool, typename PlaceType, typename... KernelImpls>
class JitKernelRegistrar {
public:
explicit JitKernelRegistrar(KernelType kt) {
JitKernelRegistrarFunctor<Pool, PlaceType, false, 0, KernelImpls...> func;
func(kt);
}
void Touch() {}
};
#define STATIC_ASSERT_JITKERNEL_GLOBAL_NAMESPACE(uniq_name, msg) \
struct __test_global_namespace_##uniq_name##__ {}; \
static_assert(std::is_same<::__test_global_namespace_##uniq_name##__, \
__test_global_namespace_##uniq_name##__>::value, \
msg)
// Refer always on CPUPlace
#define REGISTER_JITKERNEL_REFER(kernel_type, ...) \
STATIC_ASSERT_JITKERNEL_GLOBAL_NAMESPACE( \
__reg_jitkernel_##kernel_type##_refer_CPUPlace, \
"REGISTER_KERNEL_REFER must be called in global namespace"); \
static ::paddle::operators::jit::JitKernelRegistrar< \
::paddle::operators::jit::ReferKernelPool, ::paddle::platform::CPUPlace, \
__VA_ARGS__> \
__jit_kernel_registrar_##kernel_type##_refer_CPUPlace_( \
::paddle::operators::jit::KernelType::kernel_type); \
int TouchJitKernelReg_##kernel_type##_refer_CPUPlace_() { \
__jit_kernel_registrar_##kernel_type##_refer_CPUPlace_.Touch(); \
return 0; \
}
// kernel_type: should be in paddle::operators::jit::KernelType
// place_type: should be one of CPUPlace and GPUPlace in paddle::platform
#define REGISTER_KERNEL_MORE(kernel_type, impl_type, place_type, ...) \
STATIC_ASSERT_JITKERNEL_GLOBAL_NAMESPACE( \
__reg_jitkernel_##kernel_type##_##impl_type##_##place_type, \
"REGISTER_KERNEL_MORE must be called in global namespace"); \
extern int TouchJitKernelReg_##kernel_type##_refer_CPUPlace_(); \
static int __assert_##kernel_type##_##impl_type##_##place_type##_has_refer_ \
UNUSED = TouchJitKernelReg_##kernel_type##_refer_CPUPlace_(); \
static ::paddle::operators::jit::JitKernelRegistrar< \
::paddle::operators::jit::KernelPool, ::paddle::platform::place_type, \
__VA_ARGS__> \
__jit_kernel_registrar_##kernel_type##_##impl_type##_##place_type##_( \
::paddle::operators::jit::KernelType::kernel_type); \
int TouchJitKernelReg_##kernel_type##_##impl_type##_##place_type##_() { \
__jit_kernel_registrar_##kernel_type##_##impl_type##_##place_type##_ \
.Touch(); \
return 0; \
}
#define REGISTER_JITKERNEL_MORE(kernel_type, impl_type, ...) \
REGISTER_KERNEL_MORE(kernel_type, impl_type, CPUPlace, __VA_ARGS__)
#define REGISTER_GPUKERNEL_MORE(kernel_type, impl_type, ...) \
REGISTER_KERNEL_MORE(kernel_type, impl_type, GPUPlace, __VA_ARGS__)
#define REGISTER_JITKERNEL_GEN(kernel_type, ...) \
STATIC_ASSERT_JITKERNEL_GLOBAL_NAMESPACE( \
__reg_jitkernel_gen_##kernel_type##_CPUPlace_, \
"REGISTER_JITKERNEL_GEN must be called in global namespace"); \
extern int TouchJitKernelReg_##kernel_type##_refer_CPUPlace_(); \
static int __assert_gen_##kernel_type##_has_refer_ UNUSED = \
TouchJitKernelReg_##kernel_type##_refer_CPUPlace_(); \
static ::paddle::operators::jit::JitKernelRegistrar< \
::paddle::operators::jit::JitCodeCreatorPool, \
::paddle::platform::CPUPlace, __VA_ARGS__> \
__jit_kernel_registrar_gen_##kernel_type##_CPUPlace_( \
::paddle::operators::jit::KernelType::kernel_type); \
int TouchJitKernelReg_gen_##kernel_type##_CPUPlace_() { \
__jit_kernel_registrar_gen_##kernel_type##_CPUPlace_.Touch(); \
return 0; \
}
#define USE_JITKERNEL_GEN(kernel_type) \
STATIC_ASSERT_JITKERNEL_GLOBAL_NAMESPACE( \
__reg_jitkernel_gen_##kernel_type##_CPUPlace_, \
"USE_JITKERNEL_GEN must be called in global namespace"); \
extern int TouchJitKernelReg_gen_##kernel_type##_CPUPlace_(); \
static int use_jitkernel_gen_##kernel_type##_CPUPlace_ UNUSED = \
TouchJitKernelReg_gen_##kernel_type##_CPUPlace_()
#define USE_JITKERNEL_REFER(kernel_type) \
STATIC_ASSERT_JITKERNEL_GLOBAL_NAMESPACE( \
__reg_jitkernel_##kernel_type##_refer_CPUPlace_, \
"USE_JITKERNEL_REFER must be called in global namespace"); \
extern int TouchJitKernelReg_##kernel_type##_refer_CPUPlace_(); \
static int use_jitkernel_##kernel_type##_refer_CPUPlace_ UNUSED = \
TouchJitKernelReg_##kernel_type##_refer_CPUPlace_()
#define USE_KERNEL_MORE(kernel_type, impl_type, place_type) \
STATIC_ASSERT_JITKERNEL_GLOBAL_NAMESPACE( \
__reg_jitkernel_##kernel_type##_##impl_type##_##place_type##_, \
"USE_JITKERNEL_MORE must be called in global namespace"); \
extern int \
TouchJitKernelReg_##kernel_type##_##impl_type##_##place_type##_(); \
static int use_jitkernel_##kernel_type##_##impl_type##_##place_type##_ \
UNUSED = \
TouchJitKernelReg_##kernel_type##_##impl_type##_##place_type##_()
#define USE_JITKERNEL_MORE(kernel_type, impl_type) \
USE_KERNEL_MORE(kernel_type, impl_type, CPUPlace)
} // namespace jit
} // namespace operators
} // namespace paddle
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License. */
#include <random>
#include <string>
#include <vector>
#include "gflags/gflags.h"
#include "glog/logging.h"
#include "gtest/gtest.h"
#include "paddle/fluid/operators/jit/kernels.h"
#include "paddle/fluid/platform/cpu_info.h"
#include "paddle/fluid/platform/place.h"
template <typename T>
void RandomVec(const int n, T* a, const T lower = static_cast<T>(-20.f),
const T upper = static_cast<T>(20.f)) {
static unsigned int seed = 100;
std::mt19937 rng(seed++);
std::uniform_real_distribution<double> uniform_dist(0, 1);
for (int i = 0; i < n; ++i) {
a[i] = static_cast<T>(uniform_dist(rng) * (upper - lower) + lower);
}
}
template <typename T>
void ExpectEQ(const T* target, const T* refer, int n) {
if (std::is_floating_point<T>::value) {
for (int i = 0; i < n; ++i) {
EXPECT_NEAR(target[i], refer[i], 1e-5);
}
} else {
for (int i = 0; i < n; ++i) {
EXPECT_EQ(target[i], refer[i]);
}
}
}
std::vector<int> TestSizes() {
std::vector<int> s;
for (int i = 1; i < 32; ++i) {
s.push_back(i);
}
// test some large size
s.push_back(100);
s.push_back(1000);
s.push_back(2000);
return s;
}
namespace jit = paddle::operators::jit;
template <typename KernelTuples, typename... Args>
struct TestFuncWithRefer {
void operator()(const typename KernelTuples::func_type tgt, Args... args) {}
};
template <typename T>
struct TestFuncWithRefer<jit::XYZNTuples<T>, std::vector<T>, std::vector<T>,
std::vector<T>> {
void operator()(const typename jit::XYZNTuples<T>::func_type tgt,
const std::vector<T>& x, const std::vector<T>& y,
const std::vector<T>& zref) {
EXPECT_TRUE(tgt != nullptr);
EXPECT_EQ(zref.size(), x.size());
EXPECT_EQ(zref.size(), y.size());
const T* x_data = x.data();
const T* y_data = y.data();
const T* zref_data = zref.data();
const int d = zref.size();
std::vector<T> ztgt(d);
T* ztgt_data = ztgt.data();
// test normal
tgt(x_data, y_data, ztgt_data, d);
ExpectEQ<T>(ztgt_data, zref_data, d);
// test inplace x
std::copy(x.begin(), x.end(), ztgt.begin());
tgt(ztgt_data, y_data, ztgt_data, d);
ExpectEQ<T>(ztgt_data, zref_data, d);
// test inplace y
std::copy(y.begin(), y.end(), ztgt.begin());
tgt(x_data, ztgt_data, ztgt_data, d);
ExpectEQ<T>(ztgt_data, zref_data, d);
}
};
template <typename T>
struct TestFuncWithRefer<jit::AXYNTuples<T>, T, std::vector<T>,
std::vector<T>> {
void operator()(const typename jit::AXYNTuples<T>::func_type tgt, const T a,
const std::vector<T>& x, const std::vector<T>& yref) {
EXPECT_TRUE(tgt != nullptr);
EXPECT_EQ(yref.size(), x.size());
const T* x_data = x.data();
const T* yref_data = yref.data();
const int d = yref.size();
std::vector<T> ytgt(d);
T* ytgt_data = ytgt.data();
// test normal
tgt(&a, x_data, ytgt_data, d);
ExpectEQ<T>(ytgt_data, yref_data, d);
// test inplace x
std::copy(x.begin(), x.end(), ytgt.begin());
tgt(&a, ytgt_data, ytgt_data, d);
ExpectEQ<T>(ytgt_data, yref_data, d);
}
};
template <typename T>
struct TestFuncWithRefer<jit::XYNTuples<T>, std::vector<T>, std::vector<T>> {
void operator()(const typename jit::XYNTuples<T>::func_type tgt,
const std::vector<T>& x, const std::vector<T>& yref) {
EXPECT_TRUE(tgt != nullptr);
EXPECT_EQ(yref.size(), x.size());
const T* x_data = x.data();
const T* yref_data = yref.data();
const int d = yref.size();
std::vector<T> ytgt(d);
T* ytgt_data = ytgt.data();
// test normal
tgt(x_data, ytgt_data, d);
ExpectEQ<T>(ytgt_data, yref_data, d);
// test inplace x
std::copy(x.begin(), x.end(), ytgt.begin());
tgt(ytgt_data, ytgt_data, d);
ExpectEQ<T>(ytgt_data, yref_data, d);
}
};
template <typename T>
struct TestFuncWithRefer<jit::LSTMTuples<T>, std::vector<T>, std::vector<T>,
std::vector<T>, std::vector<T>, std::vector<T>> {
void operator()(const typename jit::LSTMTuples<T>::func_type tgt,
const std::vector<T>& xsrc, const std::vector<T>& wp,
const std::vector<T>& ct_1, const std::vector<T>& ct_ref,
const std::vector<T>& ht_ref,
const typename jit::LSTMTuples<T>::attr_type& attr) {
EXPECT_TRUE(tgt != nullptr);
EXPECT_EQ(ct_ref.size(), ht_ref.size());
EXPECT_EQ(ct_1.size(), ht_ref.size());
EXPECT_EQ(xsrc.size(), 4 * ht_ref.size());
EXPECT_EQ(wp.size(), 3 * ht_ref.size());
// x could be changed after compute, so copy to save src
int d = ht_ref.size();
std::vector<T> x(xsrc.size()), ct(ct_ref.size()), ht(ht_ref.size());
std::vector<T> checked(2 * d);
std::copy(xsrc.begin(), xsrc.end(), x.begin());
const T* ct_1_data = ct_1.data();
const T* wp_data = wp.data();
const T* ct_ref_data = ct_ref.data();
const T* ht_ref_data = ht_ref.data();
T* x_data = x.data();
T* ct_data = ct.data();
T* ht_data = ht.data();
T* checked_data = checked.data();
paddle::operators::jit::lstm_t step;
step.gates = x_data;
step.ct_1 = ct_1_data;
step.ct = ct_data;
step.ht = ht_data;
if (attr.use_peephole) {
step.wp = wp_data;
step.checked = checked_data;
}
tgt(&step, &attr);
ExpectEQ<T>(ct_data, ct_ref_data, d);
ExpectEQ<T>(ht_data, ht_ref_data, d);
}
};
template <typename T>
struct TestFuncWithRefer<jit::GRUTuples<T>, std::vector<T>, std::vector<T>,
std::vector<T>> {
void operator()(const typename jit::GRUTuples<T>::func_type tgt,
const std::vector<T>& xsrc, const std::vector<T>& ht_1,
const std::vector<T>& ht_ref,
const typename jit::GRUTuples<T>::attr_type& attr) {
EXPECT_TRUE(tgt != nullptr);
EXPECT_EQ(ht_1.size(), ht_ref.size());
EXPECT_EQ(xsrc.size(), 3 * ht_ref.size());
// x could be changed after compute, so copy to save src
int d = ht_ref.size();
std::vector<T> x(xsrc.size()), ht(ht_ref.size());
std::copy(xsrc.begin(), xsrc.end(), x.begin());
const T* ht_1_data = ht_1.data();
const T* ht_ref_data = ht_ref.data();
T* x_data = x.data();
T* ht_data = ht.data();
paddle::operators::jit::gru_t step;
step.gates = x_data;
step.ht_1 = ht_1_data;
step.ht = ht_data;
tgt(&step, &attr);
ExpectEQ<T>(ht_data, ht_ref_data, d);
}
};
template <paddle::operators::jit::KernelType KT, typename KernelTuples,
typename PlaceType, typename... Args>
void TestAllImpls(const typename KernelTuples::attr_type& attr, Args... args) {
TestFuncWithRefer<KernelTuples, Args...> test;
// test jitcode
auto jitcode = jit::GetJitCode<KT, KernelTuples, PlaceType>(attr);
if (jitcode) {
VLOG(10) << "Test Jitcode Kernel ";
test(jitcode, args...);
}
// test all impls in more
jit::KernelKey kkey(KT, PlaceType());
auto& pool = jit::KernelPool().Instance().AllKernels();
auto iter = pool.find(kkey);
if (iter != pool.end()) {
auto& impls = iter->second;
for (auto& impl : impls) {
auto i = dynamic_cast<const jit::KernelMore<KernelTuples>*>(impl.get());
if (i && i->UseMe(attr)) {
auto more = i->GetFunc();
VLOG(10) << "Test More Kernel : " << i->ImplType();
test(more, args...);
}
}
}
// test result from Get function
// VLOG(10) << "Test Get function ";
auto tgt = jit::Get<KT, KernelTuples, PlaceType>(attr);
test(tgt, args...);
}
template <paddle::operators::jit::KernelType KT, typename T, typename PlaceType>
void TestXYZNKernel() {
namespace jit = paddle::operators::jit;
VLOG(10) << "===== Test JITKernel " << jit::to_string(KT);
for (int d : TestSizes()) {
auto ref = jit::GetRefer<KT, jit::XYZNTuples<T>>();
EXPECT_TRUE(ref != nullptr);
std::vector<T> x(d), y(d), zref(d);
RandomVec<T>(d, x.data());
RandomVec<T>(d, y.data());
std::vector<T> xinp(d), yinp(d); // inplace test
std::copy(x.begin(), x.end(), xinp.begin());
std::copy(y.begin(), y.end(), yinp.begin());
const T* x_data = x.data();
const T* y_data = y.data();
T* zref_data = zref.data();
T* xinp_data = xinp.data();
T* yinp_data = yinp.data();
// test refer code inplace
ref(x_data, y_data, zref_data, d);
ref(x_data, yinp_data, yinp_data, d);
ref(xinp_data, y_data, xinp_data, d);
ExpectEQ<T>(xinp_data, zref_data, d);
ExpectEQ<T>(yinp_data, zref_data, d);
TestAllImpls<KT, jit::XYZNTuples<T>, PlaceType, std::vector<T>,
std::vector<T>, std::vector<T>>(d, x, y, zref);
}
}
template <paddle::operators::jit::KernelType KT, typename T, typename PlaceType>
void TestAXYNKernel() {
namespace jit = paddle::operators::jit;
VLOG(10) << "===== Test JITKernel " << jit::to_string(KT);
for (int d : TestSizes()) {
auto ref = jit::GetRefer<KT, jit::AXYNTuples<T>>();
EXPECT_TRUE(ref != nullptr);
const T a = static_cast<T>(3);
std::vector<T> x(d), yref(d);
std::vector<T> xinp(d); // inplace test
RandomVec<T>(d, x.data());
std::copy(x.begin(), x.end(), xinp.begin());
const T* x_data = x.data();
T* yref_data = yref.data();
T* xinp_data = xinp.data();
// test refer code inplace
ref(&a, x_data, yref_data, d);
ref(&a, xinp_data, xinp_data, d);
ExpectEQ<T>(xinp_data, yref_data, d);
TestAllImpls<KT, jit::AXYNTuples<T>, PlaceType, T, std::vector<T>,
std::vector<T>>(d, a, x, yref);
}
}
template <paddle::operators::jit::KernelType KT, typename T, typename PlaceType>
void TestXYNKernel() {
namespace jit = paddle::operators::jit;
VLOG(10) << "===== Test JITKernel " << jit::to_string(KT);
for (int d : TestSizes()) {
auto ref = jit::GetRefer<KT, jit::XYNTuples<T>>();
EXPECT_TRUE(ref != nullptr);
std::vector<T> x(d), yref(d);
std::vector<T> xinp(d); // inplace test
RandomVec<T>(d, x.data(), -2.f, 2.f);
std::copy(x.begin(), x.end(), xinp.begin());
const T* x_data = x.data();
T* yref_data = yref.data();
T* xinp_data = xinp.data();
// test refer code inplace
ref(x_data, yref_data, d);
ref(xinp_data, xinp_data, d);
ExpectEQ<T>(xinp_data, yref_data, d);
TestAllImpls<KT, jit::XYNTuples<T>, PlaceType, std::vector<T>,
std::vector<T>>(d, x, yref);
}
}
template <paddle::operators::jit::KernelType KT, typename T, typename PlaceType>
void TestLSTMKernel() {
namespace jit = paddle::operators::jit;
VLOG(10) << "===== Test JITKernel " << jit::to_string(KT);
std::vector<std::string> all_acts = {"sigmoid", "tanh", "relu", "identity"};
for (int d : TestSizes()) {
for (bool use_peephole : {true, false}) {
for (auto& act_gate : all_acts) {
for (auto& act_cand : all_acts) {
for (auto& act_cell : all_acts) {
const jit::lstm_attr_t attr(
d, jit::to_kerneltype(act_gate), jit::to_kerneltype(act_cand),
jit::to_kerneltype(act_cell), use_peephole);
auto ref = jit::GetRefer<KT, jit::LSTMTuples<T>>();
EXPECT_TRUE(ref != nullptr);
std::vector<T> xsrc(4 * d), wp(3 * d), ct_1(d);
std::vector<T> ct_ref(d), ht_ref(d), checked(2 * d);
RandomVec<T>(4 * d, xsrc.data(), -2.f, 2.f);
RandomVec<T>(3 * d, wp.data(), -2.f, 2.f);
RandomVec<T>(d, ct_1.data(), -2.f, 2.f);
// x could be changed after compute, so copy to save src
std::vector<T> x(xsrc.size());
std::copy(xsrc.begin(), xsrc.end(), x.begin());
const T* ct_1_data = ct_1.data();
const T* wp_data = wp.data();
T* x_data = x.data();
T* checked_data = checked.data();
T* ct_ref_data = ct_ref.data();
T* ht_ref_data = ht_ref.data();
jit::lstm_t step;
step.gates = x_data;
step.ct_1 = ct_1_data;
step.ct = ct_ref_data;
step.ht = ht_ref_data;
if (use_peephole) {
step.wp = wp_data;
step.checked = checked_data;
}
ref(&step, &attr);
VLOG(10) << attr;
TestAllImpls<KT, jit::LSTMTuples<T>, PlaceType, std::vector<T>,
std::vector<T>, std::vector<T>, std::vector<T>,
std::vector<T>>(attr, xsrc, wp, ct_1, ct_ref, ht_ref,
attr);
}
}
}
}
}
}
template <paddle::operators::jit::KernelType KT, typename T, typename PlaceType>
void TestGRUKernel() {
namespace jit = paddle::operators::jit;
VLOG(10) << "===== Test JITKernel " << jit::to_string(KT);
std::vector<std::string> all_acts = {"sigmoid", "tanh", "relu", "identity"};
for (int d : TestSizes()) {
for (auto& act_gate : all_acts) {
for (auto& act_cand : all_acts) {
const jit::gru_attr_t attr(d, jit::to_kerneltype(act_gate),
jit::to_kerneltype(act_cand));
auto ref = jit::GetRefer<KT, jit::GRUTuples<T>>();
EXPECT_TRUE(ref != nullptr);
std::vector<T> xsrc(3 * d), ht_1(d), ht_ref(d);
RandomVec<T>(3 * d, xsrc.data(), -2.f, 2.f);
RandomVec<T>(d, ht_1.data(), -2.f, 2.f);
// x could be changed after compute, so copy to save src
std::vector<T> x(xsrc.size());
std::copy(xsrc.begin(), xsrc.end(), x.begin());
const T* ht_1_data = ht_1.data();
T* x_data = x.data();
T* ht_ref_data = ht_ref.data();
jit::gru_t step;
step.gates = x_data;
step.ht_1 = ht_1_data;
step.ht = ht_ref_data;
ref(&step, &attr);
VLOG(10) << attr;
TestAllImpls<KT, jit::GRUTuples<T>, PlaceType, std::vector<T>,
std::vector<T>, std::vector<T>>(attr, xsrc, ht_1, ht_ref,
attr);
}
}
}
}
template <paddle::operators::jit::KernelType KT, typename T, typename PlaceType>
void TestNCHW16CMulNCKernel() {
VLOG(10) << "===== Test JITKernel " << jit::to_string(KT);
const int n = 3, c = 16 * 4, h = 10, w = 10;
auto ref = jit::GetRefer<KT, jit::NCHW16CMulNCTuples<T>>();
EXPECT_TRUE(ref != nullptr);
int sz = n * c * h * w;
std::vector<T> x(sz), y(n * c), zref(sz);
std::vector<T> ztgt(sz), zjit(sz);
RandomVec<T>(sz, x.data(), -2.f, 2.f);
RandomVec<T>(n * c, y.data(), -2.f, 2.f);
const T* x_data = x.data();
const T* y_data = y.data();
T* zref_data = zref.data();
T* ztgt_data = ztgt.data();
T* zjit_data = zjit.data();
constexpr int simd_width = ZMM_FLOAT_BLOCK;
int C = c / simd_width;
auto tgt = jit::Get<KT, jit::NCHW16CMulNCTuples<T>, PlaceType>(0);
auto jitcode = jit::GetJitCode<KT, jit::NCHW16CMulNCTuples<T>, PlaceType>(0);
EXPECT_TRUE(tgt != nullptr);
if (std::is_same<T, float>::value &&
paddle::platform::MayIUse(paddle::platform::avx512f)) {
EXPECT_TRUE(jitcode != nullptr);
}
for (int ni = 0; ni < n; ni++) {
for (int ci = 0; ci < C; ci++) {
auto ptr_x =
x_data + ni * C * h * w * simd_width + ci * h * w * simd_width;
auto ptr_y = y_data + ni * C * simd_width + ci * simd_width;
auto ptr_zref =
zref_data + ni * C * h * w * simd_width + ci * h * w * simd_width;
auto ptr_ztgt =
ztgt_data + ni * C * h * w * simd_width + ci * h * w * simd_width;
ref(ptr_x, ptr_y, ptr_zref, h, w);
tgt(ptr_x, ptr_y, ptr_ztgt, h, w);
if (jitcode) {
auto ptr_zjit =
zjit_data + ni * C * h * w * simd_width + ci * h * w * simd_width;
jitcode(ptr_x, ptr_y, ptr_zjit, h, w);
}
}
}
ExpectEQ<T>(ztgt_data, zref_data, sz);
if (jitcode) {
ExpectEQ<T>(zjit_data, zref_data, sz);
}
}
// XYZNTuple
TEST(JITKernel, kVMul) {
namespace jit = paddle::operators::jit;
TestXYZNKernel<jit::kVMul, float, paddle::platform::CPUPlace>();
TestXYZNKernel<jit::kVMul, double, paddle::platform::CPUPlace>();
}
TEST(JITKernel, kVAdd) {
namespace jit = paddle::operators::jit;
TestXYZNKernel<jit::kVAdd, float, paddle::platform::CPUPlace>();
TestXYZNKernel<jit::kVAdd, double, paddle::platform::CPUPlace>();
}
TEST(JITKernel, kVAddRelu) {
namespace jit = paddle::operators::jit;
TestXYZNKernel<jit::kVAddRelu, float, paddle::platform::CPUPlace>();
TestXYZNKernel<jit::kVAddRelu, double, paddle::platform::CPUPlace>();
}
TEST(JITKernel, kVSub) {
namespace jit = paddle::operators::jit;
TestXYZNKernel<jit::kVSub, float, paddle::platform::CPUPlace>();
TestXYZNKernel<jit::kVSub, double, paddle::platform::CPUPlace>();
}
// AXYNTuples
TEST(JITKernel, kVScal) {
namespace jit = paddle::operators::jit;
TestAXYNKernel<jit::kVScal, float, paddle::platform::CPUPlace>();
TestAXYNKernel<jit::kVScal, double, paddle::platform::CPUPlace>();
}
TEST(JITKernel, kVAddBias) {
namespace jit = paddle::operators::jit;
TestAXYNKernel<jit::kVAddBias, float, paddle::platform::CPUPlace>();
TestAXYNKernel<jit::kVAddBias, double, paddle::platform::CPUPlace>();
}
// XYNTuples
TEST(JITKernel, kVRelu) {
namespace jit = paddle::operators::jit;
TestXYNKernel<jit::kVRelu, float, paddle::platform::CPUPlace>();
TestXYNKernel<jit::kVRelu, double, paddle::platform::CPUPlace>();
}
TEST(JITKernel, kVIdentity) {
namespace jit = paddle::operators::jit;
TestXYNKernel<jit::kVIdentity, float, paddle::platform::CPUPlace>();
TestXYNKernel<jit::kVIdentity, double, paddle::platform::CPUPlace>();
}
TEST(JITKernel, kVExp) {
namespace jit = paddle::operators::jit;
TestXYNKernel<jit::kVExp, float, paddle::platform::CPUPlace>();
TestXYNKernel<jit::kVExp, double, paddle::platform::CPUPlace>();
}
TEST(JITKernel, kVSigmoid) {
namespace jit = paddle::operators::jit;
TestXYNKernel<jit::kVSigmoid, float, paddle::platform::CPUPlace>();
TestXYNKernel<jit::kVSigmoid, double, paddle::platform::CPUPlace>();
}
TEST(JITKernel, kVTanh) {
namespace jit = paddle::operators::jit;
TestXYNKernel<jit::kVTanh, float, paddle::platform::CPUPlace>();
TestXYNKernel<jit::kVTanh, double, paddle::platform::CPUPlace>();
}
// LSTM
TEST(JITKernel, kLSTMCtHt) {
namespace jit = paddle::operators::jit;
TestLSTMKernel<jit::kLSTMCtHt, float, paddle::platform::CPUPlace>();
TestLSTMKernel<jit::kLSTMCtHt, double, paddle::platform::CPUPlace>();
}
TEST(JITKernel, kLSTMC1H1) {
namespace jit = paddle::operators::jit;
TestLSTMKernel<jit::kLSTMC1H1, float, paddle::platform::CPUPlace>();
TestLSTMKernel<jit::kLSTMC1H1, double, paddle::platform::CPUPlace>();
}
// GRU
TEST(JITKernel, kGRUH1) {
namespace jit = paddle::operators::jit;
TestGRUKernel<jit::kGRUH1, float, paddle::platform::CPUPlace>();
TestGRUKernel<jit::kGRUH1, double, paddle::platform::CPUPlace>();
}
TEST(JITKernel, kGRUHtPart1) {
namespace jit = paddle::operators::jit;
TestGRUKernel<jit::kGRUHtPart1, float, paddle::platform::CPUPlace>();
TestGRUKernel<jit::kGRUHtPart1, double, paddle::platform::CPUPlace>();
}
TEST(JITKernel, kGRUHtPart2) {
namespace jit = paddle::operators::jit;
TestGRUKernel<jit::kGRUHtPart2, float, paddle::platform::CPUPlace>();
TestGRUKernel<jit::kGRUHtPart2, double, paddle::platform::CPUPlace>();
}
TEST(JITKernel, kNCHW16CMulNC) {
namespace jit = paddle::operators::jit;
TestNCHW16CMulNCKernel<jit::kNCHW16CMulNC, float,
paddle::platform::CPUPlace>();
TestNCHW16CMulNCKernel<jit::kNCHW16CMulNC, double,
paddle::platform::CPUPlace>();
}
// TODO(yihua/TJ): add crf decoding and layer norm unit tests
TEST(JITKernel, pool) {
// TODO(TJ): add some test
}
...@@ -19,7 +19,7 @@ limitations under the License. */ ...@@ -19,7 +19,7 @@ limitations under the License. */
#include "paddle/fluid/operators/math/blas.h" #include "paddle/fluid/operators/math/blas.h"
#if !defined(PADDLE_WITH_CUDA) && !defined(_WIN32) && !defined(__APPLE__) && \ #if !defined(PADDLE_WITH_CUDA) && !defined(_WIN32) && !defined(__APPLE__) && \
!defined(__OSX__) !defined(__OSX__)
#include "paddle/fluid/operators/math/jit_kernel.h" #include "paddle/fluid/operators/jit/kernels.h"
#endif #endif
#include "paddle/fluid/operators/math/math_function.h" #include "paddle/fluid/operators/math/math_function.h"
...@@ -229,12 +229,12 @@ class LayerNormKernel : public framework::OpKernel<T> { ...@@ -229,12 +229,12 @@ class LayerNormKernel : public framework::OpKernel<T> {
PADDLE_ENFORCE_EQ(scale->numel(), right); PADDLE_ENFORCE_EQ(scale->numel(), right);
PADDLE_ENFORCE_EQ(bias->numel(), right); PADDLE_ENFORCE_EQ(bias->numel(), right);
const auto& ker = math::jitkernel::KernelPool::Instance() auto ker =
.template Get<math::jitkernel::LayerNormKernel<T>>( jit::Get<jit::kLayerNorm, jit::LayerNormTuples<T>, platform::CPUPlace>(
static_cast<int>(right)); right);
ker->Compute(x.data<T>(), out.data<T>(), mean->data<T>(), var->data<T>(), ker(x.data<T>(), out.data<T>(), mean->data<T>(), var->data<T>(),
scale->data<T>(), bias->data<T>(), static_cast<int>(left), scale->data<T>(), bias->data<T>(), static_cast<int>(left),
static_cast<const float>(epsilon)); static_cast<const float>(epsilon), right);
#endif #endif
} }
}; };
......
...@@ -73,12 +73,3 @@ if(WITH_GPU) ...@@ -73,12 +73,3 @@ if(WITH_GPU)
endif() endif()
cc_test(concat_test SRCS concat_test.cc DEPS concat_and_split) cc_test(concat_test SRCS concat_test.cc DEPS concat_and_split)
cc_test(cpu_vec_test SRCS cpu_vec_test.cc DEPS blas cpu_info) cc_test(cpu_vec_test SRCS cpu_vec_test.cc DEPS blas cpu_info)
set(JIT_KERNEL_SRCS jit_kernel.cc jit_kernel_blas.cc jit_kernel_exp.cc jit_kernel_rnn.cc jit_kernel_crf_decode.cc jit_kernel_layer_norm.cc)
set(JIT_KERNEL_DEPS cpu_info cblas gflags enforce)
if(WITH_XBYAK)
list(APPEND JIT_KERNEL_SRCS jit_gen.cc jit_code.cc)
list(APPEND JIT_KERNEL_DEPS xbyak)
endif()
cc_library(jit_kernel SRCS ${JIT_KERNEL_SRCS} DEPS ${JIT_KERNEL_DEPS})
cc_test(jit_kernel_test SRCS jit_kernel_test.cc DEPS jit_kernel)
...@@ -14,8 +14,8 @@ limitations under the License. */ ...@@ -14,8 +14,8 @@ limitations under the License. */
#pragma once #pragma once
#include "paddle/fluid/operators/jit/kernels.h"
#include "paddle/fluid/operators/math/blas.h" #include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/operators/math/jit_kernel.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -30,22 +30,21 @@ inline void FCCompute(const BlasT<DeviceContext, T>& blas, const int M, ...@@ -30,22 +30,21 @@ inline void FCCompute(const BlasT<DeviceContext, T>& blas, const int M,
return; return;
} }
if (relu) { if (relu) {
const auto& vaddrelu = jitkernel::KernelPool::Instance() auto compute =
.template Get<jitkernel::VAddReluKernel<T>>(N); jit::Get<jit::kVAddRelu, jit::XYZNTuples<T>, platform::CPUPlace>(N);
for (int i = 0; i < M; i++) { for (int i = 0; i < M; i++) {
T* dst = Y + i * N; T* dst = Y + i * N;
vaddrelu->Compute(B, dst, dst, N); compute(B, dst, dst, N);
} }
} else { } else {
const auto& vadd = jitkernel::KernelPool::Instance() auto compute =
.template Get<jitkernel::VAddKernel<T>>(N); jit::Get<jit::kVAdd, jit::XYZNTuples<T>, platform::CPUPlace>(N);
#ifdef PADDLE_WITH_MKLML #ifdef PADDLE_WITH_MKLML
#pragma omp parallel for #pragma omp parallel for
#endif #endif
for (int i = 0; i < M; i++) { for (int i = 0; i < M; i++) {
T* dst = Y + i * N; T* dst = Y + i * N;
vadd->Compute(B, dst, dst, N); compute(B, dst, dst, N);
} }
} }
} }
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/math/jit_code.h"
#include <stddef.h> // offsetof
#include "paddle/fluid/operators/math/jit_kernel.h" // TODO(TJ): remove me
namespace paddle {
namespace operators {
namespace math {
namespace jitkernel {
namespace gen {
using namespace platform; // NOLINT
bool VXXJitCode::init(int d, int scalar_index) {
// It's not necessary to use avx512 since it would slow down the frequency
// and this kernel is not compute bound.
return MayIUse(avx) && scalar_index >= 0 && scalar_index <= 2;
}
void VXXJitCode::generate() {
// do not need push stack, and do not need save avx512reg if do not use avx512
int offset = 0;
if (with_relu_) {
vxorps(ymm_zero, ymm_zero, ymm_zero);
}
if (scalar_index_ == 1) {
vbroadcastss(ymm_src1, ptr[param1]);
} else if (scalar_index_ == 2) {
vbroadcastss(ymm_src2, ptr[param2]);
}
for (int i = 0; i < num_ / YMM_FLOAT_BLOCK; ++i) {
if (scalar_index_ != 1) {
vmovups(ymm_src1, ptr[param1 + offset]);
}
if (scalar_index_ != 2) {
vmovups(ymm_src2, ptr[param2 + offset]);
}
if (type_ == operand_type::mul) {
vmulps(ymm_dst, ymm_src1, ymm_src2);
} else if (type_ == operand_type::add) {
vaddps(ymm_dst, ymm_src1, ymm_src2);
}
if (with_relu_) {
vmaxps(ymm_dst, ymm_zero, ymm_dst);
}
vmovups(ptr[param3 + offset], ymm_dst);
offset += sizeof(float) * YMM_FLOAT_BLOCK;
}
int rest = num_ % YMM_FLOAT_BLOCK;
while (rest > 0) {
int block = XMM_FLOAT_BLOCK;
if (rest >= 4) {
block = 4;
if (scalar_index_ != 1) {
vmovups(xmm_src1, ptr[param1 + offset]);
}
if (scalar_index_ != 2) {
vmovups(xmm_src2, ptr[param2 + offset]);
}
} else if (rest >= 2) {
block = 2;
if (scalar_index_ != 1) {
vmovq(xmm_src1, ptr[param1 + offset]);
}
if (scalar_index_ != 2) {
vmovq(xmm_src2, ptr[param2 + offset]);
}
} else {
block = 1;
if (scalar_index_ != 1) {
vmovss(xmm_src1, ptr[param1 + offset]);
}
if (scalar_index_ != 2) {
vmovss(xmm_src2, ptr[param2 + offset]);
}
}
switch (type_) {
case operand_type::mul:
vmulps(xmm_dst, xmm_src1, xmm_src2);
break;
case operand_type::add:
vaddps(xmm_dst, xmm_src1, xmm_src2);
break;
default:
break;
}
if (with_relu_) {
vmaxps(xmm_dst, xmm_zero, xmm_dst);
}
if (rest >= 4) {
vmovups(ptr[param3 + offset], xmm_dst);
} else if (rest >= 2) {
vmovq(ptr[param3 + offset], xmm_dst);
} else {
vmovss(ptr[param3 + offset], xmm_dst);
}
offset += sizeof(float) * block;
rest -= block;
}
ret();
}
const float ALIGN32_BEG exp_float_consts[] ALIGN32_END = {
REPEAT_8TIMES(1.f),
REPEAT_8TIMES(2.f),
REPEAT_8TIMES(0.5f),
REPEAT_8TIMES(EXP_HIG),
REPEAT_8TIMES(EXP_LOW),
REPEAT_8TIMES(CEPHES_LOG2EF),
REPEAT_8TIMES(CEPHES_EXP_C1),
REPEAT_8TIMES(CEPHES_EXP_C2),
REPEAT_8TIMES(CEPHES_EXP_P0),
REPEAT_8TIMES(CEPHES_EXP_P1),
REPEAT_8TIMES(CEPHES_EXP_P2),
REPEAT_8TIMES(CEPHES_EXP_P3),
REPEAT_8TIMES(CEPHES_EXP_P4),
REPEAT_8TIMES(CEPHES_EXP_P5),
REPEAT_8TIMES(EXP_MAX_INPUT),
REPEAT_8TIMES(SIGMOID_THRESHOLD_MAX),
REPEAT_8TIMES(SIGMOID_THRESHOLD_MIN)};
const int ALIGN32_BEG exp_int_0x7f[] ALIGN32_END = {REPEAT_8TIMES(0x7f)};
int ALIGN32_BEG g_tmp_mem[16] ALIGN32_END = {0};
bool VActJitCode::init(int d, operand_type type) {
// TODO(TJ): implement avx512, avx_exp is slower than mkl when d >= 256
return MayIUse(avx);
}
void VActJitCode::generate() {
int offset = 0;
for (int i = 0; i < num_ / YMM_FLOAT_BLOCK; ++i) {
vmovups(ymm_src, ptr[param1 + offset]);
act<ymm_t>(ymm_dst, ymm_src, type_);
vmovups(ptr[param2 + offset], ymm_dst);
offset += sizeof(float) * YMM_FLOAT_BLOCK;
}
int rest = num_ % YMM_FLOAT_BLOCK;
while (rest > 0) {
int block = XMM_FLOAT_BLOCK;
if (rest >= 4) {
block = 4;
vmovups(xmm_src, ptr[param1 + offset]);
} else if (rest >= 2) {
block = 2;
vmovq(xmm_src, ptr[param1 + offset]);
} else {
block = 1;
vmovss(xmm_src, ptr[param1 + offset]);
}
act<xmm_t>(xmm_dst, xmm_src, type_);
if (rest >= 4) {
vmovups(ptr[param2 + offset], xmm_dst);
} else if (rest >= 2) {
vmovq(ptr[param2 + offset], xmm_dst);
} else {
vmovss(ptr[param2 + offset], xmm_dst);
}
offset += sizeof(float) * block;
rest -= block;
}
ret();
}
bool LSTMJitCode::init(int d) { return MayIUse(avx) && d % 8 == 0; }
void LSTMJitCode::generate() {
if (use_peephole_) {
preCode();
}
reg64_t reg_ptr_gates = rax;
reg64_t reg_ptr_ct_1 = r9;
reg64_t reg_ptr_ct = r10;
reg64_t reg_ptr_ht = r11;
reg64_t reg_ptr_wp = r12;
mov(reg_ptr_gates, ptr[param1 + offsetof(lstm_t, gates)]);
mov(reg_ptr_ct_1, ptr[param1 + offsetof(lstm_t, ct_1)]);
mov(reg_ptr_ct, ptr[param1 + offsetof(lstm_t, ct)]);
mov(reg_ptr_ht, ptr[param1 + offsetof(lstm_t, ht)]);
if (use_peephole_) {
mov(reg_ptr_wp, ptr[param1 + offsetof(lstm_t, wp)]);
}
int offset = 0;
int d = num_ * sizeof(float);
for (int i = 0; i < num_ / YMM_FLOAT_BLOCK; ++i) {
/* gates: W_ch, W_ih, W_fh, W_oh */
ymm_t ymm_c = ymm_t(0);
ymm_t ymm_i = ymm_t(1);
ymm_t ymm_f = ymm_t(2);
ymm_t ymm_o = ymm_t(3);
ymm_t ymm_ct_1 = ymm_t(4);
ymm_t ymm_wp0 = ymm_t(5);
ymm_t ymm_wp1 = ymm_t(6);
ymm_t ymm_wp2 = ymm_t(7);
vmovups(ymm_c, ptr[reg_ptr_gates + offset]);
vmovups(ymm_i, ptr[reg_ptr_gates + offset + d]);
vmovups(ymm_f, ptr[reg_ptr_gates + offset + 2 * d]);
vmovups(ymm_o, ptr[reg_ptr_gates + offset + 3 * d]);
if (!compute_c1h1_) {
vmovups(ymm_ct_1, ptr[reg_ptr_ct_1 + offset]);
}
if (use_peephole_) {
vmovups(ymm_wp0, ptr[reg_ptr_wp + offset]);
vmovups(ymm_wp1, ptr[reg_ptr_wp + offset + d]);
vmovups(ymm_wp2, ptr[reg_ptr_wp + offset + 2 * d]);
}
/* C_t = act_cand(c) * act_gate(i) + C_t-1 * act_gate(f) */
// act_cand(c)
act<ymm_t>(ymm_c, ymm_c, act_cand_);
// act_gate(i) or act_gate(ct_1 * wp0 + i)
if (!compute_c1h1_ && use_peephole_) {
vmulps(ymm_wp0, ymm_ct_1, ymm_wp0);
vaddps(ymm_i, ymm_i, ymm_wp0);
}
act<ymm_t>(ymm_i, ymm_i, act_gate_);
vmulps(ymm_c, ymm_c, ymm_i);
if (!compute_c1h1_) {
// act_gate(f) or act_gate(ct_1 * wp1 + f)
if (use_peephole_) {
vmulps(ymm_wp1, ymm_ct_1, ymm_wp1);
vaddps(ymm_f, ymm_f, ymm_wp1);
}
act<ymm_t>(ymm_f, ymm_f, act_gate_);
// ct
vmulps(ymm_f, ymm_f, ymm_ct_1);
vaddps(ymm_f, ymm_f, ymm_c);
}
/* H_t = act_cell(C_t) * act_gate(o) */
// act_cell(C_t)
ymm_t ymm_ct = compute_c1h1_ ? ymm_c : ymm_f;
ymm_t ymm_tmp = ymm_i;
act<ymm_t>(ymm_tmp, ymm_ct, act_cell_);
// act_gate(o) or act_gate(ct * wp2 + o)
if (use_peephole_) {
vmulps(ymm_wp2, ymm_ct, ymm_wp2);
vaddps(ymm_o, ymm_o, ymm_wp2);
}
act<ymm_t>(ymm_o, ymm_o, act_gate_);
// ht
vmulps(ymm_o, ymm_o, ymm_tmp);
// save ct and ht
vmovups(ptr[reg_ptr_ct + offset], ymm_ct);
vmovups(ptr[reg_ptr_ht + offset], ymm_o);
offset += sizeof(float) * YMM_FLOAT_BLOCK;
}
if (use_peephole_) {
postCode();
} else {
ret();
}
}
bool GRUJitCode::init(int d) { return MayIUse(avx) && d % 8 == 0; }
void GRUJitCode::generate() {
reg64_t reg_ptr_gates = rax;
reg64_t reg_ptr_ht_1 = r9;
reg64_t reg_ptr_ht = r10;
mov(reg_ptr_gates, ptr[param1 + offsetof(gru_t, gates)]);
mov(reg_ptr_ht_1, ptr[param1 + offsetof(gru_t, ht_1)]);
mov(reg_ptr_ht, ptr[param1 + offsetof(gru_t, ht)]);
ymm_t ymm_one = ymm_t(0);
if (id_ == 2) {
reg64_t reg_ptr_tmp = r11;
mov(reg_ptr_tmp, reinterpret_cast<size_t>(exp_float_consts));
vmovaps(ymm_one, ptr[reg_ptr_tmp + OFFSET_EXP_ONE]);
}
int offset = 0;
int d = num_ * sizeof(float);
for (int i = 0; i < num_ / YMM_FLOAT_BLOCK; ++i) {
ymm_t ymm_u = ymm_t(1);
ymm_t ymm_r = ymm_t(2);
ymm_t ymm_s = ymm_t(3);
ymm_t ymm_ht_1 = ymm_t(4);
// W: {W_update, W_reset; W_state}
if (id_ == 0 || id_ == 2) {
vmovups(ymm_u, ptr[reg_ptr_gates + offset]);
vmovups(ymm_s, ptr[reg_ptr_gates + offset + 2 * d]);
}
if (id_ == 1) {
vmovups(ymm_r, ptr[reg_ptr_gates + offset + d]);
}
if (id_ == 1 || id_ == 2) {
vmovups(ymm_ht_1, ptr[reg_ptr_ht_1 + offset]);
}
if (id_ == 0) {
// ht = act_gate(u) * act_cand(s)
act<ymm_t>(ymm_u, ymm_u, act_gate_);
act<ymm_t>(ymm_s, ymm_s, act_cand_);
vmulps(ymm_s, ymm_s, ymm_u);
vmovups(ptr[reg_ptr_ht + offset], ymm_s);
} else if (id_ == 1) {
// ht = act_gate(r) * ht_1
act<ymm_t>(ymm_r, ymm_r, act_gate_);
vmulps(ymm_r, ymm_r, ymm_ht_1);
vmovups(ptr[reg_ptr_ht + offset], ymm_r);
} else if (id_ == 2) {
// ht = act_gate(u) * act_cand(s) + (1-act_gate(u)) * ht_1
ymm_t ymm_one_inner = ymm_t(ymm_one.getIdx());
act<ymm_t>(ymm_u, ymm_u, act_gate_);
act<ymm_t>(ymm_s, ymm_s, act_cand_);
vmulps(ymm_s, ymm_s, ymm_u);
vsubps(ymm_u, ymm_one_inner, ymm_u);
vmulps(ymm_u, ymm_ht_1, ymm_u);
vaddps(ymm_u, ymm_s, ymm_u);
vmovups(ptr[reg_ptr_ht + offset], ymm_u);
}
offset += sizeof(float) * YMM_FLOAT_BLOCK;
}
ret();
}
} // namespace gen
} // namespace jitkernel
} // namespace math
} // namespace operators
} // namespace paddle
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/math/jit_gen.h"
#include <fstream>
#include <iostream>
#include <sstream>
#include "paddle/fluid/platform/cpu_info.h"
DEFINE_bool(dump_jitcode, false, "Whether to dump the jitcode to file");
namespace paddle {
namespace operators {
namespace math {
namespace jitkernel {
namespace gen {
constexpr Xbyak::Operand::Code g_abi_regs[] = {
Xbyak::Operand::RBX, Xbyak::Operand::RBP, Xbyak::Operand::R12,
Xbyak::Operand::R13, Xbyak::Operand::R14, Xbyak::Operand::R15};
constexpr int num_g_abi_regs = sizeof(g_abi_regs) / sizeof(g_abi_regs[0]);
void JitCode::preCode() {
for (int i = 0; i < num_g_abi_regs; ++i) {
push(Xbyak::Reg64(g_abi_regs[i]));
}
if (platform::MayIUse(platform::avx512f)) {
mov(reg_EVEX_max_8b_offt, 2 * EVEX_max_8b_offt);
}
}
void JitCode::postCode() {
for (int i = 0; i < num_g_abi_regs; ++i) {
pop(Xbyak::Reg64(g_abi_regs[num_g_abi_regs - 1 - i]));
}
ret();
}
void JitCode::dumpCode(const Xbyak::uint8 *code) const {
if (code) {
static int counter = 0;
std::ostringstream filename;
filename << "paddle_jitcode_" << name() << "." << counter << ".bin";
counter++;
std::ofstream fout(filename.str(), std::ios::out);
if (fout.is_open()) {
fout.write(reinterpret_cast<const char *>(code), getSize());
fout.close();
}
}
}
Xbyak::Address JitCode::EVEX_compress_addr(Xbyak::Reg64 base, int offt,
bool bcast) {
int scale = 0;
if (EVEX_max_8b_offt <= offt && offt < 3 * EVEX_max_8b_offt) {
offt = offt - 2 * EVEX_max_8b_offt;
scale = 1;
} else if (3 * EVEX_max_8b_offt <= offt && offt < 5 * EVEX_max_8b_offt) {
offt = offt - 4 * EVEX_max_8b_offt;
scale = 2;
}
auto re = Xbyak::RegExp() + base + offt;
if (scale) {
re = re + reg_EVEX_max_8b_offt * scale;
}
if (bcast) {
return zword_b[re];
} else {
return zword[re];
}
}
} // namespace gen
} // namespace jitkernel
} // namespace math
} // namespace operators
} // namespace paddle
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <gflags/gflags.h>
#include <type_traits>
#include "paddle/fluid/platform/macros.h"
#define XBYAK_USE_MMAP_ALLOCATOR
#include "xbyak/xbyak.h"
#include "xbyak/xbyak_util.h"
DECLARE_bool(dump_jitcode);
namespace paddle {
namespace operators {
namespace math {
namespace jitkernel {
namespace gen {
#define DECLARE_JIT_CODE(codename) \
const char *name() const override { return #codename; }
// Application Binary Interface
constexpr Xbyak::Operand::Code abi_param1(Xbyak::Operand::RDI),
abi_param2(Xbyak::Operand::RSI), abi_param3(Xbyak::Operand::RDX),
abi_param4(Xbyak::Operand::RCX), abi_not_param1(Xbyak::Operand::RCX);
class JitCode : public Xbyak::CodeGenerator {
public:
explicit JitCode(size_t code_size = 256 * 1024, void *code_ptr = nullptr)
: Xbyak::CodeGenerator(code_size, code_ptr) {}
virtual ~JitCode() {}
virtual const char *name() const = 0;
virtual void generate() = 0;
template <typename FUNC>
const FUNC getCode() {
this->generate();
const Xbyak::uint8 *code = CodeGenerator::getCode();
if (FLAGS_dump_jitcode) {
this->dumpCode(code);
}
return reinterpret_cast<const FUNC>(code);
}
DISABLE_COPY_AND_ASSIGN(JitCode);
protected:
Xbyak::Reg64 param1{abi_param1};
const int EVEX_max_8b_offt = 0x200;
const Xbyak::Reg64 reg_EVEX_max_8b_offt = rbp;
void preCode();
void postCode();
void dumpCode(const Xbyak::uint8 *code) const;
void L(const char *label) { Xbyak::CodeGenerator::L(label); }
void L(const Xbyak::Label &label) { Xbyak::CodeGenerator::L(label); }
// Enhanced vector extension
Xbyak::Address EVEX_compress_addr(Xbyak::Reg64 base, int offt,
bool bcast = false);
};
} // namespace gen
} // namespace jitkernel
} // namespace math
} // namespace operators
} // namespace paddle
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <functional>
#include <memory> // for shared_ptr
#include <string>
#include <unordered_map>
#include "paddle/fluid/operators/math/jit_kernel_impl.h"
#include "paddle/fluid/platform/cpu_info.h"
#include "paddle/fluid/platform/macros.h"
// Note: Only support on CPU yet.
namespace paddle {
namespace operators {
namespace math {
namespace jitkernel {
// TODO(TJ): remove me
typedef enum { kLT8, kEQ8, kGT8LT16, kEQ16, kGT16 } jit_block;
class Kernel {
public:
Kernel() = default;
virtual ~Kernel() = default;
// TODO(TJ): below members should be deprecated.
int num_{0};
int end_{0};
int rest_{0};
DISABLE_COPY_AND_ASSIGN(Kernel);
};
class KernelPool {
public:
static KernelPool &Instance();
template <typename Ker, typename... ARGS>
std::shared_ptr<const Ker> Get(ARGS... args);
std::shared_ptr<const Kernel> Get(const std::string &key) const;
private:
KernelPool() = default;
std::unordered_map<std::string, std::shared_ptr<const Kernel>> kers_;
DISABLE_COPY_AND_ASSIGN(KernelPool);
};
template <typename T>
class VMulKernel : public Kernel {
public:
void (*Compute)(const T *, const T *, T *, int);
};
template <typename T>
class VAddKernel : public Kernel {
public:
void (*Compute)(const T *, const T *, T *, int);
};
template <typename T>
class VAddReluKernel : public Kernel {
public:
void (*Compute)(const T *, const T *, T *, int);
};
template <typename T>
class VScalKernel : public Kernel {
public:
// y = a.*x
void (*Compute)(const T *, const T *, T *, int);
};
template <typename T>
class VAddBiasKernel : public Kernel {
public:
// y = a.+x
void (*Compute)(const T *, const T *, T *, int);
};
#ifdef PADDLE_WITH_MKLDNN
template <typename T>
class EltwiseMulnChw16cNCKernel : public Kernel {
public:
// nChw16c = nChw16c .* NC
void (*Compute)(const float *, const float *, float *, int, int);
};
#endif
template <typename T>
class VActKernel : public Kernel {
public:
void (*Compute)(const T *, T *, int);
};
template <typename T>
class VReluKernel : public VActKernel<T> {};
template <typename T>
class VIdentityKernel : public VActKernel<T> {};
template <typename T>
class VExpKernel : public VActKernel<T> {};
template <typename T>
class VSigmoidKernel : public VActKernel<T> {};
template <typename T>
class VTanhKernel : public VActKernel<T> {};
template <typename T>
class LSTMKernel : public Kernel {
public:
// compute c1 and h1 without c0 or h0
void (*ComputeC1H1)(lstm_t *, const lstm_attr_t *);
void (*ComputeCtHt)(lstm_t *, const lstm_attr_t *);
};
template <typename T>
class GRUKernel : public Kernel {
public:
// compute h1 without h0
void (*ComputeH1)(gru_t *, const gru_attr_t *);
void (*ComputeHtPart1)(gru_t *, const gru_attr_t *);
void (*ComputeHtPart2)(gru_t *, const gru_attr_t *);
};
template <typename T>
class CRFDecodeKernel : public Kernel {
public:
virtual void Compute(const int seq_len, const T *x, const T *w, T *alpha,
int *track) const = 0;
};
template <typename T>
class LayerNormKernel : public Kernel {
public:
virtual void Compute(T *x, T *out, T *mean, T *var, const T *scale,
const T *bias, int height,
const float epsilon) const = 0;
};
} // namespace jitkernel
} // namespace math
} // namespace operators
} // namespace paddle
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/math/jit_kernel.h"
#include <string>
#include "paddle/fluid/operators/math/jit_kernel_macro.h"
#include "paddle/fluid/operators/math/jit_kernel_refer.h"
#include "paddle/fluid/platform/enforce.h"
#ifdef PADDLE_WITH_XBYAK
#include "paddle/fluid/operators/math/jit_code.h"
#endif
#ifdef PADDLE_WITH_MKLML
#include "paddle/fluid/platform/dynload/mklml.h"
#endif
namespace paddle {
namespace operators {
namespace math {
namespace jitkernel {
#ifdef PADDLE_WITH_MKLML
template <typename T>
void VMulMKL(const T* x, const T* y, T* z, int n);
template <>
void VMulMKL<float>(const float* x, const float* y, float* z, int n) {
platform::dynload::vsMul(n, x, y, z);
}
template <>
void VMulMKL<double>(const double* x, const double* y, double* z, int n) {
platform::dynload::vdMul(n, x, y, z);
}
template <typename T>
void VAddMKL(const T* x, const T* y, T* z, int n);
template <>
void VAddMKL<float>(const float* x, const float* y, float* z, int n) {
platform::dynload::vsAdd(n, x, y, z);
}
template <>
void VAddMKL<double>(const double* x, const double* y, double* z, int n) {
platform::dynload::vdAdd(n, x, y, z);
}
template <typename T>
void VScalMKL(const T* a, const T* x, T* y, int n);
template <>
void VScalMKL<float>(const float* a, const float* x, float* y, int n) {
if (x == y) {
platform::dynload::cblas_sscal(n, *a, y, 1);
} else {
refer::VScal<float>(a, x, y, n);
}
}
template <>
void VScalMKL<double>(const double* a, const double* x, double* y, int n) {
if (x == y) {
platform::dynload::cblas_dscal(n, *a, y, 1);
} else {
refer::VScal<double>(a, x, y, n);
}
}
#endif
/* VMUL JitKernel */
template <typename T>
class VMulKernelImpl : public VMulKernel<T> {
public:
JITKERNEL_DECLARE_STATIC_FUNC;
explicit VMulKernelImpl(int d) : VMulKernel<T>() {
#ifdef PADDLE_WITH_XBYAK
if (useJIT(d)) {
// roughly estimate the size of code
size_t sz = 96 + d / YMM_FLOAT_BLOCK * 4 * 8;
jitcode_.reset(new gen::VXXJitCode(d, gen::operand_type::mul, 0, false,
sz > 4096 ? sz : 4096));
this->Compute =
jitcode_->getCode<void (*)(const T*, const T*, T*, int)>();
return;
}
#endif
#ifdef PADDLE_WITH_MKLML
if (useMKL(d)) {
this->Compute = VMulMKL<T>;
return;
}
#endif
this->Compute = refer::VMul<T>;
}
#ifdef PADDLE_WITH_XBYAK
private:
std::unique_ptr<gen::VXXJitCode> jitcode_{nullptr};
#endif
};
#ifdef PADDLE_WITH_XBYAK
template <>
bool VMulKernelImpl<float>::useJIT(int d) {
return gen::VXXJitCode::init(d);
}
#endif
#ifdef PADDLE_WITH_MKLML
template <>
bool VMulKernelImpl<float>::useMKL(int d) {
return platform::MayIUse(platform::avx512f) && d > 512;
}
template <>
bool VMulKernelImpl<double>::useMKL(int d) {
return true;
}
#endif
/* VAdd JitKernel */
template <typename T>
class VAddKernelImpl : public VAddKernel<T> {
public:
JITKERNEL_DECLARE_STATIC_FUNC;
explicit VAddKernelImpl(int d) : VAddKernel<T>() {
#ifdef PADDLE_WITH_XBYAK
if (useJIT(d)) {
size_t sz = 96 + d / YMM_FLOAT_BLOCK * 4 * 8;
jitcode_.reset(new gen::VXXJitCode(d, gen::operand_type::add, 0, false,
sz > 4096 ? sz : 4096));
this->Compute =
jitcode_->getCode<void (*)(const T*, const T*, T*, int)>();
return;
}
#endif
#ifdef PADDLE_WITH_MKLML
if (useMKL(d)) {
this->Compute = VAddMKL<T>;
return;
}
#endif
this->Compute = refer::VAdd<T>;
}
#ifdef PADDLE_WITH_XBYAK
private:
std::unique_ptr<gen::VXXJitCode> jitcode_{nullptr};
#endif
};
#ifdef PADDLE_WITH_XBYAK
template <>
bool VAddKernelImpl<float>::useJIT(int d) {
return gen::VXXJitCode::init(d);
}
#endif
#ifdef PADDLE_WITH_MKLML
template <>
bool VAddKernelImpl<float>::useMKL(int d) {
return d > 512;
}
template <>
bool VAddKernelImpl<double>::useMKL(int d) {
return true;
}
#endif
#ifdef PADDLE_WITH_MKLDNN
/* EltwiseMul for nChw16c & NC inputs JitKernel */
template <typename T>
class EltwiseMulnChw16cNCKernelImpl
: public math::jitkernel::EltwiseMulnChw16cNCKernel<T> {
public:
JITKERNEL_DECLARE_STATIC_FUNC;
explicit EltwiseMulnChw16cNCKernelImpl(int d)
: EltwiseMulnChw16cNCKernel<T>() {
using mul_func_t = void (*)(const float*, const float*, float*, int, int);
#ifdef PADDLE_WITH_XBYAK
if (useJIT(d)) {
// roughly estimate the size of code
size_t sz = 96 + d / YMM_FLOAT_BLOCK * 4 * 8;
sz = sz > 4096 ? sz : 4096;
jitcode_.reset(new gen::EltwiseMulnChw16cNC(sz));
this->Compute = (mul_func_t)jitcode_->getCode();
return;
}
#endif
PADDLE_THROW(
"This kernel shouldn't be used in Non-Xbyak, Non-MKL-DNN "
"environemnt");
}
#ifdef PADDLE_WITH_XBYAK
private:
std::unique_ptr<gen::EltwiseMulnChw16cNC> jitcode_{nullptr};
#endif
};
#ifdef PADDLE_WITH_XBYAK
template <>
bool EltwiseMulnChw16cNCKernelImpl<float>::useJIT(int d) {
return true;
}
#endif
#endif
/* VAddRelu JitKernel */
template <typename T>
class VAddReluKernelImpl : public VAddReluKernel<T> {
public:
JITKERNEL_DECLARE_STATIC_FUNC;
explicit VAddReluKernelImpl(int d) : VAddReluKernel<T>() {
#ifdef PADDLE_WITH_XBYAK
if (useJIT(d)) {
size_t sz = 96 + d / YMM_FLOAT_BLOCK * 4 * 8;
jitcode_.reset(new gen::VXXJitCode(d, gen::operand_type::add, 0, true,
sz > 4096 ? sz : 4096));
this->Compute =
jitcode_->getCode<void (*)(const T*, const T*, T*, int)>();
return;
}
#endif
this->Compute = refer::VAddRelu<T>;
}
#ifdef PADDLE_WITH_XBYAK
private:
std::unique_ptr<gen::VXXJitCode> jitcode_{nullptr};
#endif
};
#ifdef PADDLE_WITH_XBYAK
template <>
bool VAddReluKernelImpl<float>::useJIT(int d) {
return gen::VXXJitCode::init(d);
}
#endif
/* VScal JitKernel */
template <typename T>
class VScalKernelImpl : public VScalKernel<T> {
public:
JITKERNEL_DECLARE_STATIC_FUNC;
explicit VScalKernelImpl(int d) : VScalKernel<T>() {
#ifdef PADDLE_WITH_XBYAK
if (useJIT(d)) {
size_t sz = 96 + d / YMM_FLOAT_BLOCK * 4 * 8;
jitcode_.reset(new gen::VXXJitCode(d, gen::operand_type::mul, 1, false,
sz > 4096 ? sz : 4096));
this->Compute =
jitcode_->getCode<void (*)(const T*, const T*, T*, int)>();
return;
}
#endif
#ifdef PADDLE_WITH_MKLML
if (useMKL(d)) {
this->Compute = VScalMKL<T>;
return;
}
#endif
this->Compute = refer::VScal<T>;
}
#ifdef PADDLE_WITH_XBYAK
private:
std::unique_ptr<gen::VXXJitCode> jitcode_{nullptr};
#endif
};
#ifdef PADDLE_WITH_XBYAK
template <>
bool VScalKernelImpl<float>::useJIT(int d) {
return gen::VXXJitCode::init(d, 1);
}
#endif
#ifdef PADDLE_WITH_MKLML
template <>
bool VScalKernelImpl<float>::useMKL(int d) {
return d > 512;
}
template <>
bool VScalKernelImpl<double>::useMKL(int d) {
return true;
}
#endif
/* VAddBias JitKernel */
template <typename T>
class VAddBiasKernelImpl : public VAddBiasKernel<T> {
public:
JITKERNEL_DECLARE_STATIC_FUNC;
explicit VAddBiasKernelImpl(int d) : VAddBiasKernel<T>() {
#ifdef PADDLE_WITH_XBYAK
if (useJIT(d)) {
size_t sz = 96 + d / YMM_FLOAT_BLOCK * 4 * 8;
jitcode_.reset(new gen::VXXJitCode(d, gen::operand_type::add, 1, false,
sz > 4096 ? sz : 4096));
this->Compute =
jitcode_->getCode<void (*)(const T*, const T*, T*, int)>();
return;
}
#endif
this->Compute = refer::VAddBias<T>;
}
#ifdef PADDLE_WITH_XBYAK
private:
std::unique_ptr<gen::VXXJitCode> jitcode_{nullptr};
#endif
};
#ifdef PADDLE_WITH_XBYAK
template <>
bool VAddBiasKernelImpl<float>::useJIT(int d) {
return gen::VXXJitCode::init(d, 1);
}
#endif
/* VRelu JitKernel */
template <typename T>
class VReluKernelImpl : public VReluKernel<T> {
public:
JITKERNEL_DECLARE_STATIC_FUNC;
explicit VReluKernelImpl(int d) : VReluKernel<T>() {
#ifdef PADDLE_WITH_XBYAK
if (useJIT(d)) {
size_t sz = 96 /* init size */ +
d / YMM_FLOAT_BLOCK * 4 /* instructions */ *
8 /* average bytes for each instruction */;
jitcode_.reset(new gen::VActJitCode(d, gen::operand_type::relu,
sz > 4096 ? sz : 4096));
this->Compute = jitcode_->getCode<void (*)(const T*, T*, int)>();
return;
}
#endif
this->Compute = refer::VRelu<T>;
}
#ifdef PADDLE_WITH_XBYAK
private:
std::unique_ptr<gen::VActJitCode> jitcode_{nullptr};
#endif
};
#ifdef PADDLE_WITH_XBYAK
template <>
bool VReluKernelImpl<float>::useJIT(int d) {
return gen::VActJitCode::init(d, gen::operand_type::relu);
}
#endif
/* An empty JitKernel */
template <typename T>
class VIdentityKernelImpl : public VIdentityKernel<T> {
public:
JITKERNEL_DECLARE_STATIC_FUNC;
explicit VIdentityKernelImpl(int d) : VIdentityKernel<T>() {
this->Compute = refer::VIdentity<T>;
}
};
REGISTER_JITKERNEL(vmul, VMulKernel);
REGISTER_JITKERNEL(vadd, VAddKernel);
REGISTER_JITKERNEL(vaddrelu, VAddReluKernel);
REGISTER_JITKERNEL(vscal, VScalKernel);
REGISTER_JITKERNEL(vaddbias, VAddBiasKernel);
REGISTER_JITKERNEL(vrelu, VReluKernel);
REGISTER_JITKERNEL(videntity, VIdentityKernel);
#ifdef PADDLE_WITH_MKLDNN
REGISTER_JITKERNEL(eltwise_mul_nchw16c, EltwiseMulnChw16cNCKernel);
#endif
} // namespace jitkernel
} // namespace math
} // namespace operators
} // namespace paddle
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/math/jit_kernel.h"
#include <limits>
#include <string>
#include "paddle/fluid/operators/math/jit_kernel_macro.h"
namespace paddle {
namespace operators {
namespace math {
namespace jitkernel {
/* CRF Decode JitKernel */
template <typename T, platform::cpu_isa_t isa, jit_block>
class CRFDecodeKernelImpl : public CRFDecodeKernel<T> {
public:
explicit CRFDecodeKernelImpl(int tag_num) : CRFDecodeKernel<T>() {
this->num_ = tag_num;
}
void Compute(const int seq_len, const T* x, const T* w, T* alpha,
int* track) const override {
constexpr int state_trans_base_idx = 2;
for (int i = 0; i < this->num_; ++i) {
alpha[i] = w[i] + x[i];
}
for (int k = 1; k < seq_len; ++k) {
for (int i = 0; i < this->num_; ++i) {
T max_score = -std::numeric_limits<T>::max();
int max_j = 0;
for (int j = 0; j < this->num_; ++j) {
T score = alpha[(k - 1) * this->num_ + j] +
w[(j + state_trans_base_idx) * this->num_ + i];
if (score > max_score) {
max_score = score;
max_j = j;
}
}
alpha[k * this->num_ + i] = max_score + x[k * this->num_ + i];
track[k * this->num_ + i] = max_j;
}
}
}
};
#define INIT_ALPHA(step_size) \
/* Setup the alpha initial value.*/ \
int i_offset = 0; \
int last_offset = this->rest_ - step_size; \
for (int i = 0; i <= this->end_; ++i) { \
/* weights, input and alpha values. */ \
__m256 w_content, x_content, alpha_content; \
/* Load the relevant data into the variables from un-aligned address.*/ \
w_content = _mm256_loadu_ps(w + i_offset); \
x_content = _mm256_loadu_ps(x + i_offset); \
alpha_content = _mm256_add_ps(w_content, x_content); \
_mm256_storeu_ps(alpha + i_offset, alpha_content); \
i_offset += step_size; \
if (i == this->end_ - 1) { \
if (this->rest_ > 0) { \
i_offset += last_offset; \
} else { \
break; \
} \
} \
}
#define UPDATE_ALPHA(step_size) \
/* Update the alpha and track values. */ \
__m256 x_content = _mm256_loadu_ps(x + seq_offset + this->num_ + j_offset); \
max_score = _mm256_add_ps(max_score, x_content); \
_mm256_storeu_ps(alpha + seq_offset + this->num_ + j_offset, max_score); \
_mm256_storeu_si256( \
reinterpret_cast<__m256i*>(track + seq_offset + this->num_ + j_offset), \
max_j); \
/* Calculate the offset of next step*/ \
j_offset += step_size; \
if (j == this->end_ - 1) { \
if (this->rest_ > 0) { \
j_offset += last_offset; \
} else { \
break; \
} \
}
#define INTRIAVX_FLOAT(block) \
template <> \
CRFDecodeKernelImpl<float, platform::avx, block>::CRFDecodeKernelImpl( \
int tag_num) \
: CRFDecodeKernel<float>() { \
this->num_ = tag_num; \
this->end_ = this->num_ / YMM_FLOAT_BLOCK; \
this->rest_ = this->num_ % YMM_FLOAT_BLOCK; \
} \
template <> \
void CRFDecodeKernelImpl<float, platform::avx, block>::Compute( \
const int seq_len, const float* x, const float* w, float* alpha, \
int* track) const { \
INIT_ALPHA(YMM_FLOAT_BLOCK) \
/* Use the column-major strategy to get the location of maximum score.*/ \
int seq_offset = 0; \
constexpr int state_trans_base_idx = 2; \
for (int k = 1; k < seq_len; ++k) { \
int j_offset = 0; \
for (int j = 0; j <= this->end_; ++j) { \
/* Initialize the variables of maximum score and location.*/ \
__m256 max_score = _mm256_set1_ps(-std::numeric_limits<float>::max()); \
__m256i max_j = _mm256_set1_epi32(0); \
/* Calculate the offset of transition_weights.*/ \
int trans_offset = state_trans_base_idx * this->num_ + j_offset; \
for (int i = 0; i < this->num_; ++i) { \
/* Initalize the content of alpha variable with related offset.*/ \
__m256 alpha_content = _mm256_broadcast_ss(alpha + seq_offset + i); \
/* Obtain the content of weights from un-aligned address.*/ \
__m256 w_content = _mm256_loadu_ps(w + trans_offset); \
__m256 score_v = _mm256_add_ps(alpha_content, w_content); \
__m256 mask = _mm256_cmp_ps(score_v, max_score, _CMP_GT_OS); \
/* According to the mask value, update the index of the max_score.*/ \
/* AVX instructions.*/ \
__m128i lo_max_j = _mm256_extractf128_si256(max_j, 0); \
__m128i hi_max_j = _mm256_extractf128_si256(max_j, 1); \
__m128i lo_mask = _mm256_extractf128_si256(*(__m256i*)&mask, 0); \
__m128i hi_mask = _mm256_extractf128_si256(*(__m256i*)&mask, 1); \
lo_max_j = _mm_andnot_si128(lo_mask, lo_max_j); \
hi_max_j = _mm_andnot_si128(hi_mask, hi_max_j); \
lo_mask = _mm_and_si128(lo_mask, _mm_set1_epi32(i)); \
hi_mask = _mm_and_si128(hi_mask, _mm_set1_epi32(i)); \
lo_max_j = _mm_or_si128(lo_mask, lo_max_j); \
hi_max_j = _mm_or_si128(hi_mask, hi_max_j); \
max_j = _mm256_insertf128_si256(max_j, lo_max_j, 0); \
max_j = _mm256_insertf128_si256(max_j, hi_max_j, 1); \
/* AVX done*/ \
/* Update the max_score value.*/ \
max_score = _mm256_max_ps(max_score, score_v); \
trans_offset += this->num_; \
} \
UPDATE_ALPHA(YMM_FLOAT_BLOCK) \
} \
seq_offset += this->num_; \
} \
}
#define INTRIAVX2_FLOAT(isa, block) \
template <> \
CRFDecodeKernelImpl<float, isa, block>::CRFDecodeKernelImpl(int tag_num) \
: CRFDecodeKernel<float>() { \
this->num_ = tag_num; \
this->end_ = this->num_ / YMM_FLOAT_BLOCK; \
this->rest_ = this->num_ % YMM_FLOAT_BLOCK; \
} \
template <> \
void CRFDecodeKernelImpl<float, isa, block>::Compute( \
const int seq_len, const float* x, const float* w, float* alpha, \
int* track) const { \
INIT_ALPHA(YMM_FLOAT_BLOCK) \
/* Use the column-major strategy to get the location of maximum score.*/ \
int seq_offset = 0; \
constexpr int state_trans_base_idx = 2; \
for (int k = 1; k < seq_len; ++k) { \
int j_offset = 0; \
for (int j = 0; j <= this->end_; ++j) { \
/* Initialize the variables of maximum score and location.*/ \
__m256 max_score = _mm256_set1_ps(-std::numeric_limits<float>::max()); \
__m256i max_j = _mm256_set1_epi32(0); \
/* Calculate the offset of transition_weights.*/ \
int trans_offset = state_trans_base_idx * this->num_ + j_offset; \
for (int i = 0; i < this->num_; ++i) { \
/* Initalize the content of alpha variable with related offset.*/ \
__m256 alpha_content = _mm256_broadcast_ss(alpha + seq_offset + i); \
/* Obtain the content of weights from un-aligned address.*/ \
__m256 w_content = _mm256_loadu_ps(w + trans_offset); \
__m256 score_v = _mm256_add_ps(alpha_content, w_content); \
__m256 mask = _mm256_cmp_ps(score_v, max_score, _CMP_GT_OS); \
/* According to the mask value, update the index of the max_score.*/ \
/* AVX2 instructions.*/ \
max_j = _mm256_or_si256( \
_mm256_andnot_si256((__m256i)mask, max_j), \
_mm256_and_si256((__m256i)mask, _mm256_set1_epi32(i))); \
/* Update the max_score value.*/ \
max_score = _mm256_max_ps(max_score, score_v); \
trans_offset += this->num_; \
} \
UPDATE_ALPHA(YMM_FLOAT_BLOCK) \
} \
seq_offset += this->num_; \
} \
}
#define INTRIAVX512_FLOAT(block) \
template <> \
CRFDecodeKernelImpl<float, platform::avx512f, block>::CRFDecodeKernelImpl( \
int tag_num) \
: CRFDecodeKernel<float>() { \
this->num_ = tag_num; \
this->end_ = this->num_ / ZMM_FLOAT_BLOCK; \
this->rest_ = this->num_ % ZMM_FLOAT_BLOCK; \
} \
template <> \
void CRFDecodeKernelImpl<float, platform::avx512f, block>::Compute( \
const int seq_len, const float* x, const float* w, float* alpha, \
int* track) const { \
INIT_ALPHA(ZMM_FLOAT_BLOCK) \
/* Use the column-major strategy to get the location of maximum score.*/ \
int seq_offset = 0; \
constexpr int state_trans_base_idx = 2; \
for (int k = 1; k < seq_len; ++k) { \
int j_offset = 0; \
for (int j = 0; j <= this->end_; ++j) { \
/* Initialize the variables of maximum score and location.*/ \
__m512 max_score = _mm512_set1_ps(-std::numeric_limits<float>::max()); \
__m512i max_j = _mm512_setzero_si512(); \
/* Calculate the offset of transition_weights.*/ \
int trans_offset = state_trans_base_idx * this->num_ + j_offset; \
for (int i = 0; i < this->num_; ++i) { \
/* Initalize the content of alpha variable with related offset.*/ \
__m512 alpha_content = _mm512_set1_ps(*(alpha + seq_offset + i)); \
/* Obtain the content of weights from un-aligned address.*/ \
__m512 w_content = _mm512_loadu_ps(w + trans_offset); \
__m512 score_v = _mm512_add_ps(alpha_content, w_content); \
__mmask16 mask = _mm512_cmp_ps_mask(score_v, max_score, _CMP_GT_OS); \
/* AVX512 instructions.*/ \
max_j = _mm512_mask_set1_epi32(max_j, mask, i); \
/* Update the max_score value.*/ \
max_score = _mm512_max_ps(max_score, score_v); \
trans_offset += this->num_; \
} \
/* Update the alpha and track values.*/ \
__m512 x_content = \
_mm512_loadu_ps(x + seq_offset + this->num_ + j_offset); \
max_score = _mm512_add_ps(max_score, x_content); \
_mm512_storeu_ps(alpha + seq_offset + this->num_ + j_offset, \
max_score); \
_mm512_storeu_si512(reinterpret_cast<__m512i*>(track + seq_offset + \
this->num_ + j_offset), \
max_j); \
/* Calculate the offset of next step*/ \
j_offset += ZMM_FLOAT_BLOCK; \
if (j == this->end_ - 1) { \
if (this->rest_ > 0) { \
j_offset += last_offset; \
} else { \
break; \
} \
} \
} \
seq_offset += this->num_; \
} \
}
#ifdef __AVX__
INTRIAVX_FLOAT(kEQ8);
INTRIAVX_FLOAT(kGT8LT16);
INTRIAVX_FLOAT(kEQ16);
INTRIAVX_FLOAT(kGT16);
#endif
#ifdef __AVX2__
INTRIAVX2_FLOAT(platform::avx2, kEQ8);
INTRIAVX2_FLOAT(platform::avx2, kGT8LT16);
INTRIAVX2_FLOAT(platform::avx2, kEQ16);
INTRIAVX2_FLOAT(platform::avx2, kGT16);
#endif
#ifdef __AVX512F__
INTRIAVX2_FLOAT(platform::avx512f, kEQ8);
INTRIAVX2_FLOAT(platform::avx512f, kGT8LT16);
INTRIAVX512_FLOAT(kEQ16);
INTRIAVX512_FLOAT(kGT16);
#endif
#undef INTRIAVX512_FLOAT
#undef INTRIAVX2_FLOAT
#undef INTRIAVX_FLOAT
#undef INIT_ALPHA
#undef UPDATE_ALPHA
REGISTER_JITKERNEL_DEPRECATED(crf_decode, CRFDecodeKernel);
} // namespace jitkernel
} // namespace math
} // namespace operators
} // namespace paddle
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/math/jit_kernel.h"
#include <string>
#include "paddle/fluid/operators/math/jit_kernel_macro.h"
#include "paddle/fluid/operators/math/jit_kernel_refer.h"
#ifdef PADDLE_WITH_XBYAK
#include "paddle/fluid/operators/math/jit_code.h"
#endif
#ifdef PADDLE_WITH_MKLML
#include "paddle/fluid/platform/dynload/mklml.h"
#endif
namespace paddle {
namespace operators {
namespace math {
namespace jitkernel {
#ifdef PADDLE_WITH_MKLML
// try to use MKL to speedup
template <typename T>
void VExpMKL(const T* x, T* y, int n);
template <>
void VExpMKL<float>(const float* x, float* y, int n) {
platform::dynload::vsExp(n, x, y);
}
template <>
void VExpMKL<double>(const double* x, double* y, int n) {
platform::dynload::vdExp(n, x, y);
}
template <typename T>
void VSigmoidMKL(const T* x, T* y, int n) {
const T min = SIGMOID_THRESHOLD_MIN;
const T max = SIGMOID_THRESHOLD_MAX;
for (int i = 0; i < n; ++i) {
y[i] = (x[i] < min) ? min : ((x[i] > max) ? max : x[i]);
y[i] = static_cast<T>(0) - y[i];
}
VExpMKL(y, y, n);
for (int i = 0; i < n; ++i) {
y[i] = static_cast<T>(1) / (static_cast<T>(1) + y[i]);
}
}
template <typename T>
void VTanhMKL(const T* x, T* y, int n) {
for (int i = 0; i < n; ++i) {
y[i] = static_cast<T>(2) * x[i];
}
VSigmoidMKL(y, y, n);
for (int i = 0; i < n; ++i) {
y[i] = static_cast<T>(2) * y[i] - static_cast<T>(1);
}
}
#endif
/* VExp JitKernel */
template <typename T>
class VExpKernelImpl : public VExpKernel<T> {
public:
JITKERNEL_DECLARE_STATIC_FUNC;
explicit VExpKernelImpl(int d) : VExpKernel<T>() {
#ifdef PADDLE_WITH_XBYAK
if (useJIT(d)) {
size_t sz = 96 + d / YMM_FLOAT_BLOCK * 70 * 8;
jitcode_.reset(new gen::VActJitCode(d, gen::operand_type::exp,
sz > 4096 ? sz : 4096));
this->Compute = jitcode_->getCode<void (*)(const T*, T*, int)>();
return;
}
#endif
#ifdef PADDLE_WITH_MKLML
if (useMKL(d)) {
this->Compute = VExpMKL<T>;
return;
}
#endif
this->Compute = refer::VExp<T>;
}
#ifdef PADDLE_WITH_XBYAK
private:
std::unique_ptr<gen::VActJitCode> jitcode_{nullptr};
#endif
};
#ifdef PADDLE_WITH_XBYAK
template <>
bool VExpKernelImpl<float>::useJIT(int d) {
return gen::VActJitCode::init(d, gen::operand_type::exp);
}
#endif
#ifdef PADDLE_WITH_MKLML
template <>
bool VExpKernelImpl<float>::useMKL(int d) {
return d > 512;
}
template <>
bool VExpKernelImpl<double>::useMKL(int d) {
return true;
}
#endif
/* VSigmoid JitKernel */
template <typename T>
class VSigmoidKernelImpl : public VSigmoidKernel<T> {
public:
JITKERNEL_DECLARE_STATIC_FUNC;
explicit VSigmoidKernelImpl(int d) : VSigmoidKernel<T>() {
#ifdef PADDLE_WITH_XBYAK
if (useJIT(d)) {
size_t sz = 96 + d / YMM_FLOAT_BLOCK * 82 * 8;
jitcode_.reset(new gen::VActJitCode(d, gen::operand_type::sigmoid,
sz > 4096 ? sz : 4096));
this->Compute = jitcode_->getCode<void (*)(const T*, T*, int)>();
return;
}
#endif
#ifdef PADDLE_WITH_MKLML
// strictly it's a better impl with MKL, then is refer
if (useMKL(d)) {
this->Compute = VSigmoidMKL<T>;
return;
}
#endif
this->Compute = refer::VSigmoid<T>;
}
#ifdef PADDLE_WITH_XBYAK
private:
std::unique_ptr<gen::VActJitCode> jitcode_{nullptr};
#endif
};
#ifdef PADDLE_WITH_XBYAK
template <>
bool VSigmoidKernelImpl<float>::useJIT(int d) {
return gen::VActJitCode::init(d, gen::operand_type::sigmoid);
}
#endif
#ifdef PADDLE_WITH_MKLML
template <>
bool VSigmoidKernelImpl<float>::useMKL(int d) {
return d > 512;
}
template <>
bool VSigmoidKernelImpl<double>::useMKL(int d) {
return true;
}
#endif
/* VTanh JitKernel */
template <typename T>
class VTanhKernelImpl : public VTanhKernel<T> {
public:
JITKERNEL_DECLARE_STATIC_FUNC;
explicit VTanhKernelImpl(int d) : VTanhKernel<T>() {
#ifdef PADDLE_WITH_XBYAK
if (useJIT(d)) {
size_t sz = 96 + d / YMM_FLOAT_BLOCK * 84 * 8;
jitcode_.reset(new gen::VActJitCode(d, gen::operand_type::tanh,
sz > 4096 ? sz : 4096));
this->Compute = jitcode_->getCode<void (*)(const T*, T*, int)>();
return;
}
#endif
#ifdef PADDLE_WITH_MKLML
// strictly it's a better impl with MKL, then is refer
if (useMKL(d)) {
this->Compute = VTanhMKL<T>;
return;
}
#endif
this->Compute = refer::VTanh<T>;
}
#ifdef PADDLE_WITH_XBYAK
private:
std::unique_ptr<gen::VActJitCode> jitcode_{nullptr};
#endif
};
#ifdef PADDLE_WITH_XBYAK
template <>
bool VTanhKernelImpl<float>::useJIT(int d) {
return gen::VActJitCode::init(d, gen::operand_type::tanh);
}
#endif
#ifdef PADDLE_WITH_MKLML
template <>
bool VTanhKernelImpl<float>::useMKL(int d) {
return d > 512;
}
template <>
bool VTanhKernelImpl<double>::useMKL(int d) {
return true;
}
#endif
REGISTER_JITKERNEL(vexp, VExpKernel);
REGISTER_JITKERNEL(vsigmoid, VSigmoidKernel);
REGISTER_JITKERNEL(vtanh, VTanhKernel);
} // namespace jitkernel
} // namespace math
} // namespace operators
} // namespace paddle
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/math/jit_kernel.h"
#include <math.h>
#include <limits>
#include <string>
#include "paddle/fluid/operators/math/jit_kernel_macro.h"
namespace paddle {
namespace operators {
namespace math {
namespace jitkernel {
/* Layer Norm JitKernel */
template <typename T, platform::cpu_isa_t isa, jit_block>
class LayerNormKernelImpl : public LayerNormKernel<T> {
public:
explicit LayerNormKernelImpl(int right) : LayerNormKernel<T>() {
this->num_ = right;
}
void Compute(T* x, T* out, T* mean, T* var, const T* scale, const T* bias,
int height, const float epsilon) const override {
// get mean
for (int i = 0; i < height; i++) {
T sum = 0.0;
int offset = i * this->num_;
for (int j = 0; j < this->num_; j++) {
sum += x[offset + j];
}
mean[i] = sum / this->num_;
}
// get variance
for (int i = 0; i < height; i++) {
T sum = 0.0;
int offset = i * this->num_;
for (int j = 0; j < this->num_; j++) {
sum += (x[offset + j] - mean[i]) * (x[offset + j] - mean[i]);
}
var[i] = sum / this->num_;
}
for (int i = 0; i < height; i++) {
int offset = i * this->num_;
T sqrt_var = sqrt(var[i] + (T)epsilon);
for (int j = 0; j < this->num_; j++) {
out[offset + j] = (x[offset + j] - mean[i]) / sqrt_var;
}
}
if (scale) {
for (int i = 0; i < height; i++) {
int offset = i * this->num_;
for (int j = 0; j < this->num_; j++) {
out[offset + j] *= scale[j];
}
}
}
if (bias) {
for (int i = 0; i < height; i++) {
int offset = i * this->num_;
for (int j = 0; j < this->num_; j++) {
out[offset + j] += bias[j];
}
}
}
}
};
#define INTRIAVX_FLOAT(isa, jit_block) \
template <> \
LayerNormKernelImpl<float, isa, jit_block>::LayerNormKernelImpl(int right) \
: LayerNormKernel<float>() { \
this->num_ = right; \
this->rest_ = this->num_ % YMM_FLOAT_BLOCK; \
this->end_ = this->num_ - this->rest_; \
} \
template <> \
void LayerNormKernelImpl<float, isa, jit_block>::Compute( \
float* x, float* out, float* mean, float* var, const float* scale, \
const float* bias, int height, const float epsilon) const { \
__m256 sum; \
__m256 mean_vec, var_vec; \
__m128 hi, lo; \
__m256 tmp; \
size_t offset; \
size_t j; \
size_t block = YMM_FLOAT_BLOCK; \
__m256 reverse_num_vec = \
_mm256_div_ps(_mm256_set1_ps(1.0), _mm256_set1_ps(this->num_)); \
__m256 epsilon_vec = _mm256_set1_ps(epsilon); \
int rest_mask = \
((-1) & (~((~0U) >> (sizeof(int) * 8 - (YMM_FLOAT_BLOCK - rest_))))) & \
0x0ff; \
__m256i mask_vec = _mm256_set_epi32( \
rest_mask & 0x80 ? 0xffffffff : 0, rest_mask & 0x40 ? 0xffffffff : 0, \
rest_mask & 0x20 ? 0xffffffff : 0, rest_mask & 0x10 ? 0xffffffff : 0, \
rest_mask & 0x8 ? 0xffffffff : 0, rest_mask & 0x4 ? 0xffffffff : 0, \
rest_mask & 0x2 ? 0xffffffff : 0, rest_mask & 0x1 ? 0xffffffff : 0); \
\
for (int i = 0; i < height; ++i) { \
offset = i * this->num_; \
\
/* get mean */ \
sum = _mm256_setzero_ps(); \
for (j = offset; j < end_ + offset; j += block) { \
sum = _mm256_add_ps(sum, _mm256_loadu_ps((const float*)x + j)); \
} \
if (rest_ != 0) { \
j = offset + this->num_ - block; \
tmp = _mm256_loadu_ps((const float*)x + j); \
tmp = _mm256_blendv_ps(_mm256_setzero_ps(), tmp, *(__m256*)&mask_vec); \
sum = _mm256_add_ps(sum, tmp); \
} \
hi = _mm256_extractf128_ps(sum, 1); \
lo = _mm256_extractf128_ps(sum, 0); \
sum = _mm256_add_ps( \
sum, _mm256_insertf128_ps( \
_mm256_insertf128_ps(_mm256_setzero_ps(), hi, 0), lo, 1)); \
sum = _mm256_hadd_ps(sum, sum); \
sum = _mm256_hadd_ps(sum, sum); \
mean_vec = _mm256_mul_ps(sum, reverse_num_vec); \
mean[i] = *reinterpret_cast<float*>(&mean_vec); \
\
/* get variance */ \
sum = _mm256_setzero_ps(); \
for (j = offset; j < end_ + offset; j += block) { \
tmp = _mm256_sub_ps(_mm256_loadu_ps((const float*)x + j), mean_vec); \
tmp = _mm256_mul_ps(tmp, tmp); \
sum = _mm256_add_ps(sum, tmp); \
} \
if (rest_ != 0) { \
j = offset + this->num_ - block; \
tmp = _mm256_sub_ps(_mm256_loadu_ps((const float*)x + j), mean_vec); \
tmp = _mm256_mul_ps(tmp, tmp); \
tmp = _mm256_blendv_ps(_mm256_setzero_ps(), tmp, *(__m256*)&mask_vec); \
sum = _mm256_add_ps(sum, tmp); \
} \
hi = _mm256_extractf128_ps(sum, 1); \
lo = _mm256_extractf128_ps(sum, 0); \
sum = _mm256_add_ps( \
sum, _mm256_insertf128_ps( \
_mm256_insertf128_ps(_mm256_setzero_ps(), hi, 0), lo, 1)); \
sum = _mm256_hadd_ps(sum, sum); \
sum = _mm256_hadd_ps(sum, sum); \
var_vec = _mm256_mul_ps(sum, reverse_num_vec); \
var[i] = *reinterpret_cast<float*>(&var_vec); \
\
/* get x_norm and calculate output*/ \
for (j = offset; j < end_ + offset; j += block) { \
tmp = _mm256_sub_ps(_mm256_loadu_ps((const float*)x + j), mean_vec); \
tmp = _mm256_div_ps( \
tmp, _mm256_sqrt_ps(_mm256_add_ps(var_vec, epsilon_vec))); \
_mm256_storeu_ps(reinterpret_cast<float*>(out) + j, tmp); \
} \
if (rest_ != 0) { \
j = offset + num_ - block; \
tmp = _mm256_sub_ps(_mm256_loadu_ps((const float*)x + j), mean_vec); \
tmp = _mm256_div_ps( \
tmp, _mm256_sqrt_ps(_mm256_add_ps(var_vec, epsilon_vec))); \
_mm256_storeu_ps(reinterpret_cast<float*>(out) + j, tmp); \
} \
\
if (scale) { \
if (rest_ != 0) { \
j = offset + this->num_ - block; \
tmp = _mm256_loadu_ps((const float*)out + j); \
} \
for (j = offset; j < end_ + offset; j += block) { \
_mm256_storeu_ps( \
reinterpret_cast<float*>(out) + j, \
_mm256_mul_ps( \
_mm256_loadu_ps((const float*)out + j), \
_mm256_loadu_ps((const float*)scale + j - offset))); \
} \
if (rest_ != 0) { \
j = offset + this->num_ - block; \
_mm256_storeu_ps( \
reinterpret_cast<float*>(out) + j, \
_mm256_mul_ps( \
tmp, _mm256_loadu_ps((const float*)scale + j - offset))); \
} \
} \
\
if (bias) { \
if (rest_ != 0) { \
j = offset + this->num_ - block; \
tmp = _mm256_loadu_ps((const float*)out + j); \
} \
for (j = offset; j < end_ + offset; j += block) { \
_mm256_storeu_ps( \
reinterpret_cast<float*>(out) + j, \
_mm256_add_ps( \
_mm256_loadu_ps((const float*)out + j), \
_mm256_loadu_ps((const float*)bias + j - offset))); \
} \
if (rest_ != 0) { \
j = offset + this->num_ - block; \
_mm256_storeu_ps( \
reinterpret_cast<float*>(out) + j, \
_mm256_add_ps( \
tmp, _mm256_loadu_ps((const float*)bias + j - offset))); \
} \
} \
} \
}
#ifdef __AVX__
INTRIAVX_FLOAT(platform::avx, kEQ8);
INTRIAVX_FLOAT(platform::avx, kGT8LT16);
INTRIAVX_FLOAT(platform::avx, kEQ16);
INTRIAVX_FLOAT(platform::avx, kGT16);
INTRIAVX_FLOAT(platform::avx2, kEQ8);
INTRIAVX_FLOAT(platform::avx2, kGT8LT16);
INTRIAVX_FLOAT(platform::avx2, kEQ16);
INTRIAVX_FLOAT(platform::avx2, kGT16);
INTRIAVX_FLOAT(platform::avx512f, kEQ8);
INTRIAVX_FLOAT(platform::avx512f, kGT8LT16);
INTRIAVX_FLOAT(platform::avx512f, kEQ16);
INTRIAVX_FLOAT(platform::avx512f, kGT16);
#endif
#undef INTRIAVX_FLOAT
REGISTER_JITKERNEL_DEPRECATED(layer_norm, LayerNormKernel);
} // namespace jitkernel
} // namespace math
} // namespace operators
} // namespace paddle
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <string>
#include "paddle/fluid/platform/cpu_info.h"
#include "paddle/fluid/platform/enforce.h"
namespace paddle {
namespace operators {
namespace math {
namespace jitkernel {
#define JITKERNEL_DECLARE_STATIC_FUNC \
static inline std::string name(int d) { \
PADDLE_THROW("DType should be either float or double"); \
} \
static inline bool useJIT(int d) { return false; } \
static inline bool useMKL(int d) { return false; }
#define JITKERNEL_DEFINE_NAME(ker_key, ker_class) \
template <> \
std::string ker_class##Impl<float>::name(int d) { \
std::string key(#ker_key "f"); \
if (useJIT(d)) { \
/* only jit code need record d*/ \
return key + "jit" + std::to_string(d); \
} else if (useMKL(d)) { \
return key + "mkl"; \
} else { \
return key + "any"; \
} \
} \
template <> \
std::string ker_class##Impl<double>::name(int d) { \
std::string key(#ker_key "d"); \
/* jit code do not support double yet*/ \
if (useMKL(d)) { \
return key + "mkl"; \
} else { \
return key + "any"; \
} \
}
#define JITKERNEL_DECLARE(ker_class, ker_dtype) \
template <> \
std::shared_ptr<const ker_class<ker_dtype>> \
KernelPool::Get<ker_class<ker_dtype>, int>(int d)
#define JITKERNEL_FIND_KEY(ker_class, ker_dtype) \
std::string key = ker_class##Impl<ker_dtype>::name(d)
#define JITKERNEL_IMPL(ker_class, ker_dtype) \
p = std::dynamic_pointer_cast<ker_class<ker_dtype>>( \
std::make_shared<ker_class##Impl<ker_dtype>>(d))
#define REGISTER_JITKERNEL_WITH_DTYPE(ker_class, ker_dtype, marco_declare, \
macro_find_key, macro_impl) \
marco_declare(ker_class, ker_dtype) { \
macro_find_key(ker_class, ker_dtype); \
if (kers_.find(key) == kers_.end()) { \
std::shared_ptr<ker_class<ker_dtype>> p; \
macro_impl(ker_class, ker_dtype); \
kers_.insert({key, std::dynamic_pointer_cast<Kernel>(p)}); \
return p; \
} \
return std::dynamic_pointer_cast<const ker_class<ker_dtype>>( \
kers_.at(key)); \
}
#define REGISTER_JITKERNEL_ARGS(ker_key, ker_class, marco_define_name, \
marco_declare, macro_find_key, macro_impl) \
marco_define_name(ker_key, ker_class); \
REGISTER_JITKERNEL_WITH_DTYPE(ker_class, float, marco_declare, \
macro_find_key, macro_impl); \
REGISTER_JITKERNEL_WITH_DTYPE(ker_class, double, marco_declare, \
macro_find_key, macro_impl)
#define REGISTER_JITKERNEL(ker_key, ker_class) \
REGISTER_JITKERNEL_ARGS(ker_key, ker_class, JITKERNEL_DEFINE_NAME, \
JITKERNEL_DECLARE, JITKERNEL_FIND_KEY, \
JITKERNEL_IMPL)
// TODO(TJ): below defines are deprecated, would be remove recently
#define SEARCH_BLOCK(macro_, ker, dtype, isa) \
if (d < YMM_FLOAT_BLOCK) { \
macro_(ker, dtype, isa, kLT8); \
} else if (d == YMM_FLOAT_BLOCK) { \
macro_(ker, dtype, isa, kEQ8); \
} else if (d > YMM_FLOAT_BLOCK && d < ZMM_FLOAT_BLOCK) { \
macro_(ker, dtype, isa, kGT8LT16); \
} else if (d == ZMM_FLOAT_BLOCK) { \
macro_(ker, dtype, isa, kEQ16); \
} else { \
macro_(ker, dtype, isa, kGT16); \
}
#define SEARCH_ISA_BLOCK(macro_, ker, dtype) \
if (platform::MayIUse(platform::avx512f)) { \
SEARCH_BLOCK(macro_, ker, dtype, platform::avx512f); \
} else if (platform::MayIUse(platform::avx2)) { \
SEARCH_BLOCK(macro_, ker, dtype, platform::avx2); \
} else if (platform::MayIUse(platform::avx)) { \
SEARCH_BLOCK(macro_, ker, dtype, platform::avx); \
} else { \
SEARCH_BLOCK(macro_, ker, dtype, platform::isa_any); \
}
#define JITKERNEL_KEY(ker_key, dtype_key) \
#ker_key #dtype_key + std::to_string(d)
#define JITKERNEL_NEW_IMPL_DEPRECATED(ker, dtype, isa, k) \
p = std::dynamic_pointer_cast<ker<dtype>>( \
std::make_shared<ker##Impl<dtype, isa, k>>(d))
#define JITKERNEL_WITH_DTYPE_DEPRECATED(ker_key, ker_class, ker_dtype, \
dtype_key, marco_declare, macro_key, \
macro_impl) \
marco_declare(ker_class, ker_dtype) { \
std::string key = macro_key(ker_key, dtype_key); \
if (kers_.find(key) == kers_.end()) { \
std::shared_ptr<ker_class<ker_dtype>> p; \
SEARCH_ISA_BLOCK(macro_impl, ker_class, ker_dtype); \
kers_.insert({key, std::dynamic_pointer_cast<Kernel>(p)}); \
return p; \
} \
return std::dynamic_pointer_cast<const ker_class<ker_dtype>>( \
kers_.at(key)); \
}
#define REGISTER_JITKERNEL_DEPRECATED(ker_key, ker_class) \
JITKERNEL_WITH_DTYPE_DEPRECATED(ker_key, ker_class, float, f, \
JITKERNEL_DECLARE, JITKERNEL_KEY, \
JITKERNEL_NEW_IMPL_DEPRECATED); \
JITKERNEL_WITH_DTYPE_DEPRECATED(ker_key, ker_class, double, d, \
JITKERNEL_DECLARE, JITKERNEL_KEY, \
JITKERNEL_NEW_IMPL_DEPRECATED)
#define REGISTER_JITKERNEL_ARGS_DEPRECATED(ker_key, ker_class, marco_declare, \
macro_key, macro_impl) \
JITKERNEL_WITH_DTYPE_DEPRECATED(ker_key, ker_class, float, f, marco_declare, \
macro_key, macro_impl); \
JITKERNEL_WITH_DTYPE_DEPRECATED(ker_key, ker_class, double, d, \
marco_declare, macro_key, macro_impl)
#define FOR_EACH_ISA(macro_, block) \
macro_(platform::avx512f, block); \
macro_(platform::avx2, block); \
macro_(platform::avx, block); \
macro_(platform::isa_any, block)
#define FOR_EACH_BLOCK(macro_, isa) \
macro_(isa, kLT8); \
macro_(isa, kEQ8); \
macro_(isa, kGT8LT16); \
macro_(isa, kEQ16); \
macro_(isa, kGT16)
#define FOR_EACH_ISA_BLOCK(macro_) \
FOR_EACH_BLOCK(macro_, platform::avx512f); \
FOR_EACH_BLOCK(macro_, platform::avx2); \
FOR_EACH_BLOCK(macro_, platform::avx); \
FOR_EACH_BLOCK(macro_, platform::isa_any)
} // namespace jitkernel
} // namespace math
} // namespace operators
} // namespace paddle
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/math/jit_kernel.h"
#include <string>
#include "paddle/fluid/operators/math/jit_kernel_macro.h"
#include "paddle/fluid/operators/math/jit_kernel_refer.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/macros.h"
#ifdef PADDLE_WITH_XBYAK
#include "paddle/fluid/operators/math/jit_code.h"
#endif
namespace paddle {
namespace operators {
namespace math {
namespace jitkernel {
/* LSTM JitKernel */
template <typename T>
class LSTMKernelImpl : public LSTMKernel<T> {
public:
static inline std::string name(const lstm_attr_t& attr) {
PADDLE_THROW("DType should be either float or double");
}
static inline bool useJIT(int d) { return false; }
static inline bool useMKL(int d) { return false; }
explicit LSTMKernelImpl(const lstm_attr_t& attr) : LSTMKernel<T>() {
#ifdef PADDLE_WITH_XBYAK
if (useJIT(attr.d)) {
size_t sz = 96 + attr.d / YMM_FLOAT_BLOCK * 90 * 4 * 8;
jitcode0_.reset(new gen::LSTMJitCode(false, attr, sz > 4096 ? sz : 4096));
this->ComputeCtHt =
jitcode0_->getCode<void (*)(lstm_t*, const lstm_attr_t*)>();
jitcode1_.reset(new gen::LSTMJitCode(true, attr, sz > 4096 ? sz : 4096));
this->ComputeC1H1 =
jitcode1_->getCode<void (*)(lstm_t*, const lstm_attr_t*)>();
return;
}
#endif
this->ComputeCtHt = refer::LSTMCtHt<T>;
this->ComputeC1H1 = refer::LSTMC1H1<T>;
}
#ifdef PADDLE_WITH_XBYAK
private:
std::unique_ptr<gen::LSTMJitCode> jitcode0_{nullptr}, jitcode1_{nullptr};
#endif
};
#ifdef PADDLE_WITH_XBYAK
template <>
bool LSTMKernelImpl<float>::useJIT(int d) {
return gen::LSTMJitCode::init(d);
}
#endif
/* Peephole JitKernel */
template <typename T>
class PeepholeKernelImpl : public LSTMKernel<T> {
public:
static inline std::string name(const lstm_attr_t& attr) {
PADDLE_THROW("DType should be either float or double");
}
static inline bool useJIT(int d) { return false; }
static inline bool useMKL(int d) { return false; }
explicit PeepholeKernelImpl(const lstm_attr_t& attr) : LSTMKernel<T>() {
#ifdef PADDLE_WITH_XBYAK
if (useJIT(attr.d)) {
size_t sz = 96 + attr.d / YMM_FLOAT_BLOCK * 96 * 4 * 8;
jitcode0_.reset(new gen::LSTMJitCode(false, attr, sz > 4096 ? sz : 4096));
this->ComputeCtHt =
jitcode0_->getCode<void (*)(lstm_t*, const lstm_attr_t*)>();
jitcode1_.reset(new gen::LSTMJitCode(true, attr, sz > 4096 ? sz : 4096));
this->ComputeC1H1 =
jitcode1_->getCode<void (*)(lstm_t*, const lstm_attr_t*)>();
return;
}
#endif
this->ComputeCtHt = refer::LSTMCtHt<T>;
this->ComputeC1H1 = refer::LSTMC1H1<T>;
}
#ifdef PADDLE_WITH_XBYAK
private:
std::unique_ptr<gen::LSTMJitCode> jitcode0_{nullptr}, jitcode1_{nullptr};
#endif
};
#ifdef PADDLE_WITH_XBYAK
template <>
bool PeepholeKernelImpl<float>::useJIT(int d) {
return gen::LSTMJitCode::init(d);
}
#endif
#define JITKERNEL_DEFINE_NAME_LSTM(ker_key, ker_class) \
template <> \
std::string ker_class##Impl<float>::name(const lstm_attr_t& attr) { \
std::string key(#ker_key "f"); \
key += (attr.act_gate + attr.act_cand + attr.act_cell + \
(attr.use_peephole ? "p" : "n")); \
if (useJIT(attr.d)) { \
/* only jit code need record d*/ \
return key + "jit" + std::to_string(attr.d); \
} else if (useMKL(attr.d)) { \
return key + "mkl"; \
} else { \
return key + "any"; \
} \
} \
template <> \
std::string ker_class##Impl<double>::name(const lstm_attr_t& attr) { \
std::string key(#ker_key "d"); \
/* jit code do not support double yet*/ \
if (useMKL(attr.d)) { \
return key + "mkl"; \
} else { \
return key + "any"; \
} \
}
#define JITKERNEL_DECLARE_LSTM(ker_class, ker_dtype) \
template <> \
std::shared_ptr<const LSTMKernel<ker_dtype>> \
KernelPool::Get<LSTMKernel<ker_dtype>, const lstm_attr_t&>( \
const lstm_attr_t& attr)
#define JITKERNEL_FIND_KEY_LSTM(ker_class, ker_dtype) \
std::string key = ker_class##Impl<ker_dtype>::name(attr)
#define JITKERNEL_LSTM_IMPL(ker, dtype) \
if (attr.use_peephole) { \
p = std::dynamic_pointer_cast<ker<dtype>>( \
std::make_shared<PeepholeKernelImpl<dtype>>(attr)); \
} else { \
p = std::dynamic_pointer_cast<ker<dtype>>( \
std::make_shared<ker##Impl<dtype>>(attr)); \
}
REGISTER_JITKERNEL_ARGS(lstm, LSTMKernel, JITKERNEL_DEFINE_NAME_LSTM,
JITKERNEL_DECLARE_LSTM, JITKERNEL_FIND_KEY_LSTM,
JITKERNEL_LSTM_IMPL);
#undef JITKERNEL_LSTM_IMPL
#undef JITKERNEL_FIND_KEY_LSTM
#undef JITKERNEL_DECLARE_LSTM
#undef JITKERNEL_DEFINE_NAME_LSTM
/* GRU JitKernel */
template <typename T>
class GRUKernelImpl : public GRUKernel<T> {
public:
static inline std::string name(const gru_attr_t& attr) {
PADDLE_THROW("DType should be either float or double");
}
static inline bool useJIT(int d) { return false; }
static inline bool useMKL(int d) { return false; }
explicit GRUKernelImpl(const gru_attr_t& attr) : GRUKernel<T>() {
#ifdef PADDLE_WITH_XBYAK
if (useJIT(attr.d)) {
size_t sz = 96 + attr.d / YMM_FLOAT_BLOCK * 96 * 2 * 8;
jitcode0_.reset(new gen::GRUJitCode(0, attr, sz > 4096 ? sz : 4096));
this->ComputeH1 =
jitcode0_->getCode<void (*)(gru_t*, const gru_attr_t*)>();
jitcode1_.reset(new gen::GRUJitCode(1, attr, sz > 4096 ? sz : 4096));
this->ComputeHtPart1 =
jitcode1_->getCode<void (*)(gru_t*, const gru_attr_t*)>();
jitcode2_.reset(new gen::GRUJitCode(2, attr, sz > 4096 ? sz : 4096));
this->ComputeHtPart2 =
jitcode2_->getCode<void (*)(gru_t*, const gru_attr_t*)>();
return;
}
#endif
this->ComputeH1 = refer::GRUH1<T>;
this->ComputeHtPart1 = refer::GRUHtPart1<T>;
this->ComputeHtPart2 = refer::GRUHtPart2<T>;
}
#ifdef PADDLE_WITH_XBYAK
private:
std::unique_ptr<gen::GRUJitCode> jitcode0_{nullptr}, jitcode1_{nullptr},
jitcode2_{nullptr};
#endif
};
#ifdef PADDLE_WITH_XBYAK
template <>
bool GRUKernelImpl<float>::useJIT(int d) {
return gen::GRUJitCode::init(d);
}
#endif
#define JITKERNEL_DEFINE_NAME_GRU(ker_key, ker_class) \
template <> \
std::string ker_class##Impl<float>::name(const gru_attr_t& attr) { \
std::string key(#ker_key "f"); \
key += (attr.act_gate + attr.act_cand); \
if (useJIT(attr.d)) { \
/* only jit code need record d*/ \
return key + "jit" + std::to_string(attr.d); \
} else if (useMKL(attr.d)) { \
return key + "mkl"; \
} else { \
return key + "any"; \
} \
} \
template <> \
std::string ker_class##Impl<double>::name(const gru_attr_t& attr) { \
std::string key(#ker_key "d"); \
/* jit code do not support double yet*/ \
if (useMKL(attr.d)) { \
return key + "mkl"; \
} else { \
return key + "any"; \
} \
}
#define JITKERNEL_DECLARE_GRU(ker_class, ker_dtype) \
template <> \
std::shared_ptr<const ker_class<ker_dtype>> \
KernelPool::Get<ker_class<ker_dtype>, const gru_attr_t&>( \
const gru_attr_t& attr)
#define JITKERNEL_FIND_KEY_GRU(ker_class, ker_dtype) \
std::string key = ker_class##Impl<ker_dtype>::name(attr)
#define JITKERNEL_GRU_IMPL(ker, dtype) \
p = std::dynamic_pointer_cast<ker<dtype>>( \
std::make_shared<ker##Impl<dtype>>(attr));
REGISTER_JITKERNEL_ARGS(gru, GRUKernel, JITKERNEL_DEFINE_NAME_GRU,
JITKERNEL_DECLARE_GRU, JITKERNEL_FIND_KEY_GRU,
JITKERNEL_GRU_IMPL);
#undef JITKERNEL_GRU_IMPL
#undef JITKERNEL_FIND_KEY_GRU
#undef JITKERNEL_DECLARE_GRU
#undef JITKERNEL_DEFINE_NAME_GRU
} // namespace jitkernel
} // namespace math
} // namespace operators
} // namespace paddle
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/math/jit_kernel.h"
#include <cmath> // for exp
#include <cstring> // for memcpy
#include <random>
#include <string>
#include <vector>
#include "gflags/gflags.h"
#include "glog/logging.h"
#include "gtest/gtest.h"
#include "paddle/fluid/operators/math/jit_kernel_refer.h"
#include "paddle/fluid/platform/port.h"
#ifdef PADDLE_WITH_MKLML
#include "paddle/fluid/platform/dynload/mklml.h"
#endif
#ifdef __AVX__
#include <immintrin.h>
#endif
constexpr int repeat = 20000;
// TODO(TJ): benchmark and test should be seperated,
// benchmark should verify more sizes
inline double GetCurrentUS() {
struct timeval time;
gettimeofday(&time, NULL);
return 1e+6 * time.tv_sec + time.tv_usec;
}
template <typename T>
void RandomVec(const int n, T* a, const T lower = static_cast<T>(-20.f),
const T upper = static_cast<T>(20.f)) {
static unsigned int seed = 100;
std::mt19937 rng(seed++);
std::uniform_real_distribution<double> uniform_dist(0, 1);
for (int i = 0; i < n; ++i) {
a[i] = static_cast<T>(uniform_dist(rng) * (upper - lower) + lower);
}
}
#if defined __AVX__ || defined __AVX2__
void vrelu_intri8(const int n, const float* x, float* y) {
__m256 tmp = _mm256_loadu_ps(x);
tmp = _mm256_max_ps(tmp, _mm256_setzero_ps());
_mm256_storeu_ps(y, tmp);
}
#endif
TEST(JitKernel, vrelu) {
namespace jit = paddle::operators::math::jitkernel;
namespace refer = paddle::operators::math::jitkernel::refer;
for (int d : {3, 7, 8, 15, 16, 30, 256, 512}) {
std::vector<float> x(d);
std::vector<float> zref(d), ztgt(d);
RandomVec<float>(d, x.data(), -10.f, 1.f);
const auto& ker =
jit::KernelPool::Instance().template Get<jit::VReluKernel<float>>(d);
const float* x_data = x.data();
float* ztgt_data = ztgt.data();
float* zref_data = zref.data();
auto trefs = GetCurrentUS();
for (int i = 0; i < repeat; ++i) {
refer::VRelu<float>(x_data, zref_data, d);
}
auto trefe = GetCurrentUS();
#if defined __AVX__ || defined __AVX2__
if (d == 8) {
auto si0 = GetCurrentUS();
for (int i = 0; i < repeat; ++i) {
vrelu_intri8(d, x_data, zref_data);
}
auto si1 = GetCurrentUS();
VLOG(3) << "Vec size 8 intr takes: " << (si1 - si0) / repeat << " us";
}
#endif
auto ttgts = GetCurrentUS();
for (int i = 0; i < repeat; ++i) {
ker->Compute(x_data, ztgt_data, d);
}
auto ttgte = GetCurrentUS();
VLOG(3) << "Vec size " << d << ": refer takes: " << (trefe - trefs) / repeat
<< " us, tgt takes: " << (ttgte - ttgts) / repeat << " us";
for (int i = 0; i < d; ++i) {
EXPECT_NEAR(ztgt_data[i], zref_data[i], 1e-3);
}
}
}
TEST(JitKernel, vaddbias) {
namespace jit = paddle::operators::math::jitkernel;
namespace refer = paddle::operators::math::jitkernel::refer;
for (int d : {7, 8, 15, 16, 30, 64, 100, 128, 256}) {
std::vector<float> x(d);
std::vector<float> zref(d), ztgt(d);
RandomVec<float>(d, x.data(), -2.f, 2.f);
const auto& ker =
jit::KernelPool::Instance().template Get<jit::VAddBiasKernel<float>>(d);
const float a = 2.f;
const float* x_data = x.data();
float* ztgt_data = ztgt.data();
float* zref_data = zref.data();
auto trefs = GetCurrentUS();
for (int i = 0; i < repeat; ++i) {
refer::VAddBias<float>(&a, x_data, zref_data, d);
}
auto trefe = GetCurrentUS();
auto ttgts = GetCurrentUS();
for (int i = 0; i < repeat; ++i) {
ker->Compute(&a, x_data, ztgt_data, d);
}
auto ttgte = GetCurrentUS();
VLOG(3) << "Vec size " << d << ": refer takes: " << (trefe - trefs) / repeat
<< " us, tgt takes: " << (ttgte - ttgts) / repeat << " us";
for (int i = 0; i < d; ++i) {
EXPECT_NEAR(ztgt_data[i], zref_data[i], 1e-3);
}
}
}
#ifdef PADDLE_WITH_MKLML
void vexp_mkl(const int n, const float* x, float* y) {
paddle::platform::dynload::vsExp(n, x, y);
}
#endif
TEST(JitKernel, vexp) {
namespace jit = paddle::operators::math::jitkernel;
namespace refer = paddle::operators::math::jitkernel::refer;
for (int d : {1, 3, 4, 6, 7, 8, 12, 15, 16, 20, 30, 128, 256}) {
std::vector<float> x(d);
std::vector<float> zref(d), ztgt(d);
RandomVec<float>(d, x.data(), -2.f, 2.f);
const auto& ker =
jit::KernelPool::Instance().template Get<jit::VExpKernel<float>>(d);
const float* x_data = x.data();
float* ztgt_data = ztgt.data();
float* zref_data = zref.data();
auto trefs = GetCurrentUS();
for (int i = 0; i < repeat; ++i) {
refer::VExp<float>(x_data, zref_data, d);
}
auto trefe = GetCurrentUS();
#ifdef PADDLE_WITH_MKLML
auto tmkls = GetCurrentUS();
for (int i = 0; i < repeat; ++i) {
vexp_mkl(d, x_data, zref_data);
}
auto tmkle = GetCurrentUS();
#endif
auto ttgts = GetCurrentUS();
for (int i = 0; i < repeat; ++i) {
// ker->Compute(x_data, ztgt_data);
ker->Compute(x_data, ztgt_data, d);
}
auto ttgte = GetCurrentUS();
VLOG(3) << "Vec size " << d << ": refer takes: " << (trefe - trefs) / repeat
#ifdef PADDLE_WITH_MKLML
<< " us, mkl takes: " << (tmkle - tmkls) / repeat << " us, "
#else
<< " us, "
#endif
<< "tgt takes: " << (ttgte - ttgts) / repeat << " us";
for (int i = 0; i < d; ++i) {
EXPECT_NEAR(ztgt_data[i], zref_data[i], 1e-3);
}
}
}
void vsigmoid_better(
const std::shared_ptr<
const paddle::operators::math::jitkernel::VExpKernel<float>>& vexp,
const int n, const float* x, float* y) {
const float min = SIGMOID_THRESHOLD_MIN;
const float max = SIGMOID_THRESHOLD_MAX;
for (int i = 0; i < n; ++i) {
y[i] = (x[i] < min) ? min : ((x[i] > max) ? max : x[i]);
y[i] = 0.f - y[i];
}
vexp->Compute(y, y, n);
for (int i = 0; i < n; ++i) {
y[i] = 1.f / (1.f + y[i]);
}
}
TEST(JitKernel, vsigmoid) {
namespace jit = paddle::operators::math::jitkernel;
namespace refer = paddle::operators::math::jitkernel::refer;
for (int d : {1, 3, 4, 6, 7, 8, 15, 16, 30, 32, 64, 100, 128, 256}) {
std::vector<float> x(d);
std::vector<float> zref(d), ztgt(d);
RandomVec<float>(d, x.data(), -2.f, 2.f);
const auto& ker =
jit::KernelPool::Instance().template Get<jit::VSigmoidKernel<float>>(d);
const auto& vexp =
jit::KernelPool::Instance().template Get<jit::VExpKernel<float>>(d);
const float* x_data = x.data();
float* ztgt_data = ztgt.data();
float* zref_data = zref.data();
auto tmkls = GetCurrentUS();
for (int i = 0; i < repeat; ++i) {
vsigmoid_better(vexp, d, x_data, zref_data);
}
auto tmkle = GetCurrentUS();
auto trefs = GetCurrentUS();
for (int i = 0; i < repeat; ++i) {
refer::VSigmoid<float>(x_data, zref_data, d);
}
auto trefe = GetCurrentUS();
auto ttgts = GetCurrentUS();
for (int i = 0; i < repeat; ++i) {
ker->Compute(x_data, ztgt_data, d);
}
auto ttgte = GetCurrentUS();
VLOG(3) << "Vec size " << d << ": refer takes: " << (trefe - trefs) / repeat
<< " us, better(jit exp) takes: " << (tmkle - tmkls) / repeat
<< " us, tgt takes: " << (ttgte - ttgts) / repeat << " us";
for (int i = 0; i < d; ++i) {
EXPECT_NEAR(ztgt_data[i], zref_data[i], 1e-3);
}
}
}
void vtanh_better(
const std::shared_ptr<
const paddle::operators::math::jitkernel::VScalKernel<float>>& vscal,
const std::shared_ptr<
const paddle::operators::math::jitkernel::VSigmoidKernel<float>>&
vsigmoid,
const std::shared_ptr<
const paddle::operators::math::jitkernel::VAddBiasKernel<float>>&
vaddbias,
const int n, const float* x, float* y) {
const float a = 2.f, b = -1.f;
vscal->Compute(&a, x, y, n);
vsigmoid->Compute(y, y, n);
vscal->Compute(&a, y, y, n);
vaddbias->Compute(&b, y, y, n);
}
TEST(JitKernel, vtanh) {
namespace jit = paddle::operators::math::jitkernel;
namespace refer = paddle::operators::math::jitkernel::refer;
for (int d : {1, 2, 3, 4, 5, 6, 7, 8, 15, 16, 30, 32, 64, 100, 128, 256}) {
std::vector<float> x(d);
std::vector<float> zref(d), ztgt(d);
RandomVec<float>(d, x.data(), -2.f, 2.f);
const auto& ker =
jit::KernelPool::Instance().template Get<jit::VTanhKernel<float>>(d);
const auto& vscal =
jit::KernelPool::Instance().template Get<jit::VScalKernel<float>>(d);
const auto& vsigmoid =
jit::KernelPool::Instance().template Get<jit::VSigmoidKernel<float>>(d);
const auto& vaddbias =
jit::KernelPool::Instance().template Get<jit::VAddBiasKernel<float>>(d);
const float* x_data = x.data();
float* ztgt_data = ztgt.data();
float* zref_data = zref.data();
auto tmkls = GetCurrentUS();
for (int i = 0; i < repeat; ++i) {
vtanh_better(vscal, vsigmoid, vaddbias, d, x_data, zref_data);
}
auto tmkle = GetCurrentUS();
auto trefs = GetCurrentUS();
for (int i = 0; i < repeat; ++i) {
refer::VTanh<float>(x_data, zref_data, d);
}
auto trefe = GetCurrentUS();
auto ttgts = GetCurrentUS();
for (int i = 0; i < repeat; ++i) {
ker->Compute(x_data, ztgt_data, d);
}
auto ttgte = GetCurrentUS();
VLOG(3) << "Vec size " << d << ": refer takes: " << (trefe - trefs) / repeat
<< " us, better(jit exp) takes: " << (tmkle - tmkls) / repeat
<< " us, tgt takes: " << (ttgte - ttgts) / repeat << " us";
for (int i = 0; i < d; ++i) {
EXPECT_NEAR(ztgt_data[i], zref_data[i], 1e-3);
}
}
}
void lstm_ctht_better(
const std::shared_ptr<
const paddle::operators::math::jitkernel::VSigmoidKernel<float>>&
vsigmoid_3d,
const std::shared_ptr<
const paddle::operators::math::jitkernel::VTanhKernel<float>>& vtanh_d,
const std::shared_ptr<
const paddle::operators::math::jitkernel::VMulKernel<float>>& vmul_d,
const std::shared_ptr<
const paddle::operators::math::jitkernel::VAddKernel<float>>& vadd_d,
const int d, float* gates, const float* ct_1, float* ct, float* ht) {
int d2 = d * 2;
vsigmoid_3d->Compute(gates + d, gates + d, 3 * d);
vtanh_d->Compute(gates, gates, d);
vmul_d->Compute(gates, gates + d, gates + d, d);
vmul_d->Compute(ct_1, gates + d2, gates + d2, d);
vadd_d->Compute(gates + d, gates + d2, ct, d);
/* H_t = act_cell(C_t) * ogated */
vtanh_d->Compute(ct, gates + d2, d);
vmul_d->Compute(gates + d2, gates + d * 3, ht, d);
}
TEST(JitKernel, lstm) {
namespace jit = paddle::operators::math::jitkernel;
namespace refer = paddle::operators::math::jitkernel::refer;
for (int d : {1, 2, 3, 4, 5, 6, 7, 8, 15, 16, 30, 32, 64, 100}) {
int d4 = d * 4;
int d3 = d * 3;
std::vector<float> x(d4), xref(d4);
std::vector<float> ct_1(d), ct_tgt(d), ht_tgt(d);
std::vector<float> ct_ref(d), ht_ref(d);
RandomVec<float>(d4, x.data(), -2.f, 2.f);
RandomVec<float>(d, ct_1.data(), -2.f, 2.f);
memcpy(xref.data(), x.data(), sizeof(float) * d4);
std::string act_gate = "sigmoid", act_cand = "tanh", act_cell = "tanh";
const jit::lstm_attr_t attr(d, act_gate, act_cand, act_cell, false);
const auto& ker =
jit::KernelPool::Instance()
.template Get<jit::LSTMKernel<float>, const jit::lstm_attr_t&>(
attr);
// below kernels are used to compute refer
const auto& vsigmoid_3d =
jit::KernelPool::Instance().template Get<jit::VSigmoidKernel<float>>(
d3);
const auto& vtanh_d =
jit::KernelPool::Instance().template Get<jit::VTanhKernel<float>>(d);
const auto& vmul_d =
jit::KernelPool::Instance().template Get<jit::VMulKernel<float>>(d);
const auto& vadd_d =
jit::KernelPool::Instance().template Get<jit::VAddKernel<float>>(d);
float* x_data = x.data();
float* xref_data = xref.data();
const float* ct_1_data = ct_1.data();
float* ct_tgt_data = ct_tgt.data();
float* ht_tgt_data = ht_tgt.data();
float* ct_ref_data = ct_ref.data();
float* ht_ref_data = ht_ref.data();
// compute once to check correctness
jit::lstm_t step;
step.gates = xref_data;
step.ct_1 = ct_1_data;
step.ct = ct_ref_data;
step.ht = ht_ref_data;
refer::LSTMCtHt<float>(&step, &attr);
step.gates = x_data;
step.ct = ct_tgt_data;
step.ht = ht_tgt_data;
ker->ComputeCtHt(&step, &attr);
for (int i = 0; i < d; ++i) {
EXPECT_NEAR(ct_tgt_data[i], ct_ref_data[i], 1e-3);
EXPECT_NEAR(ht_tgt_data[i], ht_ref_data[i], 1e-3);
}
auto tmkls = GetCurrentUS();
for (int i = 0; i < repeat; ++i) {
lstm_ctht_better(vsigmoid_3d, vtanh_d, vmul_d, vadd_d, d, xref_data,
ct_1_data, ct_ref_data, ht_ref_data);
}
auto tmkle = GetCurrentUS();
auto trefs = GetCurrentUS();
for (int i = 0; i < repeat; ++i) {
refer::LSTMCtHt<float>(&step, &attr);
}
auto trefe = GetCurrentUS();
auto ttgts = GetCurrentUS();
for (int i = 0; i < repeat; ++i) {
ker->ComputeCtHt(&step, &attr);
}
auto ttgte = GetCurrentUS();
VLOG(3) << "Vec size " << d << ": refer takes: " << (trefe - trefs) / repeat
<< " us, better(jit) takes: " << (tmkle - tmkls) / repeat
<< " us, tgt takes: " << (ttgte - ttgts) / repeat << " us";
}
}
#if defined __AVX__ || defined __AVX2__
void vscal_intri8(const int n, const float a, const float* x, float* y) {
__m256 tmp;
__m256 scalar = _mm256_set1_ps(a);
tmp = _mm256_loadu_ps(x);
tmp = _mm256_mul_ps(tmp, scalar);
_mm256_storeu_ps(y, tmp);
}
void vscal_inp_intri8(const int n, const float a, float* x) {
__m256 tmp;
__m256 scalar = _mm256_set1_ps(a);
tmp = _mm256_loadu_ps(x);
tmp = _mm256_mul_ps(tmp, scalar);
_mm256_storeu_ps(x, tmp);
}
#endif
#ifdef PADDLE_WITH_MKLML
void vscal_inp_mkl(const int n, const float a, float* x) {
paddle::platform::dynload::cblas_sscal(n, a, x, 1);
}
#endif
TEST(JitKernel, vscal) {
namespace jit = paddle::operators::math::jitkernel;
namespace refer = paddle::operators::math::jitkernel::refer;
for (int d : {7, 8, 15, 16, 30, 256, 512}) {
std::vector<float> x(d), y(d);
std::vector<float> zref(d), ztgt(d);
RandomVec<float>(d, x.data());
std::memcpy(y.data(), x.data(), sizeof(float) * d);
float a = 2.f;
const auto& ker =
jit::KernelPool::Instance().template Get<jit::VScalKernel<float>>(d);
const float* x_data = x.data();
float* y_data = y.data();
float* ztgt_data = ztgt.data();
float* zref_data = zref.data();
auto trefs = GetCurrentUS();
for (int i = 0; i < repeat; ++i) {
refer::VScal<float>(&a, x_data, zref_data, d);
}
auto trefe = GetCurrentUS();
auto trefs1 = GetCurrentUS();
for (int i = 0; i < repeat; ++i) {
refer::VScal<float>(&a, y_data, y_data, d);
}
auto trefe1 = GetCurrentUS();
#ifdef PADDLE_WITH_MKLML
auto tmkls = GetCurrentUS();
for (int i = 0; i < repeat; ++i) {
vscal_inp_mkl(d, a, y_data);
}
auto tmkle = GetCurrentUS();
#endif
#if defined __AVX__ || defined __AVX2__
if (d == 8) {
auto si0 = GetCurrentUS();
for (int i = 0; i < repeat; ++i) {
vscal_intri8(d, a, x_data, zref_data);
}
auto si1 = GetCurrentUS();
auto si2 = GetCurrentUS();
for (int i = 0; i < repeat; ++i) {
vscal_inp_intri8(d, a, y_data);
}
auto si3 = GetCurrentUS();
VLOG(3) << "Vec size 8 intr takes: " << (si1 - si0) / repeat
<< " us, inplace: " << (si3 - si2) / repeat << " us";
}
#endif
auto ttgts = GetCurrentUS();
for (int i = 0; i < repeat; ++i) {
ker->Compute(&a, x_data, ztgt_data, d);
}
auto ttgte = GetCurrentUS();
auto ttgts1 = GetCurrentUS();
for (int i = 0; i < repeat; ++i) {
ker->Compute(&a, y_data, y_data, d);
}
auto ttgte1 = GetCurrentUS();
VLOG(3) << "Vec size " << d << ": refer takes: " << (trefe - trefs) / repeat
<< " us, inplace takes: " << (trefe1 - trefs1) / repeat
#ifdef PADDLE_WITH_MKLML
<< " us, mkl inplace takes: " << (tmkle - tmkls) / repeat << " us, "
#else
<< " us, "
#endif
<< "tgt takes: " << (ttgte - ttgts) / repeat
<< "us, tgt inplace takes: " << (ttgte1 - ttgts1) / repeat << " us";
for (int i = 0; i < d; ++i) {
EXPECT_NEAR(ztgt_data[i], zref_data[i], 1e-3);
}
}
}
#if defined __AVX__ || defined __AVX2__
void vmul_intri8(const int n, const float* x, const float* y, float* z) {
__m256 tmpx, tmpy;
tmpx = _mm256_loadu_ps(x);
tmpy = _mm256_loadu_ps(y);
tmpx = _mm256_mul_ps(tmpx, tmpy);
_mm256_storeu_ps(z, tmpx);
}
#endif
#ifdef PADDLE_WITH_MKLML
void vmul_mkl(const int n, const float* x, const float* y, float* z) {
paddle::platform::dynload::vsMul(n, x, y, z);
}
#endif
TEST(JitKernel, vmul) {
namespace jit = paddle::operators::math::jitkernel;
namespace refer = paddle::operators::math::jitkernel::refer;
for (int d : {7, 8, 15, 16, 20, 30, 256, 512, 1000, 1024}) {
std::vector<float> x(d), y(d);
std::vector<float> zref(d), ztgt(d);
RandomVec<float>(d, x.data());
RandomVec<float>(d, y.data());
const auto& ker =
jit::KernelPool::Instance().template Get<jit::VMulKernel<float>>(d);
const float* x_data = x.data();
const float* y_data = y.data();
float* ztgt_data = ztgt.data();
float* zref_data = zref.data();
auto trefs = GetCurrentUS();
for (int i = 0; i < repeat; ++i) {
refer::VMul<float>(x_data, y_data, zref_data, d);
}
auto trefe = GetCurrentUS();
#ifdef PADDLE_WITH_MKLML
auto tmkls = GetCurrentUS();
for (int i = 0; i < repeat; ++i) {
vmul_mkl(d, x_data, y_data, zref_data);
}
auto tmkle = GetCurrentUS();
#endif
#if defined __AVX__ || defined __AVX2__
if (d == 8) {
auto si0 = GetCurrentUS();
for (int i = 0; i < repeat; ++i) {
vmul_intri8(d, x_data, y_data, zref_data);
}
auto si1 = GetCurrentUS();
VLOG(3) << "Vec size 8 intr takes: " << (si1 - si0) / repeat;
}
#endif
auto ttgts = GetCurrentUS();
for (int i = 0; i < repeat; ++i) {
ker->Compute(x_data, y_data, ztgt_data, d);
}
auto ttgte = GetCurrentUS();
VLOG(3) << "Vec size " << d << ": refer takes: " << (trefe - trefs) / repeat
#ifdef PADDLE_WITH_MKLML
<< " us, mkl takes: " << (tmkle - tmkls) / repeat << " us, "
#else
<< " us, "
#endif
<< "tgt takes: " << (ttgte - ttgts) / repeat << " us";
for (int i = 0; i < d; ++i) {
EXPECT_NEAR(ztgt_data[i], zref_data[i], 1e-3);
}
}
}
#if defined __AVX__ || defined __AVX2__
void vadd_intri8(const int n, const float* x, const float* y, float* z) {
__m256 tmpx, tmpy;
tmpx = _mm256_loadu_ps(x);
tmpy = _mm256_loadu_ps(y);
tmpx = _mm256_add_ps(tmpx, tmpy);
_mm256_storeu_ps(z, tmpx);
}
#endif
#ifdef PADDLE_WITH_MKLML
void vadd_mkl(const int n, const float* x, const float* y, float* z) {
paddle::platform::dynload::vsAdd(n, x, y, z);
}
#endif
TEST(JitKernel, vadd) {
namespace jit = paddle::operators::math::jitkernel;
namespace refer = paddle::operators::math::jitkernel::refer;
for (int d : {7, 8, 15, 16, 30, 256, 512}) {
std::vector<float> x(d), y(d);
std::vector<float> zref(d), ztgt(d);
RandomVec<float>(d, x.data());
RandomVec<float>(d, y.data());
const auto& ker =
jit::KernelPool::Instance().template Get<jit::VAddKernel<float>>(d);
const float* x_data = x.data();
const float* y_data = y.data();
float* ztgt_data = ztgt.data();
float* zref_data = zref.data();
auto trefs = GetCurrentUS();
for (int i = 0; i < repeat; ++i) {
refer::VAdd<float>(x_data, y_data, zref_data, d);
}
auto trefe = GetCurrentUS();
#ifdef PADDLE_WITH_MKLML
auto tmkls = GetCurrentUS();
for (int i = 0; i < repeat; ++i) {
vadd_mkl(d, x_data, y_data, zref_data);
}
auto tmkle = GetCurrentUS();
#endif
#if defined __AVX__ || defined __AVX2__
if (d == 8) {
auto si0 = GetCurrentUS();
for (int i = 0; i < repeat; ++i) {
vadd_intri8(d, x_data, y_data, zref_data);
}
auto si1 = GetCurrentUS();
VLOG(3) << "Vec size 8 intr takes: " << (si1 - si0) / repeat;
}
#endif
auto ttgts = GetCurrentUS();
for (int i = 0; i < repeat; ++i) {
ker->Compute(x_data, y_data, ztgt_data, d);
}
auto ttgte = GetCurrentUS();
VLOG(3) << "Vec size " << d << ": refer takes: " << (trefe - trefs) / repeat
#ifdef PADDLE_WITH_MKLML
<< " us, mkl takes: " << (tmkle - tmkls) / repeat << " us, "
#else
<< " us, "
#endif
<< "tgt takes: " << (ttgte - ttgts) / repeat << " us";
for (int i = 0; i < d; ++i) {
EXPECT_NEAR(ztgt_data[i], zref_data[i], 1e-3);
}
}
}
void vaddrelu_better(
const std::shared_ptr<
const paddle::operators::math::jitkernel::VAddKernel<float>>& vadd,
const std::shared_ptr<
const paddle::operators::math::jitkernel::VReluKernel<float>>& vrelu,
const float* x, const float* y, float* z, int d) {
vadd->Compute(x, y, z, d);
vrelu->Compute(z, z, d);
}
TEST(JitKernel, vaddrelu) {
namespace jit = paddle::operators::math::jitkernel;
namespace refer = paddle::operators::math::jitkernel::refer;
for (int d : {7, 8, 15, 16, 30, 256, 512}) {
std::vector<float> x(d), y(d);
std::vector<float> zref(d), ztgt(d);
RandomVec<float>(d, x.data());
RandomVec<float>(d, y.data());
const auto& ker =
jit::KernelPool::Instance().template Get<jit::VAddReluKernel<float>>(d);
const auto& vadd =
jit::KernelPool::Instance().template Get<jit::VAddKernel<float>>(d);
const auto& vrelu =
jit::KernelPool::Instance().template Get<jit::VReluKernel<float>>(d);
const float* x_data = x.data();
const float* y_data = y.data();
float* ztgt_data = ztgt.data();
float* zref_data = zref.data();
auto trefs = GetCurrentUS();
for (int i = 0; i < repeat; ++i) {
refer::VAddRelu<float>(x_data, y_data, zref_data, d);
}
auto trefe = GetCurrentUS();
auto tmkls = GetCurrentUS();
for (int i = 0; i < repeat; ++i) {
vaddrelu_better(vadd, vrelu, x_data, y_data, zref_data, d);
}
auto tmkle = GetCurrentUS();
auto ttgts = GetCurrentUS();
for (int i = 0; i < repeat; ++i) {
ker->Compute(x_data, y_data, ztgt_data, d);
}
auto ttgte = GetCurrentUS();
VLOG(3) << "Vec size " << d << ": refer takes: " << (trefe - trefs) / repeat
<< " us, better takes: " << (tmkle - tmkls) / repeat << " us, "
<< "tgt takes: " << (ttgte - ttgts) / repeat << " us";
for (int i = 0; i < d; ++i) {
EXPECT_NEAR(ztgt_data[i], zref_data[i], 1e-3);
}
}
}
TEST(JitKernel, pool) {
namespace jit = paddle::operators::math::jitkernel;
const int frame_size = 4;
std::string act_gate = "sigmoid", act_cand = "tanh", act_cell = "tanh";
jit::lstm_attr_t attr(frame_size, act_gate, act_cand, act_cell, false);
// empty call it to avoid unknown flag 'use_pinned_memory' on Mac
paddle::platform::MayIUse(paddle::platform::avx);
const auto& plstm1 =
jit::KernelPool::Instance()
.template Get<jit::LSTMKernel<float>, const jit::lstm_attr_t&>(attr);
const auto& plstm2 =
jit::KernelPool::Instance()
.template Get<jit::LSTMKernel<float>, const jit::lstm_attr_t&>(attr);
EXPECT_EQ(plstm1, plstm2);
const auto& peephole =
jit::KernelPool::Instance()
.template Get<jit::LSTMKernel<float>, const jit::lstm_attr_t&>(
jit::lstm_attr_t(frame_size, act_gate, act_cand, act_cell, true));
EXPECT_TRUE(plstm1 != peephole);
const auto& pvmul_f =
jit::KernelPool::Instance().template Get<jit::VMulKernel<float>>(4);
EXPECT_TRUE(std::dynamic_pointer_cast<const jit::Kernel>(plstm2) !=
std::dynamic_pointer_cast<const jit::Kernel>(pvmul_f));
const auto& pvmul_d =
jit::KernelPool::Instance().template Get<jit::VMulKernel<double>>(4);
EXPECT_TRUE(std::dynamic_pointer_cast<const jit::Kernel>(pvmul_f) !=
std::dynamic_pointer_cast<const jit::Kernel>(pvmul_d));
const auto& pvmul_from_key = jit::KernelPool::Instance().Get("vmulfjit4");
#if defined(__APPLE__) || defined(__OSX__) || defined(_WIN32)
EXPECT_EQ(pvmul_from_key, nullptr);
#else
EXPECT_EQ(pvmul_from_key, pvmul_f);
#endif
const auto& pvmul_from_key2 = jit::KernelPool::Instance().Get("vmulfjit");
EXPECT_TRUE(pvmul_from_key2 == nullptr);
}
...@@ -26,6 +26,13 @@ class MergeSelectedRowsOp : public framework::OperatorWithKernel { ...@@ -26,6 +26,13 @@ class MergeSelectedRowsOp : public framework::OperatorWithKernel {
"Input(X) of MergeSelectedRowsOp should not be null."); "Input(X) of MergeSelectedRowsOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"), PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(Out) of MergeSelectedRowsOp should not be null."); "Output(Out) of MergeSelectedRowsOp should not be null.");
PADDLE_ENFORCE_EQ(ctx->GetInputsVarType("X").front(),
framework::proto::VarType::SELECTED_ROWS,
"Input X only should be SelectedRows.");
PADDLE_ENFORCE_EQ(ctx->GetOutputsVarType("Out").front(),
framework::proto::VarType::SELECTED_ROWS,
"Output Y only should be SelectedRows.");
ctx->ShareDim("X", /*->*/ "Out"); ctx->ShareDim("X", /*->*/ "Out");
} }
}; };
...@@ -43,7 +50,28 @@ class MergeSelectedRowsOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -43,7 +50,28 @@ class MergeSelectedRowsOpMaker : public framework::OpProtoAndCheckerMaker {
R"DOC( R"DOC(
MergeSelectedRows Operator. MergeSelectedRows Operator.
MergeSelectedRows is used to merge the duplicated rows of the input. MergeSelectedRows is used to merge the duplicated rows of the input. The
output's row has no duplicated, and it's order is incremental.
Example:
Input:
X.rows is [0, 5, 5, 4, 19]
X.height is 20
X.value is:
[[1, 1]
[2, 2]
[3, 3]
[4, 4]
[6, 6]]
Output:
Out.row is [0, 4, 5, 19]
Out.height is 20
Out.value is:
[[1, 1]
[4, 4]
[5, 5]
[6, 6]]
)DOC"); )DOC");
} }
}; };
......
...@@ -49,7 +49,8 @@ class MulOp : public framework::OperatorWithKernel { ...@@ -49,7 +49,8 @@ class MulOp : public framework::OperatorWithKernel {
PADDLE_ENFORCE_GT( PADDLE_ENFORCE_GT(
y_dims.size(), y_num_col_dims, y_dims.size(), y_num_col_dims,
"The input tensor Y's rank of MulOp should be larger than " "The input tensor Y's rank of MulOp should be larger than "
"y_num_col_dims."); "y_num_col_dims: %ld vs %ld",
y_dims.size(), y_num_col_dims);
auto x_mat_dims = framework::flatten_to_2d(x_dims, x_num_col_dims); auto x_mat_dims = framework::flatten_to_2d(x_dims, x_num_col_dims);
auto y_mat_dims = framework::flatten_to_2d(y_dims, y_num_col_dims); auto y_mat_dims = framework::flatten_to_2d(y_dims, y_num_col_dims);
......
...@@ -12,28 +12,14 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,28 +12,14 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/operators/math/jit_kernel.h" /*
#include <iostream> * This file contains the list of the ngraph operators for Paddle.
#include <string> *
* ATTENTION: It requires some C++11 features, for lower version C++ or C, we
* might release another API.
*/
namespace paddle { #pragma once
namespace operators {
namespace math {
namespace jitkernel {
KernelPool& KernelPool::Instance() { #include "ops/binary_unnary_op.h"
static thread_local KernelPool g_jit_kernels; #include "ops/mul_op.h"
return g_jit_kernels;
}
std::shared_ptr<const Kernel> KernelPool::Get(const std::string& key) const {
if (kers_.find(key) == kers_.end()) {
return nullptr;
}
return kers_.at(key);
}
} // namespace jitkernel
} // namespace math
} // namespace operators
} // namespace paddle
...@@ -4,7 +4,7 @@ Licensed under the Apache License, Version 2.0 (the "License"); ...@@ -4,7 +4,7 @@ Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License. you may not use this file except in compliance with the License.
You may obtain a copy of the License at You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0 http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS, distributed under the License is distributed on an "AS IS" BASIS,
...@@ -12,62 +12,41 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,62 +12,41 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#ifdef PADDLE_WITH_NGRAPH
#pragma once #pragma once
#include <string> #include <string>
#include <type_traits> #include "ngraph/ngraph.hpp"
#include "paddle/fluid/platform/ngraph_helper.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
namespace math { namespace ngraphs {
namespace jitkernel {
template <typename T>
#define SIGMOID_THRESHOLD_MIN -40.0 static void BuildBinaryNode(
#define SIGMOID_THRESHOLD_MAX 13.0 const std::shared_ptr<paddle::framework::OperatorBase>& op,
#define EXP_MAX_INPUT 40.0 std::shared_ptr<
#define XMM_FLOAT_BLOCK 4 std::unordered_map<std::string, std::shared_ptr<ngraph::Node>>>
#define YMM_FLOAT_BLOCK 8 ngb_node_map) {
#define ZMM_FLOAT_BLOCK 16 auto x = paddle::platform::GetInputNode(op, "X", ngb_node_map);
auto y = paddle::platform::GetInputNode(op, "Y", ngb_node_map);
typedef struct { auto out = std::make_shared<T>(x, y);
void* gates; // gates: W_ch, W_ih, W_fh, W_oh paddle::platform::SetOutputNode(op, "Out", out, ngb_node_map);
const void* ct_1; }
void* ct;
void* ht; template <typename T>
/* weight_peephole and checked data are only used in peephole*/ static void BuildUnaryNode(
const void* wp{nullptr}; const std::shared_ptr<paddle::framework::OperatorBase>& op,
void* checked{nullptr}; std::shared_ptr<
} lstm_t; std::unordered_map<std::string, std::shared_ptr<ngraph::Node>>>
ngb_node_map) {
typedef struct { auto input = paddle::platform::GetInputNode(op, "X", ngb_node_map);
void* gates; // gates: {W_update, W_reset; W_state} auto out = std::make_shared<T>(input);
const void* ht_1; paddle::platform::SetOutputNode(op, "Out", out, ngb_node_map);
void* ht; }
} gru_t;
} // namespace ngraphs
struct rnn_attr_s {
int d;
std::string act_gate, act_cand;
rnn_attr_s() = default;
rnn_attr_s(int _d, const std::string& _act_gate, const std::string& _act_cand)
: d(_d), act_gate(_act_gate), act_cand(_act_cand) {}
};
struct lstm_attr_s : public rnn_attr_s {
bool use_peephole;
std::string act_cell;
lstm_attr_s() = default;
lstm_attr_s(int _d, const std::string& _act_gate,
const std::string& _act_cand, const std::string& _act_cell,
bool _use_peephole = false)
: rnn_attr_s(_d, _act_gate, _act_cand),
use_peephole(_use_peephole),
act_cell(_act_cell) {}
};
typedef struct rnn_attr_s gru_attr_t;
typedef struct lstm_attr_s lstm_attr_t;
} // namespace jitkernel
} // namespace math
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
#endif
/*Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef PADDLE_WITH_NGRAPH
#pragma once
#include <string>
#include "ngraph/ngraph.hpp"
#include "paddle/fluid/platform/ngraph_helper.h"
namespace paddle {
namespace operators {
namespace ngraphs {
static void BuildMulNode(
const std::shared_ptr<paddle::framework::OperatorBase>& op,
std::shared_ptr<
std::unordered_map<std::string, std::shared_ptr<ngraph::Node>>>
ngb_node_map) {
auto op_attrs = paddle::framework::AttrReader(op->Attrs());
int x_num_col_dims = op_attrs.Get<int>("x_num_col_dims");
int y_num_col_dims = op_attrs.Get<int>("y_num_col_dims");
auto x = paddle::platform::GetInputNode(op, "X", ngb_node_map);
auto y = paddle::platform::GetInputNode(op, "Y", ngb_node_map);
auto x_reshape = x;
auto y_reshape = y;
if (x->get_shape().size() > 2) {
auto x_2d = paddle::platform::FlattenTo2d(x->get_shape(), x_num_col_dims);
x_reshape = paddle::platform::NgReshaper(x, x_2d);
}
if (y->get_shape().size() > 2) {
auto y_2d = paddle::platform::FlattenTo2d(y->get_shape(), y_num_col_dims);
y_reshape = paddle::platform::NgReshaper(y, y_2d);
}
std::shared_ptr<ngraph::Node> out =
std::make_shared<ngraph::op::Dot>(x_reshape, y_reshape);
auto dummy_out = paddle::platform::GetOutputNode(op, "Out", ngb_node_map);
if (dummy_out && dummy_out->get_shape() != out->get_shape()) {
out = paddle::platform::NgReshaper(out, dummy_out->get_shape());
}
paddle::platform::SetOutputNode(op, "Out", out, ngb_node_map);
}
static void BuildMulGradNode(
const std::shared_ptr<paddle::framework::OperatorBase>& op,
std::shared_ptr<
std::unordered_map<std::string, std::shared_ptr<ngraph::Node>>>
ngb_node_map) {
auto op_attrs = paddle::framework::AttrReader(op->Attrs());
int x_num_col_dims = op_attrs.Get<int>("x_num_col_dims");
int y_num_col_dims = op_attrs.Get<int>("y_num_col_dims");
auto x = paddle::platform::GetInputNode(op, "X", ngb_node_map);
auto y = paddle::platform::GetInputNode(op, "Y", ngb_node_map);
auto dout = paddle::platform::GetInputNode(op, "Out@GRAD", ngb_node_map);
bool is_dx = paddle::platform::HasOutput(op, "X@GRAD") ? true : false;
bool is_dy = paddle::platform::HasOutput(op, "Y@GRAD") ? true : false;
auto x_shape = x->get_shape();
auto y_shape = y->get_shape();
auto x_reshape = x;
auto y_reshape = y;
if (x_shape.size() > 2) {
auto x_2d_shape = paddle::platform::FlattenTo2d(x_shape, x_num_col_dims);
x_reshape = paddle::platform::NgReshaper(x, x_2d_shape);
}
if (y_shape.size() > 2) {
auto y_2d_shape = paddle::platform::FlattenTo2d(y_shape, y_num_col_dims);
y_reshape = paddle::platform::NgReshaper(y, y_2d_shape);
}
auto x_reshape_shape = x_reshape->get_shape();
std::reverse(x_reshape_shape.begin(), x_reshape_shape.end());
auto x_transpose = std::make_shared<ngraph::op::Reshape>(
x_reshape, ngraph::AxisVector{1, 0}, x_reshape_shape);
auto y_reshape_shape = y_reshape->get_shape();
std::reverse(y_reshape_shape.begin(), y_reshape_shape.end());
auto y_transpose = std::make_shared<ngraph::op::Reshape>(
y_reshape, ngraph::AxisVector{1, 0}, y_reshape_shape);
if (is_dx) {
if (dout->get_shape().size() > 2) {
auto dout_2d_shape = paddle::platform::FlattenTo2d(dout->get_shape(), 2);
dout = paddle::platform::NgReshaper(dout, dout_2d_shape);
}
auto dx = std::make_shared<ngraph::op::Dot>(dout, y_transpose);
if (dx->get_shape() == x_shape) {
paddle::platform::SetOutputNode(op, "X@GRAD", dx, ngb_node_map);
} else {
auto dx_reshape = paddle::platform::NgReshaper(dx, x_shape);
paddle::platform::SetOutputNode(op, "X@GRAD", dx_reshape, ngb_node_map);
}
}
if (is_dy) {
if (dout->get_shape().size() > 2) {
auto dout_2d_shape = paddle::platform::FlattenTo2d(dout->get_shape(), 2);
dout = paddle::platform::NgReshaper(dout, dout_2d_shape);
}
auto dy = std::make_shared<ngraph::op::Dot>(x_transpose, dout);
if (dy->get_shape() == y_shape) {
paddle::platform::SetOutputNode(op, "Y@GRAD", dy, ngb_node_map);
} else {
auto dy_reshape = paddle::platform::NgReshaper(dy, y_shape);
paddle::platform::SetOutputNode(op, "Y@GRAD", dy_reshape, ngb_node_map);
}
}
}
} // namespace ngraphs
} // namespace operators
} // namespace paddle
#endif
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/py_func_op.h"
#include <set>
#include <string>
#include <vector>
#include "Python.h"
#include "paddle/fluid/framework/op_registry.h"
namespace paddle {
namespace operators {
namespace py = ::pybind11;
static std::vector<py::object> g_py_callables;
const char kForwardPythonCallableId[] = "forward_callable_id";
const char kBackwardPythonCallableId[] = "backward_callable_id";
const char kPyFuncBackwardSkipVars[] = "backward_skip_vars";
size_t AppendPythonCallableObjectAndReturnId(const py::object &py_obj) {
g_py_callables.emplace_back(py_obj);
return g_py_callables.size() - 1;
}
// Return py::object* instead of py::object
// Returning py::object would cause reference count increasing
// but without GIL, reference count in Python may not be safe
static py::object *GetPythonCallableObject(size_t i) {
PADDLE_ENFORCE_LT(i, g_py_callables.size(), "Invalid python callable id");
return &g_py_callables[i];
}
static std::string PythonFuncDebugString(const py::object &py_callable) {
py::gil_scoped_acquire guard;
std::string wrapper_func_str = py::str(py_callable);
auto inner_func = py_callable.attr("_func");
std::string inner_func_str = py::str(inner_func);
return inner_func_str + " wrapped by " + wrapper_func_str;
}
static void CallPythonFunc(py::object *callable,
const std::vector<framework::LoDTensor> &ins,
std::vector<framework::LoDTensor *> *outs) {
py::gil_scoped_acquire guard;
py::tuple in_args(ins.size());
for (size_t i = 0; i < ins.size(); ++i) {
in_args[i] = ins[i].IsInitialized() ? py::cast(ins[i]) : py::cast(nullptr);
}
auto ret = (*callable)(*in_args);
auto ret_tuple = py::cast<py::tuple>(ret);
size_t ret_num = py::len(ret_tuple);
size_t out_num = outs->size();
if (UNLIKELY(ret_num != out_num)) {
// Python function has no return values or returns None
// In this case, ret_num = 1 && ret[0] == None && out_num should be 0
// Otherwise, ret_num must be equal to out_num
PADDLE_ENFORCE(
ret_num == 1 && out_num == 0 &&
py::cast<framework::LoDTensor *>(ret_tuple[0]) == nullptr,
"Output number not match. Expected %d, actual %d", out_num, ret_num);
}
for (size_t i = 0; i < out_num; ++i) {
auto *out = (*outs)[i];
if (out == nullptr) {
continue;
}
try {
auto *py_out_tensor = py::cast<framework::LoDTensor *>(ret_tuple[i]);
PADDLE_ENFORCE_NOT_NULL(py_out_tensor,
"Output tensor %d should not be nullptr", i);
out->set_lod(py_out_tensor->lod());
out->ShareDataWith(*py_out_tensor);
} catch (py::cast_error &) {
PADDLE_THROW("The %d-th output must be LoDTensor", i);
}
}
}
class PyFuncOpVarTypInference : public framework::VarTypeInference {
public:
void operator()(const framework::OpDesc &op,
framework::BlockDesc *block) const override {
auto &outs = op.Outputs();
bool has_out = (outs.count("Out") > 0 && !outs.at("Out").empty());
auto &ins = op.Inputs();
bool has_in = (ins.count("X") > 0 && !ins.at("X").empty());
/**
* X or Out can be empty, so that py_func can be more flexible
* to support Python functions with no input or no output
*/
PADDLE_ENFORCE(has_in || has_out, "Input(X) or Output(Out) must exist");
PADDLE_ENFORCE_GE(boost::get<int>(op.GetAttr(kForwardPythonCallableId)), 0,
"Function id cannot be less than 0");
if (!has_out) return;
/**
* Traverse all outputs, check if name of any output ends with @GRAD.
* If found, set its shape, dtype, lod_level, type to be the same as
* the corresponding forward variable
*/
const std::string kGradVarSuffix = framework::kGradVarSuffix;
auto &out_var_names = outs.at("Out");
for (auto &out_var_name : out_var_names) {
if (out_var_name == framework::kEmptyVarName ||
out_var_name.size() < kGradVarSuffix.size()) {
continue;
}
size_t len = out_var_name.size() - kGradVarSuffix.size();
if (out_var_name.substr(len) == kGradVarSuffix) {
auto fwd_var_name = out_var_name.substr(0, len);
auto *out_var_desc = block->FindVarRecursive(out_var_name);
auto *fwd_var_desc = block->FindVarRecursive(fwd_var_name);
PADDLE_ENFORCE_NOT_NULL(out_var_desc, "Backward variable %s not found",
out_var_name);
PADDLE_ENFORCE_NOT_NULL(fwd_var_desc, "Forward variable %s not found",
fwd_var_name);
VLOG(10) << "Infer var_desc of Output(" << out_var_name << ") as Input("
<< fwd_var_name << ")";
out_var_desc->SetShape(fwd_var_desc->GetShape());
out_var_desc->SetDataType(fwd_var_desc->GetDataType());
out_var_desc->SetLoDLevel(fwd_var_desc->GetLoDLevel());
out_var_desc->SetType(fwd_var_desc->GetType());
}
}
}
};
class PyFuncOpShapeInference : public framework::InferShapeBase {
public:
void operator()(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE(!ctx->IsRuntime(),
"Infer shape cannot be called in runtime.");
}
};
class PyFuncOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X", "Inputs of py_func op.").AsDuplicable();
AddOutput("Out", "Outputs of py_func op").AsDuplicable();
AddAttr<int>(kForwardPythonCallableId,
"Index of registered forward Python function.")
.SetDefault(0);
AddAttr<int>(kBackwardPythonCallableId,
"Index of registered backward Python function.")
.SetDefault(-1);
AddAttr<std::vector<std::string>>(kPyFuncBackwardSkipVars,
"Unused forward in/out in backward op")
.SetDefault(std::vector<std::string>());
AddComment(R"DOC("PyFunc Op")DOC");
}
};
/**
* There are several benefits when backward op of py_func op is
* still py_func op.
*
* - Less codes are needed, since codes of backward is almost
* the same as forward.
*
* - To support high order derivative, so that py_func is
* infinite-order differentiable
*/
class PyFuncOpGradDescMaker : public framework::GradOpDescMakerBase {
private:
static std::string DebugString(const std::vector<std::string> &strs) {
if (strs.empty()) return "";
std::string ret = strs[0];
for (size_t i = 1; i < strs.size(); ++i) {
ret += " ";
ret += strs[i];
}
return ret;
}
public:
using framework::GradOpDescMakerBase::GradOpDescMakerBase;
std::vector<std::unique_ptr<framework::OpDesc>> operator()() const override {
auto &fwd_attrs = Attrs();
// no backward op when backward_id is less than 0
if (boost::get<int>(fwd_attrs.at(kBackwardPythonCallableId)) < 0) {
return {};
}
std::unique_ptr<framework::OpDesc> grad_op(new framework::OpDesc());
grad_op->SetType("py_func");
framework::AttributeMap bwd_attrs;
bwd_attrs[kForwardPythonCallableId] =
fwd_attrs.at(kBackwardPythonCallableId);
bwd_attrs[kBackwardPythonCallableId] = -1;
grad_op->SetAttrMap(bwd_attrs);
// All forward inputs
auto fwd_ins = Input("X");
// All forward outputs
auto fwd_outs = Output("Out");
// For memory reused, some inputs/output in forward part may be not needed
// in backward part. Skipping these vars helps to save memory
auto &backward_skip_var_list = boost::get<std::vector<std::string>>(
fwd_attrs.at(kPyFuncBackwardSkipVars));
std::unordered_set<std::string> backward_skip_var_set(
backward_skip_var_list.begin(), backward_skip_var_list.end());
std::vector<std::string> bwd_ins;
bwd_ins.reserve(fwd_ins.size() + fwd_outs.size());
for (auto &fwd_in : fwd_ins) {
if (backward_skip_var_set.count(fwd_in) == 0) {
bwd_ins.emplace_back(fwd_in);
}
}
for (auto &fwd_out : fwd_outs) {
if (backward_skip_var_set.count(fwd_out) == 0) {
bwd_ins.emplace_back(fwd_out);
}
}
// Backward OG cannot be skipped
// But in Python side, if OG is kEmptyVarName, input tensor would be None
auto fwd_out_grads = OutputGrad("Out");
bwd_ins.reserve(bwd_ins.size() + fwd_out_grads.size());
bwd_ins.insert(bwd_ins.end(), fwd_out_grads.begin(), fwd_out_grads.end());
// Backward IG cannot be skipped
// But in Python side, if IG is not needed, users can just return None
auto bwd_outs = InputGrad("X", false);
VLOG(10) << "PyFunc Grad Input: " << DebugString(bwd_ins);
VLOG(10) << "PyFunc Grad Output: " << DebugString(bwd_outs);
grad_op->SetInput("X", bwd_ins);
grad_op->SetOutput("Out", bwd_outs);
std::vector<std::unique_ptr<framework::OpDesc>> ret(1);
ret[0] = std::move(grad_op);
return ret;
}
};
class PyFuncOp : public framework::OperatorBase {
public:
using framework::OperatorBase::OperatorBase;
protected:
void RunImpl(const framework::Scope &scope,
const platform::Place &place) const override {
auto &in_arg_names = Inputs("X");
auto &out_arg_names = Outputs("Out");
std::vector<framework::LoDTensor> inputs(in_arg_names.size());
for (size_t i = 0; i < in_arg_names.size(); ++i) {
auto in_var = scope.FindVar(in_arg_names[i]);
// When py_func op is called in backward, in_var may be null
if (in_var == nullptr) {
continue;
}
auto &in_tensor = in_var->Get<framework::LoDTensor>();
if (!in_tensor.IsInitialized()) {
continue;
}
if (platform::is_gpu_place(in_tensor.place())) {
framework::TensorCopySync(in_tensor, platform::CPUPlace(), &inputs[i]);
} else {
inputs[i].ShareDataWith(in_tensor);
}
inputs[i].set_lod(in_tensor.lod());
}
std::vector<framework::LoDTensor *> outputs(out_arg_names.size());
for (size_t i = 0; i < out_arg_names.size(); ++i) {
auto *out_var = scope.FindVar(out_arg_names[i]);
outputs[i] =
out_var ? out_var->GetMutable<framework::LoDTensor>() : nullptr;
}
auto callable_id = static_cast<size_t>(Attr<int>(kForwardPythonCallableId));
auto *py_callable = GetPythonCallableObject(callable_id);
VLOG(10) << "Call Python function with id " << callable_id << ": "
<< PythonFuncDebugString(*py_callable);
CallPythonFunc(py_callable, inputs, &outputs);
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(py_func, ops::PyFuncOp, ops::PyFuncOpMaker,
ops::PyFuncOpVarTypInference, ops::PyFuncOpShapeInference,
ops::PyFuncOpGradDescMaker);
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "pybind11/pybind11.h"
namespace paddle {
namespace operators {
size_t AppendPythonCallableObjectAndReturnId(const ::pybind11::object &py_obj);
} // namespace operators
} // namespace paddle
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/data_layout_transform.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/memory/malloc.h"
#include "paddle/fluid/platform/mkldnn_reuse.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
using framework::DataLayout;
template <typename T>
class TransposeMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
public:
void Compute(const paddle::framework::ExecutionContext& ctx) const override {
PADDLE_ENFORCE(paddle::platform::is_cpu_place(ctx.GetPlace()),
"It must use CPUPlace.");
const bool is_test = ctx.Attr<bool>("is_test");
PADDLE_ENFORCE(
is_test == true,
"TransposeMKLDNN works only for inference!. Set is_test = True");
auto& dev_ctx =
ctx.template device_context<paddle::platform::MKLDNNDeviceContext>();
const auto& mkldnn_engine = dev_ctx.GetEngine();
std::vector<int> axis = ctx.Attr<std::vector<int>>("axis");
int ndims = axis.size();
auto* input = ctx.Input<Tensor>("X");
auto* output = ctx.Output<Tensor>("Out");
const T* input_data = input->data<T>();
if (ndims == 1) {
output->ShareDataWith(*input);
return;
}
std::vector<int> nchw_tz = paddle::framework::vectorize2int(input->dims());
const std::string key = platform::TransposeMKLDNNHandler::GetHash(
nchw_tz, axis, ctx.op().Output("Out"));
platform::TransposeMKLDNNHandler handler(nchw_tz, axis, dev_ctx,
mkldnn_engine, key);
auto transpose_src_memory_p = handler.AcquireSrcMemory(
input->format(), platform::to_void_cast<T>(input_data));
auto transpose_dst_memory_p =
handler.AcquireDstMemory(output, ctx.GetPlace());
auto transpose_p = handler.AcquireTranspose(transpose_dst_memory_p,
transpose_src_memory_p);
std::vector<mkldnn::primitive> pipeline;
pipeline.push_back(*transpose_p);
mkldnn::stream(mkldnn::stream::kind::eager).submit(pipeline).wait();
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_KERNEL(transpose2, MKLDNN, ::paddle::platform::CPUPlace,
ops::TransposeMKLDNNOpKernel<float>);
REGISTER_OP_KERNEL(transpose, MKLDNN, ::paddle::platform::CPUPlace,
ops::TransposeMKLDNNOpKernel<float>);
...@@ -16,6 +16,10 @@ limitations under the License. */ ...@@ -16,6 +16,10 @@ limitations under the License. */
#include <string> #include <string>
#include <vector> #include <vector>
#ifdef PADDLE_WITH_MKLDNN
#include "paddle/fluid/platform/mkldnn_helper.h"
#endif
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -53,11 +57,32 @@ class TransposeOp : public framework::OperatorWithKernel { ...@@ -53,11 +57,32 @@ class TransposeOp : public framework::OperatorWithKernel {
} }
ctx->SetOutputDim("Out", out_dims); ctx->SetOutputDim("Out", out_dims);
} }
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
framework::LibraryType library_{framework::LibraryType::kPlain};
std::string data_format = ctx.Attr<std::string>("data_format");
framework::DataLayout layout_ = framework::StringToDataLayout(data_format);
#ifdef PADDLE_WITH_MKLDNN
if (library_ == framework::LibraryType::kPlain &&
platform::CanMKLDNNBeUsed(ctx)) {
library_ = framework::LibraryType::kMKLDNN;
layout_ = framework::DataLayout::kMKLDNN;
}
#endif
return framework::OpKernelType(ctx.Input<Tensor>("X")->type(),
ctx.GetPlace(), layout_, library_);
}
}; };
class TransposeOpMaker : public framework::OpProtoAndCheckerMaker { class TransposeOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
void Make() override { void Make() override {
AddAttr<bool>("is_test",
"(bool, default false) Set to true for inference only, false "
"for training. Some layers may run faster when this is true.")
.SetDefault(false);
AddInput( AddInput(
"X", "X",
"(Tensor) The input tensor, tensors with rank up to 6 are supported."); "(Tensor) The input tensor, tensors with rank up to 6 are supported.");
...@@ -67,6 +92,16 @@ class TransposeOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -67,6 +92,16 @@ class TransposeOpMaker : public framework::OpProtoAndCheckerMaker {
"(vector<int>) A list of values, and the size of the list should be " "(vector<int>) A list of values, and the size of the list should be "
"the same with the input tensor rank. This operator permutes the input " "the same with the input tensor rank. This operator permutes the input "
"tensor's axes according to the values given."); "tensor's axes according to the values given.");
AddAttr<bool>("use_mkldnn",
"(bool, default false) Only used in mkldnn kernel")
.SetDefault(false);
AddAttr<std::string>(
"data_format",
"(string, default NCHW) Only used in "
"An optional string from: \"NHWC\", \"NCHW\". "
"Defaults to \"NHWC\". Specify the data format of the output data, "
"the input will be transformed automatically. ")
.SetDefault("AnyLayout");
AddComment(R"DOC( AddComment(R"DOC(
Transpose Operator. Transpose Operator.
...@@ -144,8 +179,18 @@ class Transpose2Op : public TransposeOp { ...@@ -144,8 +179,18 @@ class Transpose2Op : public TransposeOp {
protected: protected:
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override { const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(ctx.Input<framework::LoDTensor>("X")->type(), framework::LibraryType library_{framework::LibraryType::kPlain};
ctx.device_context()); std::string data_format = ctx.Attr<std::string>("data_format");
framework::DataLayout layout_ = framework::StringToDataLayout(data_format);
#ifdef PADDLE_WITH_MKLDNN
if (library_ == framework::LibraryType::kPlain &&
platform::CanMKLDNNBeUsed(ctx)) {
library_ = framework::LibraryType::kMKLDNN;
layout_ = framework::DataLayout::kMKLDNN;
}
#endif
return framework::OpKernelType(ctx.Input<Tensor>("X")->type(),
ctx.GetPlace(), layout_, library_);
} }
}; };
......
...@@ -197,6 +197,130 @@ class MKLDNNHandler { ...@@ -197,6 +197,130 @@ class MKLDNNHandler {
bool is_reusing_; bool is_reusing_;
}; };
class TransposeMKLDNNHandler : public MKLDNNHandler {
public:
TransposeMKLDNNHandler(std::vector<int>& dims, std::vector<int>& axis,
const platform::MKLDNNDeviceContext& dev_ctx,
mkldnn::engine engine, const std::string& base_key)
: platform::MKLDNNHandler(dev_ctx, engine, base_key),
dims_(dims),
axis_(axis),
logical_axis_(dims.size(), 0) {}
std::shared_ptr<mkldnn::memory> AcquireSrcMemory(
const mkldnn::memory::format& fmt, void* ptr) {
auto local_key = key_ + "@user_src_mem_p";
auto mem_p =
std::static_pointer_cast<mkldnn::memory>(dev_ctx_.GetBlob(local_key));
PADDLE_ENFORCE((mem_p != nullptr) || (is_reusing_ == false),
" find mem primitive in device context");
if (mem_p == nullptr) {
// Make memory descriptor using input format, unless it
// cannot be trusted (nchw) then make up memory fmt manually
for (size_t i = 0; i < logical_axis_.size(); ++i) {
logical_axis_[i] = i;
}
auto src_md = fmt != mkldnn::memory::format::nchw
? platform::MKLDNNMemDesc(
dims_, platform::MKLDNNGetDataType<float>(), fmt)
: Axis2MemoryDesc(dims_, logical_axis_);
mem_p = std::make_shared<mkldnn::memory>(
mkldnn::memory::primitive_desc{src_md, engine_}, ptr);
dev_ctx_.SetBlob(local_key, mem_p);
} else {
mem_p->set_data_handle(ptr);
// Mark that reusing happenned. All primitives from operator instance
// should be reused or none of them. So we check consistency
is_reusing_ = true;
}
return mem_p;
}
std::shared_ptr<mkldnn::memory> AcquireDstMemory(framework::Tensor* output,
platform::Place place) {
auto local_key = key_ + "@user_dst_mem_p";
auto mem_p =
std::static_pointer_cast<mkldnn::memory>(dev_ctx_.GetBlob(local_key));
PADDLE_ENFORCE((mem_p != nullptr) || (is_reusing_ == false),
" find mem primitive in device context");
if (mem_p == nullptr) {
auto dst_mdp = mkldnn::memory::primitive_desc{
Axis2MemoryDesc(dims_, axis_), engine_};
auto dst_data = output->mutable_data<float>(
place, paddle::memory::Allocator::kDefault, dst_mdp.get_size());
mem_p = std::make_shared<mkldnn::memory>(dst_mdp, dst_data);
dev_ctx_.SetBlob(local_key, mem_p);
} else {
auto dst_data = output->mutable_data<float>(place);
mem_p->set_data_handle(dst_data);
// Mark that reusing happenned. All primitives from operator instance
// should be reused or none of them. So we check consistency
is_reusing_ = true;
}
return mem_p;
}
std::shared_ptr<mkldnn::reorder> AcquireTranspose(
std::shared_ptr<mkldnn::memory> dst_memory_p,
std::shared_ptr<mkldnn::memory> src_memory_p) {
auto prim_key = key_ + "@transpose_p";
auto transpose_p =
std::static_pointer_cast<mkldnn::reorder>(dev_ctx_.GetBlob(prim_key));
PADDLE_ENFORCE((transpose_p != nullptr) || (is_reusing_ == false),
"Fail to find convolution primitive in device context");
if (transpose_p == nullptr) {
transpose_p =
std::make_shared<mkldnn::reorder>(*(src_memory_p), *(dst_memory_p));
dev_ctx_.SetBlob(prim_key, transpose_p);
} else {
is_reusing_ = true;
}
return transpose_p;
}
static std::string GetHash(std::vector<int>& shape, // NOLINT
std::vector<int>& axis, // NOLINT
const std::string& suffix) {
return dims2str(shape) + dims2str(axis) + suffix;
}
protected:
mkldnn_memory_desc_t Axis2MemoryDesc(std::vector<int>& nchw_tz,
std::vector<int>& axis) {
mkldnn_memory_desc_t mem_fmt;
mem_fmt.primitive_kind = mkldnn_memory;
mem_fmt.ndims = axis.size();
for (unsigned int i = 0; i < nchw_tz.size(); ++i) {
mem_fmt.dims[i] = nchw_tz[i]; // logical dimensions (nchw format,
// regardless physical layout)
}
mem_fmt.data_type = mkldnn_f32;
mem_fmt.format = mkldnn_blocked;
unsigned int total_stride = 1;
for (int i = nchw_tz.size() - 1; i >= 0; --i) {
mem_fmt.layout_desc.blocking.padding_dims[i] =
nchw_tz[i]; // logical dimensions (nchw format, regardless physical
// layout)
mem_fmt.layout_desc.blocking.block_dims[i] = 1;
mem_fmt.layout_desc.blocking.offset_padding_to_data[i] = 0; // no offset
mem_fmt.layout_desc.blocking.strides[0][axis[i]] = total_stride;
mem_fmt.layout_desc.blocking.strides[1][axis[i]] = 1;
total_stride *= nchw_tz[axis[i]];
}
mem_fmt.layout_desc.blocking.offset_padding = 0; // no initial offset
return mem_fmt;
}
private:
std::vector<int> dims_;
std::vector<int> axis_;
std::vector<int> logical_axis_;
};
template <class forward_t, class backward_data_t, class backward_weights_t> template <class forward_t, class backward_data_t, class backward_weights_t>
class ConvMKLDNNTemplateHandler : public MKLDNNHandler { class ConvMKLDNNTemplateHandler : public MKLDNNHandler {
public: public:
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef PADDLE_WITH_NGRAPH
#pragma once
#include <functional>
#include <string>
#include <vector>
#include "ngraph/ngraph.hpp"
namespace paddle {
namespace platform {
static ngraph::Shape FlattenTo2d(ngraph::Shape sh, int num) {
auto x1 = std::accumulate(std::begin(sh), std::begin(sh) + num, 1,
std::multiplies<size_t>());
auto x2 = std::accumulate(std::begin(sh) + num, std::end(sh), 1,
std::multiplies<size_t>());
size_t x1_l = static_cast<size_t>(x1);
size_t x2_l = static_cast<size_t>(x2);
return ngraph::Shape{x1_l, x2_l};
}
static std::shared_ptr<ngraph::Node> NgReshaper(
std::shared_ptr<ngraph::Node> input, ngraph::Shape shape) {
std::vector<size_t> input_order(input->get_shape().size());
std::iota(std::begin(input_order), std::end(input_order), 0);
return std::make_shared<ngraph::op::Reshape>(
input, ngraph::AxisVector(input_order), shape);
}
static std::shared_ptr<ngraph::Node> GetNode(
const std::shared_ptr<paddle::framework::OperatorBase>& op,
const std::string prm, const paddle::framework::VariableNameMap& var_map,
std::shared_ptr<
std::unordered_map<std::string, std::shared_ptr<ngraph::Node>>>
ngb_node_map) {
auto& var_names = var_map.at(prm);
PADDLE_ENFORCE_EQ(var_names.size(), 1,
"op %s prm %s expects one associated var", op->Type(), prm);
if (ngb_node_map->find(var_names[0]) != ngb_node_map->end()) {
return (*ngb_node_map)[var_names[0]];
} else {
return nullptr;
}
}
static std::shared_ptr<ngraph::Node> GetInputNode(
const std::shared_ptr<paddle::framework::OperatorBase>& op,
const std::string prm,
std::shared_ptr<
std::unordered_map<std::string, std::shared_ptr<ngraph::Node>>>
ngb_node_map) {
return GetNode(op, prm, op->Inputs(), ngb_node_map);
}
static std::shared_ptr<ngraph::Node> GetOutputNode(
const std::shared_ptr<paddle::framework::OperatorBase>& op,
const std::string prm,
std::shared_ptr<
std::unordered_map<std::string, std::shared_ptr<ngraph::Node>>>
ngb_node_map) {
return GetNode(op, prm, op->Outputs(), ngb_node_map);
}
static void SetOutputNode(
const std::shared_ptr<paddle::framework::OperatorBase>& op,
const std::string prm, std::shared_ptr<ngraph::Node> node,
std::shared_ptr<
std::unordered_map<std::string, std::shared_ptr<ngraph::Node>>>
ngb_node_map) {
auto& var_names = op->Outputs().at(prm);
if (var_names.size() == 1) {
(*ngb_node_map)[var_names[0]] = node;
} else if (var_names.size() == 0) {
(*ngb_node_map)[""] = node;
} else {
PADDLE_THROW("prm %s has more than 1 var_names.", prm);
}
}
static bool HasOutput(
const std::shared_ptr<paddle::framework::OperatorBase>& op,
const std::string prm) {
auto& outputs = op->Outputs();
if (outputs.find(prm) == outputs.end()) return false;
return outputs.at(prm).size() > 0;
}
} // namespace platform
} // namespace paddle
#endif
set(PYBIND_DEPS pybind python proto_desc memory executor async_executor prune feed_fetch_method pass_builder parallel_executor profiler layer) set(PYBIND_DEPS pybind python proto_desc memory executor async_executor prune feed_fetch_method pass_builder parallel_executor profiler layer)
if(WITH_PYTHON)
list(APPEND PYBIND_DEPS py_func_op)
endif()
set(PYBIND_SRCS pybind.cc exception.cc protobuf.cc const_value.cc recordio.cc async_executor_py.cc imperative.cc) set(PYBIND_SRCS pybind.cc exception.cc protobuf.cc const_value.cc recordio.cc async_executor_py.cc imperative.cc)
if(WITH_PYTHON) if(WITH_PYTHON)
......
...@@ -24,8 +24,9 @@ namespace pybind { ...@@ -24,8 +24,9 @@ namespace pybind {
void BindTracer(pybind11::module *m) { void BindTracer(pybind11::module *m) {
pybind11::class_<imperative::Tracer>(*m, "Tracer", "") pybind11::class_<imperative::Tracer>(*m, "Tracer", "")
.def("__init__", .def("__init__",
[](imperative::Tracer &self, framework::BlockDesc *root_block) { [](imperative::Tracer &self, framework::BlockDesc *root_block,
new (&self) imperative::Tracer(root_block); framework::BlockDesc *startup_block) {
new (&self) imperative::Tracer(root_block, startup_block);
}) })
.def("trace", &imperative::Tracer::Trace) .def("trace", &imperative::Tracer::Trace)
.def("get_scope", &imperative::Tracer::GetScope, .def("get_scope", &imperative::Tracer::GetScope,
......
...@@ -328,7 +328,7 @@ void BindOpDesc(pybind11::module *m) { ...@@ -328,7 +328,7 @@ void BindOpDesc(pybind11::module *m) {
.def("infer_var_type", &pd::OpDesc::InferVarType) .def("infer_var_type", &pd::OpDesc::InferVarType)
.def("set_is_target", &pd::OpDesc::SetIsTarget) .def("set_is_target", &pd::OpDesc::SetIsTarget)
.def("serialize_to_string", SerializeMessage<pd::OpDesc>) .def("serialize_to_string", SerializeMessage<pd::OpDesc>)
.def("block", &pd::OpDesc::Block, .def("block", [](pd::OpDesc &self) { return self.Block(); },
pybind11::return_value_policy::reference); pybind11::return_value_policy::reference);
} }
......
...@@ -37,6 +37,7 @@ limitations under the License. */ ...@@ -37,6 +37,7 @@ limitations under the License. */
#include "paddle/fluid/imperative/layer.h" #include "paddle/fluid/imperative/layer.h"
#include "paddle/fluid/memory/allocation/allocator_strategy.h" #include "paddle/fluid/memory/allocation/allocator_strategy.h"
#include "paddle/fluid/operators/activation_op.h" #include "paddle/fluid/operators/activation_op.h"
#include "paddle/fluid/operators/py_func_op.h"
#include "paddle/fluid/operators/reader/lod_tensor_blocking_queue.h" #include "paddle/fluid/operators/reader/lod_tensor_blocking_queue.h"
#include "paddle/fluid/platform/cpu_info.h" #include "paddle/fluid/platform/cpu_info.h"
#include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/enforce.h"
...@@ -110,6 +111,12 @@ PYBIND11_MODULE(core, m) { ...@@ -110,6 +111,12 @@ PYBIND11_MODULE(core, m) {
BindException(&m); BindException(&m);
m.def(
"_append_python_callable_object_and_return_id",
[](py::object py_obj) -> size_t {
return paddle::operators::AppendPythonCallableObjectAndReturnId(py_obj);
});
py::class_<imperative::VarBase, PyVarBase>(m, "VarBase", R"DOC()DOC") py::class_<imperative::VarBase, PyVarBase>(m, "VarBase", R"DOC()DOC")
.def(py::init<>()) .def(py::init<>())
.def("_run_backward", .def("_run_backward",
...@@ -977,7 +984,6 @@ All parameter, weight, gradient are variables in Paddle. ...@@ -977,7 +984,6 @@ All parameter, weight, gradient are variables in Paddle.
cannot be updated after being finalized.)DOC"); cannot be updated after being finalized.)DOC");
pe.def(py::init<const std::vector<platform::Place> &, pe.def(py::init<const std::vector<platform::Place> &,
const std::unordered_set<std::string> &,
const std::unordered_set<std::string> &, const ProgramDesc &, const std::unordered_set<std::string> &, const ProgramDesc &,
const std::string &, Scope *, std::vector<Scope *> &, const std::string &, Scope *, std::vector<Scope *> &,
const ExecutionStrategy &, const BuildStrategy &, size_t, const ExecutionStrategy &, const BuildStrategy &, size_t,
......
...@@ -509,11 +509,11 @@ function assert_api_spec_approvals() { ...@@ -509,11 +509,11 @@ function assert_api_spec_approvals() {
if [ ${API_CHANGE} ] && [ "${GIT_PR_ID}" != "" ]; then if [ ${API_CHANGE} ] && [ "${GIT_PR_ID}" != "" ]; then
# NOTE: per_page=10000 should be ok for all cases, a PR review > 10000 is not human readable. # NOTE: per_page=10000 should be ok for all cases, a PR review > 10000 is not human readable.
APPROVALS=`curl -H "Authorization: token ${GITHUB_API_TOKEN}" https://api.github.com/repos/PaddlePaddle/Paddle/pulls/${GIT_PR_ID}/reviews?per_page=10000 | \ APPROVALS=`curl -H "Authorization: token ${GITHUB_API_TOKEN}" https://api.github.com/repos/PaddlePaddle/Paddle/pulls/${GIT_PR_ID}/reviews?per_page=10000 | \
python ${PADDLE_ROOT}/tools/check_pr_approval.py 2 7845005 2887803 728699 13348433` python ${PADDLE_ROOT}/tools/check_pr_approval.py 1 2887803`
echo "current pr ${GIT_PR_ID} got approvals: ${APPROVALS}" echo "current pr ${GIT_PR_ID} got approvals: ${APPROVALS}"
if [ "${APPROVALS}" == "FALSE" ]; then if [ "${APPROVALS}" == "FALSE" ]; then
echo "You must have at least 2 approvals for the api change! ${API_FILE}" echo "You must have panyx0718 approval for the api change! ${API_FILE}"
exit 1 exit 1
fi fi
fi fi
done done
...@@ -521,11 +521,11 @@ function assert_api_spec_approvals() { ...@@ -521,11 +521,11 @@ function assert_api_spec_approvals() {
HAS_CONST_CAST=`git diff -U0 upstream/$BRANCH |grep -o -m 1 "const_cast" || true` HAS_CONST_CAST=`git diff -U0 upstream/$BRANCH |grep -o -m 1 "const_cast" || true`
if [ ${HAS_CONST_CAST} ] && [ "${GIT_PR_ID}" != "" ]; then if [ ${HAS_CONST_CAST} ] && [ "${GIT_PR_ID}" != "" ]; then
APPROVALS=`curl -H "Authorization: token ${GITHUB_API_TOKEN}" https://api.github.com/repos/PaddlePaddle/Paddle/pulls/${GIT_PR_ID}/reviews?per_page=10000 | \ APPROVALS=`curl -H "Authorization: token ${GITHUB_API_TOKEN}" https://api.github.com/repos/PaddlePaddle/Paddle/pulls/${GIT_PR_ID}/reviews?per_page=10000 | \
python ${PADDLE_ROOT}/tools/check_pr_approval.py 2 7845005 2887803 728699 13348433` python ${PADDLE_ROOT}/tools/check_pr_approval.py 1 2887803`
echo "current pr ${GIT_PR_ID} got approvals: ${APPROVALS}" echo "current pr ${GIT_PR_ID} got approvals: ${APPROVALS}"
if [ "${APPROVALS}" == "FALSE" ]; then if [ "${APPROVALS}" == "FALSE" ]; then
echo "You must have at least 2 approvals for the const_cast" echo "You must have panyx0718 approval for the const_cast"
exit 1 exit 1
fi fi
fi fi
......
...@@ -489,8 +489,11 @@ def append_backward(loss, parameter_list=None, no_grad_set=None, ...@@ -489,8 +489,11 @@ def append_backward(loss, parameter_list=None, no_grad_set=None,
grad_to_var = dict() grad_to_var = dict()
op_desc = _create_op_desc_( op_desc = _create_op_desc_(
"fill_constant", {}, {"Out": [_append_grad_suffix_(loss.name)]}, { "fill_constant",
"shape": [1], {},
{"Out": [_append_grad_suffix_(loss.name)]},
{
"shape": [1], # TODO(panyx0718): This can be loss.shape.
"value": 1.0, "value": 1.0,
"dtype": loss.dtype, "dtype": loss.dtype,
"force_cpu": False, "force_cpu": False,
......
...@@ -22,9 +22,12 @@ from . import op_frequence ...@@ -22,9 +22,12 @@ from . import op_frequence
from .op_frequence import * from .op_frequence import *
from . import quantize from . import quantize
from .quantize import * from .quantize import *
from . import utils
from .utils import *
__all__ = [] __all__ = []
__all__ += decoder.__all__ __all__ += decoder.__all__
__all__ += memory_usage_calc.__all__ __all__ += memory_usage_calc.__all__
__all__ += op_frequence.__all__ __all__ += op_frequence.__all__
__all__ += quantize.__all__ __all__ += quantize.__all__
__all__ += utils.__all__
...@@ -13,10 +13,11 @@ ...@@ -13,10 +13,11 @@
# limitations under the License. # limitations under the License.
from __future__ import print_function from __future__ import print_function
#from . import lookup_table_utils from . import lookup_table_utils
#from .lookup_table_utils import * from .lookup_table_utils import *
from . import hdfs_utils from . import hdfs_utils
from .hdfs_utils import * from .hdfs_utils import *
#__all__ = lookup_table_utils.__all__ __all__ = []
__all__ = hdfs_utils.__all__ __all__ += lookup_table_utils.__all__
__all__ += hdfs_utils.__all__
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
"""HDFS Utils""" """HDFS Utils"""
import os import os
import sys
import subprocess import subprocess
import multiprocessing import multiprocessing
from datetime import datetime from datetime import datetime
...@@ -24,7 +25,7 @@ import errno ...@@ -24,7 +25,7 @@ import errno
import logging import logging
__all__ = ["HDFSClient", "multi_download"] __all__ = ["HDFSClient", "multi_download", "multi_upload"]
logging.basicConfig(format='%(asctime)s - %(levelname)s - %(message)s') logging.basicConfig(format='%(asctime)s - %(levelname)s - %(message)s')
_logger = logging.getLogger("hdfs_utils") _logger = logging.getLogger("hdfs_utils")
...@@ -93,13 +94,15 @@ class HDFSClient(object): ...@@ -93,13 +94,15 @@ class HDFSClient(object):
def upload(self, hdfs_path, local_path, overwrite=False, retry_times=5): def upload(self, hdfs_path, local_path, overwrite=False, retry_times=5):
""" """
upload the local file to hdfs upload the local file to hdfs
Args:
hdfs_path: hdfs path, target path Args:
local_path: local file path, source path hdfs_path(str): the hdfs file path
overwrite: will overwrite the original file local_path(str): the local file path
retry_times: max times retry to upload overwrite(bool|None): will overwrite the file on HDFS or not
Returns: retry_times(int|5): retry times
Returns:
True or False True or False
""" """
assert hdfs_path is not None assert hdfs_path is not None
...@@ -109,7 +112,7 @@ class HDFSClient(object): ...@@ -109,7 +112,7 @@ class HDFSClient(object):
_logger.warn( _logger.warn(
"The Local path: {} is dir and I will support it later, return". "The Local path: {} is dir and I will support it later, return".
format(local_path)) format(local_path))
return return False
base = os.path.basename(local_path) base = os.path.basename(local_path)
if not self.is_exist(hdfs_path): if not self.is_exist(hdfs_path):
...@@ -141,14 +144,16 @@ class HDFSClient(object): ...@@ -141,14 +144,16 @@ class HDFSClient(object):
def download(self, hdfs_path, local_path, overwrite=False, unzip=False): def download(self, hdfs_path, local_path, overwrite=False, unzip=False):
""" """
download from hdfs download file from HDFS
Args:
hdfs_path: hdfs path, target path Args:
local_path: local file path, source path hdfs_path(str): the hdfs file path
overwrite: will remove original file and overwrite it. local_path(str): the local file path
unzip: ignore this param overwrite(bool|None): will overwrite the file on HDFS or not
Returns unzip(bool|False): if the download file is compressed by zip, unzip it or not.
True or False
Returns:
True or False
""" """
_logger.info('Downloading %r to %r.', hdfs_path, local_path) _logger.info('Downloading %r to %r.', hdfs_path, local_path)
_logger.info('Download of %s to %r complete.', hdfs_path, local_path) _logger.info('Download of %s to %r complete.', hdfs_path, local_path)
...@@ -188,13 +193,13 @@ class HDFSClient(object): ...@@ -188,13 +193,13 @@ class HDFSClient(object):
def is_exist(self, hdfs_path=None): def is_exist(self, hdfs_path=None):
""" """
whether the remote hdfs path exists? whether the remote HDFS path exists
Args:
hdfs_path: default value(${OUTPUT_PATH}/${SYS_USER_ID}/${SYS_JOB_ID}/tmp) Args:
fs_name: The default values are the same as in the job configuration hdfs_path(str): the hdfs file path
fs_ugi: The default values are the same as in the job configuration
Returns: Returns:
True or False True or False
""" """
exist_cmd = ['-test', '-e', hdfs_path] exist_cmd = ['-test', '-e', hdfs_path]
returncode, output, errors = self.__run_hdfs_cmd( returncode, output, errors = self.__run_hdfs_cmd(
...@@ -211,13 +216,13 @@ class HDFSClient(object): ...@@ -211,13 +216,13 @@ class HDFSClient(object):
def is_dir(self, hdfs_path=None): def is_dir(self, hdfs_path=None):
""" """
whether the remote hdfs path exists? whether the remote HDFS path is directory
Args:
remote_file_path: default value(${OUTPUT_PATH}/${SYS_USER_ID}/${SYS_JOB_ID}/tmp) Args:
fs_name: The default values are the same as in the job configuration hdfs_path(str): the hdfs file path
fs_ugi: The default values are the same as in the job configuration
Returns: Returns:
True or False True or False
""" """
if not self.is_exist(hdfs_path): if not self.is_exist(hdfs_path):
...@@ -237,17 +242,17 @@ class HDFSClient(object): ...@@ -237,17 +242,17 @@ class HDFSClient(object):
def delete(self, hdfs_path): def delete(self, hdfs_path):
""" """
Remove a file or directory from HDFS. Remove a file or directory from HDFS.
whether the remote HDFS path exists
Args: Args:
param hdfs_path: HDFS path. hdfs_path: HDFS path.
param recursive: Recursively delete files and directories. By default,
this method will raise an :class:`HdfsError` if trying to delete a
non-empty directory.
Returns: Returns:
True or False
This function returns `True` if the deletion was successful and `False` if This function returns `True` if the deletion was successful and `False` if
no file or directory previously existed at `hdfs_path`. no file or directory previously existed at `hdfs_path`.
""" """
_logger.info('Deleting %r.', hdfs_path) _logger.info('Deleting %r.', hdfs_path)
...@@ -273,16 +278,14 @@ class HDFSClient(object): ...@@ -273,16 +278,14 @@ class HDFSClient(object):
def rename(self, hdfs_src_path, hdfs_dst_path, overwrite=False): def rename(self, hdfs_src_path, hdfs_dst_path, overwrite=False):
""" """
Rename a file or folder. Move a file or folder on HDFS.
Args:
:param hdfs_src_path: Source path. Args:
:param hdfs_dst_path: Destination path. If the path already exists and is hdfs_path(str): HDFS path.
a directory, the source will be moved into it. If the path exists and is overwrite(bool|False): If the path already exists and overwrite is False, will return False.
a file, or if a parent destination directory is missing, this method will
raise an :class:`HdfsError`.
Returns: Returns:
This function returns `True` if the rename was successful and `False` if True or False
rename was faild.
""" """
assert hdfs_src_path is not None assert hdfs_src_path is not None
assert hdfs_dst_path is not None assert hdfs_dst_path is not None
...@@ -320,17 +323,20 @@ class HDFSClient(object): ...@@ -320,17 +323,20 @@ class HDFSClient(object):
raise raise
def makedirs(self, hdfs_path): def makedirs(self, hdfs_path):
"""Create a remote directory, recursively if necessary. """
Create a remote directory, recursively if necessary.
Args: Args:
:param hdfs_path: Remote path. Intermediate directories will be created hdfs_path(str): Remote path. Intermediate directories will be created appropriately.
appropriately.
Returns: Returns:
True if make a directories was successful, False when make a directiries was failed. True or False
""" """
_logger.info('Creating directories to %r.', hdfs_path) _logger.info('Creating directories to %r.', hdfs_path)
assert hdfs_path is not None assert hdfs_path is not None
if self.is_exist(hdfs_path): if self.is_exist(hdfs_path):
_logger.error("HDFS path is exist: {}".format(hdfs_path))
return return
mkdirs_commands = ['-mkdir', hdfs_path] mkdirs_commands = ['-mkdir', hdfs_path]
...@@ -346,11 +352,13 @@ class HDFSClient(object): ...@@ -346,11 +352,13 @@ class HDFSClient(object):
def ls(self, hdfs_path): def ls(self, hdfs_path):
""" """
ls a hdfs_path. ls directory contents about HDFS hdfs_path
Args:
:param hdfs_path: hdfs_path will be ls. Args:
hdfs_path(str): Remote HDFS path will be ls.
Returns: Returns:
This function returns a `list` that contaion all files in the hdfs_path. List: a contents list about hdfs_path.
""" """
assert hdfs_path is not None assert hdfs_path is not None
...@@ -378,11 +386,15 @@ class HDFSClient(object): ...@@ -378,11 +386,15 @@ class HDFSClient(object):
def lsr(self, hdfs_path, only_file=True, sort=True): def lsr(self, hdfs_path, only_file=True, sort=True):
""" """
ls a hdfs_path sort by time. list directory contents about HDFS hdfs_path recursively
Args:
:param hdfs_path: hdfs_path will be ls. Args:
hdfs_path(str): Remote HDFS path.
only_file(bool|True): will discard folders.
sort(bool|True): will be sorted by create time.
Returns: Returns:
This function returns a `list` that contaion all files sorted by time in the hdfs_path. List: a contents list about hdfs_path.
""" """
def sort_by_time(v1, v2): def sort_by_time(v1, v2):
...@@ -422,21 +434,106 @@ class HDFSClient(object): ...@@ -422,21 +434,106 @@ class HDFSClient(object):
return ret_lines return ret_lines
def multi_download(client,
hdfs_path,
local_path,
trainer_id,
trainers,
multi_processes=5):
"""
Download files from HDFS using multi process.
Args:
client(HDFSClient): instance of HDFSClient
hdfs_path(str): path on hdfs
local_path(str): path on local
trainer_id(int): current trainer id
trainers(int): all trainers number
multi_processes(int|5): the download data process at the same time, default=5
Returns:
List:
Download files in local folder.
"""
def __subprocess_download(datas):
for data in datas:
re_path = os.path.relpath(os.path.dirname(data), hdfs_path)
if re_path == os.curdir:
sub_local_re_path = local_path
else:
sub_local_re_path = os.path.join(local_path, re_path)
client.download(data, sub_local_re_path)
assert isinstance(client, HDFSClient)
client.make_local_dirs(local_path)
_logger.info("Make local dir {} successfully".format(local_path))
all_need_download = client.lsr(hdfs_path, sort=True)
need_download = all_need_download[trainer_id::trainers]
_logger.info("Get {} files From all {} files need to be download from {}".
format(len(need_download), len(all_need_download), hdfs_path))
_logger.info("Start {} multi process to download datas".format(
multi_processes))
procs = []
for i in range(multi_processes):
process_datas = need_download[i::multi_processes]
p = multiprocessing.Process(
target=__subprocess_download, args=(process_datas, ))
procs.append(p)
p.start()
# complete the processes
for proc in procs:
proc.join()
_logger.info("Finish {} multi process to download datas".format(
multi_processes))
local_downloads = []
for data in need_download:
data_name = os.path.basename(data)
re_path = os.path.relpath(os.path.dirname(data), hdfs_path)
if re_path == os.curdir:
local_re_path = os.path.join(local_path, data_name)
else:
local_re_path = os.path.join(local_path, re_path, data_name)
local_downloads.append(local_re_path)
return local_downloads
def getfilelist(path):
rlist = []
for dir, folder, file in os.walk(path):
for i in file:
t = os.path.join(dir, i)
rlist.append(t)
for r in rlist:
print(r)
def multi_upload(client, def multi_upload(client,
hdfs_path, hdfs_path,
local_path, local_path,
multi_processes=5, multi_processes=5,
overwrite=False): overwrite=False,
sync=True):
""" """
Upload file to hdfs. Upload files to HDFS using multi process.
Args: Args:
:param overwrite: will overwrite hdfs file or not client(HDFSClient): instance of HDFSClient
:param multi_processes: the upload data process at the same time, default=5 hdfs_path(str): path on hdfs
:param client: instance of HDFSClient local_path(str): path on local
:param hdfs_path: path on hdfs multi_processes(int|5): the upload data process at the same time, default=5
:param local_path: path on local overwrite(bool|False): will overwrite file on HDFS or not
sync(bool|True): upload files sync or not.
Returns: Returns:
None
""" """
def __subprocess_upload(datas): def __subprocess_upload(datas):
...@@ -446,13 +543,6 @@ def multi_upload(client, ...@@ -446,13 +543,6 @@ def multi_upload(client,
client.upload(hdfs_re_path, data, overwrite, retry_times=5) client.upload(hdfs_re_path, data, overwrite, retry_times=5)
def get_local_files(path): def get_local_files(path):
"""
Get all local files
Args:
path: local file path
Returns:
A list that contation all files in the path.
"""
rlist = [] rlist = []
if not os.path.isdir(path): if not os.path.isdir(path):
...@@ -488,71 +578,6 @@ def multi_upload(client, ...@@ -488,71 +578,6 @@ def multi_upload(client,
multi_processes)) multi_processes))
def multi_download(client,
hdfs_path,
local_path,
trainer_id,
trainers,
file_cnt,
multi_processes=5):
"""
multi_download
Args:
:param client: instance of HDFSClient
:param hdfs_path: path on hdfs
:param local_path: path on local
:param trainer_id: current trainer id
:param trainers: all trainers number
:param file_cnt: all file number
:param multi_processes: the download data process at the same time, default=5
:return: None
Returns:
A list that be downloaded.
"""
def __subprocess_download(datas):
for data in datas:
re_path = os.path.relpath(os.path.dirname(data), hdfs_path)
local_re_path = os.path.join(local_path, re_path)
client.download(data, local_re_path)
assert isinstance(client, HDFSClient)
client.make_local_dirs(local_path)
_logger.info("Make local dir {} successfully".format(local_path))
all_need_download = client.lsr(hdfs_path, sort=True)[:file_cnt]
need_download = all_need_download[trainer_id::trainers]
_logger.info("Get {} files From all {} files need to be download from {}".
format(len(need_download), len(all_need_download), hdfs_path))
_logger.info("Start {} multi process to download datas".format(
multi_processes))
procs = []
for i in range(multi_processes):
process_datas = need_download[i::multi_processes]
p = multiprocessing.Process(
target=__subprocess_download, args=(process_datas, ))
procs.append(p)
p.start()
# complete the processes
for proc in procs:
proc.join()
_logger.info("Finish {} multi process to download datas".format(
multi_processes))
local_downloads = []
for data in need_download:
data_name = os.path.basename(data)
re_path = os.path.relpath(os.path.dirname(data), hdfs_path)
local_re_path = os.path.join(local_path, re_path, data_name)
local_downloads.append(local_re_path)
return local_downloads
if __name__ == "__main__": if __name__ == "__main__":
hadoop_home = "/home/client/hadoop-client/hadoop/" hadoop_home = "/home/client/hadoop-client/hadoop/"
......
...@@ -18,14 +18,12 @@ import os ...@@ -18,14 +18,12 @@ import os
import time import time
import logging import logging
import paddle
import paddle.fluid as fluid
from paddle.fluid import core from paddle.fluid import core
from paddle.fluid import io from paddle.fluid import io
from paddle.fluid import Program from paddle.fluid import Program
__all__ = [ __all__ = [
"load_inference_model", "load_persistable_vars", "load_persistables_for_increment", "load_persistables_for_inference",
"convert_dist_to_sparse_program" "convert_dist_to_sparse_program"
] ]
...@@ -80,19 +78,28 @@ def __get_prefetch_op_tuples(main_program): ...@@ -80,19 +78,28 @@ def __get_prefetch_op_tuples(main_program):
return prefetch_op_tuples return prefetch_op_tuples
def convert_dist_to_sparse_program(main_program): def convert_dist_to_sparse_program(program):
if not main_program._distributed_lookup_table: """
WARNING: this function will only be used for distributed training with distributed lookup table.
when we train model with distributed lookup table but want to do the local inference, we can use
this function to convert the train program with distributed lookup table to sparse lookup table.
:param program(Program): the program must be the trainer program, which will be get by the distribute transpiler.
:return:
program: The `program` is a Program, it's the program replace distributed lookup table to sparse lookup table.
"""
if not program._distributed_lookup_table:
_logger.warn( _logger.warn(
"There are no distributed lookup tables need to be converted") "There are no distributed lookup tables need to be converted")
return return
# create table param and grad var in pserver program # create table param and grad var in pserver program
origin_emb_var = "{}.origin".format(main_program._distributed_lookup_table) origin_emb_var = "{}.origin".format(program._distributed_lookup_table)
emb_var = main_program._distributed_lookup_table emb_var = program._distributed_lookup_table
main_program.global_block()._rename_var(emb_var, origin_emb_var) program.global_block()._rename_var(emb_var, origin_emb_var)
origin_param_var = main_program.global_block().vars[origin_emb_var] origin_param_var = program.global_block().vars[origin_emb_var]
param_var = main_program.global_block().create_var( param_var = program.global_block().create_var(
name=emb_var, name=emb_var,
shape=origin_param_var.shape, shape=origin_param_var.shape,
dtype=origin_param_var.dtype, dtype=origin_param_var.dtype,
...@@ -100,28 +107,28 @@ def convert_dist_to_sparse_program(main_program): ...@@ -100,28 +107,28 @@ def convert_dist_to_sparse_program(main_program):
persistable=True) persistable=True)
# parameter must be selected rows # parameter must be selected rows
param_var.desc.set_type(core.VarDesc.VarType.SELECTED_ROWS) param_var.desc.set_type(core.VarDesc.VarType.SELECTED_ROWS)
main_program._sync_with_cpp() program._sync_with_cpp()
prefetch_op_tuples = __get_prefetch_op_tuples(main_program) prefetch_op_tuples = __get_prefetch_op_tuples(program)
split_ids_id = prefetch_op_tuples[0] split_ids_id = prefetch_op_tuples[0]
for idx in range(split_ids_id + 2, split_ids_id - 1, -1): for idx in range(split_ids_id + 2, split_ids_id - 1, -1):
main_program.global_block()._remove_op(idx) program.global_block()._remove_op(idx)
main_program.desc.flush() program.desc.flush()
in_out_pairs = zip(prefetch_op_tuples[1], prefetch_op_tuples[2]) in_out_pairs = zip(prefetch_op_tuples[1], prefetch_op_tuples[2])
for in_out_pair in in_out_pairs: for in_out_pair in in_out_pairs:
idx = split_ids_id idx = split_ids_id
ids = main_program.global_block().vars[in_out_pair[0]] ids = program.global_block().vars[in_out_pair[0]]
out = main_program.global_block().vars[in_out_pair[1]] out = program.global_block().vars[in_out_pair[1]]
__insert_lookup_sparse_table_op(main_program, idx, ids, param_var, out) __insert_lookup_sparse_table_op(program, idx, ids, param_var, out)
main_program.desc.flush() program.desc.flush()
return main_program return program
def load_persistable_vars(executor, dirname, program, lookup_table_var): def _load_persistable_vars(executor, dirname, program, lookup_table_vars):
def _is_checkpoint_var(exclude_fluid_vars=None): def _is_checkpoint_var(exclude_fluid_vars=None):
""" """
the checkpoint will not save or load all the variables. the checkpoint will not save or load all the variables.
...@@ -159,8 +166,82 @@ def load_persistable_vars(executor, dirname, program, lookup_table_var): ...@@ -159,8 +166,82 @@ def load_persistable_vars(executor, dirname, program, lookup_table_var):
return is_valid return is_valid
def _load_lookup_table_vars(executor, dirname, main_program, io.load_vars(
lookup_table_vars): executor,
dirname=dirname,
main_program=program,
predicate=_is_checkpoint_var(lookup_table_vars),
filename=None)
def load_persistables_for_increment(dirname, executor, program,
lookup_table_var, lookup_table_var_path):
"""
WARNING: this function will only be used for distributed training with distributed lookup table.
for increment trainning, the pserver will not only load dense variables,
but also load the suitable lookup table var. Because of slice lookup table
var with HASH, we must load the correct slice var.
:param dirname(str): The directory path
:param executor(Executor): The executor to run for loading inference model.
:param program(Program): The parameter server program, which will run on Pserver.
:param lookup_table_var: the distributed lookup tables var name.
:param lookup_table_var_path: the the distributed lookup tables var location.
:return: None
"""
def __load_lookup_table_vars(executor, main_program, lookup_table_var,
lookup_table_var_path):
emb_var = main_program.global_block().var(lookup_table_var)
load_program = Program()
load_block = load_program.global_block()
load_block.append_op(
type='load',
inputs={},
outputs={'Out': [emb_var]},
attrs={'file_path': lookup_table_var_path})
executor.run(load_program)
if not os.path.isdir(dirname):
raise ValueError("There is no directory named '%s'", dirname)
if not os.path.exists(lookup_table_var_path):
raise ValueError("There is no file named '%s'", lookup_table_var_path)
if not isinstance(program, Program):
raise ValueError("program must be an instance of fluid.Program")
_logger.info("Start Load Sparse Program With "
"Distributed Lookup Table Vars from {}, time = {}".format(
dirname, time.ctime()))
_load_persistable_vars(executor, dirname, program, [lookup_table_var])
__load_lookup_table_vars(executor, program, lookup_table_var,
lookup_table_var_path)
_logger.info("Finish Load Sparse Program With "
"Distributed Lookup Table Vars from {}, time = {}".format(
dirname, time.ctime()))
def load_persistables_for_inference(dirname, executor, program,
lookup_table_var_name):
"""
WARNING: this function will only be used for inference with distributed lookup table.
Inference with distributed lookup table is a little funky, this function will load distributed
lookup table vars into sparse var, can be used in local inference mode.
:param dirname(str): The directory path
:param executor(Executor): The executor to run for loading inference model.
:param program(Program): The parameter server program, which will run on Pserver.
:param lookup_table_var_name: the distributed lookup tables var name.
:return: None
"""
def __load_lookup_table_vars(executor, dirname, main_program,
lookup_table_vars):
if not os.path.isdir(dirname): if not os.path.isdir(dirname):
raise ValueError("There is no directory named '%s'", dirname) raise ValueError("There is no directory named '%s'", dirname)
...@@ -209,48 +290,34 @@ def load_persistable_vars(executor, dirname, program, lookup_table_var): ...@@ -209,48 +290,34 @@ def load_persistable_vars(executor, dirname, program, lookup_table_var):
global_block.append_op(type='delete_var', inputs={'X': sums}) global_block.append_op(type='delete_var', inputs={'X': sums})
executor.run(convert_program) executor.run(convert_program)
_logger.info("Start Load Sparse Program With "
"Distributed Lookup Table Vars from {}, time = {}".format(
dirname, time.ctime()))
lookup_table_vars = [lookup_table_var]
io.load_vars(
executor,
dirname=dirname,
main_program=program,
predicate=_is_checkpoint_var(lookup_table_vars),
filename=None)
_load_lookup_table_vars(executor, dirname, program, lookup_table_vars)
_logger.info("Finish Load Sparse Program With "
"Distributed Lookup Table Vars from {}, time = {}".format(
dirname, time.ctime()))
def load_inference_model(dirname, executor, lookup_table_var_name):
if not os.path.isdir(dirname): if not os.path.isdir(dirname):
raise ValueError("There is no directory named '%s'", dirname) raise ValueError("There is no directory named '%s'", dirname)
local_model = os.path.join(dirname, model_filename) if program:
if not isinstance(program, Program):
raise ValueError("program must be an instance of fluid.Program")
else:
local_model = os.path.join(dirname, model_filename)
with open(local_model, "rb") as f: with open(local_model, "rb") as f:
program_desc_str = f.read() program_desc_str = f.read()
program = Program.parse_from_string(program_desc_str) program = Program.parse_from_string(program_desc_str)
if not core._is_program_version_supported(program._version()): if not core._is_program_version_supported(program._version()):
raise ValueError("Unsupported program version: %d\n" % raise ValueError("Unsupported program version: %d\n" %
program._version()) program._version())
# Binary data also need version. _logger.info("Start Load Sparse Program With "
load_persistable_vars(executor, dirname, program, lookup_table_var_name) "Distributed Lookup Table Vars from {}, time = {}".format(
dirname, time.ctime()))
_load_persistable_vars(executor, dirname, program, [lookup_table_var_name])
__load_lookup_table_vars(executor, dirname, program,
[lookup_table_var_name])
feed_target_names = program.desc.get_feed_target_names() _logger.info("Finish Load Sparse Program With "
fetch_target_names = program.desc.get_fetch_target_names() "Distributed Lookup Table Vars from {}, time = {}".format(
fetch_targets = [ dirname, time.ctime()))
program.global_block().var(name) for name in fetch_target_names
]
return [program, feed_target_names, fetch_targets] return program
...@@ -1324,6 +1324,9 @@ class Block(object): ...@@ -1324,6 +1324,9 @@ class Block(object):
def _prepend_op(self, *args, **kwargs): def _prepend_op(self, *args, **kwargs):
op_desc = self.desc._prepend_op() op_desc = self.desc._prepend_op()
op = Operator(self, op_desc, *args, **kwargs) op = Operator(self, op_desc, *args, **kwargs)
if _in_imperative_mode():
_imperative_tracer().trace(op.iop, [v._ivar for v in op.inputs],
[v._ivar for v in op.outputs], self.desc)
self.ops.insert(0, op) self.ops.insert(0, op)
return op return op
......
...@@ -28,7 +28,8 @@ def enabled(): ...@@ -28,7 +28,8 @@ def enabled():
def guard(): def guard():
train = framework.Program() train = framework.Program()
startup = framework.Program() startup = framework.Program()
tracer = core.Tracer(train.current_block().desc) tracer = core.Tracer(train.current_block().desc,
startup.current_block().desc)
with framework.program_guard(train, startup): with framework.program_guard(train, startup):
with framework.unique_name.guard(): with framework.unique_name.guard():
with framework._imperative_guard(tracer): with framework._imperative_guard(tracer):
......
...@@ -25,11 +25,9 @@ __all__ = ['PyLayer'] ...@@ -25,11 +25,9 @@ __all__ = ['PyLayer']
class PyLayer(core.Layer): class PyLayer(core.Layer):
def __init__(self): def __init__(self):
pass self._built = False
def __call__(self, inputs): def __call__(self, inputs):
# TODO(panyx0718): Support declarative mode as well.
assert base.enabled()
if not isinstance(inputs, list) and not isinstance(inputs, tuple): if not isinstance(inputs, list) and not isinstance(inputs, tuple):
inputs = [inputs] inputs = [inputs]
...@@ -37,8 +35,15 @@ class PyLayer(core.Layer): ...@@ -37,8 +35,15 @@ class PyLayer(core.Layer):
for x in inputs: for x in inputs:
py_var = base.to_variable(x) py_var = base.to_variable(x)
var_inputs.append(py_var) var_inputs.append(py_var)
if not self._built:
self._build_once(inputs)
self._built = True
outputs = self.forward(var_inputs) outputs = self.forward(var_inputs)
return outputs return outputs
def _build_once(self, inputs):
pass
def forward(self, inputs): def forward(self, inputs):
return [] return []
...@@ -18,7 +18,9 @@ All layers just related to the neural network. ...@@ -18,7 +18,9 @@ All layers just related to the neural network.
from __future__ import print_function from __future__ import print_function
import numpy as np import numpy as np
import six
import os import os
import inspect
from ..layer_helper import LayerHelper from ..layer_helper import LayerHelper
from ..initializer import Normal, Constant from ..initializer import Normal, Constant
from ..framework import Variable, OpProtoHolder from ..framework import Variable, OpProtoHolder
...@@ -29,6 +31,7 @@ from . import utils ...@@ -29,6 +31,7 @@ from . import utils
from .. import unique_name from .. import unique_name
from functools import reduce from functools import reduce
from .. import core from .. import core
from ..imperative import layers
__all__ = [ __all__ = [
'fc', 'fc',
...@@ -175,6 +178,7 @@ __all__ = [ ...@@ -175,6 +178,7 @@ __all__ = [
'merge_selected_rows', 'merge_selected_rows',
'get_tensor_from_selected_rows', 'get_tensor_from_selected_rows',
'lstm', 'lstm',
'py_func',
'psroi_pool', 'psroi_pool',
'huber_loss', 'huber_loss',
] ]
...@@ -9326,6 +9330,224 @@ def get_tensor_from_selected_rows(x, name=None): ...@@ -9326,6 +9330,224 @@ def get_tensor_from_selected_rows(x, name=None):
return out return out
class PyFuncRegistry(object):
_register_funcs = []
def __init__(self, func):
if func is None or not callable(func):
raise TypeError('func must be a Python function')
self._func = func
# find named args using reflection
args = inspect.getargspec(self._func)
if len(args[0]) == 0 and args[1] is None and args[2] is None:
# Function with no inputs
self._named_args = None
else:
self._named_args = args[0]
self._id = core._append_python_callable_object_and_return_id(self)
'''
Why record self here?
1. For debug usage. Users can call
:code:`py_func.registered_func(idx)` method
to find the registered function corresponding
to :code:`idx`.
2. For increasing reference count of self.
It seems that to release Python object
whose reference count is 1 would cause
segmentation fault error in C++ side.
May be lack of Python GC in C++ side?
'''
PyFuncRegistry._register_funcs.append(self)
@classmethod
def registered_func(cls, idx):
return cls._register_funcs[idx]._func
@classmethod
def registered_func_num(cls):
return len(cls._register_funcs)
@property
def id(self):
return self._id
def __call__(self, *args):
if self._named_args is None:
func_ret = self._func()
else:
kwargs = dict()
idx = 0
for arg in self._named_args:
kwargs[arg] = args[idx]
idx += 1
func_ret = self._func(*args[idx:], **kwargs)
if not isinstance(func_ret, (list, tuple)):
func_ret = (func_ret, )
ret = []
for each_ret in func_ret:
if each_ret is None or isinstance(each_ret, core.LoDTensor):
ret.append(each_ret)
continue
if not isinstance(each_ret, np.ndarray):
each_ret = np.array(each_ret)
tensor = core.LoDTensor()
tensor.set(each_ret, core.CPUPlace())
ret.append(tensor)
return tuple(ret)
@templatedoc()
def py_func(func, x, out, backward_func=None, skip_vars_in_backward_input=None):
"""
PyFunc Operator.
User can use :code:`py_func` to register operators in Python side.
The inputs of :code:`func` is :code:`LoDTensor` and outputs can be
numpy array or :code:`LoDTensor`. Paddle would call the registered
:code:`func` in forward part, and call :code:`backward_func` in
backward part (if :code:`backward_func` is not None).
User should set the right data type and shape of :code:`out` before
calling this function. However, data types and shapes of gradients of
:code:`out` and :code:`x` would be inferred automatically.
Input orders of :code:`backward_func` would be: forward inputs
:code:`x`, forward outputs :code:`out` and backward input gradients of
:code:`out`. If some variables of :code:`out` have no gradient, the input
tensor would be None in Python side. If some variables of :code:`in` have
no gradient, users should return None.
This function can also be used to debug the running network. User can
add a :code:`py_func` operator without output, and print input
:code:`x` inside :code:`func`.
Args:
func (callable): forward Python function.
x (Variable|list(Variable)|tuple(Variable)): inputs of :code:`func`.
out (Variable|list(Variable)|tuple(Variable)): outputs of :code:`func`.
Paddle cannot infer shapes and data types of :code:`out`. Users
should create :code:`out` beforehand.
backward_func (callable|None): backward Python function.
None means no backward. Default None.
skip_vars_in_backward_input (Variable|list(Variable)|tuple(Variable)):
Variables that are not needed in :code:`backward_func` inputs.
These variables must be any of :code:`x` and :code:`out`.
If set, these vars would not be inputs of :code:`backward_func`,
Only useful when :code:`backward_func` is not None. Default None.
Returns:
out (Variable|list(Variable)|tuple(Variable)): input :code:`out`
Examples:
>>> import paddle.fluid as fluid
>>> import six
>>>
>>> def create_tmp_var(name, dtype, shape):
>>> return fluid.default_main_program().current_block().create_var(
>>> name=name, dtype=dtype, shape=shape)
>>>
>>> # tanh activation has been provided by Paddle C++ op
>>> # Here, we only use tanh to be an example to show the usage
>>> # of py_func
>>> def tanh(x):
>>> return np.tanh(x)
>>>
>>> # forward input x is skipped
>>> def tanh_grad(y, dy):
>>> return np.array(dy) * (1 - np.square(np.array(y)))
>>>
>>> def debug_func(x):
>>> print(x)
>>>
>>> def simple_net(img, label):
>>> hidden = img
>>> for idx in six.moves.range(4):
>>> hidden = fluid.layers.fc(hidden, size=200)
>>> new_hidden = create_tmp_var(name='hidden_{}'.format(idx),
>>> dtype=hidden.dtype, shape=hidden.shape)
>>>
>>> # user-defined layers with forward and backward
>>> hidden = fluid.layers.py_func(func=tanh, x=hidden,
>>> out=new_hidden, backward_func=tanh_grad,
>>> skip_vars_in_backward_input=hidden)
>>>
>>> # user-defined debug layers to print variables
>>> fluid.layers.py_func(func=debug_func, x=hidden, out=None)
>>>
>>> prediction = fluid.layers.fc(hidden, size=10, act='softmax')
>>> loss = fluid.layers.cross_entropy(input=prediction, label=label)
>>> return fluid.layers.mean(loss)
"""
helper = LayerHelper('py_func', **locals())
if x is None:
x = []
elif isinstance(x, Variable):
x = [x]
elif not isinstance(x, (list, tuple)):
raise TypeError('Input must be Variable/list(Variable)/tuple(Variable)')
if out is None:
out_list = []
elif isinstance(out, Variable):
out_list = [out]
elif isinstance(out, (list, tuple)):
out_list = out
else:
raise TypeError(
'Output must be Variable/list(Variable)/tuple(Variable)')
fwd_func_id = PyFuncRegistry(func).id
bwd_func_id = PyFuncRegistry(
backward_func).id if backward_func is not None else -1
for each_out in out_list:
if len(each_out.shape) == 0:
raise ValueError(
'Output shapes of py_func op should be provided by users manually'
)
backward_skip_vars = set()
if backward_func is not None and skip_vars_in_backward_input is not None:
if isinstance(skip_vars_in_backward_input, Variable):
skip_vars_in_backward_input = [skip_vars_in_backward_input]
fwd_in_out = [v.name for v in x]
fwd_in_out.extend([v.name for v in out_list])
fwd_in_out = set(fwd_in_out)
backward_skip_vars = set()
for v in skip_vars_in_backward_input:
if not v.name in fwd_in_out:
raise ValueError(
'Variable {} is not found in forward inputs and outputs'
.format(v.name))
backward_skip_vars.add(v.name)
helper.append_op(
type='py_func',
inputs={'X': x},
outputs={'Out': out_list},
attrs={
'forward_callable_id': fwd_func_id,
'backward_callable_id': bwd_func_id,
'backward_skip_vars': list(backward_skip_vars)
})
return out
# For debug usage
py_func.registered_func = PyFuncRegistry.registered_func
py_func.registered_func_num = PyFuncRegistry.registered_func_num
@templatedoc() @templatedoc()
def psroi_pool(input, def psroi_pool(input,
rois, rois,
...@@ -9426,3 +9648,47 @@ def huber_loss(input, label, delta): ...@@ -9426,3 +9648,47 @@ def huber_loss(input, label, delta):
'Residual': residual}, 'Residual': residual},
attrs={'delta': delta}) attrs={'delta': delta})
return out return out
class FC(layers.PyLayer):
def __init__(self,
size,
param_attr=None,
num_flatten_dims=1,
dtype=core.VarDesc.VarType.FP32):
super(FC, self).__init__()
self._size = size
self._num_flatten_dims = num_flatten_dims
self._dtype = dtype
self._helper = LayerHelper('FC', param_attr=param_attr)
def _build_once(self, inputs):
input_shape = inputs[0].shape
param_shape = [
reduce(lambda a, b: a * b, input_shape[self._num_flatten_dims:], 1)
] + [self._size]
self._w = self._helper.create_parameter(
attr=self._helper.param_attr,
shape=param_shape,
dtype=self._dtype,
is_bias=False)
def forward(self, inputs):
tmp = self._helper.create_variable_for_type_inference(self._dtype)
self._helper.append_op(
type="mul",
inputs={"X": inputs[0],
"Y": self._w},
outputs={"Out": tmp},
attrs={
"x_num_col_dims": self._num_flatten_dims,
"y_num_col_dims": 1
})
out = self._helper.create_variable_for_type_inference(self._dtype)
self._helper.append_op(
type="sum",
inputs={"X": [tmp]},
outputs={"Out": out},
attrs={"use_mkldnn": False})
return out
...@@ -92,35 +92,27 @@ class ParallelExecutor(object): ...@@ -92,35 +92,27 @@ class ParallelExecutor(object):
num_trainers=1, num_trainers=1,
trainer_id=0, trainer_id=0,
scope=None): scope=None):
# step1: get places, the places are used in run too.
self._places = [] self._places = []
self._act_places = []
if use_cuda: if use_cuda:
gpus = []
gpus_env = os.getenv("FLAGS_selected_gpus") gpus_env = os.getenv("FLAGS_selected_gpus")
if gpus_env: if gpus_env:
gpus = [int(s) for s in gpus_env.split(",")] gpus = [int(s) for s in gpus_env.split(",")]
else: else:
for i in six.moves.range(core.get_cuda_device_count()): gpus = [
gpus.append(i) i for i in six.moves.range(core.get_cuda_device_count())
for i in gpus: ]
p = core.Place() self._places = [core.CUDAPlace(i) for i in gpus]
self._act_places.append(core.CUDAPlace(i))
p.set_place(self._act_places[-1])
self._places.append(p)
else: else:
cpu_num = int( cpu_num = int(
os.environ.get('CPU_NUM', multiprocessing.cpu_count())) os.environ.get('CPU_NUM', multiprocessing.cpu_count()))
for i in six.moves.range(cpu_num): self._places = [core.CPUPlace() for _ in six.moves.range(cpu_num)]
p = core.Place()
self._act_places.append(core.CPUPlace())
p.set_place(self._act_places[-1])
self._places.append(p)
assert self._places, "no place for execution" assert self._places, "no place for execution"
# step2: init exec_strategy
if exec_strategy is None: if exec_strategy is None:
exec_strategy = ExecutionStrategy() exec_strategy = ExecutionStrategy()
exec_strategy.use_cuda = use_cuda exec_strategy.use_cuda = use_cuda
if exec_strategy.num_threads == 0: if exec_strategy.num_threads == 0:
if use_cuda: if use_cuda:
# Experiments on se-resnext shows that too many threads hurt # Experiments on se-resnext shows that too many threads hurt
...@@ -131,49 +123,54 @@ class ParallelExecutor(object): ...@@ -131,49 +123,54 @@ class ParallelExecutor(object):
os.environ.get('CPU_NUM', multiprocessing.cpu_count())) os.environ.get('CPU_NUM', multiprocessing.cpu_count()))
exec_strategy.num_threads = cpu_num * 2 exec_strategy.num_threads = cpu_num * 2
# step3: init build_strategy
if build_strategy is None: if build_strategy is None:
build_strategy = BuildStrategy() build_strategy = BuildStrategy()
build_strategy.num_trainers = num_trainers build_strategy.num_trainers = num_trainers
build_strategy.trainer_id = trainer_id build_strategy.trainer_id = trainer_id
main = main_program # step4: get main_program, scope, local_scopes
main = main if main else framework.default_main_program() main = main_program if main_program \
else framework.default_main_program()
scope = scope if scope is not None else executor.global_scope()
if share_vars_from and not isinstance(share_vars_from,
ParallelExecutor):
raise TypeError("share_vars_from must be ParallelExecutor.")
local_scopes = share_vars_from.executor.local_scopes()\
if share_vars_from else []
# step5: check trainers_endpoints, it is used for distribution.
trainers_endpoints = main._trainers_endpoints trainers_endpoints = main._trainers_endpoints
if num_trainers > 1 and trainers_endpoints: if num_trainers > 1 and trainers_endpoints:
assert num_trainers == len( assert num_trainers == len(
trainers_endpoints), "num_trainers == len(end_points)" trainers_endpoints), "num_trainers == len(end_points)"
build_strategy.trainers_endpoints = trainers_endpoints build_strategy.trainers_endpoints = trainers_endpoints
if scope == None: # step5: get persistable_vars, parameter_vars, places. persistable_vars
scope = executor.global_scope() # need be broadcast to other local_scope.
persistable_vars = set([
if share_vars_from and not isinstance(share_vars_from, cpt.to_text(v.name) for v in [
ParallelExecutor):
raise TypeError("share_vars_from must be ParallelExecutor.")
local_scopes = share_vars_from.executor.local_scopes(
) if share_vars_from else []
self.persistable_vars = [
v.name for v in [
var for var in main.list_vars() var for var in main.list_vars()
if var.persistable and var.type != core.VarDesc.VarType.RAW if var.persistable and var.type != core.VarDesc.VarType.RAW
] ]
] ])
def place_obj(place):
p = core.Place()
p.set_place(place)
return p
places = list(map(place_obj, self._places))
# step6: init ParallelExecutor
self.executor = core.ParallelExecutor( self.executor = core.ParallelExecutor(
self._places, places, persistable_vars, main.desc,
set([
cpt.to_text(p.name)
for p in main.global_block().iter_parameters()
if not p.stop_gradient
]),
set(cpt.to_text(var) for var in self.persistable_vars), main.desc,
cpt.to_text(loss_name) cpt.to_text(loss_name)
if loss_name else six.u(''), scope, local_scopes, exec_strategy, if loss_name else six.u(''), scope, local_scopes, exec_strategy,
build_strategy, num_trainers, trainer_id) build_strategy, num_trainers, trainer_id)
self.scope = scope self.scope = scope
def run(self, fetch_list, feed=None, feed_dict=None, return_numpy=True): def run(self, fetch_list, feed=None, feed_dict=None, return_numpy=True):
...@@ -261,7 +258,7 @@ class ParallelExecutor(object): ...@@ -261,7 +258,7 @@ class ParallelExecutor(object):
self.executor.feed_and_split_tensor_into_local_scopes( self.executor.feed_and_split_tensor_into_local_scopes(
feed_tensor_dict) feed_tensor_dict)
elif isinstance(feed, list) or isinstance(feed, tuple): elif isinstance(feed, list) or isinstance(feed, tuple):
if len(feed) != len(self._act_places): if len(feed) != len(self._places):
raise ValueError( raise ValueError(
"Feed a list of tensor, the list should be the same size as places" "Feed a list of tensor, the list should be the same size as places"
) )
...@@ -277,7 +274,7 @@ class ParallelExecutor(object): ...@@ -277,7 +274,7 @@ class ParallelExecutor(object):
tensor = each[feed_name] tensor = each[feed_name]
if not isinstance(tensor, core.LoDTensor): if not isinstance(tensor, core.LoDTensor):
tmp = core.LoDTensor() tmp = core.LoDTensor()
tmp.set(tensor, self._act_places[i]) tmp.set(tensor, self._places[i])
tensor = tmp tensor = tmp
res_dict[feed_name] = tensor res_dict[feed_name] = tensor
res.append(res_dict) res.append(res_dict)
...@@ -294,4 +291,4 @@ class ParallelExecutor(object): ...@@ -294,4 +291,4 @@ class ParallelExecutor(object):
@property @property
def device_count(self): def device_count(self):
return len(self._act_places) return len(self._places)
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
import paddle.fluid.core as core
from paddle.fluid.tests.unittests.op_test import OpTest
from paddle.fluid.tests.unittests.test_activation_op import TestRelu, TestTanh
class TestNGRAPHReluDim2(TestRelu):
def setUp(self):
super(TestNGRAPHReluDim2, self).setUp()
class TestNGRAPHTanhDim2(TestTanh):
def setUp(self):
super(TestNGRAPHTanhDim2, self).setUp()
class TestNGRAPHReluDim4(TestRelu):
def setUp(self):
super(TestNGRAPHReluDim4, self).setUp()
x = np.random.uniform(-1, 1, [2, 4, 3, 5]).astype("float32")
# The same reason with TestAbs
x[np.abs(x) < 0.005] = 0.02
out = np.maximum(x, 0)
self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)}
self.outputs = {'Out': out}
class TestNGRAPHTanhDim4(TestTanh):
def setUp(self):
super(TestNGRAPHTanhDim4, self).setUp()
self.inputs = {
'X': np.random.uniform(0.1, 1, [2, 4, 3, 5]).astype("float32")
}
self.outputs = {'Out': np.tanh(self.inputs['X'])}
if __name__ == '__main__':
unittest.main()
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
from paddle.fluid.tests.unittests.test_mul_op import TestMulOp, TestMulOp2, TestFP16MulOp1, TestFP16MulOp2
class TestNGRAPHMulOp(TestMulOp):
def init_dtype_type(self):
pass
class TestNGRAPHMulOp2(TestMulOp2):
def init_dtype_type(self):
pass
class TestNGRAPHFP16MulOp1(TestFP16MulOp1):
def init_dtype_type(self):
pass
class TestNGRAPHFP16MulOp2(TestFP16MulOp2):
def init_dtype_type(self):
pass
if __name__ == "__main__":
unittest.main()
...@@ -16,7 +16,7 @@ from __future__ import print_function ...@@ -16,7 +16,7 @@ from __future__ import print_function
import unittest import unittest
from test_conv2d_op import TestConv2dOp, TestWithPad, TestWithStride from test_conv2d_op import TestConv2dOp, TestWithPad, TestWithStride, TestWithGroup, TestWith1x1, TestWithInput1x1Filter1x1
class TestMKLDNN(TestConv2dOp): class TestMKLDNN(TestConv2dOp):
...@@ -37,5 +37,23 @@ class TestMKLDNNWithStride(TestWithStride): ...@@ -37,5 +37,23 @@ class TestMKLDNNWithStride(TestWithStride):
self.data_format = "NCHW" self.data_format = "NCHW"
class TestMKLDNNWithGroup(TestWithGroup):
def init_kernel_type(self):
self.use_mkldnn = True
self.data_format = "NCHW"
class TestMKLDNNWith1x1(TestWith1x1):
def init_kernel_type(self):
self.use_mkldnn = True
self.data_format = "NCHW"
class TestMKLDNNWithInput1x1Filter1x1(TestWithInput1x1Filter1x1):
def init_kernel_type(self):
self.use_mkldnn = True
self.data_format = "NCHW"
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
...@@ -29,7 +29,7 @@ class TestGetTensorFromSelectedRows(unittest.TestCase): ...@@ -29,7 +29,7 @@ class TestGetTensorFromSelectedRows(unittest.TestCase):
def check_with_place(self, place): def check_with_place(self, place):
scope = core.Scope() scope = core.Scope()
x_rows = [0, 5, 5, 4, 20] x_rows = [0, 5, 5, 4, 19]
height = 20 height = 20
row_numel = 2 row_numel = 2
......
...@@ -12,12 +12,23 @@ ...@@ -12,12 +12,23 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import contextlib
import unittest import unittest
import sys
import numpy as np import numpy as np
import paddle.fluid as fluid import paddle.fluid as fluid
from paddle.fluid import core from paddle.fluid import core
from paddle.fluid.layers.nn import FC
@contextlib.contextmanager
def new_program_scope():
prog = fluid.Program()
startup_prog = fluid.Program()
scope = fluid.core.Scope()
with fluid.scope_guard(scope):
with fluid.program_guard(prog, startup_prog):
yield
class MyLayer(fluid.imperative.PyLayer): class MyLayer(fluid.imperative.PyLayer):
...@@ -30,6 +41,23 @@ class MyLayer(fluid.imperative.PyLayer): ...@@ -30,6 +41,23 @@ class MyLayer(fluid.imperative.PyLayer):
return [fluid.layers.elementwise_mul(x, x)] return [fluid.layers.elementwise_mul(x, x)]
class MLP(fluid.imperative.PyLayer):
def __init__(self):
super(MLP, self).__init__()
self._fc1 = FC(3,
fluid.ParamAttr(
initializer=fluid.initializer.Constant(value=0.1)))
self._fc2 = FC(4,
fluid.ParamAttr(
initializer=fluid.initializer.Constant(value=0.1)))
def forward(self, inputs):
x = self._fc1(inputs[0])
x = self._fc2(x)
x = fluid.layers.reduce_sum(x)
return x
class TestImperative(unittest.TestCase): class TestImperative(unittest.TestCase):
def test_layer(self): def test_layer(self):
with fluid.imperative.guard(): with fluid.imperative.guard():
...@@ -39,13 +67,56 @@ class TestImperative(unittest.TestCase): ...@@ -39,13 +67,56 @@ class TestImperative(unittest.TestCase):
l.forward([]) l.forward([])
def test_layer_in_out(self): def test_layer_in_out(self):
np_inp = np.array([1.0, 2.0, -1.0], dtype=np.float32)
with fluid.imperative.guard(): with fluid.imperative.guard():
l = MyLayer() l = MyLayer()
x = l(np.array([1.0, 2.0, -1.0], dtype=np.float32))[0] x = l(np_inp)[0]
self.assertIsNotNone(x) self.assertIsNotNone(x)
sys.stderr.write("%s output: %s\n" % (x, x._numpy())) dy_out = x._numpy()
x._backward() x._backward()
sys.stderr.write("grad %s\n" % l._x_for_debug._gradient()) dy_grad = l._x_for_debug._gradient()
with new_program_scope():
inp = fluid.layers.data(
name="inp", shape=[3], append_batch_size=False)
l = MyLayer()
x = l(inp)[0]
param_grads = fluid.backward.append_backward(
x, parameter_list=[l._x_for_debug.name])[0]
exe = fluid.Executor(fluid.CPUPlace())
static_out, static_grad = exe.run(
feed={inp.name: np_inp},
fetch_list=[x.name, param_grads[1].name])
self.assertTrue(np.allclose(dy_out, static_out))
self.assertTrue(np.allclose(dy_grad, static_grad))
def test_mlp(self):
np_inp = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32)
with fluid.imperative.guard():
mlp = MLP()
out = mlp(np_inp)
dy_out = out._numpy()
out._backward()
dy_grad = mlp._fc1._w._gradient()
with new_program_scope():
inp = fluid.layers.data(
name="inp", shape=[2, 2], append_batch_size=False)
mlp = MLP()
out = mlp(inp)
param_grads = fluid.backward.append_backward(
out, parameter_list=[mlp._fc1._w.name])[0]
exe = fluid.Executor(fluid.CPUPlace())
exe.run(fluid.default_startup_program())
static_out, static_grad = exe.run(
feed={inp.name: np_inp},
fetch_list=[out.name, param_grads[1].name])
self.assertTrue(np.allclose(dy_out, static_out))
self.assertTrue(np.allclose(dy_grad, static_grad))
if __name__ == '__main__': if __name__ == '__main__':
......
...@@ -29,8 +29,8 @@ class TestMergeSelectedRows(unittest.TestCase): ...@@ -29,8 +29,8 @@ class TestMergeSelectedRows(unittest.TestCase):
def check_with_place(self, place): def check_with_place(self, place):
scope = core.Scope() scope = core.Scope()
x_rows = [0, 5, 5, 4, 20] x_rows = [0, 5, 5, 4, 19]
out_rows = [0, 4, 5, 20] out_rows = [0, 4, 5, 19]
height = 20 height = 20
row_numel = 2 row_numel = 2
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import paddle.fluid as fluid
import paddle
import unittest
import six
import numpy as np
dev_cnt = 2
if fluid.core.is_compiled_with_cuda():
dev_cnt = fluid.core.get_cuda_device_count()
os.environ['CPU_NUM'] = str(dev_cnt)
def dummy_func_with_no_input():
return float(1.0)
def dummy_func_with_no_output(x):
pass
def tanh(x):
return np.tanh(x)
def tanh_grad(y, dy):
return np.array(dy) * (1 - np.square(np.array(y)))
def cross_entropy(logits, labels):
logits = np.array(logits)
labels = np.array(labels)
M = logits.shape[0]
N = logits.shape[1]
ret = np.ndarray([M, 1]).astype(logits.dtype)
for idx in six.moves.range(M):
ret[idx][0] = -np.log(logits[idx][labels[idx][0]])
return ret
def cross_entropy_grad(logits, labels, bwd_dout):
logits = np.array(logits)
labels = np.array(labels)
bwd_dout = np.array(bwd_dout)
M = logits.shape[0]
N = logits.shape[1]
dlogits = np.zeros([M, N]).astype(logits.dtype)
for idx in six.moves.range(M):
dlogits[idx][labels[idx][0]] = -bwd_dout[idx] / logits[idx][labels[idx][
0]]
return dlogits, None
def simple_fc_net(img, label, use_py_func_op):
hidden = img
for idx in range(4):
hidden = fluid.layers.fc(
hidden,
size=200,
bias_attr=fluid.ParamAttr(
initializer=fluid.initializer.Constant(value=1.0)))
if not use_py_func_op:
hidden = fluid.layers.tanh(hidden)
else:
new_hidden = fluid.default_main_program().current_block(
).create_var(
name='hidden_{}'.format(idx),
dtype='float32',
shape=hidden.shape)
hidden = fluid.layers.py_func(
func=tanh,
x=hidden,
out=new_hidden,
backward_func=tanh_grad,
skip_vars_in_backward_input=hidden)
prediction = fluid.layers.fc(hidden, size=10, act='softmax')
if not use_py_func_op:
loss = fluid.layers.cross_entropy(input=prediction, label=label)
else:
loss = fluid.default_main_program().current_block().create_var(
name='loss', dtype='float32', shape=[-1, 1])
loss = fluid.layers.py_func(
func=cross_entropy,
x=[prediction, label],
out=loss,
backward_func=cross_entropy_grad,
skip_vars_in_backward_input=loss)
dummy_var = fluid.default_main_program().current_block().create_var(
name='test_tmp_var', dtype='float32', shape=[1])
fluid.layers.py_func(
func=dummy_func_with_no_input, x=None, out=dummy_var)
fluid.layers.py_func(func=dummy_func_with_no_output, x=loss, out=None)
loss = fluid.layers.mean(loss)
return loss
def reader():
for _ in six.moves.range(dev_cnt * 100):
yield np.random.random([784]), np.random.random_integers(
size=[1], low=0, high=9)
def test_main(use_cuda, use_py_func_op, use_parallel_executor):
if use_cuda and not fluid.core.is_compiled_with_cuda():
return None
with fluid.program_guard(fluid.Program(), fluid.Program()):
with fluid.scope_guard(fluid.core.Scope()):
fluid.default_main_program().random_seed = 1
fluid.default_startup_program().random_seed = 1
np.random.seed(1)
img = fluid.layers.data(name='image', shape=[784], dtype='float32')
label = fluid.layers.data(name='label', shape=[1], dtype='int64')
loss = simple_fc_net(img, label, use_py_func_op)
optimizer = fluid.optimizer.SGD(learning_rate=1e-3)
optimizer.minimize(loss)
place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
feeder = fluid.DataFeeder(feed_list=[img, label], place=place)
r = paddle.batch(reader, batch_size=10)
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())
if use_parallel_executor:
exe = fluid.ParallelExecutor(
use_cuda=use_cuda, loss_name=loss.name)
fetch_list = [loss.name]
else:
fetch_list = [loss]
ret = []
for epoch_id in six.moves.range(2):
for d in r():
L, = exe.run(feed=feeder.feed(d), fetch_list=fetch_list)
ret.append(L)
return np.array(ret)
class TestPyFuncOpUseExecutor(unittest.TestCase):
def setUp(self):
self.use_parallel_executor = False
def test_loss_diff(self):
losses = []
for use_cuda in [True, False]:
for use_py_func_op in [True, False]:
L = test_main(use_cuda, use_py_func_op,
self.use_parallel_executor)
if L is not None:
losses.append(L)
for idx in six.moves.range(len(losses) - 1):
max_diff = np.max(np.abs(losses[idx] - losses[0]))
self.assertAlmostEqual(max_diff, 0, delta=1e-3)
class TestPyFuncOpUseParallelExecutor(unittest.TestCase):
def setUp(self):
self.use_parallel_executor = True
if __name__ == '__main__':
unittest.main()
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
from test_transpose_op import TestTransposeOp
class TestTransposeMKLDNN(TestTransposeOp):
def init_op_type(self):
self.op_type = "transpose2"
self.use_mkldnn = True
self.is_test = True
return
def test_check_grad(self):
return
def test_check_grad_no_input(self):
return
def test_check_grad_no_filter(self):
return
class TestCase0MKLDNN(TestTransposeMKLDNN):
def initTestCase(self):
self.shape = (3, )
self.axis = (0, )
class TestCase1a(TestTransposeMKLDNN):
def initTestCase(self):
self.shape = (3, 4, 5)
self.axis = (0, 2, 1)
class TestCase1b(TestTransposeMKLDNN):
def initTestCase(self):
self.shape = (3, 4, 5)
self.axis = (2, 1, 0)
class TestCase2(TestTransposeMKLDNN):
def initTestCase(self):
self.shape = (2, 3, 4, 5)
self.axis = (0, 2, 3, 1)
class TestCase3(TestTransposeMKLDNN):
def initTestCase(self):
self.shape = (2, 3, 4, 5, 6)
self.axis = (4, 2, 3, 1, 0)
class TestCase4(TestTransposeMKLDNN):
def initTestCase(self):
self.shape = (2, 3, 4, 5, 6, 1)
self.axis = (4, 2, 3, 1, 0, 5)
if __name__ == '__main__':
unittest.main()
...@@ -21,15 +21,24 @@ from op_test import OpTest ...@@ -21,15 +21,24 @@ from op_test import OpTest
class TestTransposeOp(OpTest): class TestTransposeOp(OpTest):
def setUp(self): def setUp(self):
self.init_op_type()
self.initTestCase() self.initTestCase()
self.op_type = "transpose2"
self.inputs = {'X': np.random.random(self.shape).astype("float32")} self.inputs = {'X': np.random.random(self.shape).astype("float32")}
self.attrs = {'axis': list(self.axis)} self.attrs = {
'axis': list(self.axis),
'use_mkldnn': self.use_mkldnn,
'is_test': self.is_test,
}
self.outputs = { self.outputs = {
'XShape': np.random.random(self.shape).astype("float32"), 'XShape': np.random.random(self.shape).astype("float32"),
'Out': self.inputs['X'].transpose(self.axis) 'Out': self.inputs['X'].transpose(self.axis)
} }
def init_op_type(self):
self.op_type = "transpose2"
self.use_mkldnn = False
self.is_test = False
def test_check_output(self): def test_check_output(self):
self.check_output(no_check_set=['XShape']) self.check_output(no_check_set=['XShape'])
......
...@@ -107,9 +107,9 @@ packages=['paddle', ...@@ -107,9 +107,9 @@ packages=['paddle',
'paddle.fluid.distributed', 'paddle.fluid.distributed',
'paddle.fluid.layers', 'paddle.fluid.layers',
'paddle.fluid.contrib', 'paddle.fluid.contrib',
'paddle.fluid.contrib.utils',
'paddle.fluid.contrib.decoder', 'paddle.fluid.contrib.decoder',
'paddle.fluid.contrib.quantize', 'paddle.fluid.contrib.quantize',
'paddle.fluid.contrib.utils',
'paddle.fluid.transpiler', 'paddle.fluid.transpiler',
'paddle.fluid.transpiler.details'] 'paddle.fluid.transpiler.details']
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册