Commit c7c81fe0 authored by phlrain

Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into move_temporal_shift_to_phi

@@ -56,23 +56,29 @@ static std::string LegalizeVariableName(const std::string& var_name) {
   return ret;
 }
 
-static bool IgnoreGradAttribute(const std::string& op_type,
-                                const std::string& attr_name) {
-  // Attributes in operators_with_attrs are created manually during code
-  // generation
-  // We should ignore these arbitrary attrs when setting up grad attribute map
-  if (operators_with_attrs.count(op_type)) {
-    if (operators_with_attrs[op_type].count(attr_name)) {
-      return true;
-    }
-  }
-
-  // Only allow SumOp
-  if (op_type != "sum") {
-    return true;
-  }
-
-  return false;
+static std::string HandleDynamicGradAttributes(const std::string& fwd_op_type,
+                                               const std::string& attrs_name) {
+  std::string additional_grad_attrs_str = "";
+  if (fwd_op_type == "sum") {
+    const char* GRAD_ATTRS_TEMPLATE = " %s[\"%s\"] = %s;\n";
+    additional_grad_attrs_str = paddle::string::Sprintf(
+        GRAD_ATTRS_TEMPLATE, attrs_name, "scale", "float(1.0)");
+    additional_grad_attrs_str += paddle::string::Sprintf(
+        GRAD_ATTRS_TEMPLATE, attrs_name, "bias", "float(0.0f)");
+    additional_grad_attrs_str += paddle::string::Sprintf(
+        GRAD_ATTRS_TEMPLATE, attrs_name, "bias_after_scale", "bool(true)");
+  } else if (fwd_op_type == "scale") {
+    const char* GRAD_ATTRS_TEMPLATE = " %s[\"%s\"] = %s;\n";
+    additional_grad_attrs_str += paddle::string::Sprintf(
+        GRAD_ATTRS_TEMPLATE, attrs_name, "bias", "float(0.0f)");
+    additional_grad_attrs_str += paddle::string::Sprintf(
+        GRAD_ATTRS_TEMPLATE, attrs_name, "bias_after_scale", "bool(true)");
+  }
+  return additional_grad_attrs_str;
 }
 
 static void PrepareAttrMapForOps() {
@@ -1866,18 +1872,9 @@ static std::string GenerateSingleOpBase(
   const char* ATTRS_TEMPLATE = " auto& %s = this->attr_map_;\n";
   std::string grad_attrs_str =
       paddle::string::Sprintf(ATTRS_TEMPLATE, attrs_name);
-  for (const auto& iter : grad_attrs) {
-    if (IgnoreGradAttribute(fwd_op_type, iter.first)) continue;
-    std::pair<std::string, std::string> type_val =
-        GetAttrType(iter.second, false /*is_arg*/);
-    const char* GRAD_ATTRS_TEMPLATE =
-        " %s %s = %s;\n"
-        " %s[\"%s\"] = %s;\n";
-    std::string var_name = iter.first + std::to_string(*outs_size);
-    grad_attrs_str += paddle::string::Sprintf(
-        GRAD_ATTRS_TEMPLATE, type_val.first, var_name, type_val.second,
-        attrs_name, iter.first, var_name);
-  }
+
+  // Handle dynamic grad attributes
+  grad_attrs_str += HandleDynamicGradAttributes(fwd_op_type, attrs_name);
   generated_grad_function_body += grad_attrs_str;
 
   const char* TRACE_OP_TEMPLATE =
......
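Note on the HandleDynamicGradAttributes helper introduced above: instead of filtering arbitrary attributes through the removed IgnoreGradAttribute check, it appends fixed attribute assignments to the generated grad-function body for the "sum" and "scale" forward ops. As a rough illustration only (the attribute-map name attrs_map0 is a made-up placeholder, not from this commit), the Sprintf templates for fwd_op_type == "scale" would expand to generated code like:

  // Illustrative expansion of GRAD_ATTRS_TEMPLATE for the "scale" branch;
  // "attrs_map0" is an assumed attribute-map variable name in the generated code.
  attrs_map0["bias"] = float(0.0f);
  attrs_map0["bias_after_scale"] = bool(true);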
@@ -28,6 +28,7 @@ namespace = ""
 yaml_types_mapping = {
     'int' : 'int', 'int32' : 'int32_t', 'int64' : 'int64_t', 'size_t' : 'size_t', \
     'float' : 'float', 'double' : 'double', 'bool' : 'bool', \
+    'str' : 'std::string', \
     'Backend' : 'paddle::experimental::Backend', 'DataLayout' : 'paddle::experimental::DataLayout', 'DataType' : 'paddle::experimental::DataType', \
     'int64[]' : 'std::vector<int64_t>', 'int[]' : 'std::vector<int>',
     'Tensor' : 'Tensor',
@@ -212,7 +213,8 @@ def ParseYamlArgs(string):
         default_value = m.group(3).split("=")[1].strip() if len(
             m.group(3).split("=")) > 1 else None
 
-        assert arg_type in yaml_types_mapping.keys()
+        assert arg_type in yaml_types_mapping.keys(
+        ), f"The argument type {arg_type} in yaml config is not supported in yaml_types_mapping."
         arg_type = yaml_types_mapping[arg_type]
 
         arg_name = RemoveSpecialSymbolsInName(arg_name)
@@ -247,7 +249,8 @@ def ParseYamlReturns(string):
         else:
             ret_type = ret.strip()
 
-        assert ret_type in yaml_types_mapping.keys()
+        assert ret_type in yaml_types_mapping.keys(
+        ), f"The return type {ret_type} in yaml config is not supported in yaml_types_mapping."
         ret_type = yaml_types_mapping[ret_type]
 
         assert "Tensor" in ret_type
@@ -1245,7 +1248,7 @@ if __name__ == "__main__":
         # Node Definition Generation
         definition_declaration_pair = GenerateForwardDefinition(
             fwd_api_name, bwd_api_name, forward_inputs_position_map,
-            forward_outputs_position_map, forward_attrs_list,
+            forward_outputs_position_map, orig_forward_attrs_list,
             backward_fwd_input_map, backward_grad_input_map,
             backward_grad_output_map, backward_attrs_list, optional_inputs,
             intermediate_outputs)
@@ -1257,7 +1260,7 @@ if __name__ == "__main__":
         # For python-level API dispatch
         CollectCoreOpsInformation(fwd_api_name, forward_inputs_position_map,
                                   forward_outputs_position_map,
-                                  forward_attrs_list)
+                                  orig_forward_attrs_list)
 
         if len(namespace) > 0:
             forward_definition_str += f"""namespace {namespace} {{
......
@@ -24,7 +24,7 @@ atype_to_parsing_function = {
     "long": "CastPyArg2Long",
     "int64_t": "CastPyArg2Long",
     "float": "CastPyArg2Float",
-    "string": "CastPyArg2String",
+    "std::string": "CastPyArg2String",
     "std::vector<bool>": "CastPyArg2Booleans",
     "std::vector<int>": "CastPyArg2Ints",
     "std::vector<long>": "CastPyArg2Longs",
......
@@ -34,6 +34,14 @@ void TransDataDevice(const Tensor &in, const platform::Place &dst_place,
     return;
   }
 
+  // NOTE(hqp): Special case for CPU->MLU, avoid stream sync.
+  if (platform::is_cpu_place(in.place()) && platform::is_mlu_place(dst_place)) {
+    paddle::framework::TensorCopy(
+        in, dst_place, *platform::DeviceContextPool::Instance().Get(dst_place),
+        out);
+    return;
+  }
+
   // NOTE(yy): TransDataDevice should wait for computation of input.
   if (!platform::is_cuda_pinned_place(in.place())) {
     platform::DeviceContextPool::Instance().Get(in.place())->Wait();
......
@@ -174,10 +174,11 @@ void Executor::Run(const ProgramDesc& pdesc, Scope* scope, int block_id,
                    bool force_disable_gc, bool keep_kid_scopes) {
   platform::RecordBlock b(block_id);
   if (FLAGS_use_mkldnn) EnableMKLDNN(pdesc);
+  auto ctx = Prepare(pdesc, block_id, skip_ref_cnt_vars, force_disable_gc);
 #ifdef PADDLE_WITH_MKLDNN
   platform::AttachPointerHashToMKLDNNKey(this, place_);
+  platform::RegisterModelLayout(ctx->ops_, place_);
 #endif
-  auto ctx = Prepare(pdesc, block_id, skip_ref_cnt_vars, force_disable_gc);
   RunPreparedContext(ctx.get(), scope, create_local_scope, create_vars,
                      keep_kid_scopes);
 }
......
@@ -148,7 +148,7 @@ void PSGPUWrapper::PreBuildTask(std::shared_ptr<HeterContext> gpu_task) {
       t.join();
     }
     timeline.Pause();
-    VLOG(1) << "GpuPs build task cost " << timeline.ElapsedSec() << " seconds.";
+    VLOG(0) << "GpuPs build task cost " << timeline.ElapsedSec() << " seconds.";
   } else {
     CHECK(data_set_name.find("MultiSlotDataset") != std::string::npos);
     VLOG(0) << "ps_gpu_wrapper use MultiSlotDataset";
@@ -182,7 +182,7 @@ void PSGPUWrapper::PreBuildTask(std::shared_ptr<HeterContext> gpu_task) {
       t.join();
     }
     timeline.Pause();
-    VLOG(1) << "GpuPs build task cost " << timeline.ElapsedSec() << " seconds.";
+    VLOG(0) << "GpuPs build task cost " << timeline.ElapsedSec() << " seconds.";
   }
   timeline.Start();
@@ -300,7 +300,7 @@ void PSGPUWrapper::BuildPull(std::shared_ptr<HeterContext> gpu_task) {
       int32_t cnt = 0;
       while (true) {
         auto tt = fleet_ptr->pslib_ptr_->_worker_ptr->pull_sparse_ptr(
-            reinterpret_cast<char**>(local_ptr[i].data()), this->table_id_, i,
+            reinterpret_cast<char**>(local_ptr[i].data()), this->table_id_,
             local_keys[i].data(), key_size);
         bool flag = true;
@@ -378,8 +378,8 @@ void PSGPUWrapper::BuildPull(std::shared_ptr<HeterContext> gpu_task) {
       int32_t cnt = 0;
       while (true) {
         auto tt = fleet_ptr->pslib_ptr_->_worker_ptr->pull_sparse_ptr(
-            reinterpret_cast<char**>(local_dim_ptr[i][j].data()), this->table_id_, i,
-            local_dim_keys[i][j].data(), key_size);
+            reinterpret_cast<char**>(local_dim_ptr[i][j].data()),
+            this->table_id_, local_dim_keys[i][j].data(), key_size);
         bool flag = true;
         tt.wait();
@@ -431,7 +431,7 @@ void PSGPUWrapper::BuildPull(std::shared_ptr<HeterContext> gpu_task) {
     t.join();
   }
   timeline.Pause();
-  VLOG(1) << "pull sparse from CpuPS into GpuPS cost " << timeline.ElapsedSec()
+  VLOG(0) << "pull sparse from CpuPS into GpuPS cost " << timeline.ElapsedSec()
          << " seconds.";
   if (multi_node_) {
     auto gloo_wrapper = paddle::framework::GlooWrapper::GetInstance();
@@ -603,7 +603,7 @@ void PSGPUWrapper::BuildPull(std::shared_ptr<HeterContext> gpu_task) {
     t.join();
   }
   timeline.Pause();
-  VLOG(1) << "GpuPs prepare for build hbm cost " << timeline.ElapsedSec()
+  VLOG(0) << "GpuPs prepare for build hbm cost " << timeline.ElapsedSec()
          << " seconds.";
 }
@@ -746,7 +746,7 @@ void PSGPUWrapper::BeginPass() {
         "[BeginPass] after build_task, current task is not null."));
   }
-  VLOG(1) << "BeginPass end, cost time: " << timer.ElapsedSec() << "s";
+  VLOG(0) << "BeginPass end, cost time: " << timer.ElapsedSec() << "s";
 }
 
 void PSGPUWrapper::EndPass() {
@@ -769,7 +769,7 @@ void PSGPUWrapper::EndPass() {
   current_task_ = nullptr;
   gpu_free_channel_->Put(current_task_);
   timer.Pause();
-  VLOG(1) << "EndPass end, cost time: " << timer.ElapsedSec() << "s";
+  VLOG(0) << "EndPass end, cost time: " << timer.ElapsedSec() << "s";
 }
 
 void PSGPUWrapper::PullSparse(const paddle::platform::Place& place,
......
@@ -95,6 +95,7 @@ std::map<std::string, std::vector<ir::Node *>> Graph::InitFromBlock(
   std::unordered_map<std::string, std::pair<VarDesc *, int>>
       name_to_desc_block_id;
 
+  block_id_ = block.ID();
   const BlockDesc *block_var_visible = &block;
   while (block_var_visible != nullptr) {
     for (auto *var : block_var_visible->AllVars()) {
......
@@ -230,6 +230,7 @@ class Graph {
     auto *x =
         AddNode(new ir::Node(var_desc, block_id == -1 ? block_id_ : block_id));
     x->SetId(num_node_created_++);
+    x->SetGraphId(block_id_);
     return x;
   }
@@ -245,6 +246,7 @@ class Graph {
             "The OpDesc used to create operator node is null."));
     auto *x = AddNode(new ir::Node(op_desc));
     x->SetId(num_node_created_++);
+    x->SetGraphId(block_id_);
     return x;
   }
@@ -263,6 +265,7 @@ class Graph {
                           num_node_created_);
     auto *x = AddNode(new ir::Node(name, ir::Node::Type::kVariable, block_id_));
     x->SetId(num_node_created_++);
+    x->SetGraphId(block_id_);
     return x;
   }
@@ -276,6 +279,7 @@ class Graph {
     }
     auto *x = AddNode(new ir::Node(name, type, block_id_));
     x->SetId(num_node_created_++);
+    x->SetGraphId(block_id_);
     return x;
   }
......
@@ -2052,18 +2052,19 @@ PDNode *patterns::Pool::operator()() {
   return output_var;
 }
 
-PDNode *patterns::ElementwiseAdd::operator()(PDNode *x_var, PDNode *y_var) {
-  auto elementwise_add_op = pattern->NewNode(elementwise_add_op_repr())
-                                ->assert_is_op("elementwise_add");
-
-  x_var->AsInput()->assert_is_op_input("elementwise_add", "X");
-  y_var->AsInput()->assert_is_op_input("elementwise_add", "Y");
-  auto out_var = pattern->NewNode(elementwise_add_out_repr())
+PDNode *patterns::Elementwise::operator()(PDNode *x_var, PDNode *y_var,
+                                          const std::string elementwise_type) {
+  auto elementwise_op =
+      pattern->NewNode(elementwise_op_repr())->assert_is_op(elementwise_type);
+
+  x_var->AsInput()->assert_is_op_input(elementwise_type, "X");
+  y_var->AsInput()->assert_is_op_input(elementwise_type, "Y");
+  auto out_var = pattern->NewNode(elementwise_out_repr())
                      ->AsOutput()
-                     ->assert_is_op_output("elementwise_add", "Out");
+                     ->assert_is_op_output(elementwise_type, "Out");
 
-  elementwise_add_op->LinksFrom({x_var, y_var});
-  elementwise_add_op->LinksTo({out_var});
+  elementwise_op->LinksFrom({x_var, y_var});
+  elementwise_op->LinksTo({out_var});
 
   return out_var;
 }
......
@@ -1016,20 +1016,20 @@ struct Pool : public PatternBase {
   PATTERN_DECL_NODE(pool_output);
 };
 
-// ElementwiseAdd used in residual connections.
-// y_var is used and convolution output.
-// The operator is removed, when residual
-// connection fusion is on.
-struct ElementwiseAdd : public PatternBase {
-  ElementwiseAdd(PDPattern* pattern, const std::string& name_scope)
-      : PatternBase(pattern, name_scope, "elementwise_add") {}
-
-  PDNode* operator()(PDNode* x_var, PDNode* y_var);
-
-  PATTERN_DECL_NODE(elementwise_add_op);
-  PATTERN_DECL_NODE(elementwise_add_x);
-  PATTERN_DECL_NODE(elementwise_add_y);
-  PATTERN_DECL_NODE(elementwise_add_out);
+// Elementwise ops
+// Forward pass for element-wise operators (add, mul)
+// elementwise_mul_out is the result of the operator
+struct Elementwise : public PatternBase {
+  Elementwise(PDPattern* pattern, const std::string& name_scope)
+      : PatternBase(pattern, name_scope, "elementwise") {}
+
+  PDNode* operator()(PDNode* x_var, PDNode* y_var,
+                     const std::string elementwise_type);
+
+  PATTERN_DECL_NODE(elementwise_op);
+  PATTERN_DECL_NODE(elementwise_x);
+  PATTERN_DECL_NODE(elementwise_y);
+  PATTERN_DECL_NODE(elementwise_out);
 };
 
 // Transpose op
......
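For orientation, here is a minimal sketch of how the generalized patterns::Elementwise above can be instantiated for elementwise_mul, mirroring the elementwise_add call sites further down in this commit (the surrounding pattern and name_scope variables are assumed to exist in the calling pass):

  // Sketch only: build the generic elementwise pattern for a specific op type.
  patterns::Elementwise elementwise_pattern{pattern, name_scope};
  elementwise_pattern(
      pattern->NewNode(elementwise_pattern.elementwise_x_repr()),
      pattern->NewNode(elementwise_pattern.elementwise_y_repr()),
      "elementwise_mul");  // the op type is now a runtime argument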
@@ -118,7 +118,7 @@ ResidualConnectionMKLDNNFusePass::ResidualConnectionMKLDNNFusePass() {
       .IsType<std::vector<int>>()
       .End()
       .AddAttr("data_format")
-      .IsStringIn({"NCHW", "AnyLayout"})
+      .IsStringIn({"NHWC", "NCHW", "AnyLayout"})
       .End();
 
   AddOpCompat(OpCompat("elementwise_add"))
@@ -145,10 +145,10 @@ GraphWithStats ResidualConnectionMKLDNNFusePass::FuseConvAsX(
   patterns::Conv conv_pattern{pattern, name_scope};
   auto conv_output = conv_pattern();
 
-  patterns::ElementwiseAdd elementwise_add_pattern{pattern, name_scope};
-  elementwise_add_pattern(
-      conv_output,
-      pattern->NewNode(elementwise_add_pattern.elementwise_add_y_repr()));
+  patterns::Elementwise elementwise_pattern{pattern, name_scope};
+  elementwise_pattern(
+      conv_output, pattern->NewNode(elementwise_pattern.elementwise_y_repr()),
+      "elementwise_add");
   conv_output->AsIntermediate();
 
   int found_conv_as_x_count = 0;
@@ -160,16 +160,16 @@ GraphWithStats ResidualConnectionMKLDNNFusePass::FuseConvAsX(
     GET_IR_NODE_FROM_SUBGRAPH(conv_filter, conv_filter, conv_pattern);
     GET_IR_NODE_FROM_SUBGRAPH(conv_output, conv_output, conv_pattern);
 
-    GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_op, elementwise_add_op,
-                              elementwise_add_pattern);
-    GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_identity, elementwise_add_y,
-                              elementwise_add_pattern);
-    GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_out, elementwise_add_out,
-                              elementwise_add_pattern);
-
-    if (FindFuseOption(*conv_op, *elementwise_add_op) != FUSE_MKLDNN) return;
-
-    if (!IsReachable(g, elementwise_add_identity, conv_output)) return;
+    GET_IR_NODE_FROM_SUBGRAPH(elementwise_op, elementwise_op,
+                              elementwise_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(elementwise_identity, elementwise_y,
+                              elementwise_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(elementwise_out, elementwise_out,
+                              elementwise_pattern);
+
+    if (FindFuseOption(*conv_op, *elementwise_op) != FUSE_MKLDNN) return;
+
+    if (!IsReachable(g, elementwise_identity, conv_output)) return;
 
     if (HasFusedActivation(conv_op)) return;
@@ -179,14 +179,14 @@ GraphWithStats ResidualConnectionMKLDNNFusePass::FuseConvAsX(
       return;
     }
 
-    conv_op->Op()->SetInput("ResidualData", {elementwise_add_identity->Name()});
-    conv_op->Op()->SetOutput("Output", {elementwise_add_out->Name()});
+    conv_op->Op()->SetInput("ResidualData", {elementwise_identity->Name()});
+    conv_op->Op()->SetOutput("Output", {elementwise_out->Name()});
     conv_op->Op()->SetAttr("fuse_residual_connection", true);
 
-    GraphSafeRemoveNodes(g, {conv_output, elementwise_add_op});
-
-    IR_NODE_LINK_TO(elementwise_add_identity, conv_op);
-    IR_NODE_LINK_TO(conv_op, elementwise_add_out);
+    GraphSafeRemoveNodes(g, {conv_output, elementwise_op});
+
+    IR_NODE_LINK_TO(elementwise_identity, conv_op);
+    IR_NODE_LINK_TO(conv_op, elementwise_out);
 
     found_conv_as_x_count++;
   };
@@ -212,10 +212,10 @@ GraphWithStats ResidualConnectionMKLDNNFusePass::FuseConvAsY(
   patterns::Conv conv_pattern{pattern, name_scope};
   auto conv_output = conv_pattern();
 
-  patterns::ElementwiseAdd elementwise_add_pattern{pattern, name_scope};
-  elementwise_add_pattern(
-      pattern->NewNode(elementwise_add_pattern.elementwise_add_x_repr()),
-      conv_output);
+  patterns::Elementwise elementwise_pattern{pattern, name_scope};
+  elementwise_pattern(
+      pattern->NewNode(elementwise_pattern.elementwise_x_repr()), conv_output,
+      "elementwise_add");
   conv_output->AsIntermediate();
 
   int found_conv_as_y_count = 0;
@@ -227,16 +227,16 @@ GraphWithStats ResidualConnectionMKLDNNFusePass::FuseConvAsY(
     GET_IR_NODE_FROM_SUBGRAPH(conv_filter, conv_filter, conv_pattern);
     GET_IR_NODE_FROM_SUBGRAPH(conv_output, conv_output, conv_pattern);
 
-    GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_op, elementwise_add_op,
-                              elementwise_add_pattern);
-    GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_x, elementwise_add_x,
-                              elementwise_add_pattern);
-    GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_out, elementwise_add_out,
-                              elementwise_add_pattern);
-
-    if (FindFuseOption(*conv_op, *elementwise_add_op) != FUSE_MKLDNN) return;
-
-    if (!IsReachable(g, elementwise_add_x, conv_output)) return;
+    GET_IR_NODE_FROM_SUBGRAPH(elementwise_op, elementwise_op,
+                              elementwise_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(elementwise_x, elementwise_x,
+                              elementwise_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(elementwise_out, elementwise_out,
+                              elementwise_pattern);
+
+    if (FindFuseOption(*conv_op, *elementwise_op) != FUSE_MKLDNN) return;
+
+    if (!IsReachable(g, elementwise_x, conv_output)) return;
 
     if (HasFusedActivation(conv_op)) return;
@@ -246,14 +246,14 @@ GraphWithStats ResidualConnectionMKLDNNFusePass::FuseConvAsY(
       return;
     }
 
-    conv_op->Op()->SetInput("ResidualData", {elementwise_add_x->Name()});
-    conv_op->Op()->SetOutput("Output", {elementwise_add_out->Name()});
+    conv_op->Op()->SetInput("ResidualData", {elementwise_x->Name()});
+    conv_op->Op()->SetOutput("Output", {elementwise_out->Name()});
     conv_op->Op()->SetAttr("fuse_residual_connection", true);
 
-    GraphSafeRemoveNodes(g, {conv_output, elementwise_add_op});
-
-    IR_NODE_LINK_TO(elementwise_add_x, conv_op);
-    IR_NODE_LINK_TO(conv_op, elementwise_add_out);
+    GraphSafeRemoveNodes(g, {conv_output, elementwise_op});
+
+    IR_NODE_LINK_TO(elementwise_x, conv_op);
+    IR_NODE_LINK_TO(conv_op, elementwise_out);
 
     found_conv_as_y_count++;
   };
@@ -282,8 +282,8 @@ GraphWithStats ResidualConnectionMKLDNNFusePass::FuseProjectionConv(
   patterns::Conv conv_y_pattern{pattern, name_scope};
   auto conv_y_output = conv_y_pattern();
 
-  patterns::ElementwiseAdd elementwise_add_pattern{pattern, name_scope};
-  elementwise_add_pattern(conv_x_output, conv_y_output);
+  patterns::Elementwise elementwise_pattern{pattern, name_scope};
+  elementwise_pattern(conv_x_output, conv_y_output, "elementwise_add");
   conv_x_output->AsIntermediate();
   conv_y_output->AsIntermediate();
@@ -301,10 +301,10 @@ GraphWithStats ResidualConnectionMKLDNNFusePass::FuseProjectionConv(
     GET_IR_NODE_FROM_SUBGRAPH(conv_y_filter, conv_filter, conv_y_pattern);
     GET_IR_NODE_FROM_SUBGRAPH(conv_y_output, conv_output, conv_y_pattern);
 
-    GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_op, elementwise_add_op,
-                              elementwise_add_pattern);
-    GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_out, elementwise_add_out,
-                              elementwise_add_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(elementwise_op, elementwise_op,
+                              elementwise_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(elementwise_out, elementwise_out,
+                              elementwise_pattern);
 
     if (!IsCompat(subgraph, g)) {
       LOG(WARNING)
@@ -312,8 +312,8 @@ GraphWithStats ResidualConnectionMKLDNNFusePass::FuseProjectionConv(
       return;
     }
 
-    if (FindFuseOption(*conv_x_op, *elementwise_add_op) != FUSE_MKLDNN) return;
-    if (FindFuseOption(*conv_y_op, *elementwise_add_op) != FUSE_MKLDNN) return;
+    if (FindFuseOption(*conv_x_op, *elementwise_op) != FUSE_MKLDNN) return;
+    if (FindFuseOption(*conv_y_op, *elementwise_op) != FUSE_MKLDNN) return;
 
     Node* projection_node;
     Node* residual_conv_op;
@@ -333,14 +333,14 @@ GraphWithStats ResidualConnectionMKLDNNFusePass::FuseProjectionConv(
     if (HasFusedActivation(residual_conv_op)) return;
 
     residual_conv_op->Op()->SetInput("ResidualData", {projection_node->Name()});
-    residual_conv_op->Op()->SetOutput("Output", {elementwise_add_out->Name()});
+    residual_conv_op->Op()->SetOutput("Output", {elementwise_out->Name()});
     residual_conv_op->Op()->SetAttr("fuse_residual_connection", true);
 
-    GraphSafeRemoveNodes(g, {residual_conv_output, elementwise_add_op});
+    GraphSafeRemoveNodes(g, {residual_conv_output, elementwise_op});
 
     IR_NODE_LINK_TO(projection_node, residual_conv_op);
-    IR_NODE_LINK_TO(residual_conv_op, elementwise_add_out);
+    IR_NODE_LINK_TO(residual_conv_op, elementwise_out);
 
     found_projection_conv_count++;
   };
......
@@ -807,74 +807,74 @@ void CPUQuantizePass::QuantizeMatmul(Graph* graph) const {
   PrettyLogDetail("--- quantized %d matmul ops", quantize_matmul_count);
 }
 
-void CPUQuantizePass::QuantizeElementwiseAdd(Graph* graph) const {
+void CPUQuantizePass::QuantizeElementwise(
+    Graph* graph, const std::string elementwise_type) const {
   GraphPatternDetector gpd;
   auto pattern = gpd.mutable_pattern();
-  patterns::ElementwiseAdd elementwise_add_pattern{pattern, name_scope_};
+  patterns::Elementwise elementwise_pattern{pattern, name_scope_};
 
-  elementwise_add_pattern(
-      pattern->NewNode(elementwise_add_pattern.elementwise_add_x_repr()),
-      pattern->NewNode(elementwise_add_pattern.elementwise_add_y_repr()));
+  elementwise_pattern(
+      pattern->NewNode(elementwise_pattern.elementwise_x_repr()),
+      pattern->NewNode(elementwise_pattern.elementwise_y_repr()),
+      elementwise_type);
 
-  int quantize_elementwise_add_count = 0;
+  int quantize_elementwise_count = 0;
   auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
                      Graph* g) {
-    VLOG(4) << "Quantize elementwise_add op";
-    GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_op, elementwise_add_op,
-                              elementwise_add_pattern);
+    VLOG(4) << "Quantize " + elementwise_type + " op";
+    GET_IR_NODE_FROM_SUBGRAPH(elementwise_op, elementwise_op,
+                              elementwise_pattern);
 
     // skip if should not be quantized
-    if (!platform::HasOpINT8DataType(elementwise_add_op->Op())) {
-      LogQuantizationDisabled(elementwise_add_op);
+    if (!platform::HasOpINT8DataType(elementwise_op->Op())) {
+      LogQuantizationDisabled(elementwise_op);
       return;
     }
 
-    GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_x, elementwise_add_x,
-                              elementwise_add_pattern);
-    GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_y, elementwise_add_y,
-                              elementwise_add_pattern);
-    GET_IR_NODE_FROM_SUBGRAPH(elementwise_add_out, elementwise_add_out,
-                              elementwise_add_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(elementwise_x, elementwise_x,
+                              elementwise_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(elementwise_y, elementwise_y,
+                              elementwise_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(elementwise_out, elementwise_out,
+                              elementwise_pattern);
 
     if (!AreScalesPresentForNodes(
-            {elementwise_add_x, elementwise_add_y, elementwise_add_out})) {
-      LogCannotQuantizeOp(elementwise_add_op,
+            {elementwise_x, elementwise_y, elementwise_out})) {
+      LogCannotQuantizeOp(elementwise_op,
                           "No scale available for the operator");
       return;
     }
 
     bool is_x_unsigned{false}, is_y_unsigned{false};
-    auto input_x_scale =
-        GetScaleValueForNode(elementwise_add_x, &is_x_unsigned);
-    auto input_y_scale =
-        GetScaleValueForNode(elementwise_add_y, &is_y_unsigned);
+    auto input_x_scale = GetScaleValueForNode(elementwise_x, &is_x_unsigned);
+    auto input_y_scale = GetScaleValueForNode(elementwise_y, &is_y_unsigned);
 
     // TODO(sfraczek): add support for different signness
     if (is_x_unsigned != is_y_unsigned) {
-      LogCannotQuantizeOp(elementwise_add_op,
-                          "ElementwiseAdd inputs must be of the same type.");
+      LogCannotQuantizeOp(elementwise_op,
                          "Elementwise inputs must be of the same type.");
      return;
    }
 
-    QuantizeInput(g, elementwise_add_op, elementwise_add_x, "X", input_x_scale,
+    QuantizeInput(g, elementwise_op, elementwise_x, "X", input_x_scale,
                   is_x_unsigned, "Scale_x");
-    QuantizeInput(g, elementwise_add_op, elementwise_add_y, "Y", input_y_scale,
+    QuantizeInput(g, elementwise_op, elementwise_y, "Y", input_y_scale,
                   is_y_unsigned, "Scale_y");
 
     bool is_output_unsigned{false};
     auto output_scale =
-        GetScaleValueForNode(elementwise_add_out, &is_output_unsigned);
-    DequantizeOutput(g, elementwise_add_op, elementwise_add_out, "Out",
-                     output_scale, is_output_unsigned, "Scale_out");
+        GetScaleValueForNode(elementwise_out, &is_output_unsigned);
+    DequantizeOutput(g, elementwise_op, elementwise_out, "Out", output_scale,
                     is_output_unsigned, "Scale_out");
 
-    ++quantize_elementwise_add_count;
+    ++quantize_elementwise_count;
   };
   gpd(graph, handler);
-  AddStatis(quantize_elementwise_add_count);
+  AddStatis(quantize_elementwise_count);
 
-  PrettyLogDetail("--- quantized %d elementwise_add ops",
-                  quantize_elementwise_add_count);
+  PrettyLogDetail("--- quantized %d %s ops", quantize_elementwise_count,
                  elementwise_type);
 }
 
 void CPUQuantizePass::QuantizeFusionGru(Graph* graph) const {
@@ -1146,7 +1146,8 @@ void CPUQuantizePass::ApplyImpl(ir::Graph* graph) const {
   QuantizeFc(graph);
   QuantizeReshape(graph);
   QuantizeMatmul(graph);
-  QuantizeElementwiseAdd(graph);
+  QuantizeElementwise(graph, "elementwise_add");
+  QuantizeElementwise(graph, "elementwise_mul");
   QuantizeFusionGru(graph);
   QuantizeMultiGru(graph);
   QuantizeFusionLSTM(graph);
......
@@ -57,7 +57,8 @@ class CPUQuantizePass : public FusePassBase {
   void QuantizeTranspose(Graph* graph) const;
   void QuantizeReshape(Graph* graph) const;
   void QuantizeMatmul(Graph* graph) const;
-  void QuantizeElementwiseAdd(Graph* graph) const;
+  void QuantizeElementwise(Graph* graph,
+                           const std::string elementwise_type) const;
   void QuantizeFusionGru(Graph* graph) const;
   void QuantizeMultiGru(Graph* graph) const;
   void QuantizeFusionLSTM(Graph* graph) const;
......
@@ -90,7 +90,7 @@ void SetOp(ProgramDesc* prog, const std::string& type, const std::string& name,
     op->SetAttr("Scale_x", 1.0f);
     op->SetAttr("Scale_y", 1.0f);
     op->SetAttr("Scale_out", 1.0f);
-  } else if (type == "elementwise_add") {
+  } else if (type == "elementwise_add" || type == "elementwise_mul") {
     op->SetInput("X", {inputs[0]});
     if (inputs.size() > 1) op->SetInput("Y", {inputs[1]});
     op->SetOutput("Out", {outputs[0]});
@@ -167,7 +167,8 @@ void CheckScales(const OpDesc* op, float scale, float shift) {
         scale);
     scale_names.push_back("Scale_in");
     scale_names.push_back("Scale_out");
-  } else if (type == "matmul" || type == "elementwise_add") {
+  } else if (type == "matmul" || type == "elementwise_add" ||
+             type == "elementwise_mul") {
     scale_names.push_back("Scale_x");
     scale_names.push_back("Scale_y");
     scale_names.push_back("Scale_out");
@@ -546,46 +547,77 @@ TEST(CpuQuantizePass, matmul_not_quantized) {
            expected_operators, added_nodes, 1.0f);
 }
 
-static const std::initializer_list<std::string> variable_names_elementwise_add =
-    {"a", "b", "c", "d", "e", "f"};
+static const std::initializer_list<std::string> variable_names_elementwise = {
+    "a", "b", "c", "d", "e", "f"};
 
-ProgramDesc BuildProgramDescElementwiseAdd() {
+ProgramDesc BuildProgramDescElementwise(const std::string elementwise_type,
+                                        const std::string elementwise_name) {
   ProgramDesc prog;
-  for (auto& v : variable_names_elementwise_add) {
+  for (auto& v : variable_names_elementwise) {
     prog.MutableBlock(0)->Var(v);
   }
   SetOp(&prog, "dequantize", "Dequantize1", {"a"}, {"b"}, true);
   SetOp(&prog, "dequantize", "Dequantize2", {"c"}, {"d"}, true);
-  SetOp(&prog, "elementwise_add", "ElementwiseAdd", {"b", "d"}, {"e"}, true,
+  SetOp(&prog, elementwise_type, elementwise_name, {"b", "d"}, {"e"}, true,
        "int8");
   SetOp(&prog, "dropout", "Dropout", {"e"}, {"f"}, true, "float32");
 
   return prog;
 }
 
-TEST(CpuQuantizePass, elementwise_add) {
+void TestElementwise(const std::string elementwise_type,
+                     const std::string elementwise_name) {
   // 2 Quant + 2 IN + 1 DeQuant + 1 OUT
   int added_nodes = 6;
   std::unordered_map<std::string, int> expected_operators = {
-      {"elementwise_add", 1}, {"quantize", 2}, {"dequantize", 3}};
-  MainTest(BuildProgramDescElementwiseAdd(), variable_names_elementwise_add,
-           expected_operators, added_nodes, SCALE * S8_MAX);
+      {elementwise_type, 1}, {"quantize", 2}, {"dequantize", 3}};
+  MainTest(BuildProgramDescElementwise(elementwise_type, elementwise_name),
+           variable_names_elementwise, expected_operators, added_nodes,
+           SCALE * S8_MAX);
 }
 
-TEST(CpuQuantizePass, elementwise_add_output_scale_missing) {
+void TestElementwiseOutputScaleMissing(const std::string elementwise_type,
+                                       const std::string elementwise_name) {
   int added_nodes = 0;
   std::unordered_map<std::string, int> expected_operators = {
-      {"elementwise_add", 1}, {"quantize", 0}, {"dequantize", 2}};
-  MainTest(BuildProgramDescElementwiseAdd(), variable_names_elementwise_add,
-           expected_operators, added_nodes, 1.f, 1.f, "e");
+      {elementwise_type, 1}, {"quantize", 0}, {"dequantize", 2}};
+  MainTest(BuildProgramDescElementwise(elementwise_type, elementwise_name),
+           variable_names_elementwise, expected_operators, added_nodes, 1.f,
+           1.f, "e");
 }
 
-TEST(CpuQuantizePass, elementwise_add_unsigned_and_signed_input) {
+void TestElementwiseUnsignedAndSignedInput(const std::string elementwise_type,
+                                           const std::string elementwise_name) {
   int added_nodes = 0;
   std::unordered_map<std::string, int> expected_operators = {
-      {"elementwise_add", 1}, {"quantize", 0}, {"dequantize", 2}};
-  MainTest(BuildProgramDescElementwiseAdd(), variable_names_elementwise_add,
-           expected_operators, added_nodes, 1.f, 1.f, "", "b");
+      {elementwise_type, 1}, {"quantize", 0}, {"dequantize", 2}};
+  MainTest(BuildProgramDescElementwise(elementwise_type, elementwise_name),
+           variable_names_elementwise, expected_operators, added_nodes, 1.f,
+           1.f, "", "b");
+}
+
+TEST(CpuQuantizePass, elementwise_add) {
+  TestElementwise("elementwise_add", "ElementwiseAdd");
+}
+
+TEST(CpuQuantizePass, elementwise_add_output_scale_missing) {
+  TestElementwiseOutputScaleMissing("elementwise_add", "ElementwiseAdd");
+}
+
+TEST(CpuQuantizePass, elementwise_add_unsigned_and_signed_input) {
+  TestElementwiseUnsignedAndSignedInput("elementwise_add", "ElementwiseAdd");
+}
+
+TEST(CpuQuantizePass, elementwise_mul) {
+  TestElementwise("elementwise_mul", "ElementwiseMul");
+}
+
+TEST(CpuQuantizePass, elementwise_mul_output_scale_missing) {
+  TestElementwiseOutputScaleMissing("elementwise_mul", "ElementwiseMul");
+}
+
+TEST(CpuQuantizePass, elementwise_mul_unsigned_and_signed_input) {
+  TestElementwiseUnsignedAndSignedInput("elementwise_mul", "ElementwiseMul");
 }
 
 const std::vector<std::string> churn_out_vars(ProgramDesc* prog,
......
@@ -26,10 +26,10 @@ void CPUQuantizePlacementPass::ApplyImpl(ir::Graph* graph) const {
   VLOG(3) << "Marks operators which are to be quantized.";
   std::unordered_set<std::string> supported_op_types =
       std::unordered_set<std::string>(
-          {"concat", "conv2d", "depthwise_conv2d", "elementwise_add", "fc",
-           "matmul", "nearest_interp", "nearest_interp_v2", "pool2d",
-           "prior_box", "reshape2", "transpose2", "fusion_gru", "fusion_lstm",
-           "multi_gru", "slice"});
+          {"concat", "conv2d", "depthwise_conv2d", "elementwise_add",
+           "elementwise_mul", "fc", "matmul", "nearest_interp",
+           "nearest_interp_v2", "pool2d", "prior_box", "reshape2", "transpose2",
+           "fusion_gru", "fusion_lstm", "multi_gru", "slice"});
   const auto& excluded_ids_list =
       Get<std::unordered_set<int>>("quantize_excluded_op_ids");
   const auto& op_types_list =
......
@@ -125,6 +125,7 @@ class Node {
   // Only use this for auto parallel.
   // A node does not have original desc if the return is zero.
   uint64_t OriginalDescId() const { return original_desc_id_; }
+  int GraphId() const { return graph_id_; }
 
   bool IsOp() const { return type_ == Type::kOperation; }
   bool IsVar() const { return type_ == Type::kVariable; }
@@ -246,10 +247,12 @@ class Node {
   // Store the original id of var desc or op desc.
   // Only use this for auto parallel.
   uint64_t original_desc_id_{0};
+  int graph_id_{-1};
 
  private:
   // ID can only set by a Graph.
   void SetId(int id) { id_ = id; }
+  void SetGraphId(int graph_id) { graph_id_ = graph_id; }
 
   // desc_order can only set by a Graph when constructing a Graph from a
   // BlockDesc.
......
@@ -41,6 +41,7 @@ void NaiveExecutor::Prepare(Scope *scope, const ProgramDesc &program_desc,
 void NaiveExecutor::Run() {
 #ifdef PADDLE_WITH_MKLDNN
   platform::AttachPointerHashToMKLDNNKey(this, place_);
+  platform::RegisterModelLayout(ops_, place_);
 #endif
   platform::ScopedFlushDenormal flush;
   for (auto &op : ops_) {
......
@@ -1456,7 +1456,8 @@ void OperatorWithKernel::ChooseKernel(const ExecutionContext& ctx) const {
     kernel_iter = kernels.find(expected_kernel_key);
   }
 #endif
-#ifdef PADDLE_WITH_XPU
+
+#if defined(PADDLE_WITH_XPU) && !defined(PADDLE_WITH_XPU_KP)
   if (platform::is_xpu_place(expected_kernel_key.place_) &&
       (kernel_iter == kernels.end() ||
        !paddle::platform::is_xpu_support_op(type_, expected_kernel_key) ||
@@ -1470,18 +1471,37 @@ void OperatorWithKernel::ChooseKernel(const ExecutionContext& ctx) const {
 #endif
 
 #ifdef PADDLE_WITH_XPU_KP
+  if (paddle::platform::is_xpu_place(expected_kernel_key.place_)) {
     bool use_xpu_kp_kernel_rt =
         FLAGS_run_kp_kernel &&
         paddle::platform::is_xpu_kp_support_op(type_, expected_kernel_key);
     bool use_xpu_kp_kernel_debug =
        paddle::platform::is_in_xpu_kpwhite_list(type_);
-  if (platform::is_xpu_place(expected_kernel_key.place_) &&
-      (use_xpu_kp_kernel_rt || use_xpu_kp_kernel_debug)) {
+    if (use_xpu_kp_kernel_rt) {
+      VLOG(3) << "xpu_kp using rt mode ";
+    }
+    if (use_xpu_kp_kernel_debug) {
+      VLOG(3) << "xpu_kp using debug mode ";
+    }
+    bool is_xpu_kp_support = (use_xpu_kp_kernel_rt || use_xpu_kp_kernel_debug);
+    if (is_xpu_kp_support) {
      expected_kernel_key.library_type_ = LibraryType::kKP;
      kernel_iter = kernels.find(expected_kernel_key);
      VLOG(3) << "using XPU KP kernel: " << type_
              << ", using_kernel_key:" << expected_kernel_key;
    }
+    bool is_xpu_unsupport =
+        (!paddle::platform::is_xpu_support_op(type_, expected_kernel_key) ||
+         paddle::platform::is_in_xpu_black_list(type_));
+    if (!is_xpu_kp_support &&
+        (kernel_iter == kernels.end() || is_xpu_unsupport)) {
+      VLOG(3) << "missing XPU kernel: " << type_
+              << ", expected_kernel_key:" << expected_kernel_key
+              << ", fallbacking to CPU one!";
+      expected_kernel_key.place_ = platform::CPUPlace();
+      kernel_iter = kernels.find(expected_kernel_key);
+    }
+  }
 #endif
 
 #ifdef PADDLE_WITH_IPU
......
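The XPU_KP branch above now prefers a KP kernel and otherwise falls back to a CPU kernel when the regular XPU kernel is missing or unsupported. A condensed sketch of that decision order, with the types and helper calls simplified for illustration (not the literal Paddle code):

  // Condensed sketch of the kernel-selection fallback added in ChooseKernel above.
  bool use_kp_rt = FLAGS_run_kp_kernel && is_xpu_kp_support_op(type, key);
  bool use_kp_debug = is_in_xpu_kpwhite_list(type);
  bool is_xpu_kp_support = use_kp_rt || use_kp_debug;
  if (is_xpu_kp_support) {
    key.library_type_ = LibraryType::kKP;  // try the KP kernel first
    kernel_iter = kernels.find(key);
  }
  bool is_xpu_unsupport =
      !is_xpu_support_op(type, key) || is_in_xpu_black_list(type);
  if (!is_xpu_kp_support && (kernel_iter == kernels.end() || is_xpu_unsupport)) {
    key.place_ = platform::CPUPlace();     // otherwise fall back to the CPU kernel
    kernel_iter = kernels.find(key);
  }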
@@ -1224,8 +1224,12 @@ void TensorFromStream(std::istream& is, Tensor* tensor,
   proto::VarType::TensorDesc desc;
   {  // int32_t size
      // proto buffer
-    int32_t size;
+    int32_t size = -1;
     is.read(reinterpret_cast<char*>(&size), sizeof(size));
+    PADDLE_ENFORCE_EQ(is.good(), true, platform::errors::Unavailable(
+                                           "Cannot read tensor desc size"));
+    PADDLE_ENFORCE_GE(size, 0, platform::errors::InvalidArgument(
+                                   "Tensor desc size should >= 0"));
     std::unique_ptr<char[]> buf(new char[size]);
     is.read(reinterpret_cast<char*>(buf.get()), size);
     PADDLE_ENFORCE_EQ(
......
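The TensorFromStream change above initializes the serialized desc size and validates both the stream state and the size before allocating the buffer. Below is a self-contained version of the same read-then-validate pattern using plain exceptions instead of Paddle's PADDLE_ENFORCE_* macros; the helper name ReadSizedBlob is hypothetical and only meant to illustrate the idea:

  #include <cstdint>
  #include <istream>
  #include <stdexcept>
  #include <vector>

  // Hypothetical helper: read a length-prefixed blob, validating before allocating.
  std::vector<char> ReadSizedBlob(std::istream& is) {
    int32_t size = -1;
    is.read(reinterpret_cast<char*>(&size), sizeof(size));
    if (!is.good()) throw std::runtime_error("Cannot read tensor desc size");
    if (size < 0) throw std::runtime_error("Tensor desc size should >= 0");
    std::vector<char> buf(static_cast<size_t>(size));
    is.read(buf.data(), size);
    if (!is.good()) throw std::runtime_error("Cannot read tensor desc data");
    return buf;
  }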
@@ -124,7 +124,7 @@ AmpOperators::AmpOperators()
       OpSupportedInfos("GPU", paddle::framework::proto::VarType::BF16));
   unsupported_bf16_ops_->insert(unsupported_ops_gpu_bf16.begin(),
                                 unsupported_ops_gpu_bf16.end());
-// NOTE: GPU/NPU/XPU is compiled seperatly.
+// NOTE: GPU/NPU/XPU/MLU is compiled seperatly.
 #elif defined(PADDLE_WITH_ASCEND_CL)
   auto unsupported_ops_npu_fp16 = std::get<2>(
       OpSupportedInfos("NPU", paddle::framework::proto::VarType::FP16));
@@ -143,6 +143,15 @@ AmpOperators::AmpOperators()
       OpSupportedInfos("XPU", paddle::framework::proto::VarType::BF16));
   unsupported_bf16_ops_->insert(unsupported_ops_xpu_bf16.begin(),
                                 unsupported_ops_xpu_bf16.end());
+#elif defined(PADDLE_WITH_MLU)
+  auto unsupported_ops_mlu_fp16 = std::get<2>(
+      OpSupportedInfos("MLU", paddle::framework::proto::VarType::FP16));
+  unsupported_fp16_ops_->insert(unsupported_ops_mlu_fp16.begin(),
+                                unsupported_ops_mlu_fp16.end());
+  auto unsupported_ops_mlu_bf16 = std::get<2>(
+      OpSupportedInfos("MLU", paddle::framework::proto::VarType::BF16));
+  unsupported_bf16_ops_->insert(unsupported_ops_mlu_bf16.begin(),
+                                unsupported_ops_mlu_bf16.end());
 #endif
   VLOG(4) << allow_ops_->size() << " " << block_ops_->size() << " "
           << unsupported_fp16_ops_->size() << " "
@@ -209,7 +218,10 @@ inline bool NeedCast(const std::shared_ptr<VarType>& var) {
   auto data_type = GetDataType<VarType>(var);
   if (paddle::platform::is_gpu_place(place) ||
       paddle::platform::is_cuda_pinned_place(place) ||
-      paddle::platform::is_xpu_place(place)) {
+      paddle::platform::is_xpu_place(place) ||
+      paddle::platform::is_mlu_place(place) ||
+      paddle::platform::is_npu_place(place) ||
+      paddle::platform::is_npu_pinned_place(place)) {
     // CudaPinndePlace is added for varbase created by dataloader
     if (data_type == paddle::framework::proto::VarType::FP32 ||
         data_type == paddle::framework::proto::VarType::FP16 ||
......
@@ -234,7 +234,7 @@ PreparedOp PrepareImpl(const NameVarMap<VarType>& ins,
   auto& kernels = kernels_iter->second;
   auto kernel_iter = kernels.find(expected_kernel_key);
 
-#ifdef PADDLE_WITH_XPU
+#if defined(PADDLE_WITH_XPU) && !defined(PADDLE_WITH_XPU_KP)
   if (paddle::platform::is_xpu_place(expected_kernel_key.place_) &&
       (kernel_iter == kernels.end() || is_xpu_unsupport)) {
     VLOG(3) << "missing XPU kernel: " << op.Type()
@@ -243,11 +243,10 @@ PreparedOp PrepareImpl(const NameVarMap<VarType>& ins,
     expected_kernel_key.place_ = platform::CPUPlace();
     kernel_iter = kernels.find(expected_kernel_key);
   }
 #endif
 
 #ifdef PADDLE_WITH_XPU_KP
-  expected_kernel_key.place_ = platform::XPUPlace();
+  if (paddle::platform::is_xpu_place(expected_kernel_key.place_)) {
     bool use_xpu_kp_kernel_rt =
         FLAGS_run_kp_kernel &&
         paddle::platform::is_xpu_kp_support_op(op.Type(), expected_kernel_key);
@@ -259,14 +258,22 @@ PreparedOp PrepareImpl(const NameVarMap<VarType>& ins,
     if (use_xpu_kp_kernel_debug) {
       VLOG(3) << "xpu_kp using debug mode ";
     }
-  if (paddle::platform::is_xpu_place(expected_kernel_key.place_) &&
-      (use_xpu_kp_kernel_rt || use_xpu_kp_kernel_debug)) {
+    bool is_xpu_kp_support = (use_xpu_kp_kernel_rt || use_xpu_kp_kernel_debug);
+    if (is_xpu_kp_support) {
+      expected_kernel_key.place_ = platform::XPUPlace();
      expected_kernel_key.library_type_ = paddle::framework::LibraryType::kKP;
      kernel_iter = kernels.find(expected_kernel_key);
      VLOG(3) << "using XPU KP kernel: " << op.Type()
              << ", using_kernel_key:" << expected_kernel_key;
    }
+    if (!is_xpu_kp_support &&
+        (kernel_iter == kernels.end() || is_xpu_unsupport)) {
+      VLOG(3) << "missing XPU kernel: " << op.Type()
+              << ", expected_kernel_key:" << expected_kernel_key
+              << ", fallbacking to CPU one!";
+      expected_kernel_key.place_ = platform::CPUPlace();
+      kernel_iter = kernels.find(expected_kernel_key);
+    }
+  }
 #endif
 #ifdef PADDLE_WITH_ASCEND_CL
......
@@ -341,7 +341,6 @@ void BuildDygraphPhiKernelContext(
   }
 
   for (size_t i = 0; i < attr_names.size(); ++i) {
-    VLOG(1) << "############## attr_name: " << i << " : " << attr_names[i];
     if (attr_defs[i].type_index == std::type_index(typeid(phi::ScalarArray))) {
       if (attrs.find(attr_names[i]) !=
           attrs.end()) {  // shape is in the attribute
......
@@ -390,8 +390,8 @@ bool Tracer::ComputeRequiredGrad(const NameTensorMap& ins,
 }
 
 phi::KernelSignature Tracer::GetExpectedKernelSignature(
-    const std::string& type, const NameVarBaseMap& ins,
-    const NameVarBaseMap& outs, framework::AttributeMap attrs) const {
+    const std::string& type, const NameTensorMap& ins,
+    const NameTensorMap& outs, framework::AttributeMap attrs) const {
   auto op = framework::OpRegistry::CreateOp(type, {}, {}, {}, false);
   framework::RuntimeContext ctx({}, {});
   platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
@@ -406,7 +406,7 @@ phi::KernelSignature Tracer::GetExpectedKernelSignature(
       attr_checker == nullptr ? empty_attrs_map
                               : attr_checker->GetDefaultAttrMap();
   auto dygraph_exe_ctx =
-      imperative::DygraphExecutionContext<imperative::VarBase>(
+      imperative::DygraphExecutionContext<egr::EagerVariable>(
          *op, framework::Scope(), *dev_ctx, ctx, ins, outs, attrs,
          default_attrs);
   auto* opbase_with_kernel =
......
@@ -156,8 +156,8 @@ class Tracer {
   }
 
   phi::KernelSignature GetExpectedKernelSignature(
-      const std::string& type, const NameVarBaseMap& ins,
-      const NameVarBaseMap& outs, framework::AttributeMap attrs) const;
+      const std::string& type, const NameTensorMap& ins,
+      const NameTensorMap& outs, framework::AttributeMap attrs) const;
 
   paddle::framework::GarbageCollector* MutableGarbageCollectorIfNotExists(
       const platform::Place& place);
......
@@ -1485,6 +1485,13 @@ REGISTER_ACTIVATION_OP(atanh, Atanh, AtanhFunctor, AtanhGradFunctor);
 REGISTER_ACTIVATION_OP(brelu, BRelu, BReluFunctor, BReluGradFunctor);
 REGISTER_ACTIVATION_OP(thresholded_relu, ThresholdedRelu,
                        ThresholdedReluFunctor, ThresholdedReluGradFunctor);
+REGISTER_ACTIVATION_OP(hard_shrink, HardShrink, HardShrinkFunctor,
+                       HardShrinkGradFunctor);
+REGISTER_ACTIVATION_OP(softshrink, SoftShrink, SoftShrinkFunctor,
+                       SoftShrinkGradFunctor);
+REGISTER_ACTIVATION_OP(tanh_shrink, TanhShrink, TanhShrinkFunctor,
+                       TanhShrinkGradFunctor);
+REGISTER_ACTIVATION_OP(silu, Silu, SiluFunctor, SiluGradFunctor);
 
 /* ==========================    sigmoid register  =============================
  */
@@ -1626,22 +1633,6 @@ REGISTER_OPERATOR(
     ops::ActivationOpDoubleGrad<ops::ELUGradFunctor<float>::FwdDeps()>,
     ops::ActivationDoubleGradOpInplaceInferer);
 
-REGISTER_OP_CPU_KERNEL(elu,
-                       ops::ActivationKernel<paddle::platform::CPUDeviceContext,
-                                             ops::ELUFunctor<float>>,
-                       ops::ActivationKernel<paddle::platform::CPUDeviceContext,
-                                             ops::ELUFunctor<double>>);
-REGISTER_OP_CPU_KERNEL(
-    elu_grad, ops::ELUGradKernel<paddle::platform::CPUDeviceContext, float>,
-    ops::ELUGradKernel<paddle::platform::CPUDeviceContext, double>);
-REGISTER_OP_CPU_KERNEL(
-    elu_grad_grad, ops::ELUDoubleGradKernel<plat::CPUDeviceContext,
-                                            ops::ELUGradGradFunctor<float>>,
-    ops::ELUDoubleGradKernel<plat::CPUDeviceContext,
-                             ops::ELUGradGradFunctor<double>>,
-    ops::ELUDoubleGradKernel<plat::CPUDeviceContext,
-                             ops::ELUGradGradFunctor<plat::float16>>);
-
 /* ========================================================================== */
 
 /* ========================    logit register  ============================
......
...@@ -280,6 +280,15 @@ USE_PHI_FUNCTOR(BRelu) ...@@ -280,6 +280,15 @@ USE_PHI_FUNCTOR(BRelu)
USE_PHI_FUNCTOR(ThresholdedRelu) USE_PHI_FUNCTOR(ThresholdedRelu)
USE_PHI_FUNCTOR(LeakyRelu) USE_PHI_FUNCTOR(LeakyRelu)
USE_PHI_DOUBLE_GRAD_FUNCTOR(LeakyRelu) USE_PHI_DOUBLE_GRAD_FUNCTOR(LeakyRelu)
USE_PHI_FUNCTOR(HardShrink)
USE_PHI_FUNCTOR(SoftShrink)
USE_PHI_FUNCTOR(TanhShrink)
USE_PHI_FUNCTOR(Silu)
USE_PHI_FUNCTOR(ELU)
USE_PHI_DOUBLE_GRAD_FUNCTOR(ELU)
template <typename T>
using ELUGradNegativeAlphaFunctor = phi::funcs::ELUGradNegativeAlphaFunctor<T>;
template <typename T> template <typename T>
struct SigmoidGradFunctor : public BaseActivationFunctor<T> { struct SigmoidGradFunctor : public BaseActivationFunctor<T> {
...@@ -393,31 +402,6 @@ struct SigmoidTripleGradFunctor : public BaseActivationFunctor<T> { ...@@ -393,31 +402,6 @@ struct SigmoidTripleGradFunctor : public BaseActivationFunctor<T> {
} }
}; };
// silu(x) = x / (1 + exp(-x))
template <typename T>
struct SiluFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Out>
void operator()(Device d, X x, Out out) const {
auto temp = static_cast<T>(1) / (static_cast<T>(1) + (-x).exp());
out.device(d) = x * temp;
}
};
// silu'(x) = (1 / (1 + e^{-x})) * (1 + out * e^{-x})
template <typename T>
struct SiluGradFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Out, typename dOut,
typename dX>
void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
auto temp1 = static_cast<T>(1) + (-x).exp(); // 1+e^(-x)
auto temp2 = x * (-x).exp(); // x*e^(-x)
dx.device(d) = dout * ((static_cast<T>(1) / temp1) *
(static_cast<T>(1) + (temp2 / temp1)));
}
static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};
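// For reference, the removed SiluFunctor / SiluGradFunctor above encode
// silu(x) = x / (1 + exp(-x)) and its derivative; this diff moves the actual
// kernels to phi (hence USE_PHI_FUNCTOR(Silu) above). A minimal standalone
// sketch of the same formulas, checked against a numerical derivative and
// using only the C++ standard library (illustration only, not Paddle code):
#include <cassert>
#include <cmath>

double silu(double x) { return x / (1.0 + std::exp(-x)); }

// silu'(x) = sigmoid(x) * (1 + x * e^(-x) / (1 + e^(-x)))
double silu_grad(double x) {
  const double temp1 = 1.0 + std::exp(-x);  // 1 + e^(-x)
  const double temp2 = x * std::exp(-x);    // x * e^(-x)
  return (1.0 / temp1) * (1.0 + temp2 / temp1);
}

int main() {
  const double eps = 1e-6;
  for (double x : {-2.0, -0.5, 0.0, 0.7, 3.0}) {
    const double numeric = (silu(x + eps) - silu(x - eps)) / (2.0 * eps);
    assert(std::abs(numeric - silu_grad(x)) < 1e-5);
  }
  return 0;
}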
// Originally: logsigmoid(x) = -log (1 + exp(-x)) // Originally: logsigmoid(x) = -log (1 + exp(-x))
// For numerical stability, we can use the log-sum-exp trick: // For numerical stability, we can use the log-sum-exp trick:
// https://hips.seas.harvard.edu/blog/2013/01/09/computing-log-sum-exp/ // https://hips.seas.harvard.edu/blog/2013/01/09/computing-log-sum-exp/
...@@ -922,59 +906,6 @@ struct SoftReluGradFunctor : public BaseActivationFunctor<T> { ...@@ -922,59 +906,6 @@ struct SoftReluGradFunctor : public BaseActivationFunctor<T> {
} }
}; };
template <typename T>
struct ELUFunctor : public BaseActivationFunctor<T> {
float alpha;
typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
return {{"alpha", &alpha}};
}
template <typename Device, typename X, typename Out>
void operator()(Device d, X x, Out out) const {
out.device(d) =
(x < static_cast<T>(0))
.select(static_cast<T>(alpha) * (x.exp() - static_cast<T>(1)), x);
}
};
template <typename T>
struct ELUGradFunctor : public BaseActivationFunctor<T> {
float alpha;
typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
return {{"alpha", &alpha}};
}
template <typename Device, typename X, typename Out, typename dOut,
typename dX>
void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
// case 1: alpha >= 0
// dx = dout, if out > 0
// dx = dout * (out + alpha), if out <= 0
dx.device(d) = (out > static_cast<T>(0))
.select(dout, dout * (out + static_cast<T>(alpha)));
}
static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};
template <typename T>
struct ELUGradNegativeAlphaFunctor : public BaseActivationFunctor<T> {
float alpha;
typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
return {{"alpha", &alpha}};
}
template <typename Device, typename X, typename Out, typename dOut,
typename dX>
void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
// case 2: alpha < 0
// dx = dout, if x > 0
// dx = dout * (out + alpha), if x <=0
dx.device(d) = (x > static_cast<T>(0))
.select(dout, dout * static_cast<T>(alpha) * x.exp());
}
static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};
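// For reference, the removed ELU functors above encode
// elu(x) = x for x > 0 and alpha * (exp(x) - 1) otherwise, with
// d elu(x)/dx = 1 for x > 0 and alpha * exp(x) otherwise; the two grad
// functors differ only in whether they branch on out (alpha >= 0) or on x
// (alpha < 0), since out > 0 no longer implies x > 0 for negative alpha.
// A scalar sketch of the same math (illustration only, not Paddle code):
#include <cassert>
#include <cmath>

double elu(double x, double alpha) {
  return x > 0.0 ? x : alpha * (std::exp(x) - 1.0);
}

double elu_grad(double x, double dout, double alpha) {
  // For x <= 0, d elu(x)/dx = alpha * exp(x), which also equals out + alpha.
  return x > 0.0 ? dout : dout * alpha * std::exp(x);
}

int main() {
  const double eps = 1e-6;
  for (double alpha : {1.5, -0.5}) {
    for (double x : {-2.0, -0.3, 0.4, 2.0}) {
      const double numeric =
          (elu(x + eps, alpha) - elu(x - eps, alpha)) / (2.0 * eps);
      assert(std::abs(numeric - elu_grad(x, /*dout=*/1.0, alpha)) < 1e-5);
    }
  }
  return 0;
}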
template <typename DeviceContext, typename T> template <typename DeviceContext, typename T>
class ELUGradKernel : public framework::OpKernel<T> { class ELUGradKernel : public framework::OpKernel<T> {
public: public:
...@@ -1207,44 +1138,6 @@ struct AbsGradGradFunctor : public BaseActivationFunctor<T> { ...@@ -1207,44 +1138,6 @@ struct AbsGradGradFunctor : public BaseActivationFunctor<T> {
static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
}; };
template <typename T>
struct ELUGradGradFunctor : public BaseActivationFunctor<T> {
float alpha;
typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
return {{"alpha", &alpha}};
}
template <typename Device>
void operator()(const Device& dev, const framework::Tensor* X,
const framework::Tensor* ddX, framework::Tensor* ddOut,
const framework::Tensor* dOut, framework::Tensor* dX) const {
auto* d = dev.eigen_device();
auto ddx = framework::EigenVector<T>::Flatten(
GET_DATA_SAFELY(ddX, "Input", "DDX", "ELUGradGrad"));
auto x = framework::EigenVector<T>::Flatten(
GET_DATA_SAFELY(X, "Input", "X", "ELUGradGrad"));
if (dX) {
auto dx = framework::EigenVector<T>::Flatten(
GET_DATA_SAFELY(dX, "Output", "DX", "ELUGradGrad"));
auto dout = framework::EigenVector<T>::Flatten(
GET_DATA_SAFELY(dOut, "Output", "DOut", "ELUGradGrad"));
dx.device(*d) = ddx * dout * static_cast<T>(alpha) * x.exp() *
(x <= static_cast<T>(0)).template cast<T>();
}
if (ddOut) {
auto ddout = framework::EigenVector<T>::Flatten(
GET_DATA_SAFELY(ddOut, "Output", "DDOut", "ELUGradGrad"));
ddout.device(*d) = ddx *
((x > static_cast<T>(0)).template cast<T>() +
static_cast<T>(alpha) * x.exp() *
(x <= static_cast<T>(0)).template cast<T>())
.template cast<T>();
}
}
static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};
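// For reference, the removed ELUGradGradFunctor above is the standard
// double-backward rule for an elementwise y = f(x): with f'(x) = 1 (x > 0)
// or alpha * e^x (x <= 0) and f''(x) = 0 (x > 0) or alpha * e^x (x <= 0),
// it computes DDOut = DDX * f'(x) and DX = DDX * DOut * f''(x). A scalar
// transcription (a sketch, not Paddle code):
#include <cmath>

struct EluDoubleGrad {
  double alpha;
  void operator()(double x, double ddx, double dout, double* ddout,
                  double* dx) const {
    const double fp = x > 0.0 ? 1.0 : alpha * std::exp(x);   // f'(x)
    const double fpp = x > 0.0 ? 0.0 : alpha * std::exp(x);  // f''(x)
    if (ddout) *ddout = ddx * fp;
    if (dx) *dx = ddx * dout * fpp;
  }
};

int main() {
  double ddout = 0.0, dx = 0.0;
  EluDoubleGrad{/*alpha=*/1.0}(-0.5, /*ddx=*/1.0, /*dout=*/2.0, &ddout, &dx);
  return (ddout != 0.0 && dx != 0.0) ? 0 : 1;
}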
template <typename T> template <typename T>
struct CELUGradGradFunctor : public BaseActivationFunctor<T> { struct CELUGradGradFunctor : public BaseActivationFunctor<T> {
float alpha; float alpha;
...@@ -1985,9 +1878,7 @@ struct LogGradGradFunctor : public BaseActivationFunctor<T> { ...@@ -1985,9 +1878,7 @@ struct LogGradGradFunctor : public BaseActivationFunctor<T> {
} // namespace paddle } // namespace paddle
#define FOR_EACH_ACTIVATION_OP(__macro) \ #define FOR_EACH_ACTIVATION_OP(__macro) \
__macro(silu, Silu, SiluFunctor, SiluGradFunctor); \
__macro(logsigmoid, LogSigmoid, LogSigmoidFunctor, LogSigmoidGradFunctor); \ __macro(logsigmoid, LogSigmoid, LogSigmoidFunctor, LogSigmoidGradFunctor); \
__macro(softshrink, SoftShrink, SoftShrinkFunctor, SoftShrinkGradFunctor); \
__macro(ceil, Ceil, CeilFunctor, ZeroGradFunctor); \ __macro(ceil, Ceil, CeilFunctor, ZeroGradFunctor); \
__macro(floor, Floor, FloorFunctor, ZeroGradFunctor); \ __macro(floor, Floor, FloorFunctor, ZeroGradFunctor); \
__macro(round, Round, RoundFunctor, ZeroGradFunctor); \ __macro(round, Round, RoundFunctor, ZeroGradFunctor); \
...@@ -2000,8 +1891,6 @@ struct LogGradGradFunctor : public BaseActivationFunctor<T> { ...@@ -2000,8 +1891,6 @@ struct LogGradGradFunctor : public BaseActivationFunctor<T> {
__macro(softplus, Softplus, SoftplusFunctor, SoftplusGradFunctor); \ __macro(softplus, Softplus, SoftplusFunctor, SoftplusGradFunctor); \
__macro(softsign, Softsign, SoftsignFunctor, SoftsignGradFunctor); \ __macro(softsign, Softsign, SoftsignFunctor, SoftsignGradFunctor); \
__macro(relu6, Relu6, Relu6Functor, Relu6GradFunctor); \ __macro(relu6, Relu6, Relu6Functor, Relu6GradFunctor); \
__macro(tanh_shrink, TanhShrink, TanhShrinkFunctor, TanhShrinkGradFunctor); \
__macro(hard_shrink, HardShrink, HardShrinkFunctor, HardShrinkGradFunctor); \
__macro(hard_sigmoid, HardSigmoid, HardSigmoidFunctor, \ __macro(hard_sigmoid, HardSigmoid, HardSigmoidFunctor, \
HardSigmoidGradFunctor); \ HardSigmoidGradFunctor); \
__macro(swish, Swish, SwishFunctor, SwishGradFunctor); \ __macro(swish, Swish, SwishFunctor, SwishGradFunctor); \
......
...@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and ...@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/operators/batch_norm_op.h" #include "paddle/fluid/operators/batch_norm_op.h"
#include "paddle/fluid/operators/amp/fp16_type_traits.h"
#include "paddle/fluid/operators/mlu/mlu_baseop.h" #include "paddle/fluid/operators/mlu/mlu_baseop.h"
namespace paddle { namespace paddle {
...@@ -20,6 +21,8 @@ namespace operators { ...@@ -20,6 +21,8 @@ namespace operators {
template <typename T> template <typename T>
class MLUBatchNormOpKernel : public framework::OpKernel<T> { class MLUBatchNormOpKernel : public framework::OpKernel<T> {
using MPDType = typename details::MPTypeTrait<T>::Type;
public: public:
void Compute(const framework::ExecutionContext &ctx) const override { void Compute(const framework::ExecutionContext &ctx) const override {
const auto &place = ctx.GetPlace(); const auto &place = ctx.GetPlace();
...@@ -68,10 +71,10 @@ class MLUBatchNormOpKernel : public framework::OpKernel<T> { ...@@ -68,10 +71,10 @@ class MLUBatchNormOpKernel : public framework::OpKernel<T> {
// alloc memory // alloc memory
y->mutable_data<T>(place); y->mutable_data<T>(place);
mean_out->mutable_data<T>(place); mean_out->mutable_data<MPDType>(place);
variance_out->mutable_data<T>(place); variance_out->mutable_data<MPDType>(place);
saved_mean->mutable_data<T>(place); saved_mean->mutable_data<MPDType>(place);
saved_variance->mutable_data<T>(place); saved_variance->mutable_data<MPDType>(place);
Tensor transformed_x; Tensor transformed_x;
Tensor transformed_y; Tensor transformed_y;
...@@ -132,6 +135,8 @@ class MLUBatchNormOpKernel : public framework::OpKernel<T> { ...@@ -132,6 +135,8 @@ class MLUBatchNormOpKernel : public framework::OpKernel<T> {
template <typename T> template <typename T>
class MLUBatchNormGradOpKernel : public framework::OpKernel<T> { class MLUBatchNormGradOpKernel : public framework::OpKernel<T> {
using MPDType = typename details::MPTypeTrait<T>::Type;
public: public:
void Compute(const framework::ExecutionContext &ctx) const override { void Compute(const framework::ExecutionContext &ctx) const override {
const auto *x = ctx.Input<Tensor>("X"); const auto *x = ctx.Input<Tensor>("X");
...@@ -154,10 +159,10 @@ class MLUBatchNormGradOpKernel : public framework::OpKernel<T> { ...@@ -154,10 +159,10 @@ class MLUBatchNormGradOpKernel : public framework::OpKernel<T> {
auto &dev_ctx = ctx.template device_context<MLUDeviceContext>(); auto &dev_ctx = ctx.template device_context<MLUDeviceContext>();
auto d_x_tmp = auto d_x_tmp =
ctx.AllocateTmpTensor<T, MLUDeviceContext>(x->dims(), dev_ctx); ctx.AllocateTmpTensor<T, MLUDeviceContext>(x->dims(), dev_ctx);
auto scale_grad_tmp = auto scale_grad_tmp = ctx.AllocateTmpTensor<MPDType, MLUDeviceContext>(
ctx.AllocateTmpTensor<T, MLUDeviceContext>(scale->dims(), dev_ctx); scale->dims(), dev_ctx);
auto bias_grad_tmp = auto bias_grad_tmp =
ctx.AllocateTmpTensor<T, MLUDeviceContext>(bias->dims(), dev_ctx); ctx.AllocateTmpTensor<MPDType, MLUDeviceContext>(bias->dims(), dev_ctx);
if (d_x == nullptr) { if (d_x == nullptr) {
d_x = &d_x_tmp; d_x = &d_x_tmp;
...@@ -171,8 +176,8 @@ class MLUBatchNormGradOpKernel : public framework::OpKernel<T> { ...@@ -171,8 +176,8 @@ class MLUBatchNormGradOpKernel : public framework::OpKernel<T> {
const auto &place = ctx.GetPlace(); const auto &place = ctx.GetPlace();
d_x->mutable_data<T>(place); d_x->mutable_data<T>(place);
d_scale->mutable_data<T>(place); d_scale->mutable_data<MPDType>(place);
d_bias->mutable_data<T>(place); d_bias->mutable_data<MPDType>(place);
use_global_stats = is_test || use_global_stats; use_global_stats = is_test || use_global_stats;
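// For reference, the MLU batch-norm change above switches the running/saved
// statistics and the scale/bias gradients from T to MPDType, taken from
// MPTypeTrait in amp/fp16_type_traits.h (included above), presumably so that
// a float16 kernel keeps its statistics in float. A minimal trait with the
// same shape (the names below are stand-ins for illustration, not the Paddle
// definitions):
#include <cstdint>
#include <type_traits>

struct float16 { uint16_t bits; };  // stand-in for platform::float16

template <typename T>
struct MasterPrecisionType { using Type = T; };  // default: keep the type
template <>
struct MasterPrecisionType<float16> { using Type = float; };  // fp16 -> float

static_assert(std::is_same<MasterPrecisionType<float>::Type, float>::value,
              "float stays float");
static_assert(std::is_same<MasterPrecisionType<float16>::Type, float>::value,
              "float16 is widened to float");

int main() { return 0; }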
......
...@@ -12,8 +12,10 @@ ...@@ -12,8 +12,10 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h" #include "paddle/fluid/framework/operator.h"
#include "paddle/phi/infermeta/unary.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -21,14 +23,6 @@ namespace operators { ...@@ -21,14 +23,6 @@ namespace operators {
class CumprodOp : public framework::OperatorWithKernel { class CumprodOp : public framework::OperatorWithKernel {
public: public:
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "Cumprod");
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "Cumprod");
ctx->ShareDim("X", "Out");
ctx->ShareLoD("X", "Out");
}
}; };
class CumprodOpMaker : public framework::OpProtoAndCheckerMaker { class CumprodOpMaker : public framework::OpProtoAndCheckerMaker {
...@@ -82,9 +76,12 @@ class CumprodGradOp : public framework::OperatorWithKernel { ...@@ -82,9 +76,12 @@ class CumprodGradOp : public framework::OperatorWithKernel {
} // namespace paddle } // namespace paddle
namespace ops = paddle::operators; namespace ops = paddle::operators;
DECLARE_INFER_SHAPE_FUNCTOR(cumprod, CumprodInferShapeFunctor,
PD_INFER_META(phi::UnchangedInferMeta));
REGISTER_OPERATOR(cumprod, ops::CumprodOp, ops::CumprodOpMaker, REGISTER_OPERATOR(cumprod, ops::CumprodOp, ops::CumprodOpMaker,
ops::CumprodGradOpMaker<paddle::framework::OpDesc>, ops::CumprodGradOpMaker<paddle::framework::OpDesc>,
ops::CumprodGradOpMaker<paddle::imperative::OpBase>); ops::CumprodGradOpMaker<paddle::imperative::OpBase>,
CumprodInferShapeFunctor);
REGISTER_OPERATOR(cumprod_grad, ops::CumprodGradOp); REGISTER_OPERATOR(cumprod_grad, ops::CumprodGradOp);
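// For reference, the cumprod change above drops the hand-written InferShape
// and binds the op to phi::UnchangedInferMeta through
// DECLARE_INFER_SHAPE_FUNCTOR: cumprod is shape-preserving, so the output
// simply inherits the input's meta. A toy version of that rule (MetaTensor
// here is a stand-in struct, not the phi type):
#include <cassert>
#include <cstdint>
#include <vector>

struct MetaTensor {
  std::vector<int64_t> dims;
  int dtype = 0;
};

// Shape-preserving infer-meta: the output copies the input's dims and dtype.
void UnchangedInferMeta(const MetaTensor& x, MetaTensor* out) {
  out->dims = x.dims;
  out->dtype = x.dtype;
}

int main() {
  MetaTensor x{{4, 3, 5}, /*dtype=*/1};
  MetaTensor out;
  UnchangedInferMeta(x, &out);
  assert(out.dims == x.dims && out.dtype == x.dtype);
  return 0;
}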
...@@ -13,6 +13,10 @@ ...@@ -13,6 +13,10 @@
// limitations under the License. // limitations under the License.
#include "paddle/fluid/operators/determinant_op.h" #include "paddle/fluid/operators/determinant_op.h"
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/backward.h"
#include "paddle/phi/infermeta/unary.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -20,11 +24,6 @@ namespace operators { ...@@ -20,11 +24,6 @@ namespace operators {
class DeterminantOp : public framework::OperatorWithKernel { class DeterminantOp : public framework::OperatorWithKernel {
public: public:
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("Input"), "Input", "Input", "determinant");
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "determinant");
}
}; };
class DeterminantOpMaker : public framework::OpProtoAndCheckerMaker { class DeterminantOpMaker : public framework::OpProtoAndCheckerMaker {
...@@ -44,19 +43,6 @@ class DeterminantGradOp : public framework::OperatorWithKernel { ...@@ -44,19 +43,6 @@ class DeterminantGradOp : public framework::OperatorWithKernel {
public: public:
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("Input"), "Input", "Input",
"DeterminantGradOp");
OP_INOUT_CHECK(ctx->HasInput("Out"), "Input", "Out", "DeterminantGradOp");
OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")), "Input",
framework::GradVarName("Out"), "DeterminantGradOp");
OP_INOUT_CHECK(ctx->HasOutput(framework::GradVarName("Input")), "Output",
framework::GradVarName("Input"), "DeterminantGradOp");
ctx->SetOutputDim(framework::GradVarName("Input"),
ctx->GetInputDim("Input"));
}
protected: protected:
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override { const framework::ExecutionContext &ctx) const override {
...@@ -162,19 +148,17 @@ DECLARE_NO_NEED_BUFFER_VARS_INFERER(SlogDeterminantGradNoNeedBufferVarsInferer, ...@@ -162,19 +148,17 @@ DECLARE_NO_NEED_BUFFER_VARS_INFERER(SlogDeterminantGradNoNeedBufferVarsInferer,
namespace ops = paddle::operators; namespace ops = paddle::operators;
namespace plat = paddle::platform; namespace plat = paddle::platform;
DECLARE_INFER_SHAPE_FUNCTOR(determinant, DeterminantInferShapeFunctor,
PD_INFER_META(phi::UnchangedInferMeta));
REGISTER_OPERATOR(determinant, ops::DeterminantOp, ops::DeterminantOpMaker, REGISTER_OPERATOR(determinant, ops::DeterminantOp, ops::DeterminantOpMaker,
ops::DeterminantGradOpMaker<paddle::framework::OpDesc>, ops::DeterminantGradOpMaker<paddle::framework::OpDesc>,
ops::DeterminantGradOpMaker<paddle::imperative::OpBase>); ops::DeterminantGradOpMaker<paddle::imperative::OpBase>,
DeterminantInferShapeFunctor);
REGISTER_OPERATOR(determinant_grad, ops::DeterminantGradOp)
REGISTER_OP_CPU_KERNEL(determinant, DECLARE_INFER_SHAPE_FUNCTOR(determinant_grad, DeterminantGradInferShapeFunctor,
ops::DeterminantKernel<plat::CPUDeviceContext, float>, PD_INFER_META(phi::GeneralUnaryGradInferMeta));
ops::DeterminantKernel<plat::CPUDeviceContext, double>); REGISTER_OPERATOR(determinant_grad, ops::DeterminantGradOp,
DeterminantGradInferShapeFunctor);
REGISTER_OP_CPU_KERNEL(
determinant_grad, ops::DeterminantGradKernel<plat::CPUDeviceContext, float>,
ops::DeterminantGradKernel<plat::CPUDeviceContext, double>);
REGISTER_OPERATOR(slogdeterminant, ops::SlogDeterminantOp, REGISTER_OPERATOR(slogdeterminant, ops::SlogDeterminantOp,
ops::SlogDeterminantOpMaker, ops::SlogDeterminantOpMaker,
......
...@@ -17,14 +17,6 @@ limitations under the License. */ ...@@ -17,14 +17,6 @@ limitations under the License. */
namespace ops = paddle::operators; namespace ops = paddle::operators;
namespace plat = paddle::platform; namespace plat = paddle::platform;
REGISTER_OP_CUDA_KERNEL(
determinant, ops::DeterminantKernel<plat::CUDADeviceContext, float>,
ops::DeterminantKernel<plat::CUDADeviceContext, double>);
REGISTER_OP_CUDA_KERNEL(
determinant_grad,
ops::DeterminantGradKernel<plat::CUDADeviceContext, float>,
ops::DeterminantGradKernel<plat::CUDADeviceContext, double>);
REGISTER_OP_CUDA_KERNEL( REGISTER_OP_CUDA_KERNEL(
slogdeterminant, ops::SlogDeterminantKernel<plat::CUDADeviceContext, float>, slogdeterminant, ops::SlogDeterminantKernel<plat::CUDADeviceContext, float>,
......
...@@ -23,10 +23,13 @@ ...@@ -23,10 +23,13 @@
#include "paddle/fluid/platform/for_range.h" #include "paddle/fluid/platform/for_range.h"
#include "paddle/phi/kernels/complex_kernel.h" #include "paddle/phi/kernels/complex_kernel.h"
#include "paddle/phi/kernels/full_kernel.h" #include "paddle/phi/kernels/full_kernel.h"
#include "paddle/phi/kernels/funcs/common_shape.h"
#include "paddle/phi/kernels/funcs/diag_functor.h" #include "paddle/phi/kernels/funcs/diag_functor.h"
#include "paddle/phi/kernels/funcs/math_function.h" #include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/phi/kernels/funcs/matrix_inverse.h" #include "paddle/phi/kernels/funcs/matrix_inverse.h"
#include "paddle/phi/kernels/funcs/unsqueeze.h" #include "paddle/phi/kernels/funcs/unsqueeze.h"
#include "paddle/phi/kernels/impl/determinant_grad_kernel_impl.h"
#include "paddle/phi/kernels/impl/determinant_kernel_impl.h"
#include "paddle/phi/kernels/math_kernel.h" #include "paddle/phi/kernels/math_kernel.h"
#include "paddle/phi/kernels/matmul_kernel.h" #include "paddle/phi/kernels/matmul_kernel.h"
#include "paddle/phi/kernels/transpose_kernel.h" #include "paddle/phi/kernels/transpose_kernel.h"
...@@ -40,232 +43,6 @@ T sign(T val) { ...@@ -40,232 +43,6 @@ T sign(T val) {
return static_cast<T>(T(0) < val) - (val < T(0)); return static_cast<T>(T(0) < val) - (val < T(0));
} }
template <typename T>
class EigenMatrix {};
template <>
class EigenMatrix<float> {
public:
using MatrixType = Eigen::MatrixXf;
};
template <>
class EigenMatrix<double> {
public:
using MatrixType = Eigen::MatrixXd;
};
inline int64_t GetBatchCount(const framework::DDim dims) {
int64_t batch_count = 1;
auto dim_size = dims.size();
PADDLE_ENFORCE_GE(
dim_size, 2,
platform::errors::InvalidArgument(
"the input matrix dimension size should greater than 2."));
  // Multiply all dimensions except the last two to get the batch count;
  // for example, a tensor with shape [3,3,3,3] contains 9 matrices.
for (int64_t i = 0; i < dims.size() - 2; i++) {
batch_count *= dims[i];
}
return batch_count;
}
template <typename T>
struct DeterminantFunctor {
void operator()(const Tensor& input, const framework::ExecutionContext ctx,
int64_t rank, int64_t batch_count, Tensor* output) {
std::vector<T> input_vec;
std::vector<T> output_vec;
framework::TensorToVector(input, ctx.device_context(), &input_vec);
for (int64_t i = 0; i < batch_count; ++i) { // maybe can be parallel
auto begin_iter = input_vec.begin() + i * rank * rank;
auto end_iter = input_vec.begin() + (i + 1) * rank * rank;
std::vector<T> sub_vec(begin_iter,
end_iter); // get every square matrix data
typename EigenMatrix<T>::MatrixType matrix(rank, rank);
for (int64_t i = 0; i < rank; ++i) {
for (int64_t j = 0; j < rank; ++j) {
matrix(i, j) = sub_vec[rank * i + j];
}
}
output_vec.push_back(matrix.determinant());
}
framework::TensorFromVector(output_vec, output);
}
};
template <typename DeviceContext, typename T>
class DeterminantKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* input = context.Input<framework::Tensor>("Input");
auto input_dim = vectorize(input->dims());
auto input_dim_size = input_dim.size();
auto* output = context.Output<framework::Tensor>("Out");
auto batch_count = GetBatchCount(input->dims());
VLOG(2) << "input dim:" << input->dims();
PADDLE_ENFORCE_GE(
input_dim_size, 2,
platform::errors::InvalidArgument(
"the input matrix dimension size should greater than 2."));
PADDLE_ENFORCE_EQ(input_dim[input_dim_size - 1],
input_dim[input_dim_size - 2],
platform::errors::InvalidArgument(
"the input matrix should be square matrix."));
auto rank = input_dim[input_dim_size - 1]; // square matrix length
DeterminantFunctor<T>()(*input, context, rank, batch_count, output);
auto output_dims = phi::slice_ddim(input->dims(), 0, input_dim_size - 2);
if (input_dim_size > 2) {
output->Resize(output_dims);
} else {
      // when the input is a two-dimensional matrix, the det value is a single number.
output->Resize({1});
}
VLOG(2) << "output dim:" << output->dims();
}
};
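// For reference, the removed determinant kernel above flattens a [..., n, n]
// tensor into batch_count = prod(leading dims) square matrices and takes the
// determinant of each (via Eigen). The same bookkeeping on a plain buffer,
// sketched with a small Gaussian-elimination determinant instead of Eigen
// (illustration only, not Paddle code):
#include <cmath>
#include <cstdint>
#include <utility>
#include <vector>

int64_t GetBatchCount(const std::vector<int64_t>& dims) {
  int64_t batch = 1;
  for (size_t i = 0; i + 2 < dims.size(); ++i) batch *= dims[i];
  return batch;
}

// Determinant of one n x n row-major matrix; the local copy is reduced in place.
double Det(std::vector<double> m, int64_t n) {
  double det = 1.0;
  for (int64_t col = 0; col < n; ++col) {
    // Partial pivoting for numerical stability.
    int64_t pivot = col;
    for (int64_t r = col + 1; r < n; ++r)
      if (std::fabs(m[r * n + col]) > std::fabs(m[pivot * n + col])) pivot = r;
    if (m[pivot * n + col] == 0.0) return 0.0;
    if (pivot != col) {
      for (int64_t c = 0; c < n; ++c)
        std::swap(m[pivot * n + c], m[col * n + c]);
      det = -det;  // a row swap flips the sign
    }
    det *= m[col * n + col];
    for (int64_t r = col + 1; r < n; ++r) {
      const double f = m[r * n + col] / m[col * n + col];
      for (int64_t c = col; c < n; ++c) m[r * n + c] -= f * m[col * n + c];
    }
  }
  return det;
}

// Batched determinant over a flattened [..., n, n] buffer.
std::vector<double> BatchedDet(const std::vector<double>& data,
                               const std::vector<int64_t>& dims) {
  const int64_t n = dims.back();
  const int64_t batch = GetBatchCount(dims);
  std::vector<double> out;
  for (int64_t b = 0; b < batch; ++b) {
    std::vector<double> sub(data.begin() + b * n * n,
                            data.begin() + (b + 1) * n * n);
    out.push_back(Det(std::move(sub), n));
  }
  return out;
}

int main() {
  // Two 2x2 matrices with determinants -2 and 1.
  const std::vector<double> data = {1, 2, 3, 4, 1, 0, 0, 1};
  const auto dets = BatchedDet(data, {2, 2, 2});
  return (dets.size() == 2 && std::fabs(dets[0] + 2.0) < 1e-9 &&
          std::fabs(dets[1] - 1.0) < 1e-9)
             ? 0
             : 1;
}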
template <typename T>
struct FoundZeroFunctor {
FoundZeroFunctor(const T* x, int64_t numel, bool* res)
: x_(x), numel_(numel), res_(res) {}
HOSTDEVICE void operator()(size_t idx) const {
if (*res_ || idx >= static_cast<size_t>(numel_)) {
      // a zero has already been found, or the index is out of range
return;
}
*res_ = (x_[idx] == static_cast<T>(0));
}
const T* x_;
int64_t numel_;
bool* res_;
};
template <typename DeviceContext, typename T>
inline bool CheckMatrixInvertible(const framework::ExecutionContext& ctx,
const framework::Tensor* det) {
auto& dev_ctx = ctx.template device_context<DeviceContext>();
auto numel = det->numel();
framework::Tensor dev_tensor;
auto* data = dev_tensor.mutable_data<bool>({1}, ctx.GetPlace());
// set false
phi::funcs::SetConstant<DeviceContext, bool> zero;
zero(dev_ctx, &dev_tensor, false);
// find whether zero
platform::ForRange<DeviceContext> for_range(dev_ctx, numel);
FoundZeroFunctor<T> functor(det->data<T>(), numel, data);
for_range(functor);
// copy to host
dev_ctx.Wait();
framework::Tensor cpu_tensor;
framework::TensorCopy(dev_tensor, platform::CPUPlace(), &cpu_tensor);
// if founded zero, the matrix is not invertible
// else the matrix is invertible
auto* res = cpu_tensor.data<bool>();
return !(*res);
}
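// For reference, the removed CheckMatrixInvertible above only answers "does
// the determinant tensor contain a zero?" (det(A) == 0 means A is singular),
// plus the plumbing needed to run that scan on the device. On host data the
// same test is a one-liner (sketch, not Paddle code):
#include <algorithm>
#include <vector>

bool AllInvertible(const std::vector<double>& det) {
  return std::none_of(det.begin(), det.end(),
                      [](double d) { return d == 0.0; });
}

int main() { return AllInvertible({1.0, -2.0}) ? 0 : 1; }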
template <typename DeviceContext, typename T>
class DeterminantGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto& orig_dev_ctx = context.template device_context<DeviceContext>();
const auto* input = context.Input<framework::Tensor>("Input");
const auto* det = context.Input<framework::Tensor>("Out");
const auto* grad =
context.Input<framework::Tensor>(framework::GradVarName("Out"));
auto* ddet =
context.Output<framework::Tensor>(framework::GradVarName("Input"));
auto input_dims_size = input->dims().size();
if (input_dims_size > 2) {
PADDLE_ENFORCE_EQ(
grad->dims().size() + 2, input_dims_size,
platform::errors::InvalidArgument(
"The grad tensor of det dims size should 2 less than"
" input tensor's, but here differ %d",
input_dims_size - grad->dims().size()));
} else if (input_dims_size == 2) {
// input dims size 2 and grad dims size 1 is possible
PADDLE_ENFORCE_EQ(
grad->dims().size(), 1,
platform::errors::InvalidArgument(
"The grad tensor of det dims size should 2 less than"
" input tensor's, but here differ %d",
input_dims_size - grad->dims().size()));
} else {
// checked in forward, pass
}
auto& dev_ctx = static_cast<
const typename framework::ConvertToPhiContext<DeviceContext>::TYPE&>(
orig_dev_ctx);
// Check Whether the matrix is invertible
// (matrix A not invertible) == (det(A)=0)
if (!CheckMatrixInvertible<DeviceContext, T>(context, det)) {
// The matrix is not invertible
VLOG(3) << "The input matrix not invertible!";
ddet->Resize(input->dims());
phi::Full<T>(dev_ctx, phi::vectorize(input->dims()), static_cast<T>(0.0f),
ddet);
return;
}
// The matrix is invertible
// let |A| = Determinant(A)
// Ref to https://people.maths.ox.ac.uk/gilesm/files/NA-08-01.pdf
// we set d|A| = unsqueeze(dA * |A|, [-1, -2]) * inverse(A).transpose(-2,
// -1)
// First: inverse(A)
framework::Tensor inverse_A;
// A must be square matrices!
inverse_A.Resize(input->dims());
inverse_A.mutable_data<T>(context.GetPlace());
phi::funcs::MatrixInverseFunctor<DeviceContext, T> mat_inv;
mat_inv(orig_dev_ctx, *input, &inverse_A);
VLOG(3) << "inverse(A) dims: " << inverse_A.dims();
// Second: inverse(A).transpose(-2, -1)
framework::Tensor transpose_inverse_A =
phi::TransposeLast2Dim<T>(dev_ctx, inverse_A);
VLOG(3) << "(dA * |A|).transpose(-2, -1) dims: "
<< transpose_inverse_A.dims();
// Third: dA * |A|
auto mul_dA_detA = phi::Multiply<T>(dev_ctx, *grad, *det);
VLOG(3) << "dA * |A| dims: " << mul_dA_detA.dims();
// Fourth: unsqueeze(dA * |A|, [-1, -2])
auto unsqueeze1 = phi::funcs::Unsqueeze(mul_dA_detA, -1);
auto unsqueeze2 = phi::funcs::Unsqueeze(unsqueeze1, -2);
VLOG(3) << "unsqueezed(dA * |A|) dims: " << unsqueeze2.dims();
// Finally: unsqueeze(dA * |A|) * inverse(A)
auto res = phi::Multiply<T>(dev_ctx, unsqueeze2, transpose_inverse_A);
VLOG(3) << "unsqueeze(dA * |A|) * inverse(A) dims: " << res.dims();
framework::TensorCopy(res, context.GetPlace(), ddet);
ddet->Resize(input->dims());
VLOG(3) << "d|A| dims: " << ddet->dims();
}
};
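// For reference, the removed DeterminantGradKernel above implements
// d det(A) / dA = det(A) * inverse(A)^T (the Giles note linked in the code),
// so the incoming scalar grad g becomes dA = g * det(A) * A^{-T}, with a
// zero gradient when A is singular. A finite-difference check of that
// identity for a 2x2 matrix (sketch, not Paddle code):
#include <cassert>
#include <cmath>

double Det2(const double a[4]) { return a[0] * a[3] - a[1] * a[2]; }

// det(A) * A^{-T} written out for a row-major 2x2 matrix.
void DetGrad2(const double a[4], double g[4]) {
  g[0] = a[3];
  g[1] = -a[2];
  g[2] = -a[1];
  g[3] = a[0];
}

int main() {
  const double a[4] = {1.0, 2.0, 3.0, 5.0};
  double g[4];
  DetGrad2(a, g);
  const double eps = 1e-6;
  for (int i = 0; i < 4; ++i) {
    double ap[4], am[4];
    for (int j = 0; j < 4; ++j) {
      ap[j] = a[j];
      am[j] = a[j];
    }
    ap[i] += eps;
    am[i] -= eps;
    const double numeric = (Det2(ap) - Det2(am)) / (2.0 * eps);
    assert(std::abs(numeric - g[i]) < 1e-5);
  }
  return 0;
}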
template <typename T> template <typename T>
struct SlogDeterminantFunctor { struct SlogDeterminantFunctor {
void operator()(const Tensor& input, const framework::ExecutionContext ctx, void operator()(const Tensor& input, const framework::ExecutionContext ctx,
...@@ -280,7 +57,7 @@ struct SlogDeterminantFunctor { ...@@ -280,7 +57,7 @@ struct SlogDeterminantFunctor {
auto end_iter = input_vec.begin() + (i + 1) * rank * rank; auto end_iter = input_vec.begin() + (i + 1) * rank * rank;
std::vector<T> sub_vec(begin_iter, std::vector<T> sub_vec(begin_iter,
end_iter); // get every square matrix data end_iter); // get every square matrix data
typename EigenMatrix<T>::MatrixType matrix(rank, rank); typename phi::detail::EigenMatrix<T>::MatrixType matrix(rank, rank);
for (int64_t i = 0; i < rank; ++i) { for (int64_t i = 0; i < rank; ++i) {
for (int64_t j = 0; j < rank; ++j) { for (int64_t j = 0; j < rank; ++j) {
matrix(i, j) = sub_vec[rank * i + j]; matrix(i, j) = sub_vec[rank * i + j];
...@@ -311,7 +88,7 @@ class SlogDeterminantKernel : public framework::OpKernel<T> { ...@@ -311,7 +88,7 @@ class SlogDeterminantKernel : public framework::OpKernel<T> {
auto input_dim_size = input_dim.size(); auto input_dim_size = input_dim.size();
auto* output = context.Output<framework::Tensor>("Out"); auto* output = context.Output<framework::Tensor>("Out");
auto batch_count = GetBatchCount(input->dims()); auto batch_count = phi::detail::GetBatchCount(input->dims());
VLOG(2) << "input dim:" << input->dims(); VLOG(2) << "input dim:" << input->dims();
PADDLE_ENFORCE_GE( PADDLE_ENFORCE_GE(
input_dim_size, 2, input_dim_size, 2,
...@@ -370,7 +147,9 @@ class SlogDeterminantGradKernel : public framework::OpKernel<T> { ...@@ -370,7 +147,9 @@ class SlogDeterminantGradKernel : public framework::OpKernel<T> {
// (matrix A not invertible) == (absslogdet(A)=0) // (matrix A not invertible) == (absslogdet(A)=0)
auto slogdet_vec = slogdet->Split(1, 0); auto slogdet_vec = slogdet->Split(1, 0);
auto absslogdet_val = slogdet_vec[0]; auto absslogdet_val = slogdet_vec[0];
if (!CheckMatrixInvertible<DeviceContext, T>(context, &absslogdet_val)) { if (!phi::detail::CheckMatrixInvertible<
T, typename framework::ConvertToPhiContext<DeviceContext>::TYPE>(
dev_ctx, &absslogdet_val)) {
// The matrix is not invertible // The matrix is not invertible
VLOG(3) << "The input matrix not invertible!"; VLOG(3) << "The input matrix not invertible!";
dslogdet->Resize(input->dims()); dslogdet->Resize(input->dims());
......
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. // Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
// //
// Licensed under the Apache License, Version 2.0 (the "License"); // Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License. // you may not use this file except in compliance with the License.
...@@ -12,100 +12,8 @@ ...@@ -12,100 +12,8 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "paddle/fluid/framework/convert_utils.h"
#include "paddle/fluid/operators/elementwise/mkldnn/elementwise_mkldnn_op.h" #include "paddle/fluid/operators/elementwise/mkldnn/elementwise_mkldnn_op.h"
namespace paddle {
namespace framework {
class ExecutionContext;
} // namespace framework
namespace platform {
class CPUDeviceContext;
} // namespace platform
} // namespace paddle
namespace paddle {
namespace operators {
template <typename T>
class EltwiseAddMKLDNNGradKernel : public ElemwiseGradKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
ElemwiseGradKernel<T>::Compute(ctx);
using Tensor = framework::Tensor;
auto& dev_ctx =
ctx.template device_context<paddle::platform::MKLDNNDeviceContext>();
const auto& onednn_engine = dev_ctx.GetEngine();
auto* dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
auto* dy = ctx.Output<Tensor>(framework::GradVarName("Y"));
auto tz = phi::vectorize<int64_t>(dout->dims());
memory::data_type dout_type = framework::ToMKLDNNDataType(
framework::TransToProtoVarType(dout->dtype()));
platform::ReorderMKLDNNHandler handler(
tz, framework::TransToProtoVarType(dout->dtype()), dout_type,
onednn_engine);
auto& astream = platform::MKLDNNDeviceContext::tls().get_stream();
auto reorder_src_memory_p = handler.AcquireSrcMemory(
dout->format(), platform::to_void_cast(dout->data<T>()));
if (dx) {
auto reorder_dst_memory_p =
handler.AcquireDstMemory(dx, dout->format(), ctx.GetPlace());
auto reorder_p =
handler.AcquireReorder(reorder_dst_memory_p, reorder_src_memory_p);
platform::RecordEvent record_reorder(
"int_reorder", platform::TracerEventType::UserDefined, 2,
platform::EventRole::kUniqueOp);
reorder_p->execute(astream, *reorder_src_memory_p, *reorder_dst_memory_p);
astream.wait();
dx->set_layout(DataLayout::kMKLDNN);
dx->set_format(platform::GetMKLDNNFormat(*reorder_dst_memory_p));
}
if (dy) {
// Direct copy
if (dout->dims() == dy->dims()) {
auto reorder_dst_memory_p =
handler.AcquireDstMemory(dy, dout->format(), ctx.GetPlace());
auto reorder_p =
handler.AcquireReorder(reorder_dst_memory_p, reorder_src_memory_p);
platform::RecordEvent record_reorder(
"int_reorder", platform::TracerEventType::UserDefined, 2,
platform::EventRole::kUniqueOp);
reorder_p->execute(astream, *reorder_src_memory_p,
*reorder_dst_memory_p);
astream.wait();
dy->set_layout(DataLayout::kMKLDNN);
dy->set_format(platform::GetMKLDNNFormat(*reorder_dst_memory_p));
} else {
// Broadcasting
platform::ReductionMKLDNNHandler<T> handler_sum(
dnnl::algorithm::reduction_sum, 0.0f, 0.0f, onednn_engine,
ctx.GetPlace(), dout, dy, CalculateBroadcastedDims(dout, dy));
auto dy_memory_p = handler_sum.AcquireDstMemory(dy);
auto reduction_p = handler_sum.AcquireForwardPrimitive();
reduction_p->execute(astream, {{DNNL_ARG_SRC, *reorder_src_memory_p},
{DNNL_ARG_DST, *dy_memory_p}});
astream.wait();
dy->set_layout(DataLayout::kMKLDNN);
dy->set_format(
platform::GetMKLDNNFormat(dy_memory_p->get_desc().reshape(
phi::vectorize<int64_t>(dy->dims()))));
}
}
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OP_KERNEL( REGISTER_OP_KERNEL(
...@@ -116,6 +24,8 @@ REGISTER_OP_KERNEL( ...@@ -116,6 +24,8 @@ REGISTER_OP_KERNEL(
ops::EltwiseMKLDNNKernel<int8_t, dnnl::algorithm::binary_add>, ops::EltwiseMKLDNNKernel<int8_t, dnnl::algorithm::binary_add>,
ops::EltwiseMKLDNNKernel<uint8_t, dnnl::algorithm::binary_add>) ops::EltwiseMKLDNNKernel<uint8_t, dnnl::algorithm::binary_add>)
REGISTER_OP_KERNEL(elementwise_add_grad, MKLDNN, ::paddle::platform::CPUPlace, REGISTER_OP_KERNEL(
ops::EltwiseAddMKLDNNGradKernel<paddle::platform::bfloat16>, elementwise_add_grad, MKLDNN, ::paddle::platform::CPUPlace,
ops::EltwiseAddMKLDNNGradKernel<float>) ops::EltwiseMKLDNNGradKernel<paddle::platform::bfloat16,
dnnl::algorithm::binary_add>,
ops::EltwiseMKLDNNGradKernel<float, dnnl::algorithm::binary_add>)
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. // Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
Licensed under the Apache License, Version 2.0 (the "License"); // Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License. // you may not use this file except in compliance with the License.
You may obtain a copy of the License at // You may obtain a copy of the License at
//
http://www.apache.org/licenses/LICENSE-2.0 // http://www.apache.org/licenses/LICENSE-2.0
//
Unless required by applicable law or agreed to in writing, software // Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS, // distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
limitations under the License. */ // limitations under the License.
#include "paddle/fluid/operators/elementwise/mkldnn/elementwise_mkldnn_op.h" #include "paddle/fluid/operators/elementwise/mkldnn/elementwise_mkldnn_op.h"
namespace paddle {
namespace framework {
class ExecutionContext;
} // namespace framework
namespace platform {
class CPUDeviceContext;
} // namespace platform
} // namespace paddle
namespace paddle {
namespace operators {
template <typename T>
class EltwiseDivMKLDNNGradKernel : public ElemwiseGradKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
ElemwiseGradKernel<T>::Compute(ctx);
auto& dev_ctx =
ctx.template device_context<platform::MKLDNNDeviceContext>();
const auto& mkldnn_engine = dev_ctx.GetEngine();
auto* y = ctx.Input<framework::Tensor>("Y");
auto* out = ctx.Input<framework::Tensor>("Out");
auto* dout = ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
auto* dx = ctx.Output<framework::Tensor>(framework::GradVarName("X"));
auto* dy = ctx.Output<framework::Tensor>(framework::GradVarName("Y"));
int axis = ctx.Attr<int>("axis");
auto& astream = platform::MKLDNNDeviceContext::tls().get_stream();
if (dx) {
// dx = dout / y
platform::BinaryMKLDNNHandler<T> handler(
dnnl::algorithm::binary_div, axis, mkldnn_engine, ctx.GetPlace(),
dout, y, dx, 1.0f, 1.0f, 1.0f);
const auto src_dout_memory = handler.AcquireSrcMemory(dout);
const auto src_y_memory = handler.AcquireSecondSrcMemory(y);
const auto dst_dx_memory = handler.AcquireDstMemory(dx);
const auto binary_prim = handler.AcquireForwardPrimitive();
const std::unordered_map<int, dnnl::memory> args = {
{DNNL_ARG_SRC_0, *src_dout_memory},
{DNNL_ARG_SRC_1, *src_y_memory},
{DNNL_ARG_DST, *dst_dx_memory}};
binary_prim->execute(astream, args);
astream.wait();
dx->set_layout(framework::DataLayout::kMKLDNN);
dx->set_format(platform::GetMKLDNNFormat(*dst_dx_memory));
}
if (dy) {
// dy = -dout * out / y
platform::BinaryMKLDNNHandler<T> y_handler(
dnnl::algorithm::binary_div, axis, mkldnn_engine, ctx.GetPlace(), y,
y, nullptr, 1.0f, 1.0f, 1.0f);
const auto y_memory = y_handler.AcquireSrcMemory(y);
dnnl::post_ops po;
po.append_binary(dnnl::algorithm::binary_div, y_memory->get_desc());
platform::BinaryMKLDNNHandler<T> handler(
dnnl::algorithm::binary_mul, axis, mkldnn_engine, ctx.GetPlace(),
dout, out, nullptr, -1.0f, 1.0f, 1.0f, po);
const auto src_dout_memory = handler.AcquireSrcMemory(dout);
const auto src_out_memory = handler.AcquireSecondSrcMemory(out);
// If broadcasting is in use then let's write to temporary
// buffer allocated by oneDNN
const auto dst_dy_memory = (dout->dims() == dy->dims())
? handler.AcquireDstMemory(dy)
: handler.AcquireDstMemory();
const auto binary_prim = handler.AcquireForwardPrimitive();
const std::unordered_map<int, dnnl::memory> args = {
{DNNL_ARG_SRC_0, *src_dout_memory},
{DNNL_ARG_SRC_1, *src_out_memory},
{DNNL_ARG_DST, *dst_dy_memory},
{DNNL_ARG_ATTR_MULTIPLE_POST_OP(0) | DNNL_ARG_SRC_1, *y_memory}};
binary_prim->execute(astream, args);
astream.wait();
dy->set_layout(framework::DataLayout::kMKLDNN);
// Reduction is needed for broadcasting scenario
if (dout->dims() != dy->dims()) {
platform::ReductionMKLDNNHandler<T> handler_sum(
dnnl::algorithm::reduction_sum, 0.0f, 0.0f, mkldnn_engine,
ctx.GetPlace(), dout, dy, CalculateBroadcastedDims(dout, dy));
auto dy_memory_p = handler_sum.AcquireDstMemory(dy);
auto reduction_p = handler_sum.AcquireForwardPrimitive();
// As source we use mem object with results from binary operation
reduction_p->execute(astream, {{DNNL_ARG_SRC, *dst_dy_memory},
{DNNL_ARG_DST, *dy_memory_p}});
astream.wait();
dy->set_format(
platform::GetMKLDNNFormat(dy_memory_p->get_desc().reshape(
phi::vectorize<int64_t>(dy->dims()))));
} else {
dy->set_format(platform::GetMKLDNNFormat(*dst_dy_memory));
}
}
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators; namespace ops = paddle::operators;
// TODO(piotrekobi) add int8, uint8 support
REGISTER_OP_KERNEL(elementwise_div, MKLDNN, paddle::platform::CPUPlace, REGISTER_OP_KERNEL(elementwise_div, MKLDNN, paddle::platform::CPUPlace,
ops::EltwiseMKLDNNKernel<float, dnnl::algorithm::binary_div>, ops::EltwiseMKLDNNKernel<float, dnnl::algorithm::binary_div>,
ops::EltwiseMKLDNNKernel<paddle::platform::bfloat16, ops::EltwiseMKLDNNKernel<paddle::platform::bfloat16,
dnnl::algorithm::binary_div>) dnnl::algorithm::binary_div>)
REGISTER_OP_KERNEL(elementwise_div_grad, MKLDNN, paddle::platform::CPUPlace, REGISTER_OP_KERNEL(
ops::EltwiseDivMKLDNNGradKernel<paddle::platform::bfloat16>, elementwise_div_grad, MKLDNN, paddle::platform::CPUPlace,
ops::EltwiseDivMKLDNNGradKernel<float>) ops::EltwiseMKLDNNGradKernel<paddle::platform::bfloat16,
dnnl::algorithm::binary_div>,
ops::EltwiseMKLDNNGradKernel<float, dnnl::algorithm::binary_div>)
...@@ -15,20 +15,35 @@ ...@@ -15,20 +15,35 @@
#pragma once #pragma once
#include <string> #include <string>
#include <unordered_map> #include <unordered_map>
#include "paddle/fluid/operators/elementwise/elementwise_add_op.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_function.h"
#include "paddle/fluid/framework/data_layout_transform.h" #include "paddle/fluid/framework/data_layout_transform.h"
#include "paddle/fluid/operators/elementwise/elementwise_op.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_function.h"
#include "paddle/fluid/platform/mkldnn_reuse.h" #include "paddle/fluid/platform/mkldnn_reuse.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
using framework::DataLayout;
using framework::Tensor;
using dnnl::memory; using dnnl::memory;
using dnnl::primitive; using dnnl::primitive;
using dnnl::stream; using dnnl::stream;
using framework::DataLayout;
using framework::Tensor;
inline std::vector<int64_t> CalculateBroadcastedDims(const Tensor* x,
const Tensor* y) {
const auto src_tz = phi::vectorize(x->dims());
const auto dst_tz = phi::vectorize(y->dims());
size_t j = 0;
std::vector<int64_t> dst_tz_ex(src_tz.size(), 1);
for (size_t i = 0; i < src_tz.size(); ++i) {
dst_tz_ex[i] = (src_tz[i] != dst_tz[j]) ? 1 : dst_tz[j++];
if (j == dst_tz.size()) break;
}
return dst_tz_ex;
}
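// For reference, CalculateBroadcastedDims above builds the reduction target
// for the broadcast case: given dout shaped like x and dy shaped like the
// (smaller) input y, it returns dims of the same rank as x with 1 on every
// axis that y does not provide, so a reduction_sum over those axes yields dy.
// E.g. x dims [2, 3, 4] and y dims [3, 1] map to [1, 3, 1]. A standalone copy
// of the same index walk (sketch, not Paddle code):
#include <cassert>
#include <cstdint>
#include <vector>

std::vector<int64_t> BroadcastedDims(const std::vector<int64_t>& src_tz,
                                     const std::vector<int64_t>& dst_tz) {
  size_t j = 0;
  std::vector<int64_t> dst_tz_ex(src_tz.size(), 1);
  for (size_t i = 0; i < src_tz.size(); ++i) {
    dst_tz_ex[i] = (src_tz[i] != dst_tz[j]) ? 1 : dst_tz[j++];
    if (j == dst_tz.size()) break;
  }
  return dst_tz_ex;
}

int main() {
  assert((BroadcastedDims({2, 3, 4}, {3, 1}) ==
          std::vector<int64_t>{1, 3, 1}));
  return 0;
}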
template <typename T, dnnl::algorithm BINARY_OP> template <typename T, dnnl::algorithm BINARY_OP>
class EltwiseMKLDNNKernel : public framework::OpKernel<T> { class EltwiseMKLDNNKernel : public framework::OpKernel<T> {
...@@ -103,7 +118,7 @@ class EltwiseMKLDNNKernel : public framework::OpKernel<T> { ...@@ -103,7 +118,7 @@ class EltwiseMKLDNNKernel : public framework::OpKernel<T> {
// operation. // operation.
const bool reuse_x_memopry = const bool reuse_x_memopry =
x->numel() == z->numel() && x->IsSharedBufferWith(*z); x->numel() == z->numel() && x->IsSharedBufferWith(*z);
std::shared_ptr<dnnl::memory> dst_memory = nullptr; std::shared_ptr<dnnl::memory> dst_memory;
if (reuse_x_memopry) { if (reuse_x_memopry) {
dst_memory = src_x_memory; dst_memory = src_x_memory;
// NOTE(chenfeiyu): when the output reuses memory from other tensor rather // NOTE(chenfeiyu): when the output reuses memory from other tensor rather
...@@ -135,19 +150,193 @@ class EltwiseMKLDNNKernel : public framework::OpKernel<T> { ...@@ -135,19 +150,193 @@ class EltwiseMKLDNNKernel : public framework::OpKernel<T> {
} }
}; };
inline std::vector<int64_t> CalculateBroadcastedDims(const Tensor* x, template <typename T, dnnl::algorithm BINARY_OP>
const Tensor* y) { class EltwiseMKLDNNGradKernel : public ElemwiseGradKernel<T> {
const auto src_tz = phi::vectorize(x->dims()); public:
const auto dst_tz = phi::vectorize(y->dims()); void Compute(const framework::ExecutionContext& ctx) const override {
ElemwiseGradKernel<T>::Compute(ctx);
using Tensor = framework::Tensor;
size_t j = 0; auto& dev_ctx =
std::vector<int64_t> dst_tz_ex(src_tz.size(), 1); ctx.template device_context<platform::MKLDNNDeviceContext>();
for (size_t i = 0; i < src_tz.size(); ++i) { const auto& onednn_engine = dev_ctx.GetEngine();
dst_tz_ex[i] = (src_tz[i] != dst_tz[j]) ? 1 : dst_tz[j++];
if (j == dst_tz.size()) break; auto* x = ctx.Input<Tensor>("X");
auto* y = ctx.Input<Tensor>("Y");
auto* out = ctx.Input<Tensor>("Out");
auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
auto* dy = ctx.Output<Tensor>(framework::GradVarName("Y"));
auto* dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
int axis = ctx.Attr<int>("axis");
auto tz = phi::vectorize<int64_t>(dout->dims());
auto proto_type_dout = framework::TransToProtoVarType(dout->dtype());
platform::ReorderMKLDNNHandler reorder_handler(
tz, proto_type_dout, framework::ToMKLDNNDataType(proto_type_dout),
onednn_engine);
auto reorder_src_memory_p = reorder_handler.AcquireSrcMemory(
dout->format(), platform::to_void_cast(dout->data<T>()));
auto& astream = platform::MKLDNNDeviceContext::tls().get_stream();
if (dx) {
std::shared_ptr<dnnl::memory> dst_memory;
// elementwise_add & elementwise_sub
if (BINARY_OP == dnnl::algorithm::binary_add ||
BINARY_OP == dnnl::algorithm::binary_sub) {
dst_memory = reorder_handler.AcquireDstMemory(dx, dout->format(),
ctx.GetPlace());
auto reorder_p =
reorder_handler.AcquireReorder(dst_memory, reorder_src_memory_p);
platform::RecordEvent record_reorder(
"int_reorder", platform::TracerEventType::UserDefined, 2,
platform::EventRole::kUniqueOp);
reorder_p->execute(astream, *reorder_src_memory_p, *dst_memory);
} }
return dst_tz_ex; // elementwise_mul & elementwise_div
} else {
platform::BinaryMKLDNNHandler<T> binary_handler(
BINARY_OP, axis, onednn_engine, ctx.GetPlace(), dout, y, dx, 1.0f,
1.0f, 1.0f);
const auto src_dout_memory = binary_handler.AcquireSrcMemory(dout);
const auto src_y_memory = binary_handler.AcquireSecondSrcMemory(y);
dst_memory = binary_handler.AcquireDstMemory(dx);
const auto binary_prim = binary_handler.AcquireForwardPrimitive();
const std::unordered_map<int, dnnl::memory> args = {
{DNNL_ARG_SRC_0, *src_dout_memory},
{DNNL_ARG_SRC_1, *src_y_memory},
{DNNL_ARG_DST, *dst_memory}};
binary_prim->execute(astream, args);
}
astream.wait();
dx->set_layout(framework::DataLayout::kMKLDNN);
dx->set_format(platform::GetMKLDNNFormat(*dst_memory));
}
if (dy) {
dnnl::primitive_attr broadcast_reduction_attr;
std::shared_ptr<dnnl::memory> broadcast_src_memory;
std::shared_ptr<dnnl::memory> dst_memory;
// elementwise_add & elementwise_sub
if (BINARY_OP == dnnl::algorithm::binary_add ||
BINARY_OP == dnnl::algorithm::binary_sub) {
if (dout->dims() == dy->dims()) {
auto reorder_dst_memory_p = reorder_handler.AcquireDstMemory(
dy, dout->format(), ctx.GetPlace());
dnnl::primitive_attr reorder_attr;
std::vector<float> scales(1);
scales[0] = (BINARY_OP == dnnl::algorithm::binary_add) ? 1 : -1;
reorder_attr.set_output_scales(0, scales);
auto reorder_p = std::make_shared<dnnl::reorder>(
*(reorder_src_memory_p), *(reorder_dst_memory_p), reorder_attr);
platform::RecordEvent record_reorder(
"int_reorder", platform::TracerEventType::UserDefined, 2,
platform::EventRole::kUniqueOp);
reorder_p->execute(astream, *reorder_src_memory_p,
*reorder_dst_memory_p);
dst_memory = reorder_dst_memory_p;
} else {
broadcast_src_memory = reorder_src_memory_p;
}
}
// elementwise_mul & elementwise_div
else {
std::unordered_map<int, dnnl::memory> args;
std::shared_ptr<dnnl::binary> binary_prim;
std::shared_ptr<dnnl::memory> post_op_memory;
std::shared_ptr<dnnl::memory> src_0_memory;
std::shared_ptr<dnnl::memory> src_1_memory;
platform::BinaryMKLDNNHandler<T> binary_handler(
dnnl::algorithm::binary_mul, axis, onednn_engine, ctx.GetPlace(),
dout, x, nullptr, 1.0f, 1.0f, 1.0f);
src_1_memory = binary_handler.AcquireSecondSrcMemory(x);
if (BINARY_OP == dnnl::algorithm::binary_div) {
platform::BinaryMKLDNNHandler<T> post_op_binary_handler(
dnnl::algorithm::binary_div, axis, onednn_engine, ctx.GetPlace(),
y, y, nullptr, 1.0f, 1.0f, 1.0f);
post_op_memory = post_op_binary_handler.AcquireSrcMemory(y);
dnnl::post_ops po;
po.append_binary(dnnl::algorithm::binary_div,
post_op_memory->get_desc());
binary_handler = platform::BinaryMKLDNNHandler<T>(
dnnl::algorithm::binary_mul, axis, onednn_engine, ctx.GetPlace(),
dout, out, nullptr, -1.0f, 1.0f, 1.0f, po);
src_1_memory = binary_handler.AcquireSecondSrcMemory(out);
}
src_0_memory = binary_handler.AcquireSrcMemory(dout);
const auto dst_dy_memory = (dout->dims() == dy->dims())
? binary_handler.AcquireDstMemory(dy)
: binary_handler.AcquireDstMemory();
binary_prim = binary_handler.AcquireForwardPrimitive();
args = {{DNNL_ARG_SRC_0, *src_0_memory},
{DNNL_ARG_SRC_1, *src_1_memory},
{DNNL_ARG_DST, *dst_dy_memory}};
if (BINARY_OP == dnnl::algorithm::binary_div)
args.insert({DNNL_ARG_ATTR_MULTIPLE_POST_OP(0) | DNNL_ARG_SRC_1,
*post_op_memory});
binary_prim->execute(astream, args);
broadcast_src_memory = dst_dy_memory;
dst_memory = dst_dy_memory;
}
astream.wait();
dy->set_layout(DataLayout::kMKLDNN);
if (dout->dims() != dy->dims()) {
// Broadcasting
if (BINARY_OP == dnnl::algorithm::binary_sub) {
dnnl::post_ops po;
po.append_eltwise(1.0f, dnnl::algorithm::eltwise_linear, -1.0f, 0);
broadcast_reduction_attr.set_post_ops(po);
}
platform::ReductionMKLDNNHandler<T> reduction_handler(
dnnl::algorithm::reduction_sum, 0.0f, 0.0f, onednn_engine,
ctx.GetPlace(), dout, dy, CalculateBroadcastedDims(dout, dy),
broadcast_reduction_attr);
dst_memory = reduction_handler.AcquireDstMemory(dy);
auto reduction_p = reduction_handler.AcquireForwardPrimitive();
reduction_p->execute(astream, {
{DNNL_ARG_SRC, *broadcast_src_memory},
{DNNL_ARG_DST, *dst_memory},
});
astream.wait();
dy->set_format(platform::GetMKLDNNFormat(dst_memory->get_desc().reshape(
phi::vectorize<int64_t>(dy->dims()))));
} else {
dy->set_format(platform::GetMKLDNNFormat(*dst_memory));
}
}
}
};
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
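// For reference, the single EltwiseMKLDNNGradKernel above replaces the four
// per-op grad kernels (add/sub/mul/div) by dispatching on BINARY_OP.
// Element-wise, the gradients it materializes are:
//   add: dx = dout,       dy = dout
//   sub: dx = dout,       dy = -dout
//   mul: dx = dout * y,   dy = dout * x
//   div: dx = dout / y,   dy = -dout * out / y   (with out = x / y)
// plus a reduction_sum over broadcast axes when dy is smaller than dout.
// A scalar transcription of the dispatch (sketch, not Paddle code):
#include <cassert>
#include <cmath>

enum class BinaryOp { kAdd, kSub, kMul, kDiv };

void EltwiseGrad(BinaryOp op, double x, double y, double dout, double* dx,
                 double* dy) {
  const double out = (op == BinaryOp::kDiv) ? x / y : 0.0;
  switch (op) {
    case BinaryOp::kAdd: *dx = dout; *dy = dout; break;
    case BinaryOp::kSub: *dx = dout; *dy = -dout; break;
    case BinaryOp::kMul: *dx = dout * y; *dy = dout * x; break;
    case BinaryOp::kDiv: *dx = dout / y; *dy = -dout * out / y; break;
  }
}

int main() {
  double dx = 0.0, dy = 0.0;
  EltwiseGrad(BinaryOp::kDiv, /*x=*/6.0, /*y=*/2.0, /*dout=*/1.0, &dx, &dy);
  assert(std::abs(dx - 0.5) < 1e-12 && std::abs(dy + 1.5) < 1e-12);
  return 0;
}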
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. // Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
Licensed under the Apache License, Version 2.0 (the "License"); // Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License. // you may not use this file except in compliance with the License.
You may obtain a copy of the License at // You may obtain a copy of the License at
//
http://www.apache.org/licenses/LICENSE-2.0 // http://www.apache.org/licenses/LICENSE-2.0
//
Unless required by applicable law or agreed to in writing, software // Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS, // distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
limitations under the License. */ // limitations under the License.
#include "paddle/fluid/operators/elementwise/mkldnn/elementwise_mkldnn_op.h" #include "paddle/fluid/operators/elementwise/mkldnn/elementwise_mkldnn_op.h"
namespace paddle {
namespace framework {
class ExecutionContext;
} // namespace framework
namespace platform {
class CPUDeviceContext;
} // namespace platform
} // namespace paddle
namespace paddle {
namespace operators {
template <typename T>
class EltwiseMulMKLDNNGradKernel : public ElemwiseGradKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
ElemwiseGradKernel<T>::Compute(ctx);
auto& dev_ctx =
ctx.template device_context<paddle::platform::MKLDNNDeviceContext>();
const auto& mkldnn_engine = dev_ctx.GetEngine();
auto* x = ctx.Input<framework::Tensor>("X");
auto* y = ctx.Input<framework::Tensor>("Y");
auto* dout = ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
auto* dx = ctx.Output<framework::Tensor>(framework::GradVarName("X"));
auto* dy = ctx.Output<framework::Tensor>(framework::GradVarName("Y"));
int axis = ctx.Attr<int>("axis");
auto& astream = platform::MKLDNNDeviceContext::tls().get_stream();
if (dx) {
// dx = dout*y
platform::BinaryMKLDNNHandler<T> handler(
dnnl::algorithm::binary_mul, axis, mkldnn_engine, ctx.GetPlace(),
dout, y, dx, 1.0f, 1.0f, 1.0f);
const auto src_dout_memory = handler.AcquireSrcMemory(dout);
const auto src_y_memory = handler.AcquireSecondSrcMemory(y);
const auto dst_dx_memory = handler.AcquireDstMemory(dx);
const auto binary_prim = handler.AcquireForwardPrimitive();
const std::unordered_map<int, dnnl::memory> args = {
{DNNL_ARG_SRC_0, *src_dout_memory},
{DNNL_ARG_SRC_1, *src_y_memory},
{DNNL_ARG_DST, *dst_dx_memory}};
binary_prim->execute(astream, args);
astream.wait();
dx->set_layout(framework::DataLayout::kMKLDNN);
dx->set_format(platform::GetMKLDNNFormat(*dst_dx_memory));
}
if (dy) {
// dy = dout*x
// Handler is having nullptr passed instead of output tensor as
// we want Dst buffer to be allocated by oneDNN not to use Tensor
platform::BinaryMKLDNNHandler<T> handler(
dnnl::algorithm::binary_mul, axis, mkldnn_engine, ctx.GetPlace(),
dout, x, nullptr, 1.0f, 1.0f, 1.0f);
const auto src_dout_memory = handler.AcquireSrcMemory(dout);
const auto src_x_memory = handler.AcquireSecondSrcMemory(x);
// If broadcasting is in use then let's write to temporary
// buffer allocated by oneDNN
const auto dst_dy_memory = (dout->dims() == dy->dims())
? handler.AcquireDstMemory(dy)
: handler.AcquireDstMemory();
const auto binary_prim = handler.AcquireForwardPrimitive();
const std::unordered_map<int, dnnl::memory> args = {
{DNNL_ARG_SRC_0, *src_dout_memory},
{DNNL_ARG_SRC_1, *src_x_memory},
{DNNL_ARG_DST, *dst_dy_memory}};
binary_prim->execute(astream, args);
astream.wait();
dy->set_layout(framework::DataLayout::kMKLDNN);
// Reduction is needed for broadcasting scenario
if (dout->dims() != dy->dims()) {
platform::ReductionMKLDNNHandler<T> handler_sum(
dnnl::algorithm::reduction_sum, 0.0f, 0.0f, mkldnn_engine,
ctx.GetPlace(), dout, dy, CalculateBroadcastedDims(dout, dy));
auto dy_memory_p = handler_sum.AcquireDstMemory(dy);
auto reduction_p = handler_sum.AcquireForwardPrimitive();
// As source we use mem object with results from binary operation
reduction_p->execute(astream, {{DNNL_ARG_SRC, *dst_dy_memory},
{DNNL_ARG_DST, *dy_memory_p}});
astream.wait();
dy->set_format(
platform::GetMKLDNNFormat(dy_memory_p->get_desc().reshape(
phi::vectorize<int64_t>(dy->dims()))));
} else {
dy->set_format(platform::GetMKLDNNFormat(*dst_dy_memory));
}
}
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OP_KERNEL( REGISTER_OP_KERNEL(
...@@ -132,6 +24,8 @@ REGISTER_OP_KERNEL( ...@@ -132,6 +24,8 @@ REGISTER_OP_KERNEL(
ops::EltwiseMKLDNNKernel<int8_t, dnnl::algorithm::binary_mul>, ops::EltwiseMKLDNNKernel<int8_t, dnnl::algorithm::binary_mul>,
ops::EltwiseMKLDNNKernel<uint8_t, dnnl::algorithm::binary_mul>) ops::EltwiseMKLDNNKernel<uint8_t, dnnl::algorithm::binary_mul>)
REGISTER_OP_KERNEL(elementwise_mul_grad, MKLDNN, ::paddle::platform::CPUPlace, REGISTER_OP_KERNEL(
ops::EltwiseMulMKLDNNGradKernel<paddle::platform::bfloat16>, elementwise_mul_grad, MKLDNN, ::paddle::platform::CPUPlace,
ops::EltwiseMulMKLDNNGradKernel<float>) ops::EltwiseMKLDNNGradKernel<paddle::platform::bfloat16,
dnnl::algorithm::binary_mul>,
ops::EltwiseMKLDNNGradKernel<float, dnnl::algorithm::binary_mul>)
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
// //
// Licensed under the Apache License, Version 2.0 (the "License"); // Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License. // you may not use this file except in compliance with the License.
...@@ -13,113 +12,7 @@ ...@@ -13,113 +12,7 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "paddle/fluid/framework/convert_utils.h"
#include "paddle/fluid/operators/elementwise/mkldnn/elementwise_mkldnn_op.h" #include "paddle/fluid/operators/elementwise/mkldnn/elementwise_mkldnn_op.h"
namespace paddle {
namespace framework {
class ExecutionContext;
} // namespace framework
namespace platform {
class CPUDeviceContext;
} // namespace platform
} // namespace paddle
namespace paddle {
namespace operators {
template <typename T>
class EltwiseSubMKLDNNGradKernel : public ElemwiseGradKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
ElemwiseGradKernel<T>::Compute(ctx);
using Tensor = framework::Tensor;
auto& dev_ctx =
ctx.template device_context<platform::MKLDNNDeviceContext>();
const auto& onednn_engine = dev_ctx.GetEngine();
auto* dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
auto* dy = ctx.Output<Tensor>(framework::GradVarName("Y"));
auto tz = phi::vectorize<int64_t>(dout->dims());
memory::data_type dout_type = framework::ToMKLDNNDataType(
framework::TransToProtoVarType(dout->dtype()));
platform::ReorderMKLDNNHandler handler(
tz, framework::TransToProtoVarType(dout->dtype()), dout_type,
onednn_engine);
auto& astream = platform::MKLDNNDeviceContext::tls().get_stream();
auto reorder_src_memory_p = handler.AcquireSrcMemory(
dout->format(), platform::to_void_cast(dout->data<T>()));
if (dx) {
auto reorder_dst_memory_p =
handler.AcquireDstMemory(dx, dout->format(), ctx.GetPlace());
auto reorder_p =
handler.AcquireReorder(reorder_dst_memory_p, reorder_src_memory_p);
platform::RecordEvent record_reorder(
"int_reorder", platform::TracerEventType::UserDefined, 2,
platform::EventRole::kUniqueOp);
reorder_p->execute(astream, *reorder_src_memory_p, *reorder_dst_memory_p);
astream.wait();
dx->set_layout(DataLayout::kMKLDNN);
dx->set_format(platform::GetMKLDNNFormat(*reorder_dst_memory_p));
}
if (dy) {
// Direct copy
if (dout->dims() == dy->dims()) {
auto reorder_dst_memory_p =
handler.AcquireDstMemory(dy, dout->format(), ctx.GetPlace());
dnnl::primitive_attr reorder_attr;
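        // The gradient w.r.t. Y for subtraction is -dOut; the negation is
        // folded into the reorder through an output scale of -1.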
std::vector<float> scales = {-1};
reorder_attr.set_output_scales(0, scales);
auto reorder_p = std::make_shared<dnnl::reorder>(
*(reorder_src_memory_p), *(reorder_dst_memory_p), reorder_attr);
platform::RecordEvent record_reorder(
"int_reorder", platform::TracerEventType::UserDefined, 2,
platform::EventRole::kUniqueOp);
reorder_p->execute(astream, *reorder_src_memory_p,
*reorder_dst_memory_p);
astream.wait();
dy->set_layout(DataLayout::kMKLDNN);
dy->set_format(platform::GetMKLDNNFormat(*reorder_dst_memory_p));
} else {
// Broadcasting
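        // Reduce dOut over the broadcast dimensions and negate the result
        // via an eltwise_linear post-op (alpha = -1).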
dnnl::post_ops po;
po.append_eltwise(1.0f, dnnl::algorithm::eltwise_linear, -1.0f, 0);
dnnl::primitive_attr attr;
attr.set_post_ops(po);
platform::ReductionMKLDNNHandler<T> handler_sum(
dnnl::algorithm::reduction_sum, 0.0f, 0.0f, onednn_engine,
ctx.GetPlace(), dout, dy, CalculateBroadcastedDims(dout, dy), attr);
auto dy_memory_p = handler_sum.AcquireDstMemory(dy);
auto reduction_p = handler_sum.AcquireForwardPrimitive();
reduction_p->execute(astream, {
{DNNL_ARG_SRC, *reorder_src_memory_p},
{DNNL_ARG_DST, *dy_memory_p},
});
astream.wait();
dy->set_layout(DataLayout::kMKLDNN);
dy->set_format(
platform::GetMKLDNNFormat(dy_memory_p->get_desc().reshape(
phi::vectorize<int64_t>(dy->dims()))));
}
}
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators; namespace ops = paddle::operators;
...@@ -131,6 +24,8 @@ REGISTER_OP_KERNEL( ...@@ -131,6 +24,8 @@ REGISTER_OP_KERNEL(
ops::EltwiseMKLDNNKernel<int8_t, dnnl::algorithm::binary_sub>, ops::EltwiseMKLDNNKernel<int8_t, dnnl::algorithm::binary_sub>,
ops::EltwiseMKLDNNKernel<uint8_t, dnnl::algorithm::binary_sub>) ops::EltwiseMKLDNNKernel<uint8_t, dnnl::algorithm::binary_sub>)
REGISTER_OP_KERNEL(elementwise_sub_grad, MKLDNN, ::paddle::platform::CPUPlace, REGISTER_OP_KERNEL(
ops::EltwiseSubMKLDNNGradKernel<paddle::platform::bfloat16>, elementwise_sub_grad, MKLDNN, ::paddle::platform::CPUPlace,
ops::EltwiseSubMKLDNNGradKernel<float>) ops::EltwiseMKLDNNGradKernel<paddle::platform::bfloat16,
dnnl::algorithm::binary_sub>,
ops::EltwiseMKLDNNGradKernel<float, dnnl::algorithm::binary_sub>)
...@@ -12,12 +12,17 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,12 +12,17 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/operators/gather_op.h"
#include <memory> #include <memory>
#include <string> #include <string>
#include <vector> #include <vector>
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/op_version_registry.h" #include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/phi/core/ddim.h" #include "paddle/phi/core/ddim.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/backward.h"
#include "paddle/phi/infermeta/binary.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -26,58 +31,6 @@ class GatherOp : public framework::OperatorWithKernel { ...@@ -26,58 +31,6 @@ class GatherOp : public framework::OperatorWithKernel {
public: public:
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true,
platform::errors::InvalidArgument(
"Input(X) of GatherOp should not be null."));
PADDLE_ENFORCE_EQ(ctx->HasInput("Index"), true,
platform::errors::InvalidArgument(
"Input(Index) of GatherOp should not be null."));
PADDLE_ENFORCE_EQ(ctx->HasOutput("Out"), true,
platform::errors::InvalidArgument(
"Output(Out) of GatherOp should not be null."));
auto index_dims = ctx->GetInputDim("Index");
if (index_dims.size() == 2) {
PADDLE_ENFORCE_EQ(
index_dims[1], 1,
platform::errors::InvalidArgument(
"The last dim of index should be 1 when it is 2D, but we get %d",
index_dims[1]));
} else {
PADDLE_ENFORCE_EQ(
index_dims.size(), 1,
platform::errors::InvalidArgument(
"The index should be 1D, when it is not 2D, but we get %d",
index_dims.size()));
}
auto axis = ctx->Attrs().Get<int>("axis");
auto input_dim = ctx->GetInputDim("X");
if (ctx->HasInput("Axis") || axis == 0) {
// if HasInput("Axis"), we can not obtain correct shape of output
int batch_size = index_dims[0];
framework::DDim output_dims(input_dim);
output_dims[0] = batch_size;
ctx->SetOutputDim("Out", output_dims);
ctx->ShareLoD("X", /*->*/ "Out");
} else {
int index_size = index_dims[0];
std::vector<int> out_dim_vec;
for (int i = 0; i < axis; i++) {
out_dim_vec.push_back(input_dim[i]);
}
out_dim_vec.push_back(index_size);
for (int i = axis + 1; i < input_dim.size(); i++) {
out_dim_vec.push_back(input_dim[i]);
}
auto output_dims = phi::make_ddim(out_dim_vec);
ctx->SetOutputDim("Out", output_dims);
ctx->ShareLoD("X", /*->*/ "Out");
}
}
protected: protected:
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
...@@ -100,11 +53,6 @@ class GatherGradOp : public framework::OperatorWithKernel { ...@@ -100,11 +53,6 @@ class GatherGradOp : public framework::OperatorWithKernel {
public: public:
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
ctx->ShareLoD("X", /*-->*/ framework::GradVarName("X"));
}
protected: protected:
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
...@@ -193,22 +141,18 @@ DECLARE_NO_NEED_BUFFER_VARS_INFERER(GatherGradNoNeedBufferVarInferer, "X"); ...@@ -193,22 +141,18 @@ DECLARE_NO_NEED_BUFFER_VARS_INFERER(GatherGradNoNeedBufferVarInferer, "X");
} // namespace paddle } // namespace paddle
namespace ops = paddle::operators; namespace ops = paddle::operators;
DECLARE_INFER_SHAPE_FUNCTOR(gather, GatherInferShapeFunctor,
PD_INFER_META(phi::GatherInferMeta));
REGISTER_OPERATOR(gather, ops::GatherOp, ops::GatherOpMaker, REGISTER_OPERATOR(gather, ops::GatherOp, ops::GatherOpMaker,
ops::GatherGradOpMaker<paddle::framework::OpDesc>, ops::GatherGradOpMaker<paddle::framework::OpDesc>,
ops::GatherGradOpMaker<paddle::imperative::OpBase>); ops::GatherGradOpMaker<paddle::imperative::OpBase>,
GatherInferShapeFunctor);
DECLARE_INFER_SHAPE_FUNCTOR(gather_grad, GatherGradInferShapeFunctor,
PD_INFER_META(phi::GeneralUnaryGradInferMeta));
REGISTER_OPERATOR(gather_grad, ops::GatherGradOp, REGISTER_OPERATOR(gather_grad, ops::GatherGradOp,
ops::GatherGradNoNeedBufferVarInferer); ops::GatherGradNoNeedBufferVarInferer,
REGISTER_OP_CPU_KERNEL(gather, ops::GatherOpKernel<float>, GatherGradInferShapeFunctor);
ops::GatherOpKernel<double>, ops::GatherOpKernel<int>,
ops::GatherOpKernel<uint8_t>,
ops::GatherOpKernel<int64_t>,
ops::GatherOpKernel<phi::dtype::bfloat16>);
REGISTER_OP_CPU_KERNEL(gather_grad, ops::GatherGradientOpKernel<float>,
ops::GatherGradientOpKernel<double>,
ops::GatherGradientOpKernel<int>,
ops::GatherGradientOpKernel<uint8_t>,
ops::GatherGradientOpKernel<int64_t>,
ops::GatherGradientOpKernel<phi::dtype::bfloat16>);
REGISTER_OP_VERSION(gather) REGISTER_OP_VERSION(gather)
.AddCheckpoint(R"ROC(upgrad gather, add a new input [Axis])ROC", .AddCheckpoint(R"ROC(upgrad gather, add a new input [Axis])ROC",
paddle::framework::compatible::OpVersionDesc().NewInput( paddle::framework::compatible::OpVersionDesc().NewInput(
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/convert_utils.h"
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/operators/gather_op.h"
#include "paddle/phi/kernels/funcs/gather.cu.h"
#include "paddle/phi/kernels/funcs/scatter.cu.h"
namespace paddle {
namespace operators {
template <typename T>
class GatherOpCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
PADDLE_ENFORCE_EQ(platform::is_gpu_place(ctx.GetPlace()), true,
platform::errors::PreconditionNotMet(
"This kernel only runs on GPU device."));
auto *x = ctx.Input<Tensor>("X");
auto *index = ctx.Input<Tensor>("Index");
auto *output = ctx.Output<Tensor>("Out");
int axis = ctx.Attr<int>("axis");
// get axis from tensor
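    // When axis is supplied as an optional Axis input tensor, copy it to the
    // host and read the scalar value (int16/int32/int64 are supported).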
if (ctx.HasInput("Axis")) {
Tensor cpu_axis;
const Tensor *axis_tensor = ctx.Input<Tensor>("Axis");
framework::TensorCopy(*axis_tensor, platform::CPUPlace(), &cpu_axis);
const auto &axis_type =
framework::TransToProtoVarType(axis_tensor->dtype());
if (axis_type == framework::proto::VarType::INT32) {
axis = static_cast<int>(cpu_axis.data<int32_t>()[0]);
} else if (axis_type == framework::proto::VarType::INT64) {
axis = static_cast<int>(cpu_axis.data<int64_t>()[0]);
} else if (axis_type == framework::proto::VarType::INT16) {
axis = static_cast<int>(cpu_axis.data<int16_t>()[0]);
}
}
const auto &place = ctx.GetPlace();
const auto &index_type = framework::TransToProtoVarType(index->dtype());
const auto &dev_ctx = ctx.cuda_device_context();
if (axis != 0) {
if (index_type == framework::proto::VarType::INT32) {
phi::funcs::GatherV2CUDAFunction<T, int32_t>(x, index, axis, output,
dev_ctx);
} else if (index_type == framework::proto::VarType::INT64) {
phi::funcs::GatherV2CUDAFunction<T, int64_t>(x, index, axis, output,
dev_ctx);
} else if (index_type == framework::proto::VarType::INT16) {
phi::funcs::GatherV2CUDAFunction<T, int16_t>(x, index, axis, output,
dev_ctx);
}
return;
}
output->mutable_data<T>(ctx.GetPlace());
if (x->numel() == 0) return;
if (index_type == framework::proto::VarType::INT32) {
phi::funcs::GPUGather<T, int>(dev_ctx, *x, *index, output);
} else if (index_type == framework::proto::VarType::INT64) {
phi::funcs::GPUGather<T, int64_t>(dev_ctx, *x, *index, output);
} else if (index_type == framework::proto::VarType::INT16) {
phi::funcs::GPUGather<T, int16_t>(dev_ctx, *x, *index, output);
}
}
};
template <typename T>
class GatherGradOpCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
PADDLE_ENFORCE_EQ(platform::is_gpu_place(ctx.GetPlace()), true,
platform::errors::PreconditionNotMet(
"This kernel only runs on GPU device."));
auto *index = ctx.Input<Tensor>("Index");
auto *dX = ctx.Output<Tensor>(framework::GradVarName("X"));
auto *dO = ctx.Input<Tensor>(framework::GradVarName("Out"));
int axis = ctx.Attr<int>("axis");
if (ctx.HasInput("Axis")) {
const Tensor *axis_tensor = ctx.Input<Tensor>("Axis");
Tensor cpu_axis;
framework::TensorCopy(*axis_tensor, platform::CPUPlace(), &cpu_axis);
const auto &axis_type =
framework::TransToProtoVarType(axis_tensor->dtype());
if (axis_type == framework::proto::VarType::INT32) {
axis = static_cast<int>(cpu_axis.data<int32_t>()[0]);
} else if (axis_type == framework::proto::VarType::INT64) {
axis = static_cast<int>(cpu_axis.data<int64_t>()[0]);
}
}
const auto &dev_ctx = ctx.cuda_device_context();
const auto &index_type = framework::TransToProtoVarType(index->dtype());
if (axis != 0) {
if (index_type == framework::proto::VarType::INT32) {
phi::funcs::GatherV2GradCUDAFunction<T, int32_t>(dO, index, axis, dX,
dev_ctx);
} else if (index_type == framework::proto::VarType::INT64) {
phi::funcs::GatherV2GradCUDAFunction<T, int64_t>(dO, index, axis, dX,
dev_ctx);
}
return;
}
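    // axis == 0 path: zero-initialize dX, then scatter the dOut rows back to
    // the positions given by Index.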
dX->mutable_data<T>(ctx.GetPlace());
auto dxt = framework::EigenVector<T>::Flatten(*dX);
auto &place = *ctx.template device_context<platform::CUDADeviceContext>()
.eigen_device();
dxt.device(place) = dxt.constant(static_cast<T>(0));
if (dO->numel() == 0) return;
if (index_type == framework::proto::VarType::INT32) {
phi::funcs::GPUScatterAssign<T, int>(dev_ctx, *dO, *index, dX,
ctx.Attr<bool>("overwrite"));
} else if (index_type == framework::proto::VarType::INT64) {
phi::funcs::GPUScatterAssign<T, int64_t>(dev_ctx, *dO, *index, dX,
ctx.Attr<bool>("overwrite"));
}
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_CUDA_KERNEL(gather, ops::GatherOpCUDAKernel<float>,
ops::GatherOpCUDAKernel<double>,
ops::GatherOpCUDAKernel<int64_t>,
ops::GatherOpCUDAKernel<int>,
ops::GatherOpCUDAKernel<int16_t>,
ops::GatherOpCUDAKernel<plat::float16>,
ops::GatherOpCUDAKernel<plat::bfloat16>);
REGISTER_OP_CUDA_KERNEL(gather_grad, ops::GatherGradOpCUDAKernel<float>,
ops::GatherGradOpCUDAKernel<double>,
ops::GatherGradOpCUDAKernel<int64_t>,
ops::GatherGradOpCUDAKernel<int>,
ops::GatherGradOpCUDAKernel<plat::float16>,
ops::GatherGradOpCUDAKernel<plat::bfloat16>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/fluid/framework/convert_utils.h"
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/phi/kernels/funcs/gather.h"
#include "paddle/phi/kernels/funcs/scatter.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
template <typename T>
class GatherOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
PADDLE_ENFORCE_EQ(
platform::is_cpu_place(ctx.GetPlace()), true,
platform::errors::PreconditionNotMet("This kernel only runs on CPU."));
auto *x = ctx.Input<Tensor>("X");
auto *index = ctx.Input<Tensor>("Index");
auto *output = ctx.Output<Tensor>("Out");
int axis = ctx.Attr<int>("axis");
// get axis from tensor
if (ctx.HasInput("Axis")) {
const Tensor *axis_tensor = ctx.Input<Tensor>("Axis");
const auto &axis_type = axis_tensor->dtype();
if (axis_type == phi::DataType::INT32) {
axis = static_cast<int>(axis_tensor->data<int32_t>()[0]);
} else if (axis_type == phi::DataType::INT64) {
axis = static_cast<int>(axis_tensor->data<int64_t>()[0]);
}
}
const auto &index_type = index->dtype();
auto &dev_ctx = ctx.template device_context<phi::CPUContext>();
if (axis != 0) {
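      // A non-zero axis takes the general GatherV2 path; axis == 0 falls
      // through to the contiguous row gather below.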
if (index_type == phi::DataType::INT32) {
phi::funcs::GatherV2Function<T, int32_t>(dev_ctx, x, index, axis,
output);
} else if (index_type == phi::DataType::INT64) {
phi::funcs::GatherV2Function<T, int64_t>(dev_ctx, x, index, axis,
output);
}
return;
}
output->mutable_data<T>(ctx.GetPlace());
if (x->numel() == 0) return;
if (index_type == phi::DataType::INT32) {
phi::funcs::CPUGather<T, int>(dev_ctx, *x, *index, output);
} else if (index_type == phi::DataType::INT64) {
phi::funcs::CPUGather<T, int64_t>(dev_ctx, *x, *index, output);
}
}
};
template <typename T>
class GatherGradientOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
PADDLE_ENFORCE_EQ(
platform::is_cpu_place(ctx.GetPlace()), true,
platform::errors::PreconditionNotMet("This kernel only runs on CPU."));
auto *index = ctx.Input<Tensor>("Index");
auto *dX = ctx.Output<Tensor>(framework::GradVarName("X"));
auto *dO = ctx.Input<Tensor>(framework::GradVarName("Out"));
int axis = ctx.Attr<int>("axis");
if (ctx.HasInput("Axis")) {
const Tensor *axis_tensor = ctx.Input<Tensor>("Axis");
const auto &axis_type = axis_tensor->dtype();
if (axis_type == phi::DataType::INT32) {
axis = static_cast<int>(axis_tensor->data<int32_t>()[0]);
} else if (axis_type == phi::DataType::INT64) {
axis = static_cast<int>(axis_tensor->data<int64_t>()[0]);
}
}
const auto &index_type = index->dtype();
auto &dev_ctx = ctx.template device_context<phi::CPUContext>();
if (axis != 0) {
if (index_type == phi::DataType::INT32) {
phi::funcs::GatherV2GradFunction<T, int32_t>(dev_ctx, dO, index, axis,
dX);
} else if (index_type == phi::DataType::INT64) {
phi::funcs::GatherV2GradFunction<T, int64_t>(dev_ctx, dO, index, axis,
dX);
}
return;
}
dX->mutable_data<T>(ctx.GetPlace());
auto dxt = framework::EigenVector<T>::Flatten(*dX);
auto &place = *dev_ctx.eigen_device();
dxt.device(place) = dxt.constant(static_cast<T>(0));
if (dO->numel() == 0) return;
bool overwrite = ctx.Attr<bool>("overwrite");
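    // `overwrite` selects plain assignment for duplicate indices
    // (ScatterAssign); otherwise their contributions are accumulated
    // (ScatterAssignAdd).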
if (index_type == phi::DataType::INT32) {
if (overwrite) {
phi::funcs::ScatterAssign<T, int32_t>(dev_ctx, *dO, *index, dX);
} else {
phi::funcs::ScatterAssignAdd<T, int32_t>(dev_ctx, *dO, *index, dX);
}
} else if (index_type == phi::DataType::INT64) {
if (overwrite) {
phi::funcs::ScatterAssign<T, int64_t>(dev_ctx, *dO, *index, dX);
} else {
phi::funcs::ScatterAssignAdd<T, int64_t>(dev_ctx, *dO, *index, dX);
}
}
}
};
} // namespace operators
} // namespace paddle
...@@ -12,10 +12,11 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,10 +12,11 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/operators/gather_op.h"
#include <memory> #include <memory>
#include <string> #include <string>
#include <vector> #include <vector>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/tensor_util.h" #include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/platform/device/npu/npu_info.h" #include "paddle/fluid/platform/device/npu/npu_info.h"
#include "paddle/fluid/platform/device/npu/npu_op_runner.h" #include "paddle/fluid/platform/device/npu/npu_op_runner.h"
......
...@@ -24,16 +24,15 @@ limitations under the License. */ ...@@ -24,16 +24,15 @@ limitations under the License. */
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h" #include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/program_desc.h" #include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/operators/gather_op.h"
#include "paddle/fluid/string/printf.h" #include "paddle/fluid/string/printf.h"
#include "paddle/phi/kernels/funcs/math_function.h" #include "paddle/phi/kernels/funcs/math_function.h"
namespace f = paddle::framework; namespace f = paddle::framework;
namespace p = paddle::platform; namespace p = paddle::platform;
USE_OP(gather); USE_OP_ITSELF(gather);
USE_OP_DEVICE_KERNEL(gather, NPU); USE_OP_DEVICE_KERNEL(gather, NPU);
USE_OP(gather_grad); USE_OP_ITSELF(gather_grad);
USE_OP_DEVICE_KERNEL(gather_grad, NPU); USE_OP_DEVICE_KERNEL(gather_grad, NPU);
template <typename T> template <typename T>
......
...@@ -13,15 +13,18 @@ See the License for the specific language governing permissions and ...@@ -13,15 +13,18 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#ifdef PADDLE_WITH_XPU #ifdef PADDLE_WITH_XPU
#include "paddle/fluid/operators/gather_op.h"
#include <memory> #include <memory>
#include <string> #include <string>
#include <vector> #include <vector>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/op_version_registry.h" #include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/phi/core/ddim.h" #include "paddle/phi/core/ddim.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
using Tensor = framework::Tensor;
template <typename T> template <typename T>
class GatherOpXPUKernel : public framework::OpKernel<T> { class GatherOpXPUKernel : public framework::OpKernel<T> {
using XPUType = typename XPUTypeTrait<T>::Type; using XPUType = typename XPUTypeTrait<T>::Type;
......
...@@ -12,9 +12,9 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,9 +12,9 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/operators/grid_sampler_op.h"
#include <memory> #include <memory>
#include <string> #include <string>
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/op_version_registry.h" #include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/platform/device/gpu/gpu_dnn.h" #include "paddle/fluid/platform/device/gpu/gpu_dnn.h"
...@@ -229,15 +229,6 @@ REGISTER_OPERATOR(grid_sampler, ops::GridSampleOp, ops::GridSampleOpMaker, ...@@ -229,15 +229,6 @@ REGISTER_OPERATOR(grid_sampler, ops::GridSampleOp, ops::GridSampleOpMaker,
ops::GridSampleGradMaker<paddle::imperative::OpBase>); ops::GridSampleGradMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(grid_sampler_grad, ops::GridSampleOpGrad); REGISTER_OPERATOR(grid_sampler_grad, ops::GridSampleOpGrad);
REGISTER_OP_CPU_KERNEL(
grid_sampler,
ops::GridSampleOpKernel<paddle::platform::CPUDeviceContext, float>,
ops::GridSampleOpKernel<paddle::platform::CPUDeviceContext, double>);
REGISTER_OP_CPU_KERNEL(
grid_sampler_grad,
ops::GridSampleGradOpKernel<paddle::platform::CPUDeviceContext, float>,
ops::GridSampleGradOpKernel<paddle::platform::CPUDeviceContext, double>);
REGISTER_OP_VERSION(grid_sampler) REGISTER_OP_VERSION(grid_sampler)
.AddCheckpoint( .AddCheckpoint(
R"ROC( R"ROC(
......
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <algorithm>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/grid_sampler_op.h"
#include "paddle/fluid/platform/device/gpu/gpu_device_function.h"
#include "paddle/fluid/platform/device/gpu/gpu_info.h"
#include "paddle/fluid/platform/device/gpu/gpu_launch_config.h"
#include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
namespace paddle {
namespace operators {
static __forceinline__ __device__ bool in_bounds(int h, int w, int H, int W) {
return h >= 0 && h < H && w >= 0 && w < W;
}
template <typename T>
static __forceinline__ __device__ void atomic_add(T* data, int h, int w, int sH,
int sW, int H, int W,
T delta) {
if (in_bounds(h, w, H, W)) {
platform::CudaAtomicAdd(data + h * sH + w * sW, delta);
}
}
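// Maps a normalized coordinate in [-1, 1] to an absolute pixel position.
// With align_corners the extremes land on the centers of the first and last
// pixels; otherwise they land on the image borders.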
template <typename T>
static __forceinline__ __device__ T _unnormalize(T coord, int size,
bool align_corners) {
if (align_corners) {
return ((coord + 1.f) / 2) * (size - 1);
} else {
return ((coord + 1.f) * size - 1) / 2;
}
}
template <typename T>
static __forceinline__ __device__ T clip_indexes(T in, int max_value) {
return min(static_cast<T>(max_value), max(in, static_cast<T>(0)));
}
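// Reflects an out-of-range coordinate back into the valid span, flipping
// direction on every traversal of the span (reflection padding).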
template <typename T>
static __forceinline__ __device__ T reflect_indexes(T in, int twice_low,
int twice_high) {
if (twice_low == twice_high) {
return static_cast<T>(0);
}
T min = static_cast<T>(twice_low) / 2;
T span = static_cast<T>(twice_high - twice_low) / 2;
in = fabs(in - min);
T extra = fmod(in, span);
int flips = static_cast<int>(floor(in / span));
if (flips % 2 == 0) {
return extra + min;
} else {
return span - extra + min;
}
}
template <typename T>
static __forceinline__ __device__ T compute_positions(T coord, int size,
PaddingMode padding_mode,
bool align_corners) {
coord = _unnormalize<T>(coord, size, align_corners);
if (padding_mode == PaddingMode::border) {
coord = clip_indexes(coord, size - 1);
} else if (padding_mode == PaddingMode::reflect) {
if (align_corners) {
coord = reflect_indexes(coord, 0, 2 * (size - 1));
} else {
coord = reflect_indexes(coord, -1, 2 * size - 1);
}
coord = clip_indexes(coord, size - 1);
}
return coord;
}
template <typename T>
static __forceinline__ __device__ T _unnormalize_with_mask(T coord, int size,
bool align_corners,
T* grad_in) {
if (align_corners) {
*grad_in = static_cast<T>(size - 1) / 2;
return ((coord + 1.f) / 2) * (size - 1);
} else {
*grad_in = static_cast<T>(size) / 2;
return ((coord + 1.f) * size - 1) / 2;
}
}
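// Clamps a coordinate to [0, clip_limit - 1] and records the local
// derivative: 0 where the clamp is active, 1 inside the valid range.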
template <typename T>
static __forceinline__ __device__ T clip_indexes_with_mask(T in, int clip_limit,
T* grad_in) {
if (in <= static_cast<T>(0)) {
*grad_in = static_cast<T>(0);
return static_cast<T>(0);
} else {
T max = static_cast<T>(clip_limit - 1);
if (in >= max) {
*grad_in = static_cast<T>(0);
return max;
} else {
*grad_in = static_cast<T>(1);
return in;
}
}
}
template <typename T>
static __forceinline__ __device__ T
reflect_indexes_with_mask(T in, int twice_low, int twice_high, T* grad_in) {
if (twice_low == twice_high) {
*grad_in = static_cast<T>(0);
return static_cast<T>(0);
}
int grad_in_mult_;
T min = static_cast<T>(twice_low) / 2;
T span = static_cast<T>(twice_high - twice_low) / 2;
in = in - min;
if (in < static_cast<T>(0)) {
grad_in_mult_ = -1;
in = -in;
} else {
grad_in_mult_ = 1;
}
T extra = fmod(in, span);
int flips = static_cast<int>(floor(in / span));
if (flips % 2 == 0) {
*grad_in = static_cast<T>(grad_in_mult_);
return extra + min;
} else {
*grad_in = static_cast<T>(-grad_in_mult_);
return span - extra + min;
}
}
template <typename T>
static __forceinline__ __device__ T
compute_positions_with_mask(T coord, int size, PaddingMode padding_mode,
bool align_corners, T* grad_in) {
T grad_clip, grad_refl;
coord = _unnormalize_with_mask<T>(coord, size, align_corners, grad_in);
if (padding_mode == PaddingMode::border) {
coord = clip_indexes_with_mask(coord, size, &grad_clip);
*grad_in = (*grad_in) * grad_clip;
} else if (padding_mode == PaddingMode::reflect) {
if (align_corners) {
coord = reflect_indexes_with_mask(coord, 0, 2 * (size - 1), &grad_refl);
} else {
coord = reflect_indexes_with_mask(coord, -1, 2 * size - 1, &grad_refl);
}
coord = clip_indexes_with_mask(coord, size, &grad_clip);
*grad_in = (*grad_in) * grad_refl * grad_clip;
}
return coord;
}
template <typename T>
__global__ void grid_sample_cuda_kernel(const int nthreads, int n, int out_c,
int out_h, int out_w, int in_h,
int in_w, const T* input, const T* grid,
T* output, const Mode mode,
const PaddingMode padding_mode,
bool align_corners) {
int inp_sN = out_c * in_h * in_w;
int inp_sC = in_h * in_w;
int inp_sH = in_w;
int inp_sW = 1;
int grid_sN = out_h * out_w * 2;
int grid_sH = out_w * 2;
int grid_sW = 2;
int grid_sCoor = 1;
int out_sN = out_c * out_h * out_w;
int out_sC = out_h * out_w;
int out_sH = out_w;
int out_sW = 1;
CUDA_KERNEL_LOOP(index, nthreads) {
const int w = index % out_w;
const int h = (index / out_w) % out_h;
const int n = index / (out_h * out_w);
const int grid_offset = n * grid_sN + h * grid_sH + w * grid_sW;
T ix = grid[grid_offset];
T iy = grid[grid_offset + grid_sCoor];
ix = compute_positions(ix, in_w, padding_mode, align_corners);
iy = compute_positions(iy, in_h, padding_mode, align_corners);
if (mode == Mode::bilinear) {
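      // Gather the four neighbouring pixels (nw, ne, sw, se) and blend them
      // with bilinear weights derived from the fractional offsets of (ix, iy).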
int ix_nw = static_cast<int>(floor(ix));
int iy_nw = static_cast<int>(floor(iy));
int ix_ne = ix_nw + 1;
int iy_ne = iy_nw;
int ix_sw = ix_nw;
int iy_sw = iy_nw + 1;
int ix_se = ix_nw + 1;
int iy_se = iy_nw + 1;
T nw = (ix_se - ix) * (iy_se - iy);
T ne = (ix - ix_sw) * (iy_sw - iy);
T sw = (ix_ne - ix) * (iy - iy_ne);
T se = (ix - ix_nw) * (iy - iy_nw);
auto inp_offset_NC = n * inp_sN;
auto out_ptr_NCHW = output + n * out_sN + h * out_sH + w * out_sW;
for (int c = 0; c < out_c;
++c, inp_offset_NC += inp_sC, out_ptr_NCHW += out_sC) {
*out_ptr_NCHW = static_cast<T>(0);
if (in_bounds(iy_nw, ix_nw, in_h, in_w)) {
*out_ptr_NCHW +=
input[inp_offset_NC + iy_nw * inp_sH + ix_nw * inp_sW] * nw;
}
if (in_bounds(iy_ne, ix_ne, in_h, in_w)) {
*out_ptr_NCHW +=
input[inp_offset_NC + iy_ne * inp_sH + ix_ne * inp_sW] * ne;
}
if (in_bounds(iy_sw, ix_sw, in_h, in_w)) {
*out_ptr_NCHW +=
input[inp_offset_NC + iy_sw * inp_sH + ix_sw * inp_sW] * sw;
}
if (in_bounds(iy_se, ix_se, in_h, in_w)) {
*out_ptr_NCHW +=
input[inp_offset_NC + iy_se * inp_sH + ix_se * inp_sW] * se;
}
}
} else if (mode == Mode::nearest) {
int ix_nearest = static_cast<int>(std::nearbyint(ix));
int iy_nearest = static_cast<int>(std::nearbyint(iy));
auto inp_offset_NC = n * inp_sN;
auto out_ptr_NCHW = output + n * out_sN + h * out_sH + w * out_sW;
for (int c = 0; c < out_c;
++c, inp_offset_NC += inp_sC, out_ptr_NCHW += out_sC) {
if (in_bounds(iy_nearest, ix_nearest, in_h, in_w)) {
*out_ptr_NCHW =
input[inp_offset_NC + iy_nearest * inp_sH + ix_nearest * inp_sW];
} else {
*out_ptr_NCHW = static_cast<T>(0);
}
}
}
}
}
template <typename T>
class GridSampleOpCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto& dev_ctx = ctx.cuda_device_context();
auto align_corners = ctx.Attr<bool>("align_corners");
auto padding_mode_s = ctx.Attr<std::string>("padding_mode");
auto mode_s = ctx.Attr<std::string>("mode");
PaddingMode padding_mode;
Mode mode;
if (padding_mode_s == "border") {
padding_mode = PaddingMode::border;
} else if (padding_mode_s == "reflection") {
padding_mode = PaddingMode::reflect;
} else {
padding_mode = PaddingMode::zeros;
}
if (mode_s == "nearest") {
mode = Mode::nearest;
} else {
mode = Mode::bilinear;
}
auto* input = ctx.Input<Tensor>("X");
auto* grid = ctx.Input<Tensor>("Grid");
const int n = grid->dims()[0];
const int out_h = grid->dims()[1];
const int out_w = grid->dims()[2];
const int c = input->dims()[1];
const int in_h = input->dims()[2];
const int in_w = input->dims()[3];
VLOG(3) << "n: " << n << "; c: " << c << "; out_h: " << out_h
<< "; out_w: " << out_w;
auto* output = ctx.Output<Tensor>("Output");
auto* output_data = output->mutable_data<T>(ctx.GetPlace());
VLOG(3) << "out dims: " << output->dims()[0] << "; " << output->dims()[1]
<< "; " << output->dims()[2] << "; " << output->dims()[3];
int count = static_cast<int>(n * out_h * out_w);
auto cu_stream = dev_ctx.stream();
platform::GpuLaunchConfig config =
platform::GetGpuLaunchConfig1D(dev_ctx, count);
grid_sample_cuda_kernel<
T><<<config.block_per_grid, config.thread_per_block, 0, cu_stream>>>(
count, n, c, out_h, out_w, in_h, in_w, input->data<T>(),
grid->data<T>(), output_data, mode, padding_mode, align_corners);
}
};
template <typename T>
__global__ void grid_sampler_cuda_backward_kernel(
const int nthreads, const T* grad_output, const T* input, const T* grid,
int n, int out_c, int out_h, int out_w, int in_h, int in_w, T* grad_input,
T* grad_grid, const Mode mode, const PaddingMode padding_mode,
bool align_corners) {
int inp_sN = out_c * in_h * in_w;
int inp_sC = in_h * in_w;
int inp_sH = in_w;
int inp_sW = 1;
int grid_sN = out_h * out_w * 2;
int grid_sH = out_w * 2;
int grid_sW = 2;
int grid_sCoor = 1;
int gOut_sN = out_c * out_h * out_w;
int gOut_sC = out_h * out_w;
int gOut_sH = out_w;
int gOut_sW = 1;
CUDA_KERNEL_LOOP(index, nthreads) {
const int w = index % out_w;
const int h = (index / out_w) % out_h;
const int n = index / (out_h * out_w);
const int grid_offset = n * grid_sN + h * grid_sH + w * grid_sW;
T ix = grid[grid_offset];
T iy = grid[grid_offset + grid_sCoor];
T gix_mult, giy_mult;
ix = compute_positions_with_mask(ix, in_w, padding_mode, align_corners,
&gix_mult);
iy = compute_positions_with_mask(iy, in_h, padding_mode, align_corners,
&giy_mult);
if (mode == Mode::bilinear) {
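      // Scatter the upstream gradient to the four neighbouring input pixels
      // with atomic adds, while accumulating the gradient w.r.t. the sampling
      // location in (gix, giy).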
int ix_nw = static_cast<int>(floor(ix));
int iy_nw = static_cast<int>(floor(iy));
int ix_ne = ix_nw + 1;
int iy_ne = iy_nw;
int ix_sw = ix_nw;
int iy_sw = iy_nw + 1;
int ix_se = ix_nw + 1;
int iy_se = iy_nw + 1;
T nw = (ix_se - ix) * (iy_se - iy);
T ne = (ix - ix_sw) * (iy_sw - iy);
T sw = (ix_ne - ix) * (iy - iy_ne);
T se = (ix - ix_nw) * (iy - iy_nw);
T gix = static_cast<T>(0), giy = static_cast<T>(0);
int gOut_offset = n * gOut_sN + h * gOut_sH + w * gOut_sW;
T* gInp_ptr_NC = grad_input + n * inp_sN;
int inp_offset_NC = n * inp_sN;
for (int c = 0; c < out_c; ++c, inp_offset_NC += inp_sC,
gInp_ptr_NC += inp_sC, gOut_offset += gOut_sC) {
T gOut = grad_output[gOut_offset];
atomic_add(gInp_ptr_NC, iy_nw, ix_nw, inp_sH, inp_sW, in_h, in_w,
nw * gOut);
atomic_add(gInp_ptr_NC, iy_ne, ix_ne, inp_sH, inp_sW, in_h, in_w,
ne * gOut);
atomic_add(gInp_ptr_NC, iy_sw, ix_sw, inp_sH, inp_sW, in_h, in_w,
sw * gOut);
atomic_add(gInp_ptr_NC, iy_se, ix_se, inp_sH, inp_sW, in_h, in_w,
se * gOut);
if (in_bounds(iy_nw, ix_nw, in_h, in_w)) {
T nw_val = input[inp_offset_NC + iy_nw * inp_sH + ix_nw * inp_sW];
gix -= nw_val * (iy_se - iy) * gOut;
giy -= nw_val * (ix_se - ix) * gOut;
}
if (in_bounds(iy_ne, ix_ne, in_h, in_w)) {
T ne_val = input[inp_offset_NC + iy_ne * inp_sH + ix_ne * inp_sW];
gix += ne_val * (iy_sw - iy) * gOut;
giy -= ne_val * (ix - ix_sw) * gOut;
}
if (in_bounds(iy_sw, ix_sw, in_h, in_w)) {
T sw_val = input[inp_offset_NC + iy_sw * inp_sH + ix_sw * inp_sW];
gix -= sw_val * (iy - iy_ne) * gOut;
giy += sw_val * (ix_ne - ix) * gOut;
}
if (in_bounds(iy_se, ix_se, in_h, in_w)) {
T se_val = input[inp_offset_NC + iy_se * inp_sH + ix_se * inp_sW];
gix += se_val * (iy - iy_nw) * gOut;
giy += se_val * (ix - ix_nw) * gOut;
}
}
if (grad_grid != nullptr) {
T* gGrid_ptr_NHW = grad_grid + index * grid_sW;
gGrid_ptr_NHW[0] = gix_mult * gix;
gGrid_ptr_NHW[1] = giy_mult * giy;
}
} else if (mode == Mode::nearest) {
int ix_nearest = static_cast<int>(std::nearbyint(ix));
int iy_nearest = static_cast<int>(std::nearbyint(iy));
int gOut_offset = n * gOut_sN + h * gOut_sH + w * gOut_sW;
T* gInp_ptr_NC = grad_input + n * inp_sN;
for (int c = 0; c < out_c;
++c, gInp_ptr_NC += inp_sC, gOut_offset += gOut_sC) {
atomic_add(gInp_ptr_NC, iy_nearest, ix_nearest, inp_sH, inp_sW, in_h,
in_w, grad_output[gOut_offset]);
}
if (grad_grid != nullptr) {
T* gGrid_ptr_NHW = grad_grid + index * grid_sW;
gGrid_ptr_NHW[0] = static_cast<T>(0);
gGrid_ptr_NHW[1] = static_cast<T>(0);
}
}
}
}
template <typename T>
class GridSampleGradOpCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto& dev_ctx = ctx.cuda_device_context();
auto align_corners = ctx.Attr<bool>("align_corners");
auto padding_mode_s = ctx.Attr<std::string>("padding_mode");
auto mode_s = ctx.Attr<std::string>("mode");
PaddingMode padding_mode;
Mode mode;
if (padding_mode_s == "border") {
padding_mode = PaddingMode::border;
} else if (padding_mode_s == "reflection") {
padding_mode = PaddingMode::reflect;
} else {
padding_mode = PaddingMode::zeros;
}
if (mode_s == "nearest") {
mode = Mode::nearest;
} else {
mode = Mode::bilinear;
}
auto* input = ctx.Input<Tensor>("X");
auto* grid = ctx.Input<Tensor>("Grid");
auto* output_grad = ctx.Input<Tensor>(framework::GradVarName("Output"));
const int n = grid->dims()[0];
const int out_h = grid->dims()[1];
const int out_w = grid->dims()[2];
const int c = input->dims()[1];
const int in_h = input->dims()[2];
const int in_w = input->dims()[3];
auto* input_grad = ctx.Output<Tensor>(framework::GradVarName("X"));
input_grad->mutable_data<T>(ctx.GetPlace());
phi::funcs::SetConstant<paddle::platform::CUDADeviceContext, T>()(
ctx.template device_context<paddle::platform::CUDADeviceContext>(),
input_grad, static_cast<T>(0));
T* grid_grad_data = nullptr;
if (ctx.HasOutput(framework::GradVarName("Grid"))) {
auto* grid_grad = ctx.Output<Tensor>(framework::GradVarName("Grid"));
grid_grad_data = grid_grad->mutable_data<T>(ctx.GetPlace());
}
int count = static_cast<int>(n * out_h * out_w);
auto cu_stream = dev_ctx.stream();
platform::GpuLaunchConfig config =
platform::GetGpuLaunchConfig1D(dev_ctx, count);
grid_sampler_cuda_backward_kernel<
T><<<config.block_per_grid, config.thread_per_block, 0, cu_stream>>>(
count, output_grad->data<T>(), input->data<T>(), grid->data<T>(), n, c,
out_h, out_w, in_h, in_w, input_grad->data<T>(), grid_grad_data, mode,
padding_mode, align_corners);
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_CUDA_KERNEL(grid_sampler, ops::GridSampleOpCUDAKernel<float>,
ops::GridSampleOpCUDAKernel<double>);
REGISTER_OP_CUDA_KERNEL(grid_sampler_grad,
ops::GridSampleGradOpCUDAKernel<float>,
ops::GridSampleGradOpCUDAKernel<double>);
...@@ -13,8 +13,13 @@ ...@@ -13,8 +13,13 @@
// limitations under the License. // limitations under the License.
#include "paddle/fluid/operators/index_select_op.h" #include "paddle/fluid/operators/index_select_op.h"
#include <memory> #include <memory>
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/phi/infermeta/binary.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -24,52 +29,6 @@ class IndexSelectOp : public framework::OperatorWithKernel { ...@@ -24,52 +29,6 @@ class IndexSelectOp : public framework::OperatorWithKernel {
public: public:
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true,
platform::errors::InvalidArgument(
"Input(X) of IndexSelectOp should not be null."));
PADDLE_ENFORCE_EQ(ctx->HasInput("Index"), true,
platform::errors::InvalidArgument(
"Input(Index) of IndexSelectOp should not be null."));
PADDLE_ENFORCE_EQ(ctx->HasOutput("Out"), true,
platform::errors::InvalidArgument(
"Output(Out) of IndexSelectOp should not be null."));
auto input_dim = ctx->GetInputDim("X");
auto index_dim = ctx->GetInputDim("Index");
auto dim = ctx->Attrs().Get<int>("dim");
PADDLE_ENFORCE_EQ(
dim < input_dim.size() && dim >= (0 - input_dim.size()), true,
platform::errors::OutOfRange(
"Attr(dim) is out of range, It's expected "
"to be in range of [-%d, %d]. But received Attr(dim) = %d.",
input_dim.size(), input_dim.size() - 1, dim));
PADDLE_ENFORCE_EQ(
index_dim.size() == 1 || (index_dim.size() == 2 && index_dim[1] == 1),
true, platform::errors::InvalidArgument(
"The 'shape' of Input(Index) must be 1-D tensor. "
"But received: the 'shape' of Input(Index) is [%s], "
"the dimension of Input(Index) is [%d].",
index_dim, index_dim.size()));
PADDLE_ENFORCE_EQ(index_dim[0] != 0, true,
platform::errors::InvalidArgument(
"The length of Input(Index) can't be 0."));
auto output_dim = phi::vectorize(input_dim);
if (dim < 0) {
dim += input_dim.size();
}
output_dim[dim] = index_dim[0];
ctx->SetOutputDim("Out", phi::make_ddim(output_dim));
auto type = ctx->GetInputsVarType("X")[0];
if (type == framework::proto::VarType::LOD_TENSOR) {
ctx->ShareLoD("X", /*->*/ "Out");
}
}
protected: protected:
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
...@@ -148,20 +107,11 @@ DECLARE_NO_NEED_BUFFER_VARS_INFERER(IndexSelectGradNoNeedBufferVarsInferer, ...@@ -148,20 +107,11 @@ DECLARE_NO_NEED_BUFFER_VARS_INFERER(IndexSelectGradNoNeedBufferVarsInferer,
} // namespace paddle } // namespace paddle
namespace ops = paddle::operators; namespace ops = paddle::operators;
DECLARE_INFER_SHAPE_FUNCTOR(index_select, IndexSelectInferShapeFunctor,
PD_INFER_META(phi::IndexSelectInferMeta));
REGISTER_OPERATOR(index_select, ops::IndexSelectOp, ops::IndexSelectOpMaker, REGISTER_OPERATOR(index_select, ops::IndexSelectOp, ops::IndexSelectOpMaker,
ops::IndexSelectGradMaker<paddle::framework::OpDesc>, ops::IndexSelectGradMaker<paddle::framework::OpDesc>,
ops::IndexSelectGradMaker<paddle::imperative::OpBase>); ops::IndexSelectGradMaker<paddle::imperative::OpBase>,
IndexSelectInferShapeFunctor);
REGISTER_OPERATOR(index_select_grad, ops::IndexSelectGradOp, REGISTER_OPERATOR(index_select_grad, ops::IndexSelectGradOp,
ops::IndexSelectGradNoNeedBufferVarsInferer); ops::IndexSelectGradNoNeedBufferVarsInferer);
REGISTER_OP_CPU_KERNEL(
index_select,
ops::IndexSelectKernel<paddle::platform::CPUDeviceContext, float>,
ops::IndexSelectKernel<paddle::platform::CPUDeviceContext, double>,
ops::IndexSelectKernel<paddle::platform::CPUDeviceContext, int>,
ops::IndexSelectKernel<paddle::platform::CPUDeviceContext, int64_t>);
REGISTER_OP_CPU_KERNEL(
index_select_grad,
ops::IndexSelectGradKernel<paddle::platform::CPUDeviceContext, float>,
ops::IndexSelectGradKernel<paddle::platform::CPUDeviceContext, double>,
ops::IndexSelectGradKernel<paddle::platform::CPUDeviceContext, int>,
ops::IndexSelectGradKernel<paddle::platform::CPUDeviceContext, int64_t>);
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/index_select_op.h"
#include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
namespace paddle {
namespace operators {
using platform::PADDLE_CUDA_NUM_THREADS;
using Tensor = framework::Tensor;
using LoDTensor = framework::LoDTensor;
template <typename T, typename IndexT>
__global__ void index_select_cuda_kernel(const T* input, T* output,
const IndexT* index, int64_t N,
int64_t stride, int64_t size,
int64_t delta) {
int64_t idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx >= N) {
return;
}
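  // Decompose the flat output index into (outer, dim, inner) coordinates,
  // remap the dim coordinate through `index` to find the source element;
  // `delta` is the size difference along `dim` between input and output.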
int64_t pre_idx = idx / (stride * size);
int64_t dim_idx = idx % (stride * size) / stride;
IndexT src_dim_idx = index[dim_idx];
int64_t input_idx = idx + (delta * pre_idx + src_dim_idx - dim_idx) * stride;
output[idx] = input[input_idx];
}
template <typename T, typename IndexT>
__global__ void index_select_grad_cuda_kernel(const T* output_grad,
T* input_grad,
const IndexT* index, int64_t nums,
int64_t N, int64_t stride,
int64_t size, int64_t delta) {
int64_t idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx >= N) {
return;
}
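  // Same index decomposition as the forward kernel; positions selected more
  // than once accumulate their gradients with an atomic add.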
int64_t pre_idx = idx / (stride * size);
int64_t dim_idx = idx % (stride * size) / stride;
IndexT src_dim_idx = index[dim_idx];
int64_t input_idx = idx + (delta * pre_idx + src_dim_idx - dim_idx) * stride;
paddle::platform::CudaAtomicAdd(&input_grad[input_idx], output_grad[idx]);
}
template <typename T>
__global__ void index_select_grad_init(T* input_grad, int64_t N) {
int64_t idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx >= N) {
return;
}
input_grad[idx] = 0.0;
}
template <typename DeviceContext, typename T>
class IndexSelectCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* in = context.Input<LoDTensor>("X");
auto* index = context.Input<LoDTensor>("Index");
auto* out = context.Output<LoDTensor>("Out");
int dim = context.Attr<int>("dim");
auto input_dim = in->dims();
auto output_dim = out->dims();
dim = dim >= 0 ? dim : dim + input_dim.size();
auto stride_dim = phi::stride(input_dim);
int64_t stride = stride_dim[dim];
int64_t size = output_dim[dim];
int64_t delta = input_dim[dim] - size;
const auto& index_type = framework::TransToProtoVarType(index->dtype());
bool index_type_match = index_type == framework::proto::VarType::INT64 ||
index_type == framework::proto::VarType::INT32;
PADDLE_ENFORCE_EQ(index_type_match, true,
platform::errors::InvalidArgument(
"Input(Index) holds the wrong type, it holds %s, but "
"desires to be %s or %s",
paddle::framework::DataTypeToString(index_type),
paddle::framework::DataTypeToString(
framework::proto::VarType::INT32),
paddle::framework::DataTypeToString(
framework::proto::VarType::INT64)));
auto* in_data = in->data<T>();
auto* out_data = out->mutable_data<T>(context.GetPlace());
int64_t numel = out->numel();
auto stream =
context.template device_context<platform::CUDADeviceContext>().stream();
if (index_type == framework::proto::VarType::INT64) {
const int64_t* index_data = index->data<int64_t>();
index_select_cuda_kernel<T, int64_t><<<
(numel + PADDLE_CUDA_NUM_THREADS - 1) / PADDLE_CUDA_NUM_THREADS,
PADDLE_CUDA_NUM_THREADS, 0, stream>>>(in_data, out_data, index_data,
numel, stride, size, delta);
platform::GpuStreamSync(stream);
} else {
const int* index_data = index->data<int>();
index_select_cuda_kernel<T, int><<<(numel + PADDLE_CUDA_NUM_THREADS - 1) /
PADDLE_CUDA_NUM_THREADS,
PADDLE_CUDA_NUM_THREADS, 0, stream>>>(
in_data, out_data, index_data, numel, stride, size, delta);
platform::GpuStreamSync(stream);
}
}
};
template <typename DeviceContext, typename T>
class IndexSelectGradCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* output_grad = context.Input<LoDTensor>(framework::GradVarName("Out"));
auto* in_grad = context.Output<LoDTensor>(framework::GradVarName("X"));
auto* index = context.Input<LoDTensor>("Index");
auto* output_grad_data = output_grad->data<T>();
auto* in_grad_data = in_grad->mutable_data<T>(context.GetPlace());
int dim = context.Attr<int>("dim");
auto input_dim = in_grad->dims();
auto output_dim = output_grad->dims();
dim = dim >= 0 ? dim : dim + input_dim.size();
auto stride_dim = phi::stride(input_dim);
int64_t stride = stride_dim[dim];
int64_t size = output_dim[dim];
int64_t delta = input_dim[dim] - size;
const auto& index_type = framework::TransToProtoVarType(index->dtype());
bool index_type_match = index_type == framework::proto::VarType::INT64 ||
index_type == framework::proto::VarType::INT32;
PADDLE_ENFORCE_EQ(index_type_match, true,
platform::errors::InvalidArgument(
"Input(Index) holds the wrong type, it holds %s, but "
"desires to be %s or %s",
paddle::framework::DataTypeToString(index_type),
paddle::framework::DataTypeToString(
framework::proto::VarType::INT32),
paddle::framework::DataTypeToString(
framework::proto::VarType::INT64)));
int64_t numel = in_grad->numel();
int64_t index_nums = index->numel();
int64_t out_nums = output_grad->numel();
auto stream =
context.template device_context<platform::CUDADeviceContext>().stream();
index_select_grad_init<
T><<<(numel + PADDLE_CUDA_NUM_THREADS - 1) / PADDLE_CUDA_NUM_THREADS,
PADDLE_CUDA_NUM_THREADS, 0, stream>>>(in_grad_data, numel);
if (index_type == framework::proto::VarType::INT64) {
const int64_t* index_data = index->data<int64_t>();
index_select_grad_cuda_kernel<T, int64_t><<<
(out_nums + PADDLE_CUDA_NUM_THREADS - 1) / PADDLE_CUDA_NUM_THREADS,
PADDLE_CUDA_NUM_THREADS, 0, stream>>>(output_grad_data, in_grad_data,
index_data, index_nums,
out_nums, stride, size, delta);
platform::GpuStreamSync(stream);
} else {
const int* index_data = index->data<int>();
index_select_grad_cuda_kernel<T, int><<<
(out_nums + PADDLE_CUDA_NUM_THREADS - 1) / PADDLE_CUDA_NUM_THREADS,
PADDLE_CUDA_NUM_THREADS, 0, stream>>>(output_grad_data, in_grad_data,
index_data, index_nums,
out_nums, stride, size, delta);
platform::GpuStreamSync(stream);
}
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(
index_select,
ops::IndexSelectCUDAKernel<paddle::platform::CUDADeviceContext, float>,
ops::IndexSelectCUDAKernel<paddle::platform::CUDADeviceContext, double>,
ops::IndexSelectCUDAKernel<paddle::platform::CUDADeviceContext,
paddle::platform::float16>,
ops::IndexSelectCUDAKernel<paddle::platform::CUDADeviceContext, int>,
ops::IndexSelectCUDAKernel<paddle::platform::CUDADeviceContext, int64_t>);
REGISTER_OP_CUDA_KERNEL(
index_select_grad,
ops::IndexSelectGradCUDAKernel<paddle::platform::CUDADeviceContext, float>,
ops::IndexSelectGradCUDAKernel<paddle::platform::CUDADeviceContext, double>,
ops::IndexSelectGradCUDAKernel<paddle::platform::CUDADeviceContext,
paddle::platform::float16>,
ops::IndexSelectGradCUDAKernel<paddle::platform::CUDADeviceContext, int>,
ops::IndexSelectGradCUDAKernel<paddle::platform::CUDADeviceContext,
int64_t>);
...@@ -91,41 +91,6 @@ void IndexSelectInner(const framework::ExecutionContext& context, ...@@ -91,41 +91,6 @@ void IndexSelectInner(const framework::ExecutionContext& context,
output->Resize(output_dim); output->Resize(output_dim);
} }
template <typename DeviceContext, typename T>
class IndexSelectKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto inputs = *context.Input<framework::LoDTensor>("X");
auto* index = context.Input<framework::LoDTensor>("Index");
auto* output = context.Output<framework::LoDTensor>("Out");
int dim = context.Attr<int>("dim");
if (dim < 0) {
dim += inputs.dims().size();
}
const auto& index_type = framework::TransToProtoVarType(index->dtype());
bool index_type_match = index_type == framework::proto::VarType::INT32 ||
index_type == framework::proto::VarType::INT64;
PADDLE_ENFORCE_EQ(index_type_match, true,
platform::errors::InvalidArgument(
"Input(Index) holds the wrong type, it holds %s, but "
"desires to be %s or %s",
paddle::framework::DataTypeToString(index_type),
paddle::framework::DataTypeToString(
framework::proto::VarType::INT32),
paddle::framework::DataTypeToString(
framework::proto::VarType::INT64)));
if (index_type == framework::proto::VarType::INT32) {
IndexSelectInner<DeviceContext, T, int>(context, &inputs, *index, output,
dim);
} else if (index_type == framework::proto::VarType::INT64) {
IndexSelectInner<DeviceContext, T, int64_t>(context, &inputs, *index,
output, dim);
}
}
};
template <typename DeviceContext, typename T, class Enable = void> template <typename DeviceContext, typename T, class Enable = void>
struct IndexSelectAdd { struct IndexSelectAdd {
void operator()(const framework::ExecutionContext& ctx, int slice_size, void operator()(const framework::ExecutionContext& ctx, int slice_size,
...@@ -197,43 +162,5 @@ void IndexSelectGradInner(const framework::ExecutionContext& context, ...@@ -197,43 +162,5 @@ void IndexSelectGradInner(const framework::ExecutionContext& context,
x_grad->Resize(output_dim); x_grad->Resize(output_dim);
} }
template <typename DeviceContext, typename T>
class IndexSelectGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* x_grad =
context.Output<framework::LoDTensor>(framework::GradVarName("X"));
auto* index = context.Input<framework::LoDTensor>("Index");
auto* out_grad =
context.Input<framework::LoDTensor>(framework::GradVarName("Out"));
int dim = context.Attr<int>("dim");
if (dim < 0) {
dim += out_grad->dims().size();
}
const auto& index_type = framework::TransToProtoVarType(index->dtype());
bool index_type_match = index_type == framework::proto::VarType::INT32 ||
index_type == framework::proto::VarType::INT64;
PADDLE_ENFORCE_EQ(index_type_match, true,
platform::errors::InvalidArgument(
"Input(Index) holds the wrong type, it holds %s, but "
"desires to be %s or %s",
paddle::framework::DataTypeToString(index_type),
paddle::framework::DataTypeToString(
framework::proto::VarType::INT32),
paddle::framework::DataTypeToString(
framework::proto::VarType::INT64)));
if (index_type == framework::proto::VarType::INT32) {
IndexSelectGradInner<DeviceContext, T, int>(context, *out_grad, *index,
x_grad, dim);
} else if (index_type == framework::proto::VarType::INT64) {
IndexSelectGradInner<DeviceContext, T, int64_t>(context, *out_grad,
*index, x_grad, dim);
}
}
};
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
...@@ -12,12 +12,15 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,12 +12,15 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/operators/index_select_op.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/device/npu/npu_op_runner.h" #include "paddle/fluid/platform/device/npu/npu_op_runner.h"
#include "paddle/phi/kernels/funcs/math_function.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
using Tensor = framework::Tensor;
template <typename DeviceContext, typename T> template <typename DeviceContext, typename T>
class IndexSelectNPUKernel : public framework::OpKernel<T> { class IndexSelectNPUKernel : public framework::OpKernel<T> {
public: public:
......
...@@ -221,7 +221,7 @@ class LRNOp : public framework::OperatorWithKernel { ...@@ -221,7 +221,7 @@ class LRNOp : public framework::OperatorWithKernel {
auto ar = paddle::framework::AttrReader(attrs); auto ar = paddle::framework::AttrReader(attrs);
const std::string data_format = ar.Get<std::string>("data_format"); const std::string data_format = ar.Get<std::string>("data_format");
auto dl = framework::StringToDataLayout(data_format); auto dl = framework::StringToDataLayout(data_format);
// Some models may have intentionally set "AnyLayout" for pool // Some models may have intentionally set "AnyLayout" for lrn
// op. Treat this as NCHW (default data_format value) // op. Treat this as NCHW (default data_format value)
if (dl != framework::DataLayout::kAnyLayout) { if (dl != framework::DataLayout::kAnyLayout) {
return framework::OpKernelType(expected_kernel_type.data_type_, return framework::OpKernelType(expected_kernel_type.data_type_,
......
...@@ -16,7 +16,8 @@ limitations under the License. */ ...@@ -16,7 +16,8 @@ limitations under the License. */
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/lu_op.h" #include "paddle/fluid/operators/lu_op.h"
#include "paddle/fluid/operators/tril_triu_op.h" #include "paddle/fluid/platform/for_range.h"
#include "paddle/phi/kernels/funcs/tril_triu_compute.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -87,7 +88,8 @@ class LU_UnpackGradKernel : public framework::OpKernel<T> { ...@@ -87,7 +88,8 @@ class LU_UnpackGradKernel : public framework::OpKernel<T> {
auto W = ldims[ldims.size() - 1]; auto W = ldims[ldims.size() - 1];
auto L_dataptr = dl_tril.mutable_data<T>(dev_ctx.GetPlace()); auto L_dataptr = dl_tril.mutable_data<T>(dev_ctx.GetPlace());
platform::ForRange<DeviceContext> l_for_range(dev_ctx, dl->numel()); platform::ForRange<DeviceContext> l_for_range(dev_ctx, dl->numel());
TrilTriuCompute<T> tril_computer(dl->data<T>(), -1, true, H, W, L_dataptr); phi::funcs::TrilTriuCompute<T> tril_computer(dl->data<T>(), -1, true, H, W,
L_dataptr);
l_for_range(tril_computer); l_for_range(tril_computer);
const auto udims = du->dims(); const auto udims = du->dims();
...@@ -96,7 +98,8 @@ class LU_UnpackGradKernel : public framework::OpKernel<T> { ...@@ -96,7 +98,8 @@ class LU_UnpackGradKernel : public framework::OpKernel<T> {
W = udims[udims.size() - 1]; W = udims[udims.size() - 1];
auto U_dataptr = du_triu.mutable_data<T>(dev_ctx.GetPlace()); auto U_dataptr = du_triu.mutable_data<T>(dev_ctx.GetPlace());
platform::ForRange<DeviceContext> u_for_range(dev_ctx, du->numel()); platform::ForRange<DeviceContext> u_for_range(dev_ctx, du->numel());
TrilTriuCompute<T> triu_computer(du->data<T>(), 0, false, H, W, U_dataptr); phi::funcs::TrilTriuCompute<T> triu_computer(du->data<T>(), 0, false, H, W,
U_dataptr);
u_for_range(triu_computer); u_for_range(triu_computer);
auto xdims = dx->dims(); auto xdims = dx->dims();
......
...@@ -50,14 +50,9 @@ class PReluMKLDNNHandler ...@@ -50,14 +50,9 @@ class PReluMKLDNNHandler
if (weights->dims().size() != x->dims().size()) { if (weights->dims().size() != x->dims().size()) {
auto new_weights_dims = std::vector<int64_t>(x->dims().size(), 1); auto new_weights_dims = std::vector<int64_t>(x->dims().size(), 1);
if (mode == "channel") { if (mode == "channel") {
if (data_format == "NHWC") {
new_weights_dims[x->dims().size() - 1] =
*std::max_element(weights_dims.begin(), weights_dims.end());
} else {
new_weights_dims[1] = new_weights_dims[1] =
*std::max_element(weights_dims.begin(), weights_dims.end()); *std::max_element(weights_dims.begin(), weights_dims.end());
} }
}
weights_dims = std::move(new_weights_dims); weights_dims = std::move(new_weights_dims);
} }
auto weights_md = memory::desc(weights_dims, MKLDNNGetDataType<T>(), auto weights_md = memory::desc(weights_dims, MKLDNNGetDataType<T>(),
......