Commit 3ebc0f73 authored by W wangruting

fix_conflict

@@ -7,16 +7,20 @@ set(XPU_PROJECT "extern_xpu")
 set(XPU_API_LIB_NAME "libxpuapi.so")
 set(XPU_RT_LIB_NAME "libxpurt.so")
+set(XPU_BASE_DATE "20230114")
+set(XPU_XCCL_BASE_VERSION "1.0.7")
 if(NOT DEFINED XPU_BASE_URL)
   set(XPU_BASE_URL_WITHOUT_DATE
       "https://baidu-kunlun-product.su.bcebos.com/KL-SDK/klsdk-dev")
-  set(XPU_BASE_URL "${XPU_BASE_URL_WITHOUT_DATE}/20230110")
+  set(XPU_BASE_URL "${XPU_BASE_URL_WITHOUT_DATE}/${XPU_BASE_DATE}")
 else()
   set(XPU_BASE_URL "${XPU_BASE_URL}")
 endif()
 set(XPU_XCCL_BASE_URL
-    "https://klx-sdk-release-public.su.bcebos.com/xccl/release/1.0.6")
+    "https://klx-sdk-release-public.su.bcebos.com/xccl/release/${XPU_XCCL_BASE_VERSION}"
+)
 if(WITH_AARCH64)
   set(XPU_XRE_DIR_NAME "xre-kylin_aarch64")
...
@@ -321,8 +321,7 @@ endif()
 if(WITH_GPU)
   if(${CMAKE_CUDA_COMPILER_VERSION} LESS 11.0
-     OR (${CMAKE_CUDA_COMPILER_VERSION} GREATER_EQUAL 11.6
-         AND ${CMAKE_CUDA_COMPILER_VERSION} LESS 11.8))
+     OR (WIN32 AND ${CMAKE_CUDA_COMPILER_VERSION} GREATER_EQUAL 11.6))
     include(external/cub) # download cub
     list(APPEND third_party_deps extern_cub)
   endif()
...
@@ -2,18 +2,26 @@ add_subdirectory(auto_parallel)
 add_subdirectory(collective)
 add_subdirectory(fleet_executor)
 if(WITH_PYTHON)
+  py_proto_compile(pslib_py_proto SRCS ps.proto)
   py_proto_compile(ps_py_proto SRCS the_one_ps.proto)
   add_custom_target(
     ps_py_proto_init ALL
     COMMAND ${CMAKE_COMMAND} -E make_directory
             ${PADDLE_BINARY_DIR}/python/paddle/distributed/fleet/proto)
   add_dependencies(ps_py_proto ps_py_proto_init)
+  set(PSLIB_PROTO_DSTPATH
+      "${PADDLE_SOURCE_DIR}/python/paddle/fluid/incubate/fleet/parameter_server/pslib/"
+  )
   if(NOT WIN32)
     add_custom_command(
       TARGET ps_py_proto
       POST_BUILD
       COMMAND mv the_one_ps_pb2.py
               ${PADDLE_BINARY_DIR}/python/paddle/distributed/fleet/proto/)
+    add_custom_command(
+      TARGET pslib_py_proto
+      POST_BUILD
+      COMMAND mv ps_pb2.py "${PSLIB_PROTO_DSTPATH}")
   else()
     string(
       REPLACE "/" "\\" fleet_proto_dstpath
@@ -25,7 +33,15 @@ if(WITH_PYTHON)
       COMMENT
         "Copy generated python the_one_ps_pb2 into directory ${fleet_proto_dstpath}."
     )
+    string(REPLACE "/" "\\" PSLIB_PROTO_DSTPATH "${PSLIB_PROTO_DSTPATH}")
+    add_custom_command(
+      TARGET pslib_py_proto
+      POST_BUILD
+      COMMAND copy /Y ps_pb2.py ${PSLIB_PROTO_DSTPATH})
   endif()
+  message(
+    STATUS
+      "Copy generated python ps_pb2.py into directory ${PSLIB_PROTO_DSTPATH}")
 endif()
 if(WITH_RPC)
...
@@ -352,41 +352,17 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupBKCL::Reduce(
           const phi::DenseTensor& input,
           BKCLContext_t comm,
           const XPUStream& stream) {
-        phi::DenseTensor output_t;
-        paddle::framework::TensorCopy(*output, platform::XPUPlace(), &output_t);
-        const auto& place = input.place();
-        auto* calc_ctx = static_cast<phi::XPUContext*>(
-            platform::DeviceContextPool::Instance().Get(place));
-        switch (input.dtype()) {
-          case phi::DataType::FLOAT32:
-            calc_ctx->template Alloc<float>(&output_t);
-            break;
-          case phi::DataType::FLOAT16:
-            calc_ctx->template Alloc<float16>(&output_t);
-            break;
-          case phi::DataType::INT32:
-            calc_ctx->template Alloc<int>(&output_t);
-            break;
-          default:
-            VLOG(0) << "Error: type " << input.dtype() << " not supported for "
-                    << GetBackendName();
-            break;
-        }
-        int ret =
-            bkcl_all_reduce(comm,
-                            input.data(),
-                            output_t.data(),
-                            input.numel(),
-                            platform::ToBKCLDataType(
-                                framework::TransToProtoVarType(input.type())),
-                            ToBKCLRedType(opts.reduce_op),
-                            stream);
-        if (rank_ == opts.root_rank) {
-          *output = output_t;
-        }
-        return ret;
+        return bkcl_reduce(comm,
+                           input.data(),
+                           output->data(),
+                           input.numel(),
+                           platform::ToBKCLDataType(
+                               framework::TransToProtoVarType(input.type())),
+                           ToBKCLRedType(opts.reduce_op),
+                           opts.root_rank,
+                           stream);
       },
-      CommType::ALLREDUCE,
+      CommType::REDUCE,
       sync_op,
       use_calc_stream);
 }
...
@@ -36,6 +36,7 @@ cc_library(
   interceptor.cc
   compute_interceptor.cc
   amplifier_interceptor.cc
+  cond_interceptor.cc
   source_interceptor.cc
   sink_interceptor.cc
   message_service.cc
@@ -66,6 +67,8 @@ if(WITH_DISTRIBUTE)
   set_source_files_properties(
     amplifier_interceptor.cc PROPERTIES COMPILE_FLAGS
                              ${DISTRIBUTE_COMPILE_FLAGS})
+  set_source_files_properties(
+    cond_interceptor.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
   set_source_files_properties(
     source_interceptor.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
   set_source_files_properties(
...
@@ -33,6 +33,7 @@ USE_INTERCEPTOR(Source);
 USE_INTERCEPTOR(Compute);
 USE_INTERCEPTOR(Amplifier);
 USE_INTERCEPTOR(Sink);
+USE_INTERCEPTOR(Cond);

 void Carrier::Init(
     int64_t rank,
@@ -96,18 +97,18 @@ void Carrier::CopyParameters(
     int microbatch_id,
     const framework::ProgramDesc& program,
     const std::vector<std::string>& inference_root_scope_vars) {
-  auto& global_block = program.Block(0);
   std::map<std::string, int> inference_root_scope_var_map;
   for (auto var_name : inference_root_scope_vars) {
     inference_root_scope_var_map.insert({var_name, 1});
   }
-  for (auto& var : global_block.AllVars()) {
-    std::string var_name = var->Name();
-    bool force_root = inference_root_scope_var_map.find(var_name) !=
-                      inference_root_scope_var_map.end();
-    if (force_root) {
-      VLOG(4) << var_name << " will be forced to be created in the root scope.";
-    }
-    if ((var->Persistable() || force_root) && microbatch_id == 0) {
-      auto* ptr = root_scope_->Var(var->Name());
+  for (size_t i = 0; i < program.Size(); ++i) {
+    for (auto& var : program.Block(i).AllVars()) {
+      std::string var_name = var->Name();
+      bool force_root = inference_root_scope_var_map.find(var_name) !=
+                        inference_root_scope_var_map.end();
+      if (force_root) {
+        VLOG(4) << var_name
+                << " will be forced to be created in the root scope.";
+      }
+      if ((var->Persistable() || force_root) && microbatch_id == 0) {
+        auto* ptr = root_scope_->Var(var->Name());
@@ -121,6 +122,7 @@ void Carrier::CopyParameters(
       InitializeVariable(ptr, var->GetType());
     }
   }
+  }
 }

 bool Carrier::EnqueueInterceptorMessage(
...
@@ -125,6 +125,7 @@ void ComputeInterceptor::SendDataReadyToDownStream() {
     InterceptorMessage ready_msg;
     ready_msg.set_message_type(DATA_IS_READY);
+    ready_msg.set_scope_idx(cur_scope_id_);
     VLOG(3) << "ComputeInterceptor " << interceptor_id_
             << " Send data_is_ready msg to " << down_id
             << " in scope: " << cur_scope_id_;
@@ -152,6 +153,7 @@ void ComputeInterceptor::ReplyCompletedToUpStream() {
     InterceptorMessage reply_msg;
     reply_msg.set_message_type(DATA_IS_USELESS);
+    reply_msg.set_scope_idx(cur_scope_id_);
     Send(up_id, reply_msg);
   }
 }
...
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/distributed/fleet_executor/cond_interceptor.h"
#include "paddle/fluid/distributed/fleet_executor/task_node.h"
#include "paddle/fluid/framework/executor_gc_helper.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/platform/errors.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/errors.h"
namespace paddle {
namespace distributed {
CondInterceptor::CondInterceptor(int64_t interceptor_id, TaskNode* node)
: Interceptor(interceptor_id, node) {
PrepareDeps();
RegisterMsgHandle([this](const InterceptorMessage& msg) { Run(msg); });
}
void CondInterceptor::PrepareDeps() {
auto& upstream = node_->upstream();
auto& downstream = node_->downstream();
auto& id_to_dep_type = node_->id_to_dep_type();
for (const auto& up : upstream) {
if (id_to_dep_type.at(up.first) == DependType::NORMAL) {
normal_in_id_.insert(up.first);
}
}
for (const auto& down : downstream) {
if (id_to_dep_type.at(down.first) == DependType::NORMAL) {
normal_out_id_.insert(down.first);
} else if (id_to_dep_type.at(down.first) == DependType::STOP_LOOP) {
stop_loop_id_ = down.first;
}
}
}
bool CondInterceptor::GetCondResult() {
  PADDLE_ENFORCE_LT(cur_scope_id_,
                    microbatch_scopes_.size(),
                    platform::errors::InvalidArgument(
                        "Step out of range. There are %ld "
                        "microbatch_scopes, but received scope index %ld",
                        microbatch_scopes_.size(),
                        cur_scope_id_));
auto* cond_var =
microbatch_scopes_[cur_scope_id_]->FindVar(node_->cond_var());
PADDLE_ENFORCE(cond_var,
platform::errors::NotFound(
"Condition variable %s not exists in scope %ld",
node_->cond_var(),
cur_scope_id_));
const auto& cond_tensor = cond_var->Get<phi::DenseTensor>();
bool res = false;
if (platform::is_gpu_place(cond_tensor.place())) {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
phi::DenseTensor cpu_tensor;
framework::TensorCopy(cond_tensor, platform::CPUPlace(), &cpu_tensor);
platform::DeviceContextPool::Instance().Get(cond_tensor.place())->Wait();
res = cpu_tensor.data<bool>()[0];
#endif
} else if (platform::is_cpu_place(cond_tensor.place())) {
res = cond_tensor.data<bool>()[0];
} else {
    PADDLE_THROW(platform::errors::Unimplemented(
        "Unsupported device for cond interceptor."));
}
return res;
}
void CondInterceptor::SendDataReady(int64_t down_id) {
InterceptorMessage ready_msg;
ready_msg.set_message_type(DATA_IS_READY);
ready_msg.set_scope_idx(cur_scope_id_);
Send(down_id, ready_msg);
}
void CondInterceptor::ReplyDataIsUseless(int64_t up_id) {
InterceptorMessage ready_msg;
ready_msg.set_message_type(DATA_IS_USELESS);
ready_msg.set_scope_idx(cur_scope_id_);
Send(up_id, ready_msg);
}
void CondInterceptor::Compute() {
cur_scope_id_ = ready_queue_.front();
ready_queue_.pop();
bool cond = GetCondResult();
VLOG(3) << "Cond interceptor get condition var " << node_->cond_var()
<< " with value " << cond;
if (cond) {
VLOG(3) << "Loop again in scope " << cur_scope_id_;
for (auto& down_id : normal_out_id_) {
SendDataReady(down_id);
}
} else {
VLOG(3) << "Finish loop in scope " << cur_scope_id_;
SendDataReady(stop_loop_id_);
}
}
void CondInterceptor::Run(const InterceptorMessage& msg) {
if (msg.message_type() == DATA_IS_READY) {
ready_queue_.push(msg.scope_idx());
Compute();
} else if (msg.message_type() == DATA_IS_USELESS) {
if (node_->id_to_dep_type().at(msg.src_id()) == DependType::STOP_LOOP) {
for (auto& up_id : normal_in_id_) {
ReplyDataIsUseless(up_id);
}
      // GC the variables in the while block
int64_t scope_id = msg.scope_idx();
if (gc_) {
VLOG(3) << "Release vars in while block in scope " << scope_id;
framework::DeleteUnusedTensors(*microbatch_scopes_[scope_id],
node_->while_block_vars(),
gc_.get());
}
}
}
}
REGISTER_INTERCEPTOR(Cond, CondInterceptor);
} // namespace distributed
} // namespace paddle
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <queue>
#include "paddle/fluid/distributed/fleet_executor/interceptor.h"
namespace paddle {
namespace distributed {
/* Condition Interceptor
 * A special interceptor whose task node holds only a single condition op.
 * It has two kinds of downstreams:
 * 1. If the condition evaluates to true, the normal downstream is selected;
 *    otherwise the stop-loop downstream is selected.
 * 2. It is used to implement the while op in a program.
 */
class CondInterceptor final : public Interceptor {
public:
CondInterceptor(int64_t interceptor_id, TaskNode* node);
private:
void PrepareDeps();
void Run(const InterceptorMessage& msg);
void Compute();
bool GetCondResult();
void SendDataReady(int64_t down_id);
void ReplyDataIsUseless(int64_t up_id);
std::queue<int64_t> ready_queue_;
int64_t cur_scope_id_;
std::set<int64_t> normal_in_id_;
std::set<int64_t> normal_out_id_;
int64_t stop_loop_id_;
};
} // namespace distributed
} // namespace paddle
@@ -66,12 +66,11 @@ void FleetExecutor::Init(
           "Fleet executor is inited with empty task node"));
   // TODO(fleet_exe devs): the unused_vars should be got from run time graph
   std::vector<std::unique_ptr<framework::OperatorBase>> ops;
-  for (auto task_node : task_nodes) {
-    for (auto op : task_node->ops()) {
-      ops.emplace_back(std::unique_ptr<framework::OperatorBase>(op));
-    }
+  for (const auto& desc : program_desc.Block(0).AllOps()) {
+    ops.emplace_back(framework::OpRegistry::CreateOp(*desc));
   }
   auto unused_vars = framework::GetUnusedVars(program_desc.Block(0), ops, {});
   // NOTE: For inference, the vars in inference_root_scope_vars
   // shouldn't be deleted during inf, for that they may be the result of the
   // inf. If they are GCed, it will cause error during ZeroCopy the result.
@@ -107,6 +106,25 @@ void FleetExecutor::Init(
   std::unordered_map<int64_t, TaskNode*> interceptor_id_to_task;
   for (auto task_node : task_nodes) {
     task_node->SetUnusedVars(unused_vars);
+    if (task_node->type() == "Cond") {
+      std::vector<std::string> while_block_vars;
+      std::vector<std::string> vars_in_parent;
+      std::vector<std::string> vars_in_sub;
+      for (auto& var : program_desc.Block(0).AllVars()) {
+        vars_in_parent.emplace_back(var->Name());
+      }
+      for (auto& var : program_desc.Block(1).AllVars()) {
+        vars_in_sub.emplace_back(var->Name());
+      }
+      std::sort(vars_in_parent.begin(), vars_in_parent.end());
+      std::sort(vars_in_sub.begin(), vars_in_sub.end());
+      std::set_difference(vars_in_sub.begin(),
+                          vars_in_sub.end(),
+                          vars_in_parent.begin(),
+                          vars_in_parent.end(),
+                          std::back_inserter(while_block_vars));
+      task_node->SetWhileBlockVars(while_block_vars);
+    }
     int64_t interceptor_id = task_node->task_id();
     interceptor_id_to_task.emplace(interceptor_id, task_node);
   }
...
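A note on the block-variable computation in the hunk above: std::set_difference requires both input ranges to be sorted, which is why the two name vectors are sorted first; the result is the set of variables that live only in the while block and can therefore be garbage-collected per iteration. Below is a minimal, self-contained sketch of the same idea with made-up variable names, independent of the Paddle types.

// Hedged sketch: compute the names that exist only in the sub (while) block.
#include <algorithm>
#include <iostream>
#include <iterator>
#include <string>
#include <vector>

int main() {
  std::vector<std::string> vars_in_parent = {"x", "y", "cond"};
  std::vector<std::string> vars_in_sub = {"x", "tmp", "i", "cond"};

  // set_difference requires sorted input ranges.
  std::sort(vars_in_parent.begin(), vars_in_parent.end());
  std::sort(vars_in_sub.begin(), vars_in_sub.end());

  // Names present in the sub block but not in the parent block; these are the
  // candidates for per-iteration garbage collection.
  std::vector<std::string> while_block_vars;
  std::set_difference(vars_in_sub.begin(),
                      vars_in_sub.end(),
                      vars_in_parent.begin(),
                      vars_in_parent.end(),
                      std::back_inserter(while_block_vars));

  for (const auto& name : while_block_vars) {
    std::cout << name << "\n";  // prints: i, tmp
  }
  return 0;
}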
@@ -24,33 +24,14 @@ namespace {
 using OperatorBase = TaskNode::OperatorBase;
 }

-TaskNode::TaskNode(paddle::framework::ProgramDesc* program,
-                   int64_t rank,
-                   int64_t max_run_times,
-                   int64_t max_slot_nums)
-    : program_(program),
-      rank_(rank),
-      max_run_times_(max_run_times),
-      max_slot_nums_(max_slot_nums) {
-  // Should be serially invoked, not thread-safe
-  // NOTE: when instantiate TaskNode with program, won't init task node
-  // immediately, since the provided program may be updated later (with
-  // high probability) by adding_feed_fetch_ops or by RuntimeGraph.
-  // So, delay the init part to the Init() function.
-  static int64_t task_node_cnt = 0;
-  task_id_ = task_node_cnt++;
-}
-
 TaskNode::TaskNode(paddle::framework::ProgramDesc* program,
                    int64_t rank,
                    int64_t task_id,
-                   int64_t max_run_times,
-                   int64_t max_slot_nums)
+                   int64_t max_run_times)
     : program_(program),
       rank_(rank),
       task_id_(task_id),
-      max_run_times_(max_run_times),
-      max_slot_nums_(max_slot_nums) {
+      max_run_times_(max_run_times) {
   // TODO(liyurui): Will be removed when execute program is supported.
   Init();
 }
@@ -58,7 +39,6 @@ TaskNode::TaskNode(paddle::framework::ProgramDesc* program,
 TaskNode::TaskNode(paddle::framework::ProgramDesc* program, int64_t rank)
     : program_(program), rank_(rank), task_id_(rank) {
   max_run_times_ = 1;
-  max_slot_nums_ = 1;
   LOG(INFO)
       << "Constructing TaskNode for DistModelInf. The TaskNode's id is: "
      << rank
@@ -98,13 +78,11 @@ TaskNode::TaskNode(int32_t role,
                    const std::vector<framework::OpDesc*>& op_descs,
                    int64_t rank,
                    int64_t task_id,
-                   int64_t max_run_times,
-                   int64_t max_slot_nums)
+                   int64_t max_run_times)
     : role_(role),
       rank_(rank),
       task_id_(task_id),
-      max_run_times_(max_run_times),
-      max_slot_nums_(max_slot_nums) {
+      max_run_times_(max_run_times) {
   if (op_descs.empty()) {
     return;
   }
@@ -121,33 +99,35 @@ TaskNode::TaskNode(int32_t role,
                    const std::vector<framework::OperatorBase*>& ops,
                    int64_t rank,
                    int64_t task_id,
-                   int64_t max_run_times,
-                   int64_t max_slot_nums)
+                   int64_t max_run_times)
     : ops_(ops),
       role_(role),
       rank_(rank),
       task_id_(task_id),
-      max_run_times_(max_run_times),
-      max_slot_nums_(max_slot_nums) {}
+      max_run_times_(max_run_times) {}

 TaskNode::TaskNode(int32_t role,
                    int64_t rank,
                    int64_t task_id,
-                   int64_t max_run_times,
-                   int64_t max_slot_nums)
+                   int64_t max_run_times)
     : role_(role),
       rank_(rank),
       task_id_(task_id),
-      max_run_times_(max_run_times),
-      max_slot_nums_(max_slot_nums) {}
+      max_run_times_(max_run_times) {}

-bool TaskNode::AddUpstreamTask(int64_t task_id, int64_t buff_size) {
+bool TaskNode::AddUpstreamTask(int64_t task_id,
+                               int64_t buff_size,
+                               DependType type) {
   const auto& ret = upstream_.emplace(task_id, buff_size);
+  id_to_dep_type_.emplace(task_id, type);
   return ret.second;
 }

-bool TaskNode::AddDownstreamTask(int64_t task_id, int64_t buff_size) {
+bool TaskNode::AddDownstreamTask(int64_t task_id,
+                                 int64_t buff_size,
+                                 DependType type) {
   const auto& ret = downstream_.emplace(task_id, buff_size);
+  id_to_dep_type_.emplace(task_id, type);
   return ret.second;
 }
...
@@ -14,8 +14,10 @@
 #pragma once

 #include <cstdint>
+#include <functional>
 #include <memory>
 #include <string>
+#include <unordered_map>
 #include <unordered_set>
 #include <vector>
@@ -29,38 +31,30 @@ class OpDesc;
 }  // namespace framework

 namespace distributed {

+enum class DependType { NORMAL, LOOP, STOP_LOOP };
+
 class TaskNode final {
  public:
   using OperatorBase = paddle::framework::OperatorBase;
   TaskNode(int64_t rank, int64_t task_id, int64_t max_run_times);
-  TaskNode(int32_t role,
-           int64_t rank,
-           int64_t task_id,
-           int64_t max_run_times,
-           int64_t max_slot_nums);
+  TaskNode(int32_t role, int64_t rank, int64_t task_id, int64_t max_run_times);
   TaskNode(int32_t role,
            const std::vector<framework::OpDesc*>& op_descs,
            int64_t rank,
            int64_t task_id,
-           int64_t max_run_times,
-           int64_t max_slot_nums);
+           int64_t max_run_times);
   TaskNode(int32_t role,
            const std::vector<framework::OperatorBase*>& ops,
            int64_t rank,
            int64_t task_id,
-           int64_t max_run_times,
-           int64_t max_slot_nums);
-  TaskNode(paddle::framework::ProgramDesc* program,
-           int64_t rank,
-           int64_t max_run_times,
-           int64_t max_slot_nums);
+           int64_t max_run_times);
   TaskNode(paddle::framework::ProgramDesc* program, int64_t rank);
   // TODO(liyurui): This will be the only constructor for task node
   TaskNode(paddle::framework::ProgramDesc* program,
            int64_t task_id,
            int64_t rank,
-           int64_t max_run_times,
-           int64_t max_slot_nums);
+           int64_t max_run_times);
   ~TaskNode() = default;

   void SetProgram(paddle::framework::ProgramDesc* program);
@@ -69,11 +63,11 @@ class TaskNode final {
   int64_t task_id() const { return task_id_; }
   int32_t role() const { return role_; }
   int64_t max_run_times() const { return max_run_times_; }
-  int64_t max_slot_nums() const { return max_slot_nums_; }
   int64_t run_per_steps() const { return run_per_steps_; }
   int64_t run_at_offset() const { return run_at_offset_; }
   int64_t reply_up_per_steps() const { return reply_up_per_steps_; }
   int64_t send_down_per_steps() const { return send_down_per_steps_; }
+  const std::string& cond_var() const { return cond_var_; }
   const std::unordered_map<int64_t, int64_t>& upstream() const {
     return upstream_;
   }
@@ -86,11 +80,20 @@ class TaskNode final {
   const std::vector<std::unique_ptr<OperatorBase>>& unique_ops() const {
     return ops_vec_;
   }
+  const std::unordered_map<int64_t, DependType> id_to_dep_type() const {
+    return id_to_dep_type_;
+  }
   const std::unordered_map<const OperatorBase*, std::vector<std::string>>&
   unused_vars() const {
     return unused_vars_;
   }
+  const std::vector<std::string> while_block_vars() const {
+    return while_block_vars_;
+  }

+  void SetCondVarName(const std::string& cond_var_name) {
+    cond_var_ = cond_var_name;
+  }
   void SetRunPerSteps(int64_t value);
   void SetRunAtOffset(int64_t value);
   void SetReplyUpPerSteps(int64_t value);
@@ -101,10 +104,17 @@ class TaskNode final {
                           unused_vars) {
     unused_vars_ = unused_vars;
   }
+  void SetWhileBlockVars(const std::vector<std::string>& vars) {
+    while_block_vars_ = vars;
+  }

   // upstream need buffs?
-  bool AddUpstreamTask(int64_t task_id, int64_t buff_size = 1);
-  bool AddDownstreamTask(int64_t task_id, int64_t buff_size = 1);
+  bool AddUpstreamTask(int64_t task_id,
+                       int64_t buff_size = 1,
+                       DependType type = DependType::NORMAL);
+  bool AddDownstreamTask(int64_t task_id,
+                         int64_t buff_size = 1,
+                         DependType type = DependType::NORMAL);
   std::string DebugString() const;

  private:
@@ -115,16 +125,20 @@ class TaskNode final {
   // task_id-->buff_size
   std::unordered_map<int64_t, int64_t> upstream_;
   std::unordered_map<int64_t, int64_t> downstream_;
+  // task_id-->type
+  std::unordered_map<int64_t, DependType> id_to_dep_type_;

   framework::ProgramDesc* program_;
+  std::string cond_var_;
   std::vector<std::unique_ptr<OperatorBase>> ops_vec_;
   std::unordered_map<const OperatorBase*, std::vector<std::string>>
       unused_vars_;
+  std::vector<std::string> while_block_vars_;

   int32_t role_;
   int64_t rank_;
   int64_t task_id_;
   int64_t max_run_times_;
-  int64_t max_slot_nums_;

   int64_t run_per_steps_{1};
   int64_t run_at_offset_{0};
...
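To make the new dependency-type API above concrete, here is a hedged sketch of how a condition task node could be wired. Only the TaskNode calls come from this commit; the helper function, the task ids, and the condition variable name are made up for illustration.

#include "paddle/fluid/distributed/fleet_executor/task_node.h"

// Hypothetical helper: connect a Cond task node to its loop body and its exit.
void WireWhileLoop(paddle::distributed::TaskNode* cond_node,
                   int64_t body_task_id,
                   int64_t exit_task_id) {
  using paddle::distributed::DependType;
  // The boolean variable the CondInterceptor reads every iteration.
  cond_node->SetCondVarName("while_cond_0");  // made-up variable name
  // NORMAL edge: taken while the condition is still true (loop body).
  cond_node->AddDownstreamTask(body_task_id, /*buff_size=*/1, DependType::NORMAL);
  // STOP_LOOP edge: signalled once the condition becomes false (loop exit).
  cond_node->AddDownstreamTask(exit_task_id, /*buff_size=*/1, DependType::STOP_LOOP);
}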
@@ -77,9 +77,8 @@ TEST(ComputeInterceptor, Compute) {
   // FIXME: don't delete, otherwise interceptor will use undefined node
   TaskNode* source =
       new TaskNode(0, SOURCE_ID, 2);  // rank, task_id, max_run_times
-  TaskNode* node_a =
-      new TaskNode(0, ops, 0, 0, 2, 0);  // role, ops, rank, task_id
-  TaskNode* node_b = new TaskNode(0, 0, 1, 2, 0);
+  TaskNode* node_a = new TaskNode(0, ops, 0, 0, 2);  // role, ops, rank, task_id
+  TaskNode* node_b = new TaskNode(0, 0, 1, 2);
   TaskNode* sink = new TaskNode(0, SINK_ID, 2);
   // source->a->b->sink
...
@@ -37,8 +37,8 @@ TEST(ComputeInterceptor, Compute) {
   // NOTE: don't delete, otherwise interceptor will use undefined node
   TaskNode* source =
       new TaskNode(0, SOURCE_ID, 3);  // rank, task_id, max_run_times
-  TaskNode* node_a = new TaskNode(0, 0, 0, 3, 0);
-  TaskNode* node_b = new TaskNode(0, 0, 1, 3, 0);
+  TaskNode* node_a = new TaskNode(0, 0, 0, 3);
+  TaskNode* node_b = new TaskNode(0, 0, 1, 3);
   TaskNode* sink = new TaskNode(0, SINK_ID, 3);
   // source->a->b->sink
...
@@ -71,12 +71,12 @@ TEST(AmplifierInterceptor, Amplifier) {
   // NOTE: don't delete, otherwise interceptor will use undefined node
   TaskNode* source =
       new TaskNode(0, SOURCE_ID, micro_steps);  // rank, task_id, max_run_times
-  TaskNode* node_a = new TaskNode(0, 0, 0, 1, 0);  // role, rank, task_id
-  TaskNode* node_b = new TaskNode(0, 0, 1, 1, 0);
-  TaskNode* node_c = new TaskNode(0, 0, 2, 1, 0);
-  TaskNode* node_d = new TaskNode(0, 0, 3, 1, 0);
-  TaskNode* node_e = new TaskNode(0, 0, 4, 1, 0);
-  TaskNode* node_f = new TaskNode(0, 0, 5, 1, 0);
+  TaskNode* node_a = new TaskNode(0, 0, 0, 1);  // role, rank, task_id
+  TaskNode* node_b = new TaskNode(0, 0, 1, 1);
+  TaskNode* node_c = new TaskNode(0, 0, 2, 1);
+  TaskNode* node_d = new TaskNode(0, 0, 3, 1);
+  TaskNode* node_e = new TaskNode(0, 0, 4, 1);
+  TaskNode* node_f = new TaskNode(0, 0, 5, 1);
   TaskNode* sink = new TaskNode(0, SINK_ID, micro_steps);
   // source->a->b->c->d->e->f->sink
...
@@ -83,11 +83,10 @@ TEST(AmplifierInterceptor, Amplifier) {
   // NOTE: don't delete, otherwise interceptor will use undefined node
   TaskNode* source =
       new TaskNode(0, SOURCE_ID, micro_steps);  // rank, task_id, max_run_times
-  TaskNode* node_a =
-      new TaskNode(0, 0, 0, micro_steps, 0);  // role, rank, task_id
-  TaskNode* node_b = new TaskNode(0, 0, 1, 3, 0);
-  TaskNode* node_c = new TaskNode(0, 0, 2, 3, 0);
-  TaskNode* node_d = new TaskNode(0, 0, 3, micro_steps, 0);
+  TaskNode* node_a = new TaskNode(0, 0, 0, micro_steps);  // role, rank, task_id
+  TaskNode* node_b = new TaskNode(0, 0, 1, 3);
+  TaskNode* node_c = new TaskNode(0, 0, 2, 3);
+  TaskNode* node_d = new TaskNode(0, 0, 3, micro_steps);
   TaskNode* sink = new TaskNode(0, SINK_ID, micro_steps);
   // source->a->b->c->d->sink
...
@@ -62,10 +62,9 @@ TEST(SourceInterceptor, Source) {
   msg_bus->Init(0, {{0, "127.0.0.0:0"}}, "");
   // NOTE: don't delete, otherwise interceptor will use undefined node
-  TaskNode* source =
-      new TaskNode(0, SOURCE_ID, 0, 3, 0);  // role, rank, task_id
-  TaskNode* node_a = new TaskNode(0, 0, 0, 3, 0);  // role, rank, task_id
-  TaskNode* sink = new TaskNode(0, SINK_ID, 0, 3, 0);  // role, rank, task_id
+  TaskNode* source = new TaskNode(0, SOURCE_ID, 0, 3);  // role, rank, task_id
+  TaskNode* node_a = new TaskNode(0, 0, 0, 3);  // role, rank, task_id
+  TaskNode* sink = new TaskNode(0, SINK_ID, 0, 3);  // role, rank, task_id
   source->AddDownstreamTask(0, 1);
   node_a->AddUpstreamTask(SOURCE_ID, 1);
...
@@ -61,9 +61,8 @@ TEST(SourceInterceptor, Source) {
   msg_bus->Init(0, {{0, "127.0.0.0:0"}}, "");
   // NOTE: don't delete, otherwise interceptor will use undefined node
-  TaskNode* source =
-      new TaskNode(0, SOURCE_ID, 0, 3, 0);  // role, rank, task_id
-  TaskNode* node_a = new TaskNode(0, 0, 0, 3, 0);  // role, rank, task_id
+  TaskNode* source = new TaskNode(0, SOURCE_ID, 0, 3);  // role, rank, task_id
+  TaskNode* node_a = new TaskNode(0, 0, 0, 3);  // role, rank, task_id
   source->AddDownstreamTask(0, 1);
   node_a->AddUpstreamTask(SOURCE_ID, 1);
...
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
syntax = "proto2";
package paddle;
option cc_generic_services = true;
option cc_enable_arenas=true;
message PSParameter {
optional string worker_class = 1;
optional string server_class = 2;
optional string instance_class = 3;
optional string init_gflags = 4 [ default = "" ];
optional WorkerParameter worker_param = 101;
optional ServerParameter server_param = 102;
repeated DownpourTrainerParameter trainer_param = 301;
optional FsClientParameter fs_client_param = 501;
}
message WorkerParameter {
optional DownpourWorkerParameter downpour_worker_param = 1;
}
message ServerParameter {
optional DownpourServerParameter downpour_server_param = 1;
}
message DownpourWorkerParameter {
repeated TableParameter downpour_table_param = 1;
}
message DownpourTrainerParameter {
repeated DenseTableParameter dense_table = 1;
repeated SparseTableParameter sparse_table = 2;
optional int32 push_sparse_per_batch = 3;
optional int32 push_dense_per_batch = 4;
repeated string skip_op = 5;
repeated ProgramConfig program_config = 6;
}
message ProgramConfig {
required string program_id = 1;
repeated int32 push_sparse_table_id = 2;
repeated int32 push_dense_table_id = 3;
repeated int32 pull_sparse_table_id = 4;
repeated int32 pull_dense_table_id = 5;
}
message DenseTableParameter {
optional int32 table_id = 1;
repeated string dense_variable_name = 2;
repeated string dense_gradient_variable_name = 3;
optional int32 fea_dim = 4;
}
message SparseTableParameter {
optional int32 table_id = 1;
optional int32 feature_dim = 2;
repeated string slot_key = 3;
repeated string slot_value = 4;
repeated string slot_gradient = 5;
}
message DownpourServerParameter {
repeated TableParameter downpour_table_param = 1;
optional ServerServiceParameter service_param = 2;
}
message ServerServiceParameter {
optional string server_class = 1 [ default = "DownpourBrpcPsServer" ];
optional string client_class = 2 [ default = "DownpourBrpcPsClient" ];
optional string service_class = 3 [ default = "DownpourPsService"];
  optional uint32 start_server_port = 4 [ default = 0 ]; // will find an available port from it
optional uint32 server_thread_num = 5 [ default = 12 ];
}
enum TableType {
PS_SPARSE_TABLE = 0;
PS_DENSE_TABLE = 1;
}
message TableParameter {
optional uint64 table_id = 1;
optional string table_class = 2;
optional uint64 shard_num = 3 [ default = 1000 ];
optional TableAccessorParameter accessor = 4;
optional TableType type = 5;
optional bool compress_in_save = 6 [default = false];
//for cache model
optional bool enable_sparse_table_cache = 7 [default = true];
optional double sparse_table_cache_rate = 8 [default = 0.00055];
optional uint32 sparse_table_cache_file_num = 9 [default = 16];
optional double sparse_table_mem_cache_rate = 10 [default = 0.5];
}
message TableAccessorParameter {
optional string accessor_class = 1;
optional SparseSGDRuleParameter sparse_sgd_param = 2;
optional DenseSGDRuleParameter dense_sgd_param = 3;
optional uint32 fea_dim = 4 [default = 11];
optional uint32 embedx_dim = 5 [default = 8];
optional uint32 embedx_threshold = 6 [default = 10];
optional DownpourTableAccessorParameter downpour_accessor_param = 7;
repeated TableAccessorSaveParameter table_accessor_save_param = 8;
optional SparseCommonSGDRuleParameter sparse_commonsgd_param = 9;
optional SparseCommonSGDRuleParameter embed_sgd_param = 10;
optional SparseCommonSGDRuleParameter embedx_sgd_param = 11;
}
message DownpourTableAccessorParameter {
optional float nonclk_coeff = 1 [default = 0.1]; // to calculate show_click_score
optional float click_coeff = 2 [default = 1]; // to calculate show_click_score
optional float base_threshold = 3 [default = 1.5]; // show_click_score > base_threshold, this feature can be saved
optional float delta_threshold = 4 [default = 0.25]; // delta_score > delta_threshold, this feature can be saved
optional float delta_keep_days = 5 [default = 16]; // unseen_day < delta_keep_days, this feature can be saved
optional float show_click_decay_rate = 6 [default = 0.98]; // show/click will update to show/click * show_click_decay_rate after a day
optional float delete_threshold = 7 [default = 0.8]; // threshold to shrink a feasign
optional float delete_after_unseen_days = 8 [default = 30]; // unseen_day > delete_after_unseen_days, this feature will be delete in shrink_model
optional int32 ssd_unseenday_threshold = 9 [default = 1]; // threshold to save ssd
}
message TableAccessorSaveParameter {
optional uint32 param = 1;
optional string converter = 2;
optional string deconverter = 3;
}
enum PsCmdID {
PS_PULL_DENSE_TABLE = 0;
PS_PUSH_DENSE_TABLE = 1;
PS_PULL_SPARSE_TABLE = 2;
PS_PUSH_SPARSE_TABLE = 3;
PS_SHRINK_TABLE = 4;
PS_SAVE_ONE_TABLE = 5;
PS_SAVE_ALL_TABLE = 6;
PS_LOAD_ONE_TABLE = 7;
PS_LOAD_ALL_TABLE = 8;
PS_CLEAR_ONE_TABLE = 9;
PS_CLEAR_ALL_TABLE = 10;
PS_PUSH_DENSE_PARAM = 11;
PS_STOP_SERVER = 12;
PS_SAVE_ONE_CACHE_TABLE = 13;
PS_GET_CACHE_THRESHOLD = 14;
PS_CACHE_SHUFFLE = 15;
PS_COPY_TABLE = 16;
PS_COPY_TABLE_BY_FEASIGN = 17;
PS_PULL_SPARSE_TABLE_WITH_DEPENDENCY = 18;
PS_PUSH_SPARSE_TABLE_WITH_DEPENDENCY = 19;
PS_PRINT_TABLE_STAT = 20;
PS_SAVE_ONE_TABLE_PREFIX = 21;
PS_SAVE_MEM_CACHE_TABLE = 22;
//pserver2pserver cmd start from 100
PS_S2S_MSG = 101;
//local_client2local_client cmd start from 200
PS_C2C_PULL_SPARSE_TABLE = 201;
}
message PsRequestMessage {
required uint32 cmd_id = 1;
optional uint32 table_id = 2;
repeated bytes params = 3;
optional int32 client_id = 4;
optional bytes data = 5;
};
message SparseSGDRuleParameter {
optional double learning_rate = 1 [default = 0.05];
optional double initial_g2sum = 2 [default = 3.0];
optional double initial_range = 3 [default = 0.0001];
repeated float weight_bounds = 4;
}
message SparseCommonSGDRuleParameter {
optional string name = 1;
optional SparseNaiveSGDRuleParameter naive = 2;
optional SparseAdagradSGDRuleParameter adagrad = 3;
optional SparseAdamSGDParameter adam = 4;
}
message SparseNaiveSGDRuleParameter {
optional double learning_rate = 1 [default = 0.05];
optional double initial_range = 2 [default = 0.0001];
repeated float weight_bounds = 3;
}
message SparseAdagradSGDRuleParameter {
optional double learning_rate = 1 [default = 0.05];
optional double initial_g2sum = 2 [default = 3.0];
optional double initial_range = 3 [default = 0.0001];
repeated float weight_bounds = 4;
}
message SparseAdamSGDParameter {
optional double learning_rate = 1 [default = 0.001];
optional double initial_range = 2 [default = 0.0001];
optional double beta1_decay_rate = 3 [default = 0.9];
optional double beta2_decay_rate = 4 [default = 0.999];
optional double ada_epsilon = 5 [default = 1e-08];
repeated float weight_bounds = 6;
}
message DenseSGDRuleParameter {
optional string name = 1;
optional AdamSGDParameter adam = 2;
optional NaiveSGDParameter naive = 3;
optional SummarySGDParameter summary = 4;
optional MovingAverageRuleParameter moving_average = 5;
}
message AdamSGDParameter {
  optional double learning_rate = 1 [default = 5e-06]; // learning rate
  optional double avg_decay_rate = 2 [default = 0.999993]; // decay rate for avg_weight
optional double ada_decay_rate = 3 [default = 0.9999];
optional double ada_epsilon = 4 [default = 1e-08];
optional double mom_decay_rate = 5 [default = 0.99];
}
message NaiveSGDParameter {
optional double learning_rate = 1 [default = 0.0002];
optional double avg_decay_rate = 2;
}
message SummarySGDParameter {
  optional double summary_decay_rate = 1 [default = 0.999999]; // decay rate for the summary weights
}
message MovingAverageRuleParameter {
optional double momentum = 1;
}
message PsResponseMessage {
required int32 err_code = 1 [default = 0];
required string err_msg = 2 [default = ""];
optional bytes data = 3;
};
service PsService {
rpc service(PsRequestMessage) returns (PsResponseMessage);
};
message FsClientParameter {
enum FsApiType {
HDFS = 0;
AFS = 1;
}
optional FsApiType fs_type = 1 [default = HDFS];
optional string uri = 2; //such as afs://tianqi.afs.baidu.com:9902
optional string user = 3; //user_name to access fs
optional string passwd = 4; //password
optional int32 buffer_size = 5; //buffer for read/write
optional string hadoop_bin = 51;
optional string afs_conf = 101;
}
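For orientation, a hedged sketch of how these proto2 messages are typically used from C++. The generated header name and the field values below are assumptions based on standard protoc codegen for this file; nothing here is added by the commit itself.

#include <string>

#include "ps.pb.h"  // assumed name of the protoc-generated header

int main() {
  // Ask the parameter server to save one table (see PsCmdID above).
  paddle::PsRequestMessage request;
  request.set_cmd_id(paddle::PS_SAVE_ONE_TABLE);
  request.set_table_id(0);
  request.add_params("/tmp/model_dir");  // made-up path carried in `params`

  // A response as a server might fill it in.
  paddle::PsResponseMessage response;
  response.set_err_code(0);
  response.set_err_msg("");

  std::string wire;
  request.SerializeToString(&wire);  // proto2 binary serialization
  return 0;
}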
@@ -1839,9 +1839,9 @@ class DygraphNodeGenerator(DygraphFunctionGeneratorBase):
             False if self.composite_func_info == {} else True
         )

-        if is_composite_grad_api:
+        if is_composite_grad_api and next_grad_node_creation_str != '':
             next_grad_node_creation_str = f"""
-  if (!paddle::prim::PrimCommonUtils::IsPrimEnabled()) {{
+  if (!paddle::prim::PrimCommonUtils::IsBwdPrimEnabled()) {{
     {next_grad_node_creation_str}
   }}
 """
@@ -1982,6 +1982,7 @@ class DygraphNodeGenerator(DygraphFunctionGeneratorBase):
         backward_attrs_list = self.backward_attrs_list
         backward_inplace_map = self.backward_inplace_map
         indent = GetIndent(1)
+        need_gen_trace_backard_for_inplace = False

         # Construct grad_api function args
         # Order: TensorWrappers, GradTensors, Attributes
@@ -2211,6 +2212,7 @@ class DygraphNodeGenerator(DygraphFunctionGeneratorBase):
   }} else {{
     {inplace_str}
   }}"""
+                need_gen_trace_backard_for_inplace = True
             else:
                 inplace_for_grad_outs_str += inplace_str
@@ -2259,7 +2261,7 @@ class DygraphNodeGenerator(DygraphFunctionGeneratorBase):
         # TODO(Ruting):using composite only when we don't have backward kernel in the future.
         elif is_composite_grad_api:
             grad_function_call_str = f"""
-  if (paddle::prim::PrimCommonUtils::IsPrimEnabled()) {{
+  if (paddle::prim::PrimCommonUtils::IsBwdPrimEnabled()) {{
   {indent}{composite_grad_api_namespace}{composite_grad_api_name}{composite_template_name}({composite_grad_api_args_str});
   VLOG(4) << "Composite api {composite_grad_api_name} is called ";
   }}else{{
@@ -2282,7 +2284,7 @@ class DygraphNodeGenerator(DygraphFunctionGeneratorBase):
         if (
             len(next_grad_node_creation_str) > 0
             or is_invoke_forward_api
-            or inplace_for_grad_outs_str != ''
+            or need_gen_trace_backard_for_inplace
         ):
             compute_require_next_grad_str = f"{indent}bool trace_backward = egr::Controller::Instance().HasGrad() && create_graph;\n"
...
@@ -618,7 +618,8 @@ if(WITH_PYTHON)
       fleet_proto_init
       pass_desc_py_proto
       ps_py_proto
-      ps_py_proto_init)
+      ps_py_proto_init
+      pslib_py_proto)
   if(NOT WIN32)
     add_custom_command(
       TARGET framework_py_proto
...
@@ -140,7 +140,116 @@ struct FusedAttentionGradPattern : public PatternBase {
                       bool do_dropout,      // dropout the softmax(qk) or not
                       bool add_residual);   // add residual to out linear or not

-  // TODO(Yuang Liu): add backward pattern
+  // post layer norm grad
PATTERN_DECL_NODE(post_layer_norm_grad_op);
PATTERN_DECL_NODE(post_layer_norm_grad_scale);
PATTERN_DECL_NODE(post_layer_norm_grad_bias);
PATTERN_DECL_NODE(post_layer_norm_grad_mean);
PATTERN_DECL_NODE(post_layer_norm_grad_variance);
PATTERN_DECL_NODE(post_layer_norm_grad_x);
PATTERN_DECL_NODE(post_layer_norm_grad_scale_grad);
PATTERN_DECL_NODE(post_layer_norm_grad_bias_grad);
PATTERN_DECL_NODE(post_layer_norm_grad_x_grad);
// residual grad
PATTERN_DECL_NODE(residual_ele_add_grad_op);
PATTERN_DECL_NODE(residual_ele_add_grad_x);
PATTERN_DECL_NODE(residual_ele_add_grad_bias);
PATTERN_DECL_NODE(residual_ele_add_grad_bias_grad);
PATTERN_DECL_NODE(residual_ele_add_grad_x_grad);
// out linear grad
PATTERN_DECL_NODE(out_linear_dropout_grad_op);
PATTERN_DECL_NODE(out_linear_dropout_grad_mask);
PATTERN_DECL_NODE(out_linear_dropout_grad_out);
PATTERN_DECL_NODE(out_linear_ele_add_grad_op);
PATTERN_DECL_NODE(out_linear_ele_add_grad_x);
PATTERN_DECL_NODE(out_linear_ele_add_grad_bias);
PATTERN_DECL_NODE(out_linear_ele_add_grad_x_grad);
PATTERN_DECL_NODE(out_linear_ele_add_grad_bias_grad);
PATTERN_DECL_NODE(out_linear_matmul_grad_op);
PATTERN_DECL_NODE(out_linear_matmul_grad_x);
PATTERN_DECL_NODE(out_linear_matmul_grad_w);
PATTERN_DECL_NODE(out_linear_matmul_grad_x_grad);
PATTERN_DECL_NODE(out_linear_matmul_grad_w_grad);
// core attention grad
PATTERN_DECL_NODE(qkv_reshape_grad_op);
PATTERN_DECL_NODE(qkv_reshape_grad_x_shape);
PATTERN_DECL_NODE(qkv_reshape_grad_out);
PATTERN_DECL_NODE(qkv_transpose_grad_op);
PATTERN_DECL_NODE(qkv_transpose_grad_x_shape);
PATTERN_DECL_NODE(qkv_transpose_grad_out);
PATTERN_DECL_NODE(qkv_matmul_grad_op);
PATTERN_DECL_NODE(qkv_matmul_grad_x);
PATTERN_DECL_NODE(qkv_matmul_grad_w);
PATTERN_DECL_NODE(qkv_matmul_grad_x_grad);
PATTERN_DECL_NODE(qkv_matmul_grad_w_grad);
PATTERN_DECL_NODE(attn_dropout_grad_op);
PATTERN_DECL_NODE(attn_dropout_grad_mask);
PATTERN_DECL_NODE(attn_dropout_grad_out);
PATTERN_DECL_NODE(qk_softmax_grad_op);
PATTERN_DECL_NODE(qk_softmax_grad_fwd_out);
PATTERN_DECL_NODE(qk_softmax_grad_out);
PATTERN_DECL_NODE(add_mask_ele_add_grad_op);
PATTERN_DECL_NODE(add_mask_ele_add_grad_x);
PATTERN_DECL_NODE(add_mask_ele_add_grad_bias);
PATTERN_DECL_NODE(add_mask_ele_add_grad_x_grad);
PATTERN_DECL_NODE(qk_scale_grad_op);
PATTERN_DECL_NODE(qk_scale_grad_out);
PATTERN_DECL_NODE(qk_matmul_grad_op);
PATTERN_DECL_NODE(qk_matmul_grad_x);
PATTERN_DECL_NODE(qk_matmul_grad_w);
PATTERN_DECL_NODE(qk_matmul_grad_x_grad);
PATTERN_DECL_NODE(qk_matmul_grad_w_grad);
// fuse qkv projection grad
PATTERN_DECL_NODE(fuse_qkv_split_grad_op); // concat op
PATTERN_DECL_NODE(fuse_qkv_split_grad_out);
PATTERN_DECL_NODE(fuse_qkv_transpose_grad_op);
PATTERN_DECL_NODE(fuse_qkv_transpose_grad_x_shape);
PATTERN_DECL_NODE(fuse_qkv_transpose_grad_out);
PATTERN_DECL_NODE(fuse_qkv_reshape_grad_op);
PATTERN_DECL_NODE(fuse_qkv_reshape_grad_x_shape);
PATTERN_DECL_NODE(fuse_qkv_reshape_grad_out);
PATTERN_DECL_NODE(fuse_qkv_ele_add_grad_op);
PATTERN_DECL_NODE(fuse_qkv_ele_add_grad_x);
PATTERN_DECL_NODE(fuse_qkv_ele_add_grad_bias);
PATTERN_DECL_NODE(fuse_qkv_ele_add_grad_x_grad);
PATTERN_DECL_NODE(fuse_qkv_ele_add_grad_bias_grad);
PATTERN_DECL_NODE(fuse_qkv_matmul_grad_op);
PATTERN_DECL_NODE(fuse_qkv_matmul_grad_x);
PATTERN_DECL_NODE(fuse_qkv_matmul_grad_w);
PATTERN_DECL_NODE(fuse_qkv_matmul_grad_x_grad);
PATTERN_DECL_NODE(fuse_qkv_matmul_grad_w_grad);
// pre layer norm grad
PATTERN_DECL_NODE(pre_layer_norm_grad_op);
PATTERN_DECL_NODE(pre_layer_norm_grad_scale);
PATTERN_DECL_NODE(pre_layer_norm_grad_bias);
PATTERN_DECL_NODE(pre_layer_norm_grad_mean);
PATTERN_DECL_NODE(pre_layer_norm_grad_variance);
PATTERN_DECL_NODE(pre_layer_norm_grad_x);
PATTERN_DECL_NODE(pre_layer_norm_grad_scale_grad);
PATTERN_DECL_NODE(pre_layer_norm_grad_bias_grad);
PATTERN_DECL_NODE(pre_layer_norm_grad_x_grad);
// grad accumulation
PATTERN_DECL_NODE(grad_accumulation_sum_op);
PATTERN_DECL_NODE(grad_accumulation_out);
};

}  // namespace patterns
...
@@ -53,8 +53,13 @@ void MapOp2AnotherPass::ApplyImpl(ir::Graph* graph) const {
       op_desc->SetAttr("shape", std::vector<int>{0, -1});
     }
   } else if (op_type == "depthwise_conv2d") {
+    auto groups = PADDLE_GET_CONST(int, op_desc->GetAttr("groups"));
+    if (groups > 1) {
+#if CUDNN_VERSION >= 8100
     op_desc->SetType(replaced_map[op_type]);
     op_desc->SetAttr("use_cudnn", true);
+#endif
+    }
   }
   op_desc->Flush();
   ++found_count;
...
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/framework/op_desc.h"
namespace paddle {
namespace framework {
namespace ir {
inline std::vector<std::string> GetSupportedActivations() {
return std::vector<std::string>{"abs",
"clip",
"gelu",
"hard_sigmoid",
"hard_swish",
"leaky_relu",
"mish",
"relu",
"relu6",
"sigmoid",
"sqrt",
"swish",
"tanh"};
}
inline std::unordered_map<std::string, std::string> GetAttributeMap(
std::string act_type) {
std::unordered_map<std::string, std::string> attr_map;
if (act_type == "swish") {
attr_map.emplace("beta", "fuse_alpha");
} else if (act_type == "relu6") {
attr_map.emplace("threshold", "fuse_alpha");
} else if (act_type == "hard_sigmoid") {
attr_map.emplace("slope", "fuse_alpha");
attr_map.emplace("offset", "fuse_beta");
} else if (act_type == "clip") {
attr_map.emplace("min", "fuse_alpha");
attr_map.emplace("max", "fuse_beta");
} else {
attr_map.emplace("alpha", "fuse_alpha");
attr_map.emplace("beta", "fuse_beta");
}
return attr_map;
}
inline void SetActivationAttrs(paddle::framework::OpDesc* fused_op,
paddle::framework::OpDesc* act_op,
const std::string& act_type) {
if (fused_op->HasAttr("use_mkldnn")) {
PADDLE_ENFORCE(PADDLE_GET_CONST(bool, fused_op->GetAttr("use_mkldnn")),
phi::errors::PreconditionNotMet(
"oneDNN activation fuses require use_mkldnn=True"));
}
fused_op->SetAttr("use_mkldnn", true);
auto attr_map = GetAttributeMap(act_type);
for (const auto& attr : attr_map) {
if (act_op->HasAttr(attr.first)) {
fused_op->SetAttr(attr.second, act_op->GetAttr(attr.first));
}
}
if (act_type == "gelu" && act_op->HasAttr("approximate")) {
std::string gelu_act_type =
PADDLE_GET_CONST(bool, act_op->GetAttr("approximate")) ? "gelu_tanh"
: "gelu_erf";
fused_op->SetAttr("fuse_activation", gelu_act_type);
} else {
fused_op->SetAttr("fuse_activation", act_type);
}
}
} // namespace ir
} // namespace framework
} // namespace paddle
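A hedged usage sketch for the helpers defined above. It only inspects the two lookup tables; the include path is the one given in this file, and everything else is illustrative rather than part of the commit.

#include <iostream>

#include "paddle/fluid/framework/ir/mkldnn/activation_onednn_fuse_pass.h"

int main() {
  namespace ir = paddle::framework::ir;

  // All activations the oneDNN fuse passes can absorb into a fused op.
  for (const auto& act : ir::GetSupportedActivations()) {
    std::cout << act << "\n";
  }

  // Attribute renaming applied during fusion, e.g. relu6's "threshold"
  // is carried over to the fused op as "fuse_alpha".
  for (const auto& attr : ir::GetAttributeMap("relu6")) {
    std::cout << attr.first << " -> " << attr.second << "\n";
  }
  return 0;
}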
@@ -14,8 +14,8 @@
 #include "paddle/fluid/framework/ir/mkldnn/conv_activation_mkldnn_fuse_pass.h"

+#include "paddle/fluid/framework/ir/mkldnn/activation_onednn_fuse_pass.h"
 #include "paddle/fluid/framework/op_version_registry.h"
-#include "paddle/phi/backends/onednn/onednn_reuse.h"
 #include "paddle/utils/string/pretty_log.h"

 namespace paddle {
@@ -25,7 +25,7 @@ namespace ir {
 using string::PrettyLogDetail;

 void ConvActivationMkldnnFusePass::ApplyImpl(Graph* graph) const {
-  auto act_types = phi::funcs::GetSupportedActivations();
+  auto act_types = GetSupportedActivations();
   std::vector<std::string> conv_types = {"fused_conv2d", "conv2d"};

   for (auto& act_type : act_types) {
@@ -40,7 +40,7 @@ void ConvActivationMkldnnFusePass::FuseConvAct(Graph* graph,
                                                const std::string& conv_type,
                                                std::string& act_type) const {
   PADDLE_ENFORCE_NOT_NULL(
-      graph, platform::errors::InvalidArgument("Graph cannot be nullptr."));
+      graph, phi::errors::InvalidArgument("Graph cannot be nullptr."));
   FusePassBase::Init(conv_type + "_" + act_type + "_mkldnn_fuse_pass", graph);

   GraphPatternDetector gpd;
@@ -62,28 +62,13 @@ void ConvActivationMkldnnFusePass::FuseConvAct(Graph* graph,
     GET_IR_NODE_FROM_SUBGRAPH(activation_out, activation_out, conv_act_pattern);

     OpDesc* conv_op = conv->Op();
-    OpDesc* act_op = activation->Op();

     if (conv_op->Type() == "conv2d") {
       conv_op->SetType("fused_conv2d");
     }

-    auto attr_map = phi::funcs::GetAttributeMap(act_type);
-    for (const auto& attrs : attr_map) {
-      if (act_op->HasAttr(attrs.first)) {
-        conv_op->SetAttr(attrs.second, act_op->GetAttr(attrs.first));
-      }
-    }
-
-    if (act_type == "gelu" && activation->Op()->HasAttr("approximate")) {
-      act_type =
-          PADDLE_GET_CONST(bool, activation->Op()->GetAttr("approximate"))
-              ? "gelu_tanh"
-              : "gelu_erf";
-      conv_op->SetAttr("fuse_alpha", 0.0f);
-      conv_op->SetAttr("fuse_beta", 0.0f);
-    }
-    conv_op->SetAttr("fuse_activation", act_type);
+    SetActivationAttrs(conv_op, activation->Op(), act_type);
     conv_op->SetOutput("Output", {activation_out->Name()});

     IR_NODE_LINK_TO(conv, activation_out);
@@ -105,7 +90,7 @@ void ConvActivationMkldnnFusePass::FuseConvAct(Graph* graph,
 void ConvActivationMkldnnFusePass::FuseConvConcatAct(
     Graph* graph, std::string& act_type) const {
   PADDLE_ENFORCE_NOT_NULL(
-      graph, platform::errors::InvalidArgument("Graph cannot be nullptr."));
+      graph, phi::errors::InvalidArgument("Graph cannot be nullptr."));
   FusePassBase::Init("conv2d_concat_" + act_type + "_mkldnn_fuse_pass", graph);

   GraphPatternDetector gpd;
@@ -137,13 +122,13 @@ void ConvActivationMkldnnFusePass::FuseConvConcatAct(
         return;
       }

-      bool is_not_conv_mkldnn =
+      bool is_not_conv_onednn =
           !(prev_op_nodes[0]->Op()->GetAttrIfExists<bool>("use_mkldnn"));
       if ((prev_op_nodes[0]->Op()->Type() != "conv2d" &&
            prev_op_nodes[0]->Op()->Type() != "fused_conv2d") ||
-          is_not_conv_mkldnn) {
-        LOG(WARNING) << "This fuse pass supports only conv2d(mkldnn) | "
-                        "fused_conv2d(mkldnn) + activation.";
+          is_not_conv_onednn) {
+        LOG(WARNING) << "This fuse pass supports only conv2d(oneDNN) | "
+                        "fused_conv2d(oneDNN) + activation.";
         return;
       }
     }
@@ -153,23 +138,8 @@ void ConvActivationMkldnnFusePass::FuseConvConcatAct(
       if (conv_op->Type() == "conv2d") {
         conv_op->SetType("fused_conv2d");
       }
-      OpDesc* act_op = activation_op->Op();

-      auto attr_map = phi::funcs::GetAttributeMap(act_type);
-      for (const auto& attrs : attr_map) {
-        if (act_op->HasAttr(attrs.first)) {
-          conv_op->SetAttr(attrs.second, act_op->GetAttr(attrs.first));
-        }
-      }
-
-      if (act_type == "gelu" && act_op->HasAttr("approximate")) {
-        act_type = PADDLE_GET_CONST(bool, act_op->GetAttr("approximate"))
-                       ? "gelu_tanh"
-                       : "gelu_erf";
-        conv_op->SetAttr("fuse_alpha", 0.0f);
-        conv_op->SetAttr("fuse_beta", 0.0f);
-      }
-      conv_op->SetAttr("fuse_activation", act_type);
+      SetActivationAttrs(conv_op, activation_op->Op(), act_type);
     }

     concat_op->Op()->SetOutput("Out", {activation_out->Name()});
...
...@@ -15,8 +15,8 @@ ...@@ -15,8 +15,8 @@
#include "paddle/fluid/framework/ir/mkldnn/elt_act_mkldnn_fuse_pass.h" #include "paddle/fluid/framework/ir/mkldnn/elt_act_mkldnn_fuse_pass.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h" #include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/ir/mkldnn/activation_onednn_fuse_pass.h"
#include "paddle/fluid/framework/op_version_registry.h" #include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/phi/backends/onednn/onednn_reuse.h"
#include "paddle/phi/core/enforce.h" #include "paddle/phi/core/enforce.h"
#include "paddle/utils/string/pretty_log.h" #include "paddle/utils/string/pretty_log.h"
...@@ -27,7 +27,7 @@ namespace ir { ...@@ -27,7 +27,7 @@ namespace ir {
using string::PrettyLogDetail; using string::PrettyLogDetail;
void ElementwiseActivationOneDNNPass::ApplyImpl(Graph *graph) const { void ElementwiseActivationOneDNNPass::ApplyImpl(Graph *graph) const {
auto act_types = phi::funcs::GetSupportedActivations(); auto act_types = GetSupportedActivations();
std::vector<std::string> elt_types = { std::vector<std::string> elt_types = {
"elementwise_add", "elementwise_sub", "elementwise_mul"}; "elementwise_add", "elementwise_sub", "elementwise_mul"};
...@@ -42,7 +42,7 @@ void ElementwiseActivationOneDNNPass::FuseElementwiseAct( ...@@ -42,7 +42,7 @@ void ElementwiseActivationOneDNNPass::FuseElementwiseAct(
const std::string &elt_type, const std::string &elt_type,
const std::string &act_type) const { const std::string &act_type) const {
PADDLE_ENFORCE_NOT_NULL( PADDLE_ENFORCE_NOT_NULL(
graph, platform::errors::InvalidArgument("Graph cannot be nullptr.")); graph, phi::errors::InvalidArgument("Graph cannot be nullptr."));
FusePassBase::Init(elt_type + "_" + act_type + "_mkldnn_fuse_pass", graph); FusePassBase::Init(elt_type + "_" + act_type + "_mkldnn_fuse_pass", graph);
GraphPatternDetector gpd; GraphPatternDetector gpd;
...@@ -62,35 +62,8 @@ void ElementwiseActivationOneDNNPass::FuseElementwiseAct( ...@@ -62,35 +62,8 @@ void ElementwiseActivationOneDNNPass::FuseElementwiseAct(
GET_IR_NODE_FROM_SUBGRAPH( GET_IR_NODE_FROM_SUBGRAPH(
activation_out, activation_out, elementwise_act_pattern); activation_out, activation_out, elementwise_act_pattern);
auto *elementwise_op = elementwise->Op(); SetActivationAttrs(elementwise->Op(), activation->Op(), act_type);
elementwise->Op()->SetOutput("Out", {activation_out->Name()});
if (elementwise_op->HasAttr("use_mkldnn")) {
const std::string wo_elt_type =
"The " + elt_type; // Workaround for PP error message checking.
PADDLE_ENFORCE_EQ(
PADDLE_GET_CONST(bool, elementwise_op->GetAttr("use_mkldnn")),
true,
platform::errors::PreconditionNotMet(
wo_elt_type + "+Act fusion may happen only when oneDNN library "
"is used."));
}
auto *activation_op = activation->Op();
auto attr_map = phi::funcs::GetAttributeMap(act_type);
for (const auto &attr : attr_map) {
if (activation_op->HasAttr(attr.first)) {
elementwise_op->SetAttr(attr.second,
activation_op->GetAttr(attr.first));
}
}
if (act_type == "gelu" && activation_op->HasAttr("approximate") &&
PADDLE_GET_CONST(bool, activation_op->GetAttr("approximate")))
elementwise_op->SetAttr("fuse_activation", std::string("gelu_tanh"));
else
elementwise_op->SetAttr("fuse_activation", act_type);
elementwise_op->SetOutput("Out", {activation_out->Name()});
IR_OP_VAR_LINK(elementwise, activation_out); IR_OP_VAR_LINK(elementwise, activation_out);
GraphSafeRemoveNodes(g, {activation, elementwise_out}); GraphSafeRemoveNodes(g, {activation, elementwise_out});
......
...@@ -14,8 +14,8 @@ ...@@ -14,8 +14,8 @@
#include "paddle/fluid/framework/ir/mkldnn/fc_act_mkldnn_fuse_pass.h" #include "paddle/fluid/framework/ir/mkldnn/fc_act_mkldnn_fuse_pass.h"
#include "paddle/fluid/framework/ir/mkldnn/activation_onednn_fuse_pass.h"
#include "paddle/fluid/framework/op_version_registry.h" #include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/phi/backends/onednn/onednn_reuse.h"
#include "paddle/utils/string/pretty_log.h" #include "paddle/utils/string/pretty_log.h"
namespace paddle { namespace paddle {
...@@ -25,7 +25,7 @@ namespace ir { ...@@ -25,7 +25,7 @@ namespace ir {
using string::PrettyLogDetail; using string::PrettyLogDetail;
void FuseFCActOneDNNPass::ApplyImpl(Graph *graph) const { void FuseFCActOneDNNPass::ApplyImpl(Graph *graph) const {
auto act_types = phi::funcs::GetSupportedActivations(); auto act_types = GetSupportedActivations();
for (auto act_type : act_types) FuseFCAct(graph, act_type); for (auto act_type : act_types) FuseFCAct(graph, act_type);
} }
...@@ -33,7 +33,7 @@ void FuseFCActOneDNNPass::ApplyImpl(Graph *graph) const { ...@@ -33,7 +33,7 @@ void FuseFCActOneDNNPass::ApplyImpl(Graph *graph) const {
void FuseFCActOneDNNPass::FuseFCAct(Graph *graph, void FuseFCActOneDNNPass::FuseFCAct(Graph *graph,
const std::string &act_type) const { const std::string &act_type) const {
PADDLE_ENFORCE_NOT_NULL( PADDLE_ENFORCE_NOT_NULL(
graph, platform::errors::InvalidArgument("Graph cannot be nullptr.")); graph, phi::errors::InvalidArgument("Graph cannot be nullptr."));
FusePassBase::Init("fc_" + act_type + "_mkldnn_fuse_pass", graph); FusePassBase::Init("fc_" + act_type + "_mkldnn_fuse_pass", graph);
GraphPatternDetector gpd; GraphPatternDetector gpd;
...@@ -50,35 +50,8 @@ void FuseFCActOneDNNPass::FuseFCAct(Graph *graph, ...@@ -50,35 +50,8 @@ void FuseFCActOneDNNPass::FuseFCAct(Graph *graph,
GET_IR_NODE_FROM_SUBGRAPH(act, activation, fc_act_pattern); GET_IR_NODE_FROM_SUBGRAPH(act, activation, fc_act_pattern);
GET_IR_NODE_FROM_SUBGRAPH(act_out, activation_out, fc_act_pattern); GET_IR_NODE_FROM_SUBGRAPH(act_out, activation_out, fc_act_pattern);
auto *fc_op = fc->Op(); SetActivationAttrs(fc->Op(), act->Op(), act_type);
auto *act_op = act->Op(); fc->Op()->SetOutput("Out", {act_out->Name()});
if (fc_op->HasAttr("use_mkldnn")) {
PADDLE_ENFORCE(
PADDLE_GET_CONST(bool, fc_op->GetAttr("use_mkldnn")),
platform::errors::PreconditionNotMet(
"The FC+Act fusion may happen only when oneDNN library "
"is used."));
}
auto attr_map = phi::funcs::GetAttributeMap(act_type);
for (const auto &attr : attr_map) {
if (act_op->HasAttr(attr.first)) {
fc_op->SetAttr(attr.second, act_op->GetAttr(attr.first));
}
}
if (act_type == "gelu" && act_op->HasAttr("approximate")) {
std::string gelu_act_type =
PADDLE_GET_CONST(bool, act_op->GetAttr("approximate")) ? "gelu_tanh"
: "gelu_erf";
fc_op->SetAttr("fuse_activation", gelu_act_type);
} else {
fc_op->SetAttr("fuse_activation", act_type);
}
fc_op->SetAttr("use_mkldnn", true);
fc_op->SetOutput("Out", {act_out->Name()});
IR_OP_VAR_LINK(fc, act_out); IR_OP_VAR_LINK(fc, act_out);
GraphSafeRemoveNodes(g, {act, fc_out}); GraphSafeRemoveNodes(g, {act, fc_out});
......
...@@ -14,8 +14,8 @@ ...@@ -14,8 +14,8 @@
#include "paddle/fluid/framework/ir/mkldnn/matmul_activation_mkldnn_fuse_pass.h" #include "paddle/fluid/framework/ir/mkldnn/matmul_activation_mkldnn_fuse_pass.h"
#include "paddle/fluid/framework/ir/mkldnn/activation_onednn_fuse_pass.h"
#include "paddle/fluid/framework/op_version_registry.h" #include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/phi/backends/onednn/onednn_reuse.h"
#include "paddle/utils/string/pretty_log.h" #include "paddle/utils/string/pretty_log.h"
namespace paddle { namespace paddle {
...@@ -25,7 +25,7 @@ namespace ir { ...@@ -25,7 +25,7 @@ namespace ir {
using string::PrettyLogDetail; using string::PrettyLogDetail;
void MatmulActivationMkldnnFusePass::ApplyImpl(Graph* graph) const { void MatmulActivationMkldnnFusePass::ApplyImpl(Graph* graph) const {
auto act_types = phi::funcs::GetSupportedActivations(); auto act_types = GetSupportedActivations();
auto matmul_types = {"matmul", "matmul_v2"}; auto matmul_types = {"matmul", "matmul_v2"};
for (const auto& matmul_type : matmul_types) for (const auto& matmul_type : matmul_types)
...@@ -37,7 +37,7 @@ void MatmulActivationMkldnnFusePass::ApplyImpl(Graph* graph) const { ...@@ -37,7 +37,7 @@ void MatmulActivationMkldnnFusePass::ApplyImpl(Graph* graph) const {
void MatmulActivationMkldnnFusePass::FuseMatmulAct( void MatmulActivationMkldnnFusePass::FuseMatmulAct(
Graph* graph, const std::string& matmul_type, std::string& act_type) const { Graph* graph, const std::string& matmul_type, std::string& act_type) const {
PADDLE_ENFORCE_NOT_NULL( PADDLE_ENFORCE_NOT_NULL(
graph, platform::errors::InvalidArgument("Graph cannot be nullptr.")); graph, phi::errors::InvalidArgument("Graph cannot be nullptr."));
FusePassBase::Init(matmul_type + "_" + act_type + "_mkldnn_fuse_pass", graph); FusePassBase::Init(matmul_type + "_" + act_type + "_mkldnn_fuse_pass", graph);
GraphPatternDetector gpd; GraphPatternDetector gpd;
...@@ -61,24 +61,8 @@ void MatmulActivationMkldnnFusePass::FuseMatmulAct( ...@@ -61,24 +61,8 @@ void MatmulActivationMkldnnFusePass::FuseMatmulAct(
GET_IR_NODE_FROM_SUBGRAPH( GET_IR_NODE_FROM_SUBGRAPH(
activation_out, activation_out, matmul_act_pattern); activation_out, activation_out, matmul_act_pattern);
OpDesc* matmul_op = matmul->Op(); SetActivationAttrs(matmul->Op(), activation->Op(), act_type);
OpDesc* act_op = activation->Op(); matmul->Op()->SetOutput("Out", {activation_out->Name()});
auto attr_map = phi::funcs::GetAttributeMap(act_type);
for (const auto& attrs : attr_map) {
if (act_op->HasAttr(attrs.first)) {
matmul_op->SetAttr(attrs.second, act_op->GetAttr(attrs.first));
}
}
if (act_type == "gelu" && activation->Op()->HasAttr("approximate")) {
act_type =
PADDLE_GET_CONST(bool, activation->Op()->GetAttr("approximate"))
? "gelu_tanh"
: "gelu_erf";
}
matmul_op->SetAttr("fuse_activation", act_type);
matmul_op->SetOutput("Out", {activation_out->Name()});
IR_NODE_LINK_TO(matmul, activation_out); IR_NODE_LINK_TO(matmul, activation_out);
GraphSafeRemoveNodes(graph, {activation, matmul_out}); GraphSafeRemoveNodes(graph, {activation, matmul_out});
......
...@@ -15,8 +15,8 @@ ...@@ -15,8 +15,8 @@
#include "paddle/fluid/framework/ir/mkldnn/softplus_activation_mkldnn_fuse_pass.h" #include "paddle/fluid/framework/ir/mkldnn/softplus_activation_mkldnn_fuse_pass.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h" #include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/ir/mkldnn/activation_onednn_fuse_pass.h"
#include "paddle/fluid/framework/op_version_registry.h" #include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/phi/backends/onednn/onednn_reuse.h"
#include "paddle/phi/core/enforce.h" #include "paddle/phi/core/enforce.h"
#include "paddle/utils/string/pretty_log.h" #include "paddle/utils/string/pretty_log.h"
...@@ -27,7 +27,7 @@ namespace ir { ...@@ -27,7 +27,7 @@ namespace ir {
using string::PrettyLogDetail; using string::PrettyLogDetail;
void SoftplusActivationOneDNNPass::ApplyImpl(Graph *graph) const { void SoftplusActivationOneDNNPass::ApplyImpl(Graph *graph) const {
auto act_types = phi::funcs::GetSupportedActivations(); auto act_types = GetSupportedActivations();
// Currently softplus can't be fused with hard_sigmoid // Currently softplus can't be fused with hard_sigmoid
act_types.erase( act_types.erase(
...@@ -42,7 +42,7 @@ void SoftplusActivationOneDNNPass::ApplyImpl(Graph *graph) const { ...@@ -42,7 +42,7 @@ void SoftplusActivationOneDNNPass::ApplyImpl(Graph *graph) const {
void SoftplusActivationOneDNNPass::FuseSoftplusActivation( void SoftplusActivationOneDNNPass::FuseSoftplusActivation(
Graph *graph, const std::string &act_type) const { Graph *graph, const std::string &act_type) const {
PADDLE_ENFORCE_NOT_NULL( PADDLE_ENFORCE_NOT_NULL(
graph, platform::errors::InvalidArgument("Graph cannot be nullptr.")); graph, phi::errors::InvalidArgument("Graph cannot be nullptr."));
FusePassBase::Init("softplus_activation", graph); FusePassBase::Init("softplus_activation", graph);
GraphPatternDetector gpd; GraphPatternDetector gpd;
...@@ -63,34 +63,8 @@ void SoftplusActivationOneDNNPass::FuseSoftplusActivation( ...@@ -63,34 +63,8 @@ void SoftplusActivationOneDNNPass::FuseSoftplusActivation(
GET_IR_NODE_FROM_SUBGRAPH( GET_IR_NODE_FROM_SUBGRAPH(
activation, activation, softplus_activation_pattern); activation, activation, softplus_activation_pattern);
auto *softplus_op = softplus->Op(); SetActivationAttrs(softplus->Op(), activation->Op(), act_type);
softplus->Op()->SetOutput("Out", {activation_out->Name()});
if (softplus_op->HasAttr("use_mkldnn")) {
PADDLE_ENFORCE_EQ(
PADDLE_GET_CONST(bool, softplus_op->GetAttr("use_mkldnn")),
true,
platform::errors::PreconditionNotMet("The softplus + activation "
"fusion may happen only when "
"oneDNN library is used."));
}
auto *activation_op = activation->Op();
auto attr_map = phi::funcs::GetAttributeMap(act_type);
for (const auto &attr : attr_map) {
if (activation_op->HasAttr(attr.first)) {
softplus_op->SetAttr(attr.second, activation_op->GetAttr(attr.first));
}
}
if (act_type == "gelu" && activation_op->HasAttr("approximate") &&
PADDLE_GET_CONST(bool, activation_op->GetAttr("approximate")))
softplus_op->SetAttr("fuse_activation", std::string("gelu_tanh"));
else
softplus_op->SetAttr("fuse_activation", act_type);
softplus_op->SetAttr("use_mkldnn", true);
softplus_op->SetOutput("Out", {activation_out->Name()});
IR_OP_VAR_LINK(softplus, activation_out); IR_OP_VAR_LINK(softplus, activation_out);
GraphSafeRemoveNodes(g, {activation, softplus_out}); GraphSafeRemoveNodes(g, {activation, softplus_out});
......
...@@ -144,8 +144,6 @@ std::unordered_set<std::string> OpTransInfo::GetDenyVarNames( ...@@ -144,8 +144,6 @@ std::unordered_set<std::string> OpTransInfo::GetDenyVarNames(
const auto& arg_names = desc->Input(param_name); const auto& arg_names = desc->Input(param_name);
for (const auto& arg_name : arg_names) { for (const auto& arg_name : arg_names) {
deny_var_set.insert(arg_name); deny_var_set.insert(arg_name);
VLOG(4) << "deny param [" << param_name << "]'s argument name"
<< " is [" << arg_name << "].";
} }
} }
...@@ -153,8 +151,6 @@ std::unordered_set<std::string> OpTransInfo::GetDenyVarNames( ...@@ -153,8 +151,6 @@ std::unordered_set<std::string> OpTransInfo::GetDenyVarNames(
const auto& arg_names = desc->Output(param_name); const auto& arg_names = desc->Output(param_name);
for (const auto& arg_name : arg_names) { for (const auto& arg_name : arg_names) {
deny_var_set.insert(arg_name); deny_var_set.insert(arg_name);
VLOG(4) << "deny param [" << param_name << "]'s argument name"
<< " is [" << arg_name << "].";
} }
} }
} }
...@@ -166,48 +162,25 @@ std::unordered_set<std::string> OpTransInfo::GetDenyVarNames( ...@@ -166,48 +162,25 @@ std::unordered_set<std::string> OpTransInfo::GetDenyVarNames(
return deny_var_set; return deny_var_set;
} }
std::unordered_set<std::string> OpTransInfo::GetIgnoreInplaceVarNames( std::unordered_set<std::string> OpTransInfo::GetInplaceVarNames(
const OpDesc& op_desc) const { const GraphNodeSet& cluster_inputs, const GraphNodeSet& cluster_outputs) {
if (!ignore_inplace_param_cond_.count(op_desc.Type())) { std::unordered_set<std::string> all_inputs, all_outputs;
return {};
}
const auto& ignore_inplace_names =
ignore_inplace_param_cond_.at(op_desc.Type());
VLOG(4) << "We found ignore inplace param "
<< GetDebugInfo(ignore_inplace_names) << " in op [" << op_desc.Type()
<< "].";
std::unordered_set<std::string> ignore_inplace_set; for (auto* var : cluster_inputs) {
for (const auto& param_name : ignore_inplace_names) { all_inputs.insert(var->Name());
if (op_desc.HasOutput(param_name)) {
const auto& arg_names = op_desc.Output(param_name);
ignore_inplace_set.insert(arg_names.begin(), arg_names.end());
} }
for (auto* var : cluster_outputs) {
all_outputs.insert(var->Name());
} }
VLOG(4) << "All ignore inplace var names are " std::unordered_set<std::string> inplace_var_set;
<< GetDebugInfo(ignore_inplace_set); for (const auto& var_name : all_inputs) {
if (all_outputs.count(var_name)) {
return ignore_inplace_set; inplace_var_set.insert(var_name);
}
bool OpTransInfo::IsInplaceOp(
const OpDesc& op_desc,
const std::unordered_set<std::string>& deny_var_names) const {
const auto& ignore_inplace_set = GetIgnoreInplaceVarNames(op_desc);
auto inputs = op_desc.InputArgumentNames();
std::unordered_set<std::string> input_set(inputs.begin(), inputs.end());
for (auto& name : op_desc.OutputArgumentNames()) {
if (input_set.count(name) > 0 && !deny_var_names.count(name) &&
!ignore_inplace_set.count(name)) {
VLOG(4) << "The argument " << name << " in op " << op_desc.Type()
<< " is a inplace op, skip!";
return true;
} }
} }
return false;
return inplace_var_set;
} }
namespace { namespace {
...@@ -503,6 +476,14 @@ std::unique_ptr<Graph> CreateNewSubGraph(const GraphNodeSet& cluster, ...@@ -503,6 +476,14 @@ std::unique_ptr<Graph> CreateNewSubGraph(const GraphNodeSet& cluster,
// initialize empty map for kMemOptVarInfoFromMainGraph attribute, // initialize empty map for kMemOptVarInfoFromMainGraph attribute,
// it will be filled on the share_mem_opt_info_to_subgraph pass // it will be filled on the share_mem_opt_info_to_subgraph pass
subgraph->GetOrInit<Name2VarInfoMap>(kMemOptVarInfoFromMainGraph); subgraph->GetOrInit<Name2VarInfoMap>(kMemOptVarInfoFromMainGraph);
auto inplace_var_names = std::make_unique<std::unordered_set<std::string>>(
OpTransInfo::GetInplaceVarNames(cluster_inputs, cluster_outputs));
VLOG_IF(4, !inplace_var_names->empty())
<< "Inplace var in cluster are: " << GetDebugInfo(*inplace_var_names);
subgraph->Set<std::unordered_set<std::string>>(kInplaceVarNames,
inplace_var_names.release());
return subgraph; return subgraph;
} }
...@@ -594,7 +575,6 @@ void AddCinnOpToGraph(const GraphNodeSet& cluster, ...@@ -594,7 +575,6 @@ void AddCinnOpToGraph(const GraphNodeSet& cluster,
const GraphNodeSet& cluster_inputs, const GraphNodeSet& cluster_inputs,
const GraphNodeSet& cluster_outputs, const GraphNodeSet& cluster_outputs,
int64_t compilation_key, int64_t compilation_key,
const std::unordered_set<std::string>& deny_var_set,
Graph* graph) { Graph* graph) {
// Add the cinn launch op // Add the cinn launch op
framework::OpDesc cinn_op_desc; framework::OpDesc cinn_op_desc;
...@@ -615,6 +595,7 @@ void AddCinnOpToGraph(const GraphNodeSet& cluster, ...@@ -615,6 +595,7 @@ void AddCinnOpToGraph(const GraphNodeSet& cluster,
cinn_op_desc.SetAttr(operators::kCompilationKey, compilation_key); cinn_op_desc.SetAttr(operators::kCompilationKey, compilation_key);
cinn_op_desc.SetAttr(OpProtoAndCheckerMaker::OpRoleAttrName(), cinn_op_desc.SetAttr(OpProtoAndCheckerMaker::OpRoleAttrName(),
ExtractOpRole(cluster)); ExtractOpRole(cluster));
cinn_op_desc.Flush(); cinn_op_desc.Flush();
auto* cinn_op_node = graph->CreateOpNode(&cinn_op_desc); auto* cinn_op_node = graph->CreateOpNode(&cinn_op_desc);
// Add new links from or to the cinn launch op node // Add new links from or to the cinn launch op node
...@@ -639,21 +620,15 @@ void RemoveSubGraphFromGraph(const GraphNodeSet& cluster, ...@@ -639,21 +620,15 @@ void RemoveSubGraphFromGraph(const GraphNodeSet& cluster,
// kCinnLaunchOp, and inputs are cluster_inputs and outputs are // kCinnLaunchOp, and inputs are cluster_inputs and outputs are
// cluster_outputs. // cluster_outputs.
// Meanwhile, move all links of cluster to the cinn op. // Meanwhile, move all links of cluster to the cinn op.
void ReplaceSubGraphWithCinnOpNode( void ReplaceSubGraphWithCinnOpNode(const GraphNodeSet& cluster,
const GraphNodeSet& cluster,
const GraphNodeSet& cluster_inputs, const GraphNodeSet& cluster_inputs,
const GraphNodeSet& cluster_outputs, const GraphNodeSet& cluster_outputs,
const GraphNodeSet& cluster_internals, const GraphNodeSet& cluster_internals,
int64_t compilation_key, int64_t compilation_key,
const std::unordered_set<std::string>& deny_var_set,
Graph* graph) { Graph* graph) {
// Add the cinn op node whose name is "kCinnLaunchOp" into graph // Add the cinn op node whose name is "kCinnLaunchOp" into graph
AddCinnOpToGraph(cluster, AddCinnOpToGraph(
cluster_inputs, cluster, cluster_inputs, cluster_outputs, compilation_key, graph);
cluster_outputs,
compilation_key,
deny_var_set,
graph);
// Remove the cinn subgraph from graph // Remove the cinn subgraph from graph
RemoveSubGraphFromGraph(cluster, cluster_internals, graph); RemoveSubGraphFromGraph(cluster, cluster_internals, graph);
} }
...@@ -667,9 +642,7 @@ void SearchAllSubgraphs(Graph* graph, bool is_inference_stage) { ...@@ -667,9 +642,7 @@ void SearchAllSubgraphs(Graph* graph, bool is_inference_stage) {
auto deny_ops = StringSplit(FLAGS_deny_cinn_ops, kDelim); auto deny_ops = StringSplit(FLAGS_deny_cinn_ops, kDelim);
OpTransInfo trans_info; OpTransInfo trans_info;
const auto& deny_var_set = trans_info.GetDenyVarNames(graph->Nodes()); auto teller = [&allow_ops, &deny_ops, &trans_info](const Node* node) {
auto teller = [&allow_ops, &deny_ops, &trans_info, &deny_var_set](
const Node* node) {
const auto& node_name = node->Name(); const auto& node_name = node->Name();
bool registered = ::cinn::frontend::OpMapperRegistry::Global()->Find( bool registered = ::cinn::frontend::OpMapperRegistry::Global()->Find(
node_name) != nullptr; node_name) != nullptr;
...@@ -679,10 +652,9 @@ void SearchAllSubgraphs(Graph* graph, bool is_inference_stage) { ...@@ -679,10 +652,9 @@ void SearchAllSubgraphs(Graph* graph, bool is_inference_stage) {
is_dynamic = trans_info.dynamic_op_cond().at(node_name)(*node); is_dynamic = trans_info.dynamic_op_cond().at(node_name)(*node);
} }
bool is_support = bool is_support = registered &&
registered && !trans_info.default_deny_ops().count(node_name) && !trans_info.default_deny_ops().count(node_name) &&
!is_dynamic && !is_dynamic;
(node->IsOp() && !trans_info.IsInplaceOp(*node->Op(), deny_var_set));
// if the op type is registered in CINN and allow_ops is not empty, return // if the op type is registered in CINN and allow_ops is not empty, return
// true only when it is in allow_ops // true only when it is in allow_ops
if (!allow_ops.empty()) { if (!allow_ops.empty()) {
...@@ -714,19 +686,23 @@ void SearchAllSubgraphs(Graph* graph, bool is_inference_stage) { ...@@ -714,19 +686,23 @@ void SearchAllSubgraphs(Graph* graph, bool is_inference_stage) {
return res; return res;
}; };
std::unordered_set<std::string> skip_gc_var_names; std::unordered_set<std::string> all_skip_gc_vars;
if (graph->Has(kSkipGcVarNames)) { if (graph->Has(kSkipGcVarNames)) {
skip_gc_var_names = all_skip_gc_vars =
graph->Get<std::unordered_set<std::string>>(kSkipGcVarNames); graph->Get<std::unordered_set<std::string>>(kSkipGcVarNames);
VLOG_IF(4, !all_skip_gc_vars.empty())
<< "All skip gc var names are: " << GetDebugInfo(all_skip_gc_vars);
} }
const auto& deny_var_set = trans_info.GetDenyVarNames(graph->Nodes());
VLOG_IF(4, !deny_var_set.empty())
<< "All deny var names are: " << GetDebugInfo(deny_var_set);
auto* cinn_compiler = CinnCompiler::GetInstance(); auto* cinn_compiler = CinnCompiler::GetInstance();
for (const auto& node_vec : clusters) { for (const auto& node_vec : clusters) {
// Classify var node to inputs, outputs, and internals. // Classify var node to inputs, outputs, and internals.
GraphNodeSet cluster_set(node_vec.begin(), node_vec.end()); GraphNodeSet cluster_set(node_vec.begin(), node_vec.end());
auto deny_var_set = trans_info.GetDenyVarNames(cluster_set);
GraphNodeSet cluster_inputs, cluster_outputs, cluster_internals; GraphNodeSet cluster_inputs, cluster_outputs, cluster_internals;
AnalyseClusterVariables(cluster_set, AnalyseClusterVariables(cluster_set,
deny_var_set, deny_var_set,
...@@ -734,7 +710,7 @@ void SearchAllSubgraphs(Graph* graph, bool is_inference_stage) { ...@@ -734,7 +710,7 @@ void SearchAllSubgraphs(Graph* graph, bool is_inference_stage) {
&cluster_outputs, &cluster_outputs,
&cluster_internals, &cluster_internals,
is_inference_stage, is_inference_stage,
skip_gc_var_names); all_skip_gc_vars);
VLOG(4) << "Cluster Ops: " << cluster_debug_info(cluster_set); VLOG(4) << "Cluster Ops: " << cluster_debug_info(cluster_set);
VLOG(4) << "Cluster input vars: " << cluster_debug_info(cluster_inputs); VLOG(4) << "Cluster input vars: " << cluster_debug_info(cluster_inputs);
...@@ -747,8 +723,6 @@ void SearchAllSubgraphs(Graph* graph, bool is_inference_stage) { ...@@ -747,8 +723,6 @@ void SearchAllSubgraphs(Graph* graph, bool is_inference_stage) {
cluster_set, cluster_internals, cluster_inputs, cluster_outputs); cluster_set, cluster_internals, cluster_inputs, cluster_outputs);
// Deliver the kSkipGcVarNames attr (if exists) to the subgraph // Deliver the kSkipGcVarNames attr (if exists) to the subgraph
if (graph->Has(kSkipGcVarNames)) { if (graph->Has(kSkipGcVarNames)) {
const auto& all_skip_gc_vars =
graph->Get<std::unordered_set<std::string>>(kSkipGcVarNames);
auto& sub_skip_gc_vars = auto& sub_skip_gc_vars =
subgraph->GetOrInit<std::unordered_set<std::string>>(kSkipGcVarNames); subgraph->GetOrInit<std::unordered_set<std::string>>(kSkipGcVarNames);
sub_skip_gc_vars = all_skip_gc_vars; sub_skip_gc_vars = all_skip_gc_vars;
...@@ -763,7 +737,6 @@ void SearchAllSubgraphs(Graph* graph, bool is_inference_stage) { ...@@ -763,7 +737,6 @@ void SearchAllSubgraphs(Graph* graph, bool is_inference_stage) {
cluster_outputs, cluster_outputs,
cluster_internals, cluster_internals,
compilation_key, compilation_key,
deny_var_set,
graph); graph);
} }
} }
......
...@@ -39,6 +39,7 @@ constexpr char kOutputVars[] = "OutputVars"; ...@@ -39,6 +39,7 @@ constexpr char kOutputVars[] = "OutputVars";
constexpr char kMemOptVarInfoFromMainGraph[] = constexpr char kMemOptVarInfoFromMainGraph[] =
"mem_opt_var_info_from_main_graph"; "mem_opt_var_info_from_main_graph";
constexpr char kSkipGcVarNames[] = "skip_gc_vars"; constexpr char kSkipGcVarNames[] = "skip_gc_vars";
constexpr char kInplaceVarNames[] = "InplaceVars";
using Name2VarInfoMap = using Name2VarInfoMap =
std::unordered_map<std::string, std::unordered_map<std::string,
...@@ -67,11 +68,8 @@ class OpTransInfo { ...@@ -67,11 +68,8 @@ class OpTransInfo {
std::unordered_set<std::string> GetDenyVarNames( std::unordered_set<std::string> GetDenyVarNames(
const GraphNodeSet& cluster) const; const GraphNodeSet& cluster) const;
std::unordered_set<std::string> GetIgnoreInplaceVarNames( static std::unordered_set<std::string> GetInplaceVarNames(
const OpDesc& op_desc) const; const GraphNodeSet& cluster_inputs, const GraphNodeSet& cluster_outputs);
bool IsInplaceOp(const OpDesc& op_desc,
const std::unordered_set<std::string>& deny_var_names) const;
private: private:
DyOpCondT dynamic_op_cond_; DyOpCondT dynamic_op_cond_;
...@@ -79,9 +77,6 @@ class OpTransInfo { ...@@ -79,9 +77,6 @@ class OpTransInfo {
DeParamCondT deny_param_cond_{{"batch_norm", {"ReserveSpace"}}, DeParamCondT deny_param_cond_{{"batch_norm", {"ReserveSpace"}},
{"batch_norm_grad", {"ReserveSpace"}}}; {"batch_norm_grad", {"ReserveSpace"}}};
DeParamCondT ignore_inplace_param_cond_{
{"batch_norm", {"MeanOut", "VarianceOut"}}};
std::unordered_set<std::string> default_deny_ops_{"feed", "fetch"}; std::unordered_set<std::string> default_deny_ops_{"feed", "fetch"};
}; };
......
...@@ -258,16 +258,15 @@ void CinnGraphSymbolization::RunGraph(const OpMapperContext& ctx) const { ...@@ -258,16 +258,15 @@ void CinnGraphSymbolization::RunGraph(const OpMapperContext& ctx) const {
std::unordered_set<std::string> CinnGraphSymbolization::GetFetchIds() const { std::unordered_set<std::string> CinnGraphSymbolization::GetFetchIds() const {
std::unordered_set<std::string> fetch_names; std::unordered_set<std::string> fetch_names;
fetch_names.reserve(fetch_var_names_.size()); fetch_names.reserve(fetch_var_names_.size());
std::for_each( std::for_each(fetch_var_names_.begin(),
fetch_var_names_.begin(),
fetch_var_names_.end(), fetch_var_names_.end(),
[this, &fetch_names](const std::string& name) { [this, &fetch_names](const std::string& name) {
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
var_model_to_program_map_.count(name), var_map_.count(name),
1, 1,
platform::errors::PreconditionNotMet( platform::errors::PreconditionNotMet(
"Cannot find %s in var_model_to_program_map_", name.c_str())); "Cannot find %s in var_map_", name.c_str()));
fetch_names.insert(var_model_to_program_map_.at(name)); fetch_names.insert(var_map_.at(name)->id);
}); });
return fetch_names; return fetch_names;
} }
......
...@@ -337,6 +337,11 @@ NameVarMap<VarType> AutoCastInputs(const std::string& op_type, ...@@ -337,6 +337,11 @@ NameVarMap<VarType> AutoCastInputs(const std::string& op_type,
pair.first != "X") { pair.first != "X") {
continue; continue;
} }
if ((op_type == "max_pool2d_with_index_grad" ||
op_type == "max_pool2d_with_index") &&
pair.first == "Mask") {
continue;
}
if ((op_type == "fused_attention" || op_type == "fused_feedforward")) { if ((op_type == "fused_attention" || op_type == "fused_feedforward")) {
if (pair.first == "LnScale" || pair.first == "LnBias" || if (pair.first == "LnScale" || pair.first == "LnBias" ||
...@@ -381,6 +386,11 @@ NameVarMap<VarType> AutoCastInputs(const std::string& op_type, ...@@ -381,6 +386,11 @@ NameVarMap<VarType> AutoCastInputs(const std::string& op_type,
pair.first == "X" && dst_type == framework::proto::VarType::FP32) { pair.first == "X" && dst_type == framework::proto::VarType::FP32) {
continue; continue;
} }
if ((op_type == "max_pool2d_with_index_grad" ||
op_type == "max_pool2d_with_index") &&
pair.first != "Mask" && dst_type == framework::proto::VarType::FP32) {
continue;
}
if ((op_type == "fused_attention" || op_type == "fused_feedforwad") && if ((op_type == "fused_attention" || op_type == "fused_feedforwad") &&
dst_type == framework::proto::VarType::FP32) { dst_type == framework::proto::VarType::FP32) {
if (pair.first != "LnScale" && pair.first != "LnBias" && if (pair.first != "LnScale" && pair.first != "LnBias" &&
...@@ -428,6 +438,11 @@ NameVarMap<VarType> CastPureFp16Inputs(const std::string& op_type, ...@@ -428,6 +438,11 @@ NameVarMap<VarType> CastPureFp16Inputs(const std::string& op_type,
pair.first != "X") { pair.first != "X") {
continue; continue;
} }
if ((op_type == "max_pool2d_with_index_grad" ||
op_type == "max_pool2d_with_index") &&
pair.first == "Mask") {
continue;
}
if ((op_type == "fused_attention" || op_type == "fused_feedforward")) { if ((op_type == "fused_attention" || op_type == "fused_feedforward")) {
if (pair.first == "LnScale" || pair.first == "LnBias" || if (pair.first == "LnScale" || pair.first == "LnBias" ||
pair.first == "Ln2Scale" || pair.first == "Ln2Bias" || pair.first == "Ln2Scale" || pair.first == "Ln2Bias" ||
......
...@@ -1609,6 +1609,51 @@ std::vector<std::string> AnalysisPredictor::GetOutputNames() { ...@@ -1609,6 +1609,51 @@ std::vector<std::string> AnalysisPredictor::GetOutputNames() {
return output_names; return output_names;
} }
std::map<std::string, std::vector<int64_t>>
AnalysisPredictor::GetOutputTensorShape() {
std::map<std::string, std::vector<int64_t>> output_shapes;
std::vector<std::string> names = GetOutputNames();
for (std::string name : names) {
auto *var = inference_program_->Block(0).FindVar(name);
PADDLE_ENFORCE_NOT_NULL(var,
platform::errors::PreconditionNotMet(
"Output %s does not exist.", name));
output_shapes[name] = var->GetShape();
}
return output_shapes;
}
std::map<std::string, paddle_infer::DataType>
AnalysisPredictor::GetOutputTypes() {
std::map<std::string, paddle_infer::DataType> output_type;
std::vector<std::string> names = GetOutputNames();
for (const auto &name : names) {
auto *var = inference_program_->Block(0).FindVar(name);
PADDLE_ENFORCE_NOT_NULL(
var,
platform::errors::PreconditionNotMet(
"Output %s does not exist inference_program_.", name));
auto dtype = var->GetDataType();
if (dtype == paddle::framework::proto::VarType::FP32) {
output_type[name] = paddle_infer::DataType::FLOAT32;
} else if (dtype == paddle::framework::proto::VarType::FP16) {
output_type[name] = paddle_infer::DataType::FLOAT16;
} else if (dtype == paddle::framework::proto::VarType::INT64) {
output_type[name] = paddle_infer::DataType::INT64;
} else if (dtype == paddle::framework::proto::VarType::INT32) {
output_type[name] = paddle_infer::DataType::INT32;
} else if (dtype == paddle::framework::proto::VarType::UINT8) {
output_type[name] = paddle_infer::DataType::UINT8;
} else if (dtype == paddle::framework::proto::VarType::INT8) {
output_type[name] = paddle_infer::DataType::INT8;
} else {
PADDLE_THROW(paddle::platform::errors::Unimplemented(
"Unsupported data type `%s` when get output dtype ", dtype));
}
}
return output_type;
}
std::unique_ptr<ZeroCopyTensor> AnalysisPredictor::GetInputTensor( std::unique_ptr<ZeroCopyTensor> AnalysisPredictor::GetInputTensor(
const std::string &name) { const std::string &name) {
framework::Scope *scope; framework::Scope *scope;
...@@ -2477,6 +2522,10 @@ std::vector<std::string> Predictor::GetInputNames() { ...@@ -2477,6 +2522,10 @@ std::vector<std::string> Predictor::GetInputNames() {
return predictor_->GetInputNames(); return predictor_->GetInputNames();
} }
std::map<std::string, std::vector<int64_t>> Predictor::GetInputTensorShape() {
return predictor_->GetInputTensorShape();
}
std::map<std::string, DataType> Predictor::GetInputTypes() { std::map<std::string, DataType> Predictor::GetInputTypes() {
return predictor_->GetInputTypes(); return predictor_->GetInputTypes();
} }
...@@ -2493,6 +2542,14 @@ std::unique_ptr<Tensor> Predictor::GetOutputHandle(const std::string &name) { ...@@ -2493,6 +2542,14 @@ std::unique_ptr<Tensor> Predictor::GetOutputHandle(const std::string &name) {
return predictor_->GetOutputTensor(name); return predictor_->GetOutputTensor(name);
} }
std::map<std::string, std::vector<int64_t>> Predictor::GetOutputTensorShape() {
return predictor_->GetOutputTensorShape();
}
std::map<std::string, DataType> Predictor::GetOutputTypes() {
return predictor_->GetOutputTypes();
}
bool Predictor::Run() { return predictor_->ZeroCopyRun(); } bool Predictor::Run() { return predictor_->ZeroCopyRun(); }
std::unique_ptr<Predictor> Predictor::Clone(void *stream) { std::unique_ptr<Predictor> Predictor::Clone(void *stream) {
......
...@@ -191,6 +191,18 @@ class AnalysisPredictor : public PaddlePredictor { ...@@ -191,6 +191,18 @@ class AnalysisPredictor : public PaddlePredictor {
/// \return the map of input names and type /// \return the map of input names and type
/// ///
std::map<std::string, paddle_infer::DataType> GetInputTypes() override; std::map<std::string, paddle_infer::DataType> GetInputTypes() override;
///
/// \brief Get all output names and their corresponding shapes
///
/// \return the map of output names and shapes
///
std::map<std::string, std::vector<int64_t>> GetOutputTensorShape() override;
///
/// \brief Get all output names and their corresponding type
///
/// \return the map of output names and type
///
std::map<std::string, paddle_infer::DataType> GetOutputTypes() override;
/// ///
/// \brief Run the prediction engine /// \brief Run the prediction engine
......
...@@ -106,6 +106,8 @@ TEST(AnalysisPredictor, analysis_on) { ...@@ -106,6 +106,8 @@ TEST(AnalysisPredictor, analysis_on) {
ASSERT_EQ(predictor->scope_->parent(), nullptr); ASSERT_EQ(predictor->scope_->parent(), nullptr);
ASSERT_EQ(predictor->sub_scope_->parent(), predictor->scope_.get()); ASSERT_EQ(predictor->sub_scope_->parent(), predictor->scope_.get());
ASSERT_EQ(predictor->GetInputTypes().size(), 4UL); ASSERT_EQ(predictor->GetInputTypes().size(), 4UL);
ASSERT_EQ(predictor->GetOutputTypes().size(), 1UL);
ASSERT_EQ(predictor->GetOutputTensorShape().size(), 1UL);
// 2. Dummy Input Data // 2. Dummy Input Data
int64_t data[4] = {1, 2, 3, 4}; int64_t data[4] = {1, 2, 3, 4};
PaddleTensor tensor; PaddleTensor tensor;
...@@ -430,6 +432,8 @@ TEST(Predictor, Run) { ...@@ -430,6 +432,8 @@ TEST(Predictor, Run) {
auto predictor = CreatePredictor(config); auto predictor = CreatePredictor(config);
ASSERT_EQ(predictor->GetInputTypes().size(), 4UL); ASSERT_EQ(predictor->GetInputTypes().size(), 4UL);
ASSERT_EQ(predictor->GetOutputTypes().size(), 1UL);
ASSERT_EQ(predictor->GetOutputTensorShape().size(), 1UL);
auto w0 = predictor->GetInputHandle("firstw"); auto w0 = predictor->GetInputHandle("firstw");
auto w1 = predictor->GetInputHandle("secondw"); auto w1 = predictor->GetInputHandle("secondw");
......
...@@ -243,6 +243,19 @@ class PD_INFER_DECL PaddlePredictor { ...@@ -243,6 +243,19 @@ class PD_INFER_DECL PaddlePredictor {
/// \return Output tensor names. /// \return Output tensor names.
virtual std::vector<std::string> GetOutputNames() { return {}; } virtual std::vector<std::string> GetOutputNames() { return {}; }
/// \brief Get the output shape of the model.
  /// \return A map containing all the output names and shapes defined in the
  /// model.
virtual std::map<std::string, std::vector<int64_t>> GetOutputTensorShape() {
return {};
}
/// \brief Get the output type of the model.
  /// \return A map containing all the output names and types defined in the model.
virtual std::map<std::string, paddle_infer::DataType> GetOutputTypes() {
return {};
}
/// \brief Get the input ZeroCopyTensor by name. /// \brief Get the input ZeroCopyTensor by name.
/// Be inherited by AnalysisPredictor, Only used in ZeroCopy scenarios. /// Be inherited by AnalysisPredictor, Only used in ZeroCopy scenarios.
/// The name is obtained from the GetInputNames() interface. /// The name is obtained from the GetInputNames() interface.
......
...@@ -92,6 +92,13 @@ class PD_INFER_DECL Predictor { ...@@ -92,6 +92,13 @@ class PD_INFER_DECL Predictor {
/// ///
explicit Predictor(const Config& config); explicit Predictor(const Config& config);
///
/// \brief Get all input names and their corresponding shapes
///
/// \return the map of input names and shape
///
std::map<std::string, std::vector<int64_t>> GetInputTensorShape();
/// ///
/// \brief Get all input names and their corresponding type /// \brief Get all input names and their corresponding type
/// ///
...@@ -136,6 +143,20 @@ class PD_INFER_DECL Predictor { ...@@ -136,6 +143,20 @@ class PD_INFER_DECL Predictor {
/// ///
std::unique_ptr<Tensor> GetOutputHandle(const std::string& name); std::unique_ptr<Tensor> GetOutputHandle(const std::string& name);
///
/// \brief Get all output names and their corresponding shapes
///
/// \return the map of output names and shape
///
std::map<std::string, std::vector<int64_t>> GetOutputTensorShape();
///
/// \brief Get all output names and their corresponding type
///
/// \return the map of output names and type
///
std::map<std::string, DataType> GetOutputTypes();
/// ///
/// \brief Clone to get the new predictor. thread safe. /// \brief Clone to get the new predictor. thread safe.
/// ///
......
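The new output-introspection methods mirror the existing input ones; a minimal usage sketch follows, assuming the installed public header name and purely illustrative model paths (neither is taken from this change):

#include <iostream>
#include "paddle_inference_api.h"  // assumed public C++ inference header

int main() {
  paddle_infer::Config config;
  config.SetModel("./model.pdmodel", "./model.pdiparams");  // hypothetical paths
  auto predictor = paddle_infer::CreatePredictor(config);

  // New in this change: inspect output metadata without running the model.
  auto out_shapes = predictor->GetOutputTensorShape();
  auto out_types = predictor->GetOutputTypes();
  for (const auto& kv : out_shapes) {
    std::cout << "output " << kv.first << ": rank " << kv.second.size()
              << ", dtype " << static_cast<int>(out_types[kv.first]) << "\n";
  }
  return 0;
}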
...@@ -55,8 +55,9 @@ __pd_give PD_Config* PD_ConfigCreate() { ...@@ -55,8 +55,9 @@ __pd_give PD_Config* PD_ConfigCreate() {
} }
void PD_ConfigDestroy(__pd_take PD_Config* pd_config) { void PD_ConfigDestroy(__pd_take PD_Config* pd_config) {
CHECK_AND_CONVERT_PD_CONFIG; if (pd_config != NULL) {
delete reinterpret_cast<Config*>(config); delete reinterpret_cast<Config*>(pd_config);
}
} }
void PD_ConfigSetModel(__pd_keep PD_Config* pd_config, void PD_ConfigSetModel(__pd_keep PD_Config* pd_config,
...@@ -116,9 +117,12 @@ PD_Bool PD_ConfigUseFcPadding(__pd_keep PD_Config* pd_config) { ...@@ -116,9 +117,12 @@ PD_Bool PD_ConfigUseFcPadding(__pd_keep PD_Config* pd_config) {
void PD_ConfigEnableUseGpu(__pd_keep PD_Config* pd_config, void PD_ConfigEnableUseGpu(__pd_keep PD_Config* pd_config,
uint64_t memory_pool_init_size_mb, uint64_t memory_pool_init_size_mb,
int32_t device_id) { int32_t device_id,
PD_PrecisionType precision_mode) {
CHECK_AND_CONVERT_PD_CONFIG; CHECK_AND_CONVERT_PD_CONFIG;
config->EnableUseGpu(memory_pool_init_size_mb, device_id); config->EnableUseGpu(memory_pool_init_size_mb,
device_id,
ConvertToCxxPrecisionType(precision_mode));
} }
void PD_ConfigDisableGpu(__pd_keep PD_Config* pd_config) { void PD_ConfigDisableGpu(__pd_keep PD_Config* pd_config) {
CHECK_AND_CONVERT_PD_CONFIG; CHECK_AND_CONVERT_PD_CONFIG;
...@@ -427,6 +431,14 @@ void PD_ConfigSetBfloat16Op(__pd_keep PD_Config* pd_config, ...@@ -427,6 +431,14 @@ void PD_ConfigSetBfloat16Op(__pd_keep PD_Config* pd_config,
} }
config->SetBfloat16Op(std::move(op_names)); config->SetBfloat16Op(std::move(op_names));
} }
void PD_ConfigEnableMkldnnInt8(__pd_keep PD_Config* pd_config) {
CHECK_AND_CONVERT_PD_CONFIG;
config->EnableMkldnnInt8();
}
PD_Bool PD_ConfigMkldnnInt8Enabled(__pd_keep PD_Config* pd_config) {
CHECK_AND_CONVERT_PD_CONFIG;
return config->mkldnn_int8_enabled();
}
PD_Bool PD_ConfigThreadLocalStreamEnabled(__pd_keep PD_Config* pd_config) { PD_Bool PD_ConfigThreadLocalStreamEnabled(__pd_keep PD_Config* pd_config) {
CHECK_AND_CONVERT_PD_CONFIG; CHECK_AND_CONVERT_PD_CONFIG;
return config->thread_local_stream_enabled(); return config->thread_local_stream_enabled();
...@@ -484,6 +496,10 @@ void PD_ConfigEnableGpuMultiStream(__pd_keep PD_Config* pd_config) { ...@@ -484,6 +496,10 @@ void PD_ConfigEnableGpuMultiStream(__pd_keep PD_Config* pd_config) {
CHECK_AND_CONVERT_PD_CONFIG; CHECK_AND_CONVERT_PD_CONFIG;
config->EnableGpuMultiStream(); config->EnableGpuMultiStream();
} }
void PD_ConfigSetExecStream(__pd_keep PD_Config* pd_config, void* stream) {
CHECK_AND_CONVERT_PD_CONFIG;
return config->SetExecStream(stream);
}
void PD_ConfigPartiallyRelease(__pd_keep PD_Config* pd_config) { void PD_ConfigPartiallyRelease(__pd_keep PD_Config* pd_config) {
CHECK_AND_CONVERT_PD_CONFIG; CHECK_AND_CONVERT_PD_CONFIG;
config->PartiallyRelease(); config->PartiallyRelease();
......
...@@ -132,11 +132,13 @@ PADDLE_CAPI_EXPORT extern PD_Bool PD_ConfigUseFcPadding( ...@@ -132,11 +132,13 @@ PADDLE_CAPI_EXPORT extern PD_Bool PD_ConfigUseFcPadding(
/// \param[in] memory_pool_init_size_mb initial size of the GPU memory pool in /// \param[in] memory_pool_init_size_mb initial size of the GPU memory pool in
/// MB. /// MB.
/// \param[in] device_id the GPU card to use. /// \param[in] device_id the GPU card to use.
/// \param[in] precision_mode the precision used in Paddle-GPU inference.
/// ///
PADDLE_CAPI_EXPORT extern void PD_ConfigEnableUseGpu( PADDLE_CAPI_EXPORT extern void PD_ConfigEnableUseGpu(
__pd_keep PD_Config* pd_config, __pd_keep PD_Config* pd_config,
uint64_t memory_pool_init_size_mb, uint64_t memory_pool_init_size_mb,
int32_t device_id); int32_t device_id,
PD_PrecisionType precision_mode);
/// ///
/// \brief Turn off GPU. /// \brief Turn off GPU.
/// ///
...@@ -607,6 +609,22 @@ PADDLE_CAPI_EXPORT extern PD_Bool PD_ConfigMkldnnBfloat16Enabled( ...@@ -607,6 +609,22 @@ PADDLE_CAPI_EXPORT extern PD_Bool PD_ConfigMkldnnBfloat16Enabled(
PADDLE_CAPI_EXPORT extern void PD_ConfigSetBfloat16Op( PADDLE_CAPI_EXPORT extern void PD_ConfigSetBfloat16Op(
__pd_keep PD_Config* pd_config, size_t ops_num, const char** op_list); __pd_keep PD_Config* pd_config, size_t ops_num, const char** op_list);
/// ///
/// \brief Turn on MKLDNN int8.
///
/// \param[in] pd_config config
///
PADDLE_CAPI_EXPORT extern void PD_ConfigEnableMkldnnInt8(
__pd_keep PD_Config* pd_config);
///
/// \brief A boolean state telling whether to use the MKLDNN int8.
///
/// \param[in] pd_config config
/// \return Whether to use the MKLDNN int8.
///
PADDLE_CAPI_EXPORT extern PD_Bool PD_ConfigMkldnnInt8Enabled(
__pd_keep PD_Config* pd_config);
///
/// \brief Enable the GPU multi-computing stream feature. /// \brief Enable the GPU multi-computing stream feature.
/// NOTE: The current behavior of this interface is to bind the computation /// NOTE: The current behavior of this interface is to bind the computation
/// stream to the thread, and this behavior may be changed in the future. /// stream to the thread, and this behavior may be changed in the future.
...@@ -625,6 +643,12 @@ PADDLE_CAPI_EXPORT extern void PD_ConfigEnableGpuMultiStream( ...@@ -625,6 +643,12 @@ PADDLE_CAPI_EXPORT extern void PD_ConfigEnableGpuMultiStream(
PADDLE_CAPI_EXPORT extern PD_Bool PD_ConfigThreadLocalStreamEnabled( PADDLE_CAPI_EXPORT extern PD_Bool PD_ConfigThreadLocalStreamEnabled(
__pd_keep PD_Config* pd_config); __pd_keep PD_Config* pd_config);
/// ///
/// \brief Set execution stream. If not set, a stream will be created
/// internally.
///
/// \param[in] pd_config config
/// \param[in] stream the execution stream
///
PADDLE_CAPI_EXPORT extern void PD_ConfigSetExecStream(
__pd_keep PD_Config* pd_config, void* stream);
///
/// \brief Specify the memory buffer of program and parameter. /// \brief Specify the memory buffer of program and parameter.
/// Used when model and params are loaded directly from memory. /// Used when model and params are loaded directly from memory.
/// ///
......
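A minimal sketch of the extended C config API; it assumes the umbrella header pd_inference_api.h and the PD_PRECISION_FLOAT32 enumerator of PD_PrecisionType (the Java binding at the end of this diff passes 0 for the same parameter), and the external stream pointer is purely illustrative:

#include <stddef.h>
#include "pd_inference_api.h"  // assumed umbrella header for the C API

void ConfigureForGpu(void* external_cuda_stream) {
  PD_Config* config = PD_ConfigCreate();

  // The precision argument is new; PD_PRECISION_FLOAT32 is assumed to keep the
  // previous default behavior.
  PD_ConfigEnableUseGpu(config, /*memory_pool_init_size_mb=*/100,
                        /*device_id=*/0, PD_PRECISION_FLOAT32);

  // Also new: hand an externally created stream to the config.
  if (external_cuda_stream != NULL) {
    PD_ConfigSetExecStream(config, external_cuda_stream);
  }

  // CPU alternative introduced in the same change:
  //   PD_ConfigEnableMkldnnInt8(config);
  //   PD_Bool int8_on = PD_ConfigMkldnnInt8Enabled(config);

  PD_ConfigDestroy(config);  // now tolerates NULL and releases the config
}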
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
#include "paddle/fluid/inference/capi_exp/pd_predictor.h" #include "paddle/fluid/inference/capi_exp/pd_predictor.h"
#include "paddle/fluid/inference/api/paddle_inference_api.h" #include "paddle/fluid/inference/api/paddle_inference_api.h"
#include "paddle/fluid/inference/capi_exp/pd_config.h"
#include "paddle/fluid/inference/capi_exp/pd_types.h" #include "paddle/fluid/inference/capi_exp/pd_types.h"
#include "paddle/fluid/inference/capi_exp/pd_utils.h" #include "paddle/fluid/inference/capi_exp/pd_utils.h"
#include "paddle/fluid/inference/capi_exp/types_internal.h" #include "paddle/fluid/inference/capi_exp/types_internal.h"
...@@ -38,7 +39,6 @@ __pd_give PD_Predictor* PD_PredictorCreate(__pd_take PD_Config* pd_config) { ...@@ -38,7 +39,6 @@ __pd_give PD_Predictor* PD_PredictorCreate(__pd_take PD_Config* pd_config) {
paddle_infer::Config* config = paddle_infer::Config* config =
reinterpret_cast<paddle_infer::Config*>(pd_config); reinterpret_cast<paddle_infer::Config*>(pd_config);
pd_predictor->predictor = paddle_infer::CreatePredictor(*config); pd_predictor->predictor = paddle_infer::CreatePredictor(*config);
delete config;
return pd_predictor; return pd_predictor;
} }
...@@ -57,6 +57,30 @@ __pd_give PD_OneDimArrayCstr* PD_PredictorGetInputNames( ...@@ -57,6 +57,30 @@ __pd_give PD_OneDimArrayCstr* PD_PredictorGetInputNames(
return paddle_infer::CvtVecToOneDimArrayCstr(names); return paddle_infer::CvtVecToOneDimArrayCstr(names);
} }
__pd_give PD_IOInfos* PD_PredictorGetInputInfos(
__pd_keep PD_Predictor* pd_predictor) {
CHECK_AND_CONVERT_PD_PREDICTOR;
std::vector<std::string> names = predictor->GetInputNames();
std::map<std::string, std::vector<int64_t>> input_shapes =
predictor->GetInputTensorShape();
std::map<std::string, paddle_infer::DataType> input_dtypes =
predictor->GetInputTypes();
PD_IOInfos* input_infos = new PD_IOInfos;
input_infos->size = names.size();
input_infos->io_info = names.empty() ? NULL : new PD_IOInfo*[names.size()];
for (size_t i = 0; i < names.size(); i++) {
const std::string& name = names[i];
input_infos->io_info[i] = new PD_IOInfo;
input_infos->io_info[i]->name = paddle_infer::CvtStrToCstr(name);
input_infos->io_info[i]->shape =
paddle_infer::CvtVecToOneDimArrayInt64(input_shapes[name]);
input_infos->io_info[i]->dtype =
paddle_infer::CvtFromCxxDatatype(input_dtypes[name]);
}
return input_infos;
}
__pd_give PD_OneDimArrayCstr* PD_PredictorGetOutputNames( __pd_give PD_OneDimArrayCstr* PD_PredictorGetOutputNames(
__pd_keep PD_Predictor* pd_predictor) { __pd_keep PD_Predictor* pd_predictor) {
CHECK_AND_CONVERT_PD_PREDICTOR; CHECK_AND_CONVERT_PD_PREDICTOR;
...@@ -64,6 +88,30 @@ __pd_give PD_OneDimArrayCstr* PD_PredictorGetOutputNames( ...@@ -64,6 +88,30 @@ __pd_give PD_OneDimArrayCstr* PD_PredictorGetOutputNames(
return paddle_infer::CvtVecToOneDimArrayCstr(names); return paddle_infer::CvtVecToOneDimArrayCstr(names);
} }
__pd_give PD_IOInfos* PD_PredictorGetOutputInfos(
__pd_keep PD_Predictor* pd_predictor) {
CHECK_AND_CONVERT_PD_PREDICTOR;
std::vector<std::string> names = predictor->GetOutputNames();
std::map<std::string, std::vector<int64_t>> output_shapes =
predictor->GetOutputTensorShape();
std::map<std::string, paddle_infer::DataType> output_dtypes =
predictor->GetOutputTypes();
PD_IOInfos* output_infos = new PD_IOInfos;
output_infos->size = names.size();
output_infos->io_info = names.empty() ? NULL : new PD_IOInfo*[names.size()];
for (size_t i = 0; i < names.size(); i++) {
const std::string& name = names[i];
output_infos->io_info[i] = new PD_IOInfo;
output_infos->io_info[i]->name = paddle_infer::CvtStrToCstr(name);
output_infos->io_info[i]->shape =
paddle_infer::CvtVecToOneDimArrayInt64(output_shapes[name]);
output_infos->io_info[i]->dtype =
paddle_infer::CvtFromCxxDatatype(output_dtypes[name]);
}
return output_infos;
}
size_t PD_PredictorGetInputNum(__pd_keep PD_Predictor* pd_predictor) { size_t PD_PredictorGetInputNum(__pd_keep PD_Predictor* pd_predictor) {
CHECK_AND_CONVERT_PD_PREDICTOR; CHECK_AND_CONVERT_PD_PREDICTOR;
return predictor->GetInputNames().size(); return predictor->GetInputNames().size();
......
...@@ -30,6 +30,7 @@ typedef struct PD_Predictor PD_Predictor; ...@@ -30,6 +30,7 @@ typedef struct PD_Predictor PD_Predictor;
typedef struct PD_Config PD_Config; typedef struct PD_Config PD_Config;
typedef struct PD_Tensor PD_Tensor; typedef struct PD_Tensor PD_Tensor;
typedef struct PD_OneDimArrayCstr PD_OneDimArrayCstr; typedef struct PD_OneDimArrayCstr PD_OneDimArrayCstr;
typedef struct PD_IOInfos PD_IOInfos;
#ifdef __cplusplus #ifdef __cplusplus
extern "C" { extern "C" {
...@@ -60,6 +61,14 @@ PADDLE_CAPI_EXPORT extern __pd_give PD_Predictor* PD_PredictorClone( ...@@ -60,6 +61,14 @@ PADDLE_CAPI_EXPORT extern __pd_give PD_Predictor* PD_PredictorClone(
PADDLE_CAPI_EXPORT extern __pd_give PD_OneDimArrayCstr* PADDLE_CAPI_EXPORT extern __pd_give PD_OneDimArrayCstr*
PD_PredictorGetInputNames(__pd_keep PD_Predictor* pd_predictor); PD_PredictorGetInputNames(__pd_keep PD_Predictor* pd_predictor);
/// ///
/// \brief Get the input infos (name/shape/dtype)
///
/// \param[in] pd_predictor predictor
/// \return input infos (name/shape/dtype)
///
PADDLE_CAPI_EXPORT extern __pd_give PD_IOInfos* PD_PredictorGetInputInfos(
__pd_keep PD_Predictor* pd_predictor);
///
/// \brief Get the output names /// \brief Get the output names
/// ///
/// \param[in] pd_predictor predictor /// \param[in] pd_predictor predictor
...@@ -67,7 +76,14 @@ PD_PredictorGetInputNames(__pd_keep PD_Predictor* pd_predictor); ...@@ -67,7 +76,14 @@ PD_PredictorGetInputNames(__pd_keep PD_Predictor* pd_predictor);
/// ///
PADDLE_CAPI_EXPORT extern __pd_give PD_OneDimArrayCstr* PADDLE_CAPI_EXPORT extern __pd_give PD_OneDimArrayCstr*
PD_PredictorGetOutputNames(__pd_keep PD_Predictor* pd_predictor); PD_PredictorGetOutputNames(__pd_keep PD_Predictor* pd_predictor);
///
/// \brief Get the output infos (name/shape/dtype)
///
/// \param[in] pd_predictor predictor
/// \return output infos (name/shape/dtype)
///
PADDLE_CAPI_EXPORT extern __pd_give PD_IOInfos* PD_PredictorGetOutputInfos(
__pd_keep PD_Predictor* pd_predictor);
/// ///
/// \brief Get the input number /// \brief Get the input number
/// ///
......
...@@ -29,6 +29,11 @@ typedef struct PD_OneDimArraySize { ...@@ -29,6 +29,11 @@ typedef struct PD_OneDimArraySize {
size_t* data; size_t* data;
} PD_OneDimArraySize; // std::vector<size_t> } PD_OneDimArraySize; // std::vector<size_t>
typedef struct PD_OneDimArrayInt64 {
size_t size;
int64_t* data;
} PD_OneDimArrayInt64; // std::vector<int64_t>
typedef struct PD_OneDimArrayCstr { typedef struct PD_OneDimArrayCstr {
size_t size; size_t size;
char** data; char** data;
...@@ -43,3 +48,14 @@ typedef struct PD_TwoDimArraySize { ...@@ -43,3 +48,14 @@ typedef struct PD_TwoDimArraySize {
size_t size; size_t size;
PD_OneDimArraySize** data; PD_OneDimArraySize** data;
} PD_TwoDimArraySize; // std::vector<std::vector<size_t>> } PD_TwoDimArraySize; // std::vector<std::vector<size_t>>
typedef struct PD_IOInfo {
PD_Cstr* name;
PD_OneDimArrayInt64* shape;
PD_DataType dtype;
} PD_IOInfo; // input or output info
typedef struct PD_IOInfos {
size_t size;
PD_IOInfo** io_info;
} PD_IOInfos; // inputs or outputs info
...@@ -11,12 +11,10 @@ ...@@ -11,12 +11,10 @@
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "paddle/fluid/inference/capi_exp/pd_utils.h"
#include <string> #include <string>
#include "paddle/fluid/inference/api/paddle_inference_api.h" #include "paddle/fluid/inference/api/paddle_inference_api.h"
#include "paddle/fluid/inference/capi_exp/pd_utils.h"
#include "paddle/fluid/inference/capi_exp/utils_internal.h" #include "paddle/fluid/inference/capi_exp/utils_internal.h"
#include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/enforce.h"
...@@ -62,6 +60,7 @@ ...@@ -62,6 +60,7 @@
ONE_DIM_ARRAY_UTILS_FUNC_IMPL(int32_t, Int32, int) ONE_DIM_ARRAY_UTILS_FUNC_IMPL(int32_t, Int32, int)
ONE_DIM_ARRAY_UTILS_FUNC_IMPL(size_t, Size, size_t) ONE_DIM_ARRAY_UTILS_FUNC_IMPL(size_t, Size, size_t)
ONE_DIM_ARRAY_UTILS_FUNC_IMPL(int64_t, Int64, int64_t)
#undef ONE_DIM_ARRAY_UTILS_FUNC_IMPL #undef ONE_DIM_ARRAY_UTILS_FUNC_IMPL
#undef CONVERT_ONE_DIM_ARRAY_TO_VEC #undef CONVERT_ONE_DIM_ARRAY_TO_VEC
...@@ -178,6 +177,38 @@ TWO_DIM_ARRAY_UTILS_FUNC_IMPL(size_t, Size, size_t) ...@@ -178,6 +177,38 @@ TWO_DIM_ARRAY_UTILS_FUNC_IMPL(size_t, Size, size_t)
#undef CONVERT_VEC_TO_TWO_DIM_ARRAY #undef CONVERT_VEC_TO_TWO_DIM_ARRAY
#undef DESTROY_TWO_DIM_ARRAY #undef DESTROY_TWO_DIM_ARRAY
#ifdef __cplusplus
extern "C" {
#endif
void PD_IOInfoDestroy(__pd_take PD_IOInfo* io_info) {
if (io_info != NULL) {
PD_CstrDestroy(io_info->name);
io_info->name = NULL;
PD_OneDimArrayInt64Destroy(io_info->shape);
io_info->shape = NULL;
delete io_info;
}
}
void PD_IOInfosDestroy(__pd_take PD_IOInfos* io_infos) {
if (io_infos != NULL) {
if (io_infos->size != 0) {
for (size_t index = 0; index < io_infos->size; ++index) {
PD_IOInfoDestroy(io_infos->io_info[index]);
}
io_infos->size = 0;
}
delete[] io_infos->io_info;
io_infos->io_info = NULL;
delete io_infos;
}
}
#ifdef __cplusplus
} // extern "C"
#endif
namespace paddle_infer { namespace paddle_infer {
PlaceType CvtToCxxPlaceType(PD_PlaceType place_type) { PlaceType CvtToCxxPlaceType(PD_PlaceType place_type) {
......
...@@ -41,6 +41,14 @@ extern "C" { ...@@ -41,6 +41,14 @@ extern "C" {
PADDLE_CAPI_EXPORT extern void PD_OneDimArrayInt32Destroy( PADDLE_CAPI_EXPORT extern void PD_OneDimArrayInt32Destroy(
__pd_take PD_OneDimArrayInt32* array); __pd_take PD_OneDimArrayInt32* array);
///
/// \brief Destroy the PD_OneDimArrayInt64 object pointed to by the pointer.
///
/// \param[in] array pointer to the PD_OneDimArrayInt64 object.
///
PADDLE_CAPI_EXPORT extern void PD_OneDimArrayInt64Destroy(
__pd_take PD_OneDimArrayInt64* array);
/// ///
/// \brief Destroy the PD_OneDimArrayCstr object pointed to by the pointer. /// \brief Destroy the PD_OneDimArrayCstr object pointed to by the pointer.
/// ///
...@@ -74,6 +82,21 @@ PADDLE_CAPI_EXPORT extern void PD_TwoDimArraySizeDestroy( ...@@ -74,6 +82,21 @@ PADDLE_CAPI_EXPORT extern void PD_TwoDimArraySizeDestroy(
/// ///
PADDLE_CAPI_EXPORT extern void PD_CstrDestroy(__pd_take PD_Cstr* cstr); PADDLE_CAPI_EXPORT extern void PD_CstrDestroy(__pd_take PD_Cstr* cstr);
///
/// \brief Destroy the PD_IOInfo object pointed to by the pointer.
///
/// \param[in] io_info pointer to the PD_IOInfo object.
///
PADDLE_CAPI_EXPORT extern void PD_IOInfoDestroy(__pd_take PD_IOInfo* io_info);
///
/// \brief Destroy the PD_IOInfos object pointed to by the pointer.
///
/// \param[in] io_infos pointer to the PD_IOInfos object.
///
PADDLE_CAPI_EXPORT extern void PD_IOInfosDestroy(
__pd_take PD_IOInfos* io_infos);
#ifdef __cplusplus #ifdef __cplusplus
} // extern "C" } // extern "C"
#endif #endif
...@@ -44,6 +44,16 @@ namespace paddle_infer { ...@@ -44,6 +44,16 @@ namespace paddle_infer {
__pd_give PD_OneDimArrayInt32* CvtVecToOneDimArrayInt32( __pd_give PD_OneDimArrayInt32* CvtVecToOneDimArrayInt32(
const std::vector<int>& vec); const std::vector<int>& vec);
///
/// \brief Convert the 'std::vector<int64_t>' object to a 'PD_OneDimArrayInt64'
/// object.
///
/// \param[in] vec source object.
/// \return target object.
///
__pd_give PD_OneDimArrayInt64* CvtVecToOneDimArrayInt64(
const std::vector<int64_t>& vec);
/// ///
/// \brief Convert the 'PD_OneDimArrayInt32' object to a 'std::vector<int>' /// \brief Convert the 'PD_OneDimArrayInt32' object to a 'std::vector<int>'
/// object. /// object.
...@@ -54,6 +64,16 @@ __pd_give PD_OneDimArrayInt32* CvtVecToOneDimArrayInt32( ...@@ -54,6 +64,16 @@ __pd_give PD_OneDimArrayInt32* CvtVecToOneDimArrayInt32(
std::vector<int> CvtOneDimArrayToVecInt32( std::vector<int> CvtOneDimArrayToVecInt32(
__pd_keep const PD_OneDimArrayInt32* array); __pd_keep const PD_OneDimArrayInt32* array);
///
/// \brief Convert the 'PD_OneDimArrayInt64' object to a 'std::vector<int64_t>'
/// object.
///
/// \param[in] array source object.
/// \return target object.
///
std::vector<int64_t> CvtOneDimArrayToVecInt64(
__pd_keep const PD_OneDimArrayInt64* array);
/// ///
/// \brief Convert the 'std::vector<size_t>' object to a 'PD_OneDimArraySize' /// \brief Convert the 'std::vector<size_t>' object to a 'PD_OneDimArraySize'
/// object. /// object.
......
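The Int64 pair declared above is generated in pd_utils.cc by the same ONE_DIM_ARRAY_UTILS_FUNC_IMPL macro that the earlier hunk extends, so the bodies never appear verbatim in this diff. A hand-written sketch of what the two helpers plausibly expand to (the real macro-generated code may differ in details such as how empty vectors are handled):

    __pd_give PD_OneDimArrayInt64* CvtVecToOneDimArrayInt64(
        const std::vector<int64_t>& vec) {
      PD_OneDimArrayInt64* array = new PD_OneDimArrayInt64;
      array->size = vec.size();
      array->data = vec.empty() ? nullptr : new int64_t[vec.size()];
      for (size_t i = 0; i < vec.size(); ++i) array->data[i] = vec[i];
      return array;
    }

    std::vector<int64_t> CvtOneDimArrayToVecInt64(
        __pd_keep const PD_OneDimArrayInt64* array) {
      if (array == nullptr || array->size == 0) return {};
      return std::vector<int64_t>(array->data, array->data + array->size);
    }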
...@@ -161,7 +161,8 @@ JNIEXPORT void JNICALL Java_com_baidu_paddle_inference_Config_enableUseGpu( ...@@ -161,7 +161,8 @@ JNIEXPORT void JNICALL Java_com_baidu_paddle_inference_Config_enableUseGpu(
jint deviceId) { jint deviceId) {
PD_ConfigEnableUseGpu(reinterpret_cast<PD_Config*>(cppPaddleConfigPointer), PD_ConfigEnableUseGpu(reinterpret_cast<PD_Config*>(cppPaddleConfigPointer),
(uint64_t)memorySize, (uint64_t)memorySize,
(int32_t)deviceId); (int32_t)deviceId,
0);
} }
JNIEXPORT void JNICALL Java_com_baidu_paddle_inference_Config_disableGpu( JNIEXPORT void JNICALL Java_com_baidu_paddle_inference_Config_disableGpu(
......
...@@ -157,7 +157,7 @@ func (config *Config) UseFcPadding() bool { ...@@ -157,7 +157,7 @@ func (config *Config) UseFcPadding() bool {
/// \param deviceId the GPU card to use. /// \param deviceId the GPU card to use.
/// ///
func (config *Config) EnableUseGpu(memorySize uint64, deviceId int32) { func (config *Config) EnableUseGpu(memorySize uint64, deviceId int32) {
C.PD_ConfigEnableUseGpu(config.c, C.uint64_t(memorySize), C.int32_t(deviceId)) C.PD_ConfigEnableUseGpu(config.c, C.uint64_t(memorySize), C.int32_t(deviceId), 0)
} }
/// ///
......
...@@ -46,6 +46,7 @@ class CastOpConverter : public OpConverter { ...@@ -46,6 +46,7 @@ class CastOpConverter : public OpConverter {
layer->setOutputType(0, nvinfer1::DataType::kBOOL); layer->setOutputType(0, nvinfer1::DataType::kBOOL);
break; break;
case 2: // INT32 = 2 case 2: // INT32 = 2
case 3: // INT64 = 3, TensorRT has no int64, so fall through to the int32 branch
layer->setOutputType(0, nvinfer1::DataType::kINT32); layer->setOutputType(0, nvinfer1::DataType::kINT32);
break; break;
case 4: // FP16 = 4 case 4: // FP16 = 4
......
...@@ -19,6 +19,10 @@ limitations under the License. */ ...@@ -19,6 +19,10 @@ limitations under the License. */
#include <string> #include <string>
#include <vector> #include <vector>
#if defined(PADDLE_WITH_CUDA)
#include <cuda_runtime.h>
#endif
#include "paddle/fluid/inference/capi_exp/pd_inference_api.h" #include "paddle/fluid/inference/capi_exp/pd_inference_api.h"
#include "paddle/fluid/inference/tests/api/tester_helper.h" #include "paddle/fluid/inference/tests/api/tester_helper.h"
...@@ -37,7 +41,7 @@ TEST(PD_Config, gpu_interface) { ...@@ -37,7 +41,7 @@ TEST(PD_Config, gpu_interface) {
PD_ConfigSetModel(config, prog_file.c_str(), param_file.c_str()); PD_ConfigSetModel(config, prog_file.c_str(), param_file.c_str());
PD_ConfigSetOptimCacheDir(config, opt_cache_dir.c_str()); PD_ConfigSetOptimCacheDir(config, opt_cache_dir.c_str());
PD_ConfigEnableUseGpu(config, 100, 0); PD_ConfigEnableUseGpu(config, 100, 0, 0);
bool use_gpu = PD_ConfigUseGpu(config); bool use_gpu = PD_ConfigUseGpu(config);
EXPECT_TRUE(use_gpu); EXPECT_TRUE(use_gpu);
int init_size = PD_ConfigMemoryPoolInitSizeMb(config); int init_size = PD_ConfigMemoryPoolInitSizeMb(config);
...@@ -84,6 +88,14 @@ TEST(PD_Config, gpu_interface) { ...@@ -84,6 +88,14 @@ TEST(PD_Config, gpu_interface) {
bool thread_local_thread = PD_ConfigThreadLocalStreamEnabled(config); bool thread_local_thread = PD_ConfigThreadLocalStreamEnabled(config);
EXPECT_TRUE(thread_local_thread); EXPECT_TRUE(thread_local_thread);
#if defined(PADDLE_WITH_CUDA)
{
cudaStream_t external_stream;
cudaStreamCreate(&external_stream);
PD_ConfigSetExecStream(config, external_stream);
}
#endif
PD_ConfigDisableGpu(config); PD_ConfigDisableGpu(config);
PD_ConfigDestroy(config); PD_ConfigDestroy(config);
} }
...@@ -104,7 +116,7 @@ TEST(PD_Config, use_gpu) { ...@@ -104,7 +116,7 @@ TEST(PD_Config, use_gpu) {
const char* model_dir_ = PD_ConfigGetModelDir(config); const char* model_dir_ = PD_ConfigGetModelDir(config);
LOG(INFO) << model_dir_; LOG(INFO) << model_dir_;
PD_ConfigEnableUseGpu(config, 100, 0); PD_ConfigEnableUseGpu(config, 100, 0, 0);
bool use_gpu = PD_ConfigUseGpu(config); bool use_gpu = PD_ConfigUseGpu(config);
EXPECT_TRUE(use_gpu); EXPECT_TRUE(use_gpu);
int device_id = PD_ConfigGpuDeviceId(config); int device_id = PD_ConfigGpuDeviceId(config);
...@@ -142,7 +154,7 @@ TEST(PD_Config, use_gpu) { ...@@ -142,7 +154,7 @@ TEST(PD_Config, use_gpu) {
TEST(PD_Config, trt_int8) { TEST(PD_Config, trt_int8) {
std::string model_dir = FLAGS_infer_model + "/mobilenet"; std::string model_dir = FLAGS_infer_model + "/mobilenet";
PD_Config* config = PD_ConfigCreate(); PD_Config* config = PD_ConfigCreate();
PD_ConfigEnableUseGpu(config, 100, 0); PD_ConfigEnableUseGpu(config, 100, 0, 0);
PD_ConfigEnableTensorRtEngine( PD_ConfigEnableTensorRtEngine(
config, 1 << 20, 1, 3, PD_PRECISION_INT8, FALSE, TRUE); config, 1 << 20, 1, 3, PD_PRECISION_INT8, FALSE, TRUE);
bool trt_enable = PD_ConfigTensorRtEngineEnabled(config); bool trt_enable = PD_ConfigTensorRtEngineEnabled(config);
...@@ -153,7 +165,7 @@ TEST(PD_Config, trt_int8) { ...@@ -153,7 +165,7 @@ TEST(PD_Config, trt_int8) {
TEST(PD_Config, trt_fp16) { TEST(PD_Config, trt_fp16) {
std::string model_dir = FLAGS_infer_model + "/mobilenet"; std::string model_dir = FLAGS_infer_model + "/mobilenet";
PD_Config* config = PD_ConfigCreate(); PD_Config* config = PD_ConfigCreate();
PD_ConfigEnableUseGpu(config, 100, 0); PD_ConfigEnableUseGpu(config, 100, 0, 0);
PD_ConfigEnableTensorRtEngine( PD_ConfigEnableTensorRtEngine(
config, 1 << 20, 1, 3, PD_PRECISION_HALF, FALSE, FALSE); config, 1 << 20, 1, 3, PD_PRECISION_HALF, FALSE, FALSE);
bool trt_enable = PD_ConfigTensorRtEngineEnabled(config); bool trt_enable = PD_ConfigTensorRtEngineEnabled(config);
......
...@@ -37,6 +37,9 @@ void predictor_run() { ...@@ -37,6 +37,9 @@ void predictor_run() {
PD_OneDimArrayCstr* input_names = PD_PredictorGetInputNames(predictor); PD_OneDimArrayCstr* input_names = PD_PredictorGetInputNames(predictor);
LOG(INFO) << "The inputs' size is: " << input_names->size; LOG(INFO) << "The inputs' size is: " << input_names->size;
EXPECT_EQ(input_names->size, 2u); EXPECT_EQ(input_names->size, 2u);
PD_IOInfos* in_infos = PD_PredictorGetInputInfos(predictor);
EXPECT_EQ(in_infos->size, 2u);
PD_IOInfos* out_infos = PD_PredictorGetOutputInfos(predictor);
int32_t shape_0[4] = {1, 3, 224, 224}; int32_t shape_0[4] = {1, 3, 224, 224};
float data_0[1 * 3 * 224 * 224] = {0}; float data_0[1 * 3 * 224 * 224] = {0};
...@@ -79,6 +82,8 @@ void predictor_run() { ...@@ -79,6 +82,8 @@ void predictor_run() {
PD_TensorDestroy(input_1); PD_TensorDestroy(input_1);
PD_TensorDestroy(input_0); PD_TensorDestroy(input_0);
PD_OneDimArrayCstrDestroy(input_names); PD_OneDimArrayCstrDestroy(input_names);
PD_IOInfosDestroy(in_infos);
PD_IOInfosDestroy(out_infos);
PD_PredictorDestroy(predictor); PD_PredictorDestroy(predictor);
} }
......
...@@ -85,6 +85,10 @@ TEST(PD_Config, interface) { ...@@ -85,6 +85,10 @@ TEST(PD_Config, interface) {
PD_ConfigEnableMkldnnBfloat16(config); PD_ConfigEnableMkldnnBfloat16(config);
PD_ConfigSetBfloat16Op(config, 1, &ops_name); PD_ConfigSetBfloat16Op(config, 1, &ops_name);
PD_ConfigEnableMkldnnInt8(config);
bool mkldnn_int8_enabled = PD_ConfigMkldnnInt8Enabled(config);
EXPECT_TRUE(mkldnn_int8_enabled);
#endif #endif
PD_ConfigEnableONNXRuntime(config); PD_ConfigEnableONNXRuntime(config);
......
...@@ -198,8 +198,7 @@ void RefcountedMemoryMapAllocation::close() { ...@@ -198,8 +198,7 @@ void RefcountedMemoryMapAllocation::close() {
MemoryMapAllocationPool::Instance().Insert(MemoryMapInfo( MemoryMapAllocationPool::Instance().Insert(MemoryMapInfo(
flags_, map_size_ - mmap_alignment, ipc_name_, map_ptr_)); flags_, map_size_ - mmap_alignment, ipc_name_, map_ptr_));
} else { } else {
if (info->refcount == 0 && if (info->refcount == 0) {
shm_open(ipc_name_.c_str(), O_RDWR, (mode_t)0600) != -1) {
shm_unlink(ipc_name_.c_str()); shm_unlink(ipc_name_.c_str());
VLOG(6) << "shm_unlink file: " << ipc_name_; VLOG(6) << "shm_unlink file: " << ipc_name_;
} }
......
...@@ -305,24 +305,15 @@ PD_REGISTER_GENERAL_KERNEL( ...@@ -305,24 +305,15 @@ PD_REGISTER_GENERAL_KERNEL(
ALL_LAYOUT, ALL_LAYOUT,
paddle::operators::FeedStringsKernel<phi::CustomContext>, paddle::operators::FeedStringsKernel<phi::CustomContext>,
ALL_DTYPE) {} ALL_DTYPE) {}
#endif
#elif defined(PADDLE_WITH_CUSTOM_DEVICE) #ifdef PADDLE_WITH_CUSTOM_DEVICE
PD_REGISTER_GENERAL_KERNEL( namespace paddle {
feed_dense_tensor, namespace operators {
custom_cpu, template void FeedDenseTensorKernel<phi::CustomContext>(
ALL_LAYOUT, const phi::CustomContext& dev_ctx,
paddle::operators::FeedDenseTensorKernel<phi::CustomContext>, const phi::ExtendedTensor& x,
ALL_DTYPE) {} int col,
PD_REGISTER_GENERAL_KERNEL( phi::DenseTensor* out);
feed_sparse_coo_tensor, } // namespace operators
custom_cpu, } // namespace paddle
ALL_LAYOUT,
paddle::operators::FeedSparseCooTensorKernel<phi::CustomContext>,
ALL_DTYPE) {}
PD_REGISTER_GENERAL_KERNEL(
feed_strings,
custom_cpu,
ALL_LAYOUT,
paddle::operators::FeedStringsKernel<phi::CustomContext>,
ALL_DTYPE) {}
#endif #endif
...@@ -17,6 +17,7 @@ limitations under the License. */ ...@@ -17,6 +17,7 @@ limitations under the License. */
#include "paddle/fluid/operators/run_program_op.h" #include "paddle/fluid/operators/run_program_op.h"
#include "paddle/fluid/operators/save_combine_op.h" #include "paddle/fluid/operators/save_combine_op.h"
#include "paddle/phi/backends/device_manager.h" #include "paddle/phi/backends/device_manager.h"
#include "paddle/phi/core/kernel_registry.h"
#define REGISTER_OP_CUSTOM_DEVICE_KERNEL(op_type, dev_type, ...) \ #define REGISTER_OP_CUSTOM_DEVICE_KERNEL(op_type, dev_type, ...) \
static paddle::framework::OpKernelRegistrar<phi::CustomPlace, __VA_ARGS__> \ static paddle::framework::OpKernelRegistrar<phi::CustomPlace, __VA_ARGS__> \
...@@ -26,10 +27,30 @@ limitations under the License. */ ...@@ -26,10 +27,30 @@ limitations under the License. */
paddle::framework::OpKernelType::kDefaultCustomizedTypeValue); \ paddle::framework::OpKernelType::kDefaultCustomizedTypeValue); \
__op_custom_device_kernel_registrar_##op_type##_##__acosf##__.Touch(); __op_custom_device_kernel_registrar_##op_type##_##__acosf##__.Touch();
#define REGISTER_CUSTOM_DEVICE_GENERAL_KERNEL( \
kernel_name, dev_type, layout, kernel_fn) \
static phi::KernelRegistrar \
__reg_custom_device_phi_kernel_##kernel_name##_##backend##_##layout( \
phi::RegType::INNER, \
#kernel_name, \
dev_type, \
DATALAYOUT(layout), \
::phi::KernelArgsParseFunctor<decltype(&kernel_fn)>::Parse, \
[](const phi::KernelKey& kernel_key, phi::Kernel* kernel) {}, \
PHI_KERNEL(kernel_fn), \
PHI_VARIADIC_KERNEL(kernel_fn))
namespace paddle { namespace paddle {
namespace operators { namespace operators {
template <typename Context>
void FeedDenseTensorKernel(const Context& dev_ctx,
const phi::ExtendedTensor& x,
int col,
phi::DenseTensor* out);
void RegisterCustomDeviceCommonKernel(const std::string& dev_type) { void RegisterCustomDeviceCommonKernel(const std::string& dev_type) {
#ifdef PADDLE_WITH_CUSTOM_DEVICE
auto device_type = dev_type.c_str(); auto device_type = dev_type.c_str();
/* see [Why use single type kernel] */ /* see [Why use single type kernel] */
REGISTER_OP_CUSTOM_DEVICE_KERNEL( REGISTER_OP_CUSTOM_DEVICE_KERNEL(
...@@ -66,9 +87,16 @@ void RegisterCustomDeviceCommonKernel(const std::string& dev_type) { ...@@ -66,9 +87,16 @@ void RegisterCustomDeviceCommonKernel(const std::string& dev_type) {
LoadCombineOpKernel<paddle::platform::CustomDeviceContext, int8_t>, LoadCombineOpKernel<paddle::platform::CustomDeviceContext, int8_t>,
paddle::operators:: paddle::operators::
LoadCombineOpKernel<paddle::platform::CustomDeviceContext, int64_t>); LoadCombineOpKernel<paddle::platform::CustomDeviceContext, int64_t>);
REGISTER_CUSTOM_DEVICE_GENERAL_KERNEL(
feed_dense_tensor,
device_type,
ALL_LAYOUT,
paddle::operators::FeedDenseTensorKernel<phi::CustomContext>);
#endif
} }
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
#undef REGISTER_OP_CUSTOM_DEVICE_KERNEL #undef REGISTER_OP_CUSTOM_DEVICE_KERNEL
#undef REGISTER_CUSTOM_DEVICE_GENERAL_KERNEL
...@@ -56,6 +56,7 @@ class SequencePadOp : public framework::OperatorWithKernel { ...@@ -56,6 +56,7 @@ class SequencePadOp : public framework::OperatorWithKernel {
auto pad_value_dims = ctx->GetInputDim("PadValue"); auto pad_value_dims = ctx->GetInputDim("PadValue");
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
pad_value_dims == phi::make_ddim({1}) || pad_value_dims == phi::make_ddim({1}) ||
pad_value_dims == phi::make_ddim({}) ||
pad_value_dims == time_step_dims, pad_value_dims == time_step_dims,
true, true,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
......
...@@ -12,6 +12,7 @@ ...@@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include <string.h>
#include <memory> #include <memory>
#include <sstream> #include <sstream>
#include <string> #include <string>
...@@ -166,7 +167,16 @@ Tensor full<DescTensor>(const IntArray& shape, ...@@ -166,7 +167,16 @@ Tensor full<DescTensor>(const IntArray& shape,
phi::errors::InvalidArgument( phi::errors::InvalidArgument(
"We only support float32/float16 for full, but we got data type: %s", "We only support float32/float16 for full, but we got data type: %s",
phi::DataTypeToString(dtype))); phi::DataTypeToString(dtype)));
if (dtype == phi::DataType::FLOAT32) {
op->SetAttr("value", value.to<float>()); op->SetAttr("value", value.to<float>());
} else if (dtype == phi::DataType::FLOAT64) {
op->SetAttr("str_value", std::to_string(value.to<double>()));
} else if (dtype == phi::DataType::FLOAT16) {
op->SetAttr("str_value", std::to_string(value.to<float>()));
} else {
PADDLE_THROW(phi::errors::Unimplemented(
"We only support float64/float32/float16 for full"));
}
op->SetAttr("dtype", paddle::framework::TransToProtoVarType(dtype)); op->SetAttr("dtype", paddle::framework::TransToProtoVarType(dtype));
op->SetOutput( op->SetOutput(
"Out", {std::static_pointer_cast<prim::DescTensor>(out.impl())->Name()}); "Out", {std::static_pointer_cast<prim::DescTensor>(out.impl())->Name()});
......
...@@ -32,7 +32,7 @@ void tanh_grad(const Tensor& out, const Tensor& grad_out, Tensor* grad_x) { ...@@ -32,7 +32,7 @@ void tanh_grad(const Tensor& out, const Tensor& grad_out, Tensor* grad_x) {
auto tmp = pow<T>(out, 2.0); auto tmp = pow<T>(out, 2.0);
tmp = scale<T>(tmp, -1.0, 1.0, true); tmp = scale<T>(tmp, -1.0, 1.0, true);
auto grad_x_tmp = multiply<T>(grad_out, tmp); auto grad_x_tmp = multiply<T>(grad_out, tmp);
set_output<T>(grad_x_tmp.impl(), grad_x); set_output<T>(grad_x_tmp, grad_x);
} }
template <typename T> template <typename T>
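The pow/scale/multiply chain in tanh_grad above is simply the closed-form tanh derivative written against the saved forward output:

    grad_x = grad_out \cdot (1 - \text{out}^2), \qquad \text{since } \tanh'(x) = 1 - \tanh^2(x).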
...@@ -53,7 +53,7 @@ void subtract_grad(const Tensor& x, ...@@ -53,7 +53,7 @@ void subtract_grad(const Tensor& x,
auto dy_reduce_res = sum<T>( auto dy_reduce_res = sum<T>(
scale_out_grad, phi::vectorize(reduce_dim), y.dtype(), false); scale_out_grad, phi::vectorize(reduce_dim), y.dtype(), false);
auto dy_tmp = reshape<T>(dy_reduce_res, phi::vectorize(y.dims())); auto dy_tmp = reshape<T>(dy_reduce_res, phi::vectorize(y.dims()));
set_output<T>(dy_tmp.impl(), dy); set_output<T>(dy_tmp, dy);
} }
} else { } else {
by_pass<T>(scale_out_grad, dy); by_pass<T>(scale_out_grad, dy);
...@@ -69,7 +69,7 @@ void subtract_grad(const Tensor& x, ...@@ -69,7 +69,7 @@ void subtract_grad(const Tensor& x,
auto dx_reduce_res = auto dx_reduce_res =
sum<T>(out_grad, phi::vectorize(reduce_dim), x.dtype(), false); sum<T>(out_grad, phi::vectorize(reduce_dim), x.dtype(), false);
auto dx_tmp = reshape<T>(dx_reduce_res, phi::vectorize(x.dims())); auto dx_tmp = reshape<T>(dx_reduce_res, phi::vectorize(x.dims()));
set_output<T>(dx_tmp.impl(), dx); set_output<T>(dx_tmp, dx);
} }
} else { } else {
by_pass<T>(out_grad, dx); by_pass<T>(out_grad, dx);
...@@ -94,7 +94,7 @@ void add_grad(const Tensor& x, ...@@ -94,7 +94,7 @@ void add_grad(const Tensor& x,
auto dy_reduce_res = auto dy_reduce_res =
sum<T>(out_grad, phi::vectorize(reduce_dim), y.dtype(), false); sum<T>(out_grad, phi::vectorize(reduce_dim), y.dtype(), false);
auto dy_tmp = reshape<T>(dy_reduce_res, phi::vectorize(y.dims())); auto dy_tmp = reshape<T>(dy_reduce_res, phi::vectorize(y.dims()));
set_output<T>(dy_tmp.impl(), dy); set_output<T>(dy_tmp, dy);
} }
} else { } else {
...@@ -111,7 +111,7 @@ void add_grad(const Tensor& x, ...@@ -111,7 +111,7 @@ void add_grad(const Tensor& x,
auto dx_reduce_res = auto dx_reduce_res =
sum<T>(out_grad, phi::vectorize(reduce_dim), x.dtype(), false); sum<T>(out_grad, phi::vectorize(reduce_dim), x.dtype(), false);
auto dx_tmp = reshape<T>(dx_reduce_res, phi::vectorize(x.dims())); auto dx_tmp = reshape<T>(dx_reduce_res, phi::vectorize(x.dims()));
set_output<T>(dx_tmp.impl(), dx); set_output<T>(dx_tmp, dx);
} }
} else { } else {
by_pass<T>(out_grad, dx); by_pass<T>(out_grad, dx);
...@@ -139,6 +139,9 @@ void sum_grad(const Tensor& x, ...@@ -139,6 +139,9 @@ void sum_grad(const Tensor& x,
reduce_all = false; reduce_all = false;
} }
auto x_grad_tmp = Tensor(); auto x_grad_tmp = Tensor();
if (x_dim_size == 1) {
x_grad_tmp = expand<T>(out_grad, IntArray(x_dim));
} else {
if (!keepdim) { if (!keepdim) {
auto axis_ = std::vector<int64_t>(); auto axis_ = std::vector<int64_t>();
if (reduce_all) { if (reduce_all) {
...@@ -153,8 +156,9 @@ void sum_grad(const Tensor& x, ...@@ -153,8 +156,9 @@ void sum_grad(const Tensor& x,
} else { } else {
x_grad_tmp = expand<T>(out_grad, IntArray(x_dim)); x_grad_tmp = expand<T>(out_grad, IntArray(x_dim));
} }
}
set_output<T>(x_grad_tmp.impl(), x_grad); set_output<T>(x_grad_tmp, x_grad);
} }
template <typename T> template <typename T>
...@@ -175,36 +179,36 @@ void divide_grad(const Tensor& x, ...@@ -175,36 +179,36 @@ void divide_grad(const Tensor& x,
// Maybe need reduce here // Maybe need reduce here
phi::DDim reduce_dim = get_reduce_dims(y.dims(), x.dims()); phi::DDim reduce_dim = get_reduce_dims(y.dims(), x.dims());
if (!reduce_dim.size()) { if (!reduce_dim.size()) {
set_output<T>(dy_res.impl(), dy); set_output<T>(dy_res, dy);
} else { } else {
auto dy_reduce_res = auto dy_reduce_res =
sum<T>(dy_res, phi::vectorize(reduce_dim), y.dtype(), false); sum<T>(dy_res, phi::vectorize(reduce_dim), y.dtype(), false);
auto dy_tmp = reshape<T>(dy_reduce_res, phi::vectorize(y.dims())); auto dy_tmp = reshape<T>(dy_reduce_res, phi::vectorize(y.dims()));
set_output<T>(dy_tmp.impl(), dy); set_output<T>(dy_tmp, dy);
} }
} else { } else {
set_output<T>(dy_res.impl(), dy); set_output<T>(dy_res, dy);
} }
} // indicate we will compute dy } // indicate we will compute dy
if (dx) { if (dx) {
// dx = (1/y) * dout // dx = (1/y) * dout
auto one_tensor = full<T>(phi::vectorize(y.dims()), 1.0); auto one_tensor = full<T>(phi::vectorize(y.dims()), 1.0, y.dtype());
auto tmp0 = divide<T>(one_tensor, y); auto tmp0 = divide<T>(one_tensor, y);
auto dx_res = multiply<T>(tmp0, out_grad); auto dx_res = multiply<T>(tmp0, out_grad);
if (y.dims() != x.dims()) { if (y.dims() != x.dims()) {
// Maybe need reduce here // Maybe need reduce here
auto reduce_dim = get_reduce_dims(x.dims(), y.dims()); auto reduce_dim = get_reduce_dims(x.dims(), y.dims());
if (!reduce_dim.size()) { if (!reduce_dim.size()) {
set_output<T>(dx_res.impl(), dx); set_output<T>(dx_res, dx);
} else { } else {
auto dx_reduce_res = auto dx_reduce_res =
sum<T>(dx_res, phi::vectorize(reduce_dim), x.dtype(), false); sum<T>(dx_res, phi::vectorize(reduce_dim), x.dtype(), false);
auto dx_tmp = reshape<T>(dx_reduce_res, phi::vectorize(x.dims())); auto dx_tmp = reshape<T>(dx_reduce_res, phi::vectorize(x.dims()));
set_output<T>(dx_tmp.impl(), dx); set_output<T>(dx_tmp, dx);
} }
} else { } else {
set_output<T>(dx_res.impl(), dx); set_output<T>(dx_res, dx);
} }
} // indicate we will compute dx } // indicate we will compute dx
} }
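The two branches of divide_grad implement the element-wise quotient rule; for z = x / y,

    \frac{\partial z}{\partial x} = \frac{1}{y}, \qquad \frac{\partial z}{\partial y} = -\frac{x}{y^{2}},

so dx is out_grad / y (the full/divide/multiply path above) and dy is -out_grad \cdot x / y^{2}; when x and y were broadcast against each other, the extra axes are folded back with the sum-then-reshape pattern used in both branches.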
...@@ -215,7 +219,7 @@ void sqrt_grad(const Tensor& out, const Tensor& out_grad, Tensor* x_grad) { ...@@ -215,7 +219,7 @@ void sqrt_grad(const Tensor& out, const Tensor& out_grad, Tensor* x_grad) {
auto div_x = full<T>(phi::vectorize(out.dims()), 0.5); auto div_x = full<T>(phi::vectorize(out.dims()), 0.5);
auto tmp = divide<T>(div_x, out); auto tmp = divide<T>(div_x, out);
auto x_grad_tmp = multiply<T>(out_grad, tmp); auto x_grad_tmp = multiply<T>(out_grad, tmp);
set_output<T>(x_grad_tmp.impl(), x_grad); set_output<T>(x_grad_tmp, x_grad);
} }
} }
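Because the forward pass stores out = \sqrt{x}, sqrt_grad can be written purely in terms of out:

    \frac{\partial L}{\partial x} = \frac{\text{out\_grad}}{2\sqrt{x}} = 0.5 \cdot \frac{\text{out\_grad}}{\text{out}},

which is exactly the full(0.5)/divide/multiply sequence above.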
...@@ -231,7 +235,7 @@ void multiply_grad(const Tensor& x, ...@@ -231,7 +235,7 @@ void multiply_grad(const Tensor& x,
if (x.dims() != y.dims()) { if (x.dims() != y.dims()) {
auto axes = get_reduce_dims(x.dims(), y.dims()); auto axes = get_reduce_dims(x.dims(), y.dims());
if (!axes.size()) { if (!axes.size()) {
set_output<T>(x_grad_unreduce.impl(), x_grad); set_output<T>(x_grad_unreduce, x_grad);
} else { } else {
auto x_grad_reduced = sum<T>(x_grad_unreduce, auto x_grad_reduced = sum<T>(x_grad_unreduce,
phi::vectorize(axes), phi::vectorize(axes),
...@@ -240,10 +244,10 @@ void multiply_grad(const Tensor& x, ...@@ -240,10 +244,10 @@ void multiply_grad(const Tensor& x,
if (x_grad_reduced.dims().size() != x.dims().size()) { if (x_grad_reduced.dims().size() != x.dims().size()) {
x_grad_reduced = reshape<T>(x_grad_reduced, x.shape()); x_grad_reduced = reshape<T>(x_grad_reduced, x.shape());
} }
set_output<T>(x_grad_reduced.impl(), x_grad); set_output<T>(x_grad_reduced, x_grad);
} }
} else { } else {
set_output<T>(x_grad_unreduce.impl(), x_grad); set_output<T>(x_grad_unreduce, x_grad);
} }
} }
if (y_grad) { if (y_grad) {
...@@ -251,7 +255,7 @@ void multiply_grad(const Tensor& x, ...@@ -251,7 +255,7 @@ void multiply_grad(const Tensor& x,
if (y.dims() != x.dims()) { if (y.dims() != x.dims()) {
auto axes = get_reduce_dims(y.dims(), x.dims()); auto axes = get_reduce_dims(y.dims(), x.dims());
if (!axes.size()) { if (!axes.size()) {
set_output<T>(y_grad_unreduce.impl(), y_grad); set_output<T>(y_grad_unreduce, y_grad);
} else { } else {
auto y_grad_reduced = sum<T>(y_grad_unreduce, auto y_grad_reduced = sum<T>(y_grad_unreduce,
phi::vectorize(axes), phi::vectorize(axes),
...@@ -260,10 +264,10 @@ void multiply_grad(const Tensor& x, ...@@ -260,10 +264,10 @@ void multiply_grad(const Tensor& x,
if (y_grad_reduced.dims().size() != y.dims().size()) { if (y_grad_reduced.dims().size() != y.dims().size()) {
y_grad_reduced = reshape<T>(y_grad_reduced, y.shape()); y_grad_reduced = reshape<T>(y_grad_reduced, y.shape());
} }
set_output<T>(y_grad_reduced.impl(), y_grad); set_output<T>(y_grad_reduced, y_grad);
} }
} else { } else {
set_output<T>(y_grad_unreduce.impl(), y_grad); set_output<T>(y_grad_unreduce, y_grad);
} }
} }
} }
...@@ -284,7 +288,7 @@ void expand_grad(const Tensor& x, ...@@ -284,7 +288,7 @@ void expand_grad(const Tensor& x,
if (reduced.dims().size() != x.dims().size()) { if (reduced.dims().size() != x.dims().size()) {
reduced = reshape<T>(reduced, x.shape()); reduced = reshape<T>(reduced, x.shape());
} }
set_output<T>(reduced.impl(), x_grad); set_output<T>(reduced, x_grad);
} }
} else { } else {
by_pass<T>(out_grad, x_grad); by_pass<T>(out_grad, x_grad);
...@@ -295,7 +299,7 @@ void expand_grad(const Tensor& x, ...@@ -295,7 +299,7 @@ void expand_grad(const Tensor& x,
template <typename T> template <typename T>
void exp_grad(const Tensor& out, const Tensor& out_grad, Tensor* x_grad) { void exp_grad(const Tensor& out, const Tensor& out_grad, Tensor* x_grad) {
if (x_grad) { if (x_grad) {
set_output<T>(multiply<T>(out_grad, out).impl(), x_grad); set_output<T>(multiply<T>(out_grad, out), x_grad);
} }
} }
......
...@@ -49,7 +49,6 @@ void set_output<Tensor>(const paddle::experimental::Tensor& x_tmp, ...@@ -49,7 +49,6 @@ void set_output<Tensor>(const paddle::experimental::Tensor& x_tmp,
template <> template <>
void by_pass<Tensor>(const paddle::experimental::Tensor& x, Tensor* out) { void by_pass<Tensor>(const paddle::experimental::Tensor& x, Tensor* out) {
set_output<Tensor>(x, out); set_output<Tensor>(x, out);
// out->set_impl(x.impl());
} }
} // namespace prim } // namespace prim
......
...@@ -69,7 +69,6 @@ void by_pass<DescTensor>(const paddle::experimental::Tensor& x, ...@@ -69,7 +69,6 @@ void by_pass<DescTensor>(const paddle::experimental::Tensor& x,
op->InferVarType(block); op->InferVarType(block);
op->InferShape(*block); op->InferShape(*block);
set_output<DescTensor>(new_out, out); set_output<DescTensor>(new_out, out);
// out->set_impl(new_out.impl());
} }
} // namespace prim } // namespace prim
......
...@@ -68,16 +68,16 @@ TEST(EagerPrim, TanhBackwardTest) { ...@@ -68,16 +68,16 @@ TEST(EagerPrim, TanhBackwardTest) {
paddle::experimental::Tensor out0 = tanh_ad_func(tensor0); paddle::experimental::Tensor out0 = tanh_ad_func(tensor0);
std::vector<paddle::experimental::Tensor> outs0 = {out0}; std::vector<paddle::experimental::Tensor> outs0 = {out0};
// Disable prim // Disable prim
PrimCommonUtils::SetPrimEnabled(false); PrimCommonUtils::SetBwdPrimEnabled(false);
ASSERT_FALSE(PrimCommonUtils::IsPrimEnabled()); ASSERT_FALSE(PrimCommonUtils::IsBwdPrimEnabled());
// 4. Run Backward // 4. Run Backward
egr::Backward(outs0, {}, false); egr::Backward(outs0, {}, false);
paddle::experimental::Tensor out1 = tanh_ad_func(tensor1); paddle::experimental::Tensor out1 = tanh_ad_func(tensor1);
std::vector<paddle::experimental::Tensor> outs1 = {out1}; std::vector<paddle::experimental::Tensor> outs1 = {out1};
// Disable prim // Disable prim
PrimCommonUtils::SetPrimEnabled(true); PrimCommonUtils::SetBwdPrimEnabled(true);
ASSERT_TRUE(PrimCommonUtils::IsPrimEnabled()); ASSERT_TRUE(PrimCommonUtils::IsBwdPrimEnabled());
// 4. Run Backward // 4. Run Backward
::egr::Backward(outs1, {}, false); ::egr::Backward(outs1, {}, false);
VLOG(7) VLOG(7)
...@@ -99,10 +99,10 @@ TEST(EagerPrim, TanhBackwardTest) { ...@@ -99,10 +99,10 @@ TEST(EagerPrim, TanhBackwardTest) {
} }
TEST(EagerPrim, TestFlags) { TEST(EagerPrim, TestFlags) {
PrimCommonUtils::SetPrimEnabled(true); PrimCommonUtils::SetBwdPrimEnabled(true);
ASSERT_TRUE(PrimCommonUtils::IsPrimEnabled()); ASSERT_TRUE(PrimCommonUtils::IsBwdPrimEnabled());
PrimCommonUtils::SetPrimEnabled(false); PrimCommonUtils::SetBwdPrimEnabled(false);
ASSERT_FALSE(PrimCommonUtils::IsPrimEnabled()); ASSERT_FALSE(PrimCommonUtils::IsBwdPrimEnabled());
} }
} // namespace prim } // namespace prim
......
...@@ -341,10 +341,10 @@ TEST(StaticCompositeGradMaker, TestMutiOutputMethod) { ...@@ -341,10 +341,10 @@ TEST(StaticCompositeGradMaker, TestMutiOutputMethod) {
} }
TEST(StaticPrim, TestFlags) { TEST(StaticPrim, TestFlags) {
PrimCommonUtils::SetPrimEnabled(true); PrimCommonUtils::SetBwdPrimEnabled(true);
ASSERT_TRUE(PrimCommonUtils::IsPrimEnabled()); ASSERT_TRUE(PrimCommonUtils::IsBwdPrimEnabled());
PrimCommonUtils::SetPrimEnabled(false); PrimCommonUtils::SetBwdPrimEnabled(false);
ASSERT_FALSE(PrimCommonUtils::IsPrimEnabled()); ASSERT_FALSE(PrimCommonUtils::IsBwdPrimEnabled());
} }
} // namespace prim } // namespace prim
......
...@@ -18,6 +18,7 @@ namespace paddle { ...@@ -18,6 +18,7 @@ namespace paddle {
namespace prim { namespace prim {
StaticCompositeContext* StaticCompositeContext::static_composite_context_ = StaticCompositeContext* StaticCompositeContext::static_composite_context_ =
new StaticCompositeContext(); new StaticCompositeContext();
thread_local bool StaticCompositeContext::enable_prim_ = false; thread_local bool StaticCompositeContext::enable_bwd_prim_ = false;
thread_local bool StaticCompositeContext::enable_fwd_prim_ = false;
} // namespace prim } // namespace prim
} // namespace paddle } // namespace paddle
...@@ -56,9 +56,18 @@ class StaticCompositeContext { ...@@ -56,9 +56,18 @@ class StaticCompositeContext {
return generator_->Generate(key); return generator_->Generate(key);
} }
void SetPrimEnabled(bool enable_prim) { enable_prim_ = enable_prim; } void SetBwdPrimEnabled(bool enable_prim) { enable_bwd_prim_ = enable_prim; }
bool IsPrimEnabled() { return enable_prim_; } bool IsBwdPrimEnabled() { return enable_bwd_prim_; }
void SetFwdPrimEnabled(bool enable_prim) { enable_fwd_prim_ = enable_prim; }
bool IsFwdPrimEnabled() { return enable_fwd_prim_; }
void SetAllPrimEnabled(bool enable_prim) {
enable_fwd_prim_ = enable_prim;
enable_bwd_prim_ = enable_prim;
}
private: private:
StaticCompositeContext() StaticCompositeContext()
...@@ -66,7 +75,8 @@ class StaticCompositeContext { ...@@ -66,7 +75,8 @@ class StaticCompositeContext {
framework::BlockDesc* current_block_desc_; framework::BlockDesc* current_block_desc_;
std::unique_ptr<UniqueNameGenerator> generator_; std::unique_ptr<UniqueNameGenerator> generator_;
static thread_local bool enable_prim_; static thread_local bool enable_bwd_prim_;
static thread_local bool enable_fwd_prim_;
static StaticCompositeContext* static_composite_context_; static StaticCompositeContext* static_composite_context_;
DISABLE_COPY_AND_ASSIGN(StaticCompositeContext); DISABLE_COPY_AND_ASSIGN(StaticCompositeContext);
}; };
......
...@@ -19,12 +19,24 @@ ...@@ -19,12 +19,24 @@
PADDLE_DEFINE_EXPORTED_bool(prim_enabled, false, "enable_prim or not"); PADDLE_DEFINE_EXPORTED_bool(prim_enabled, false, "enable_prim or not");
namespace paddle { namespace paddle {
namespace prim { namespace prim {
bool PrimCommonUtils::IsPrimEnabled() { bool PrimCommonUtils::IsBwdPrimEnabled() {
return StaticCompositeContext::Instance().IsPrimEnabled(); return StaticCompositeContext::Instance().IsBwdPrimEnabled();
} }
void PrimCommonUtils::SetPrimEnabled(bool enable_prim) { void PrimCommonUtils::SetBwdPrimEnabled(bool enable_prim) {
return StaticCompositeContext::Instance().SetPrimEnabled(enable_prim); return StaticCompositeContext::Instance().SetBwdPrimEnabled(enable_prim);
}
bool PrimCommonUtils::IsFwdPrimEnabled() {
return StaticCompositeContext::Instance().IsFwdPrimEnabled();
}
void PrimCommonUtils::SetFwdPrimEnabled(bool enable_prim) {
return StaticCompositeContext::Instance().SetFwdPrimEnabled(enable_prim);
}
void PrimCommonUtils::SetAllPrimEnabled(bool enable_prim) {
return StaticCompositeContext::Instance().SetAllPrimEnabled(enable_prim);
} }
} // namespace prim } // namespace prim
} // namespace paddle } // namespace paddle
...@@ -18,8 +18,11 @@ namespace paddle { ...@@ -18,8 +18,11 @@ namespace paddle {
namespace prim { namespace prim {
class PrimCommonUtils { class PrimCommonUtils {
public: public:
static bool IsPrimEnabled(); static bool IsBwdPrimEnabled();
static void SetPrimEnabled(bool enabled); static void SetBwdPrimEnabled(bool enabled);
static bool IsFwdPrimEnabled();
static void SetFwdPrimEnabled(bool enabled);
static void SetAllPrimEnabled(bool enabled);
}; };
} // namespace prim } // namespace prim
} // namespace paddle } // namespace paddle
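With the single switch split into forward and backward flags, callers can opt into each decomposition independently. A minimal C++ sketch against the class above — the header path is an assumption, and the underscored Python-facing names are the pybind registrations shown further below:

    #include "paddle/fluid/prim/utils/utils.h"  // assumed location of PrimCommonUtils

    void EnableCompositeBackwardOnly() {
      paddle::prim::PrimCommonUtils::SetBwdPrimEnabled(true);
      paddle::prim::PrimCommonUtils::SetFwdPrimEnabled(false);
      // Or flip both switches at once:
      // paddle::prim::PrimCommonUtils::SetAllPrimEnabled(true);
    }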
...@@ -65,6 +65,7 @@ struct npy_format_descriptor<paddle::platform::float16> { ...@@ -65,6 +65,7 @@ struct npy_format_descriptor<paddle::platform::float16> {
namespace paddle { namespace paddle {
namespace pybind { namespace pybind {
using paddle::distributed::DependType;
using paddle::distributed::DistModel; using paddle::distributed::DistModel;
using paddle::distributed::DistModelConfig; using paddle::distributed::DistModelConfig;
using paddle::distributed::DistModelDataBuf; using paddle::distributed::DistModelDataBuf;
...@@ -164,18 +165,17 @@ void BindFleetExecutor(py::module* m) { ...@@ -164,18 +165,17 @@ void BindFleetExecutor(py::module* m) {
.def( .def(
"run", &FleetExecutor::Run, py::call_guard<py::gil_scoped_release>()); "run", &FleetExecutor::Run, py::call_guard<py::gil_scoped_release>());
py::enum_<DependType>(*m, "DependType")
.value("NORMAL", DependType::NORMAL)
.value("LOOP", DependType::LOOP)
.value("STOP_LOOP", DependType::STOP_LOOP);
py::class_<TaskNode>(*m, "TaskNode") py::class_<TaskNode>(*m, "TaskNode")
.def(py::init<framework::ProgramDesc*,
int64_t,
int64_t,
int64_t,
int64_t>())
.def(py::init<framework::ProgramDesc*, int64_t, int64_t, int64_t>()) .def(py::init<framework::ProgramDesc*, int64_t, int64_t, int64_t>())
.def(py::init<int32_t, .def(py::init<int32_t,
const std::vector<framework::OpDesc*>&, const std::vector<framework::OpDesc*>&,
int64_t, int64_t,
int64_t, int64_t,
int64_t,
int64_t>()) int64_t>())
.def("task_id", &TaskNode::task_id) .def("task_id", &TaskNode::task_id)
.def("add_upstream_task", &TaskNode::AddUpstreamTask) .def("add_upstream_task", &TaskNode::AddUpstreamTask)
...@@ -183,6 +183,7 @@ void BindFleetExecutor(py::module* m) { ...@@ -183,6 +183,7 @@ void BindFleetExecutor(py::module* m) {
.def("set_run_pre_steps", &TaskNode::SetRunPerSteps) .def("set_run_pre_steps", &TaskNode::SetRunPerSteps)
.def("set_run_at_offset", &TaskNode::SetRunAtOffset) .def("set_run_at_offset", &TaskNode::SetRunAtOffset)
.def("set_type", &TaskNode::SetType) .def("set_type", &TaskNode::SetType)
.def("set_cond_var_name", &TaskNode::SetCondVarName)
.def("role", &TaskNode::role) .def("role", &TaskNode::role)
.def("init", [](TaskNode& self) { self.Init(); }) .def("init", [](TaskNode& self) { self.Init(); })
.def("set_program", &TaskNode::SetProgram); .def("set_program", &TaskNode::SetProgram);
......
...@@ -660,8 +660,16 @@ PYBIND11_MODULE(libpaddle, m) { ...@@ -660,8 +660,16 @@ PYBIND11_MODULE(libpaddle, m) {
return oss.str(); return oss.str();
}); });
m.def("set_prim_enabled", &paddle::prim::PrimCommonUtils::SetPrimEnabled); m.def("__set_bwd_prim_enabled",
m.def("is_prim_enabled", &paddle::prim::PrimCommonUtils::IsPrimEnabled); &paddle::prim::PrimCommonUtils::SetBwdPrimEnabled);
m.def("_is_bwd_prim_enabled",
&paddle::prim::PrimCommonUtils::IsBwdPrimEnabled);
m.def("__set_fwd_prim_enabled",
&paddle::prim::PrimCommonUtils::SetFwdPrimEnabled);
m.def("_is_fwd_prim_enabled",
&paddle::prim::PrimCommonUtils::IsFwdPrimEnabled);
m.def("__set_all_prim_enabled",
&paddle::prim::PrimCommonUtils::SetAllPrimEnabled);
m.def("set_num_threads", &platform::SetNumThreads); m.def("set_num_threads", &platform::SetNumThreads);
m.def("disable_signal_handler", &DisableSignalHandler); m.def("disable_signal_handler", &DisableSignalHandler);
...@@ -1264,8 +1272,9 @@ All parameter, weight, gradient are variables in Paddle. ...@@ -1264,8 +1272,9 @@ All parameter, weight, gradient are variables in Paddle.
// priority of GradCompOpMaker is less than GradCompMaker for better // priority of GradCompOpMaker is less than GradCompMaker for better
// performance. // performance.
std::vector<std::unique_ptr<OpDesc>> grad_op_descs; std::vector<std::unique_ptr<OpDesc>> grad_op_descs;
if (paddle::prim::PrimCommonUtils::IsPrimEnabled()) { if (paddle::prim::PrimCommonUtils::IsBwdPrimEnabled()) {
if (grad_comp_op_maker != nullptr) { if (grad_comp_op_maker != nullptr) {
VLOG(3) << "Runing composite fun for " << op_desc.Type();
grad_op_descs = grad_comp_op_maker(op_desc, grad_op_descs = grad_comp_op_maker(op_desc,
no_grad_set, no_grad_set,
&grad_to_var, &grad_to_var,
......
...@@ -42,7 +42,7 @@ ...@@ -42,7 +42,7 @@
kernel : kernel :
func : add_grad func : add_grad
no_need_buffer : x, y no_need_buffer : x, y
composite : add_grad(Tensor x, Tensor y, Tensor out_grad, int axis) composite : add_grad(x, y, out_grad, axis)
backward : add_double_grad backward : add_double_grad
inplace : (out_grad -> x_grad) inplace : (out_grad -> x_grad)
...@@ -390,7 +390,7 @@ ...@@ -390,7 +390,7 @@
param : [x, y] param : [x, y]
kernel : kernel :
func : divide_grad func : divide_grad
composite : divide_grad(Tensor x, Tensor y, Tensor out, Tensor out_grad, int axis = -1) composite : divide_grad(x, y, out, out_grad, -1)
backward : divide_double_grad backward : divide_double_grad
- backward_op : dropout_grad - backward_op : dropout_grad
...@@ -1319,7 +1319,7 @@ ...@@ -1319,7 +1319,7 @@
kernel : kernel :
func : subtract_grad func : subtract_grad
no_need_buffer : x, y no_need_buffer : x, y
composite : subtract_grad(Tensor x, Tensor y, Tensor out_grad, int axis) composite : subtract_grad(x, y, out_grad, axis)
backward : subtract_double_grad backward : subtract_double_grad
inplace : (out_grad -> x_grad) inplace : (out_grad -> x_grad)
......
...@@ -112,42 +112,6 @@ static void AppendActivation(const OneDNNContext& dev_ctx, ...@@ -112,42 +112,6 @@ static void AppendActivation(const OneDNNContext& dev_ctx,
} }
} }
static std::unordered_map<std::string, std::string> GetAttributeMap(
std::string act_type) {
std::unordered_map<std::string, std::string> attr_map;
if (act_type == "swish") {
attr_map.emplace("beta", "fuse_alpha");
} else if (act_type == "relu6") {
attr_map.emplace("threshold", "fuse_alpha");
} else if (act_type == "hard_sigmoid") {
attr_map.emplace("slope", "fuse_alpha");
attr_map.emplace("offset", "fuse_beta");
} else if (act_type == "clip") {
attr_map.emplace("min", "fuse_alpha");
attr_map.emplace("max", "fuse_beta");
} else {
attr_map.emplace("alpha", "fuse_alpha");
attr_map.emplace("beta", "fuse_beta");
}
return attr_map;
}
static std::vector<std::string> GetSupportedActivations() {
return std::vector<std::string>{"abs",
"clip",
"gelu",
"hard_sigmoid",
"hard_swish",
"leaky_relu",
"mish",
"relu",
"relu6",
"sigmoid",
"sqrt",
"swish",
"tanh"};
}
template <typename T, template <typename T,
typename TForward, typename TForward,
typename TBackward = onednn_dummy_primitive, typename TBackward = onednn_dummy_primitive,
...@@ -1756,13 +1720,13 @@ static std::vector<int64_t> TransposeAxis(const std::vector<int64_t>& x, ...@@ -1756,13 +1720,13 @@ static std::vector<int64_t> TransposeAxis(const std::vector<int64_t>& x,
auto axis_set = std::set<int>(axis.begin(), axis.end()); auto axis_set = std::set<int>(axis.begin(), axis.end());
PADDLE_ENFORCE_EQ(axis_set.size(), PADDLE_ENFORCE_EQ(axis_set.size(),
axis_size, axis_size,
paddle::platform::errors::InvalidArgument( phi::errors::InvalidArgument(
"In an axis array, elements must be unique.")); "In an axis array, elements must be unique."));
PADDLE_ENFORCE_EQ(in_rank, PADDLE_ENFORCE_EQ(
in_rank,
axis_size, axis_size,
paddle::platform::errors::InvalidArgument( phi::errors::InvalidArgument("The input dimension's size "
"The input dimension's size "
"should be equal to the axis's size. " "should be equal to the axis's size. "
"But received dimension is %d, " "But received dimension is %d, "
"axis's size is %d", "axis's size is %d",
...@@ -1771,7 +1735,7 @@ static std::vector<int64_t> TransposeAxis(const std::vector<int64_t>& x, ...@@ -1771,7 +1735,7 @@ static std::vector<int64_t> TransposeAxis(const std::vector<int64_t>& x,
PADDLE_ENFORCE_LT(*std::max_element(axis.begin(), axis.end()), PADDLE_ENFORCE_LT(*std::max_element(axis.begin(), axis.end()),
axis_size, axis_size,
paddle::platform::errors::InvalidArgument( phi::errors::InvalidArgument(
"Axis values must be ranging from 0 to (dims - 1).")); "Axis values must be ranging from 0 to (dims - 1)."));
std::vector<int64_t> new_x(x.size()); std::vector<int64_t> new_x(x.size());
......
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. /* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License"); Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License. you may not use this file except in compliance with the License.
You may obtain a copy of the License at You may obtain a copy of the License at
...@@ -67,10 +67,7 @@ XPUOpMap& get_kl2_ops() { ...@@ -67,10 +67,7 @@ XPUOpMap& get_kl2_ops() {
phi::DataType::INT64})}, phi::DataType::INT64})},
{"bilinear_interp_v2", XPUKernelSet({phi::DataType::FLOAT32})}, {"bilinear_interp_v2", XPUKernelSet({phi::DataType::FLOAT32})},
{"bilinear_interp_v2_grad", XPUKernelSet({phi::DataType::FLOAT32})}, {"bilinear_interp_v2_grad", XPUKernelSet({phi::DataType::FLOAT32})},
{"bitwise_and", XPUKernelSet({phi::DataType::BOOL})},
{"bitwise_not", XPUKernelSet({phi::DataType::BOOL})}, {"bitwise_not", XPUKernelSet({phi::DataType::BOOL})},
{"bitwise_or", XPUKernelSet({phi::DataType::BOOL})},
{"bitwise_xor", XPUKernelSet({phi::DataType::BOOL})},
{"broadcast", XPUKernelSet({phi::DataType::FLOAT32})}, {"broadcast", XPUKernelSet({phi::DataType::FLOAT32})},
{"c_allgather", {"c_allgather",
XPUKernelSet({phi::DataType::FLOAT16, XPUKernelSet({phi::DataType::FLOAT16,
...@@ -109,6 +106,8 @@ XPUOpMap& get_kl2_ops() { ...@@ -109,6 +106,8 @@ XPUOpMap& get_kl2_ops() {
XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
{"clip", XPUKernelSet({phi::DataType::FLOAT32})}, {"clip", XPUKernelSet({phi::DataType::FLOAT32})},
{"clip_by_norm", XPUKernelSet({phi::DataType::FLOAT32})}, {"clip_by_norm", XPUKernelSet({phi::DataType::FLOAT32})},
{"clip_grad",
XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::INT32})},
{"coalesce_tensor", {"coalesce_tensor",
XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
{"concat_grad", {"concat_grad",
...@@ -374,6 +373,10 @@ XPUOpMap& get_kl2_ops() { ...@@ -374,6 +373,10 @@ XPUOpMap& get_kl2_ops() {
phi::DataType::BOOL, phi::DataType::BOOL,
phi::DataType::FLOAT16, phi::DataType::FLOAT16,
phi::DataType::FLOAT32})}, phi::DataType::FLOAT32})},
{"max_pool2d_with_index",
XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
{"max_pool2d_with_index_grad",
XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
{"matmul_grad", {"matmul_grad",
XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
{"matmul_v2_grad", {"matmul_v2_grad",
...@@ -435,7 +438,10 @@ XPUOpMap& get_kl2_ops() { ...@@ -435,7 +438,10 @@ XPUOpMap& get_kl2_ops() {
{"reduce_min", XPUKernelSet({phi::DataType::FLOAT32})}, {"reduce_min", XPUKernelSet({phi::DataType::FLOAT32})},
{"reduce_prod", XPUKernelSet({phi::DataType::FLOAT32})}, {"reduce_prod", XPUKernelSet({phi::DataType::FLOAT32})},
{"reduce_sum_grad", XPUKernelSet({phi::DataType::FLOAT32})}, {"reduce_sum_grad", XPUKernelSet({phi::DataType::FLOAT32})},
{"reduce_sum", XPUKernelSet({phi::DataType::FLOAT32})}, {"reduce_sum",
XPUKernelSet({phi::DataType::FLOAT32,
phi::DataType::INT8,
phi::DataType::INT64})},
{"relu6", XPUKernelSet({phi::DataType::FLOAT32})}, {"relu6", XPUKernelSet({phi::DataType::FLOAT32})},
{"relu6_grad", XPUKernelSet({phi::DataType::FLOAT32})}, {"relu6_grad", XPUKernelSet({phi::DataType::FLOAT32})},
{"relu_grad", {"relu_grad",
......
...@@ -146,17 +146,17 @@ PADDLE_DEFINE_EXPORTED_bool( ...@@ -146,17 +146,17 @@ PADDLE_DEFINE_EXPORTED_bool(
* CUDA related related FLAG * CUDA related related FLAG
* Name: FLAGS_gemm_use_half_precision_compute_type * Name: FLAGS_gemm_use_half_precision_compute_type
* Since Version: 2.4 * Since Version: 2.4
* Value Range: bool, default=true * Value Range: bool, default=false
* Example: * Example:
* Note: whether to use fp16 compute type when the input and output are fp16, * Note: whether to use fp16 compute type when the input and output are fp16,
* faster but it may lose precision. * faster but it may lose precision.
*/ */
PADDLE_DEFINE_EXPORTED_bool( PADDLE_DEFINE_EXPORTED_bool(
gemm_use_half_precision_compute_type, gemm_use_half_precision_compute_type,
true, false,
"Whether to use fp16 compute type when the input and output is fp16, " "Whether to use fp16 compute type when the input and output is fp16, "
"faster but it may loss precision in most case. If true, the compute " "faster but it may loss precision in most case. If true, the compute "
"type will be set to fp32. Default is true."); "type will be set to fp16. Default is false.");
/** /**
* CUDA related FLAG * CUDA related FLAG
......
...@@ -4596,10 +4596,10 @@ void UniqueRawInferMeta(const MetaTensor& x, ...@@ -4596,10 +4596,10 @@ void UniqueRawInferMeta(const MetaTensor& x,
MetaTensor* index, MetaTensor* index,
MetaTensor* counts) { MetaTensor* counts) {
if (!is_sorted) { if (!is_sorted) {
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(x.dims().size() == 1 || x.dims().size() == 0,
x.dims().size(), true,
1, phi::errors::InvalidArgument(
phi::errors::InvalidArgument("The Input(X) should be 1-D Tensor, " "The Input(X) should be 0-D or 1-D Tensor, "
"But now the dims of Input(X) is %d.", "But now the dims of Input(X) is %d.",
x.dims().size())); x.dims().size()));
out->set_dims(phi::make_ddim({-1})); out->set_dims(phi::make_ddim({-1}));
...@@ -4607,6 +4607,15 @@ void UniqueRawInferMeta(const MetaTensor& x, ...@@ -4607,6 +4607,15 @@ void UniqueRawInferMeta(const MetaTensor& x,
return; return;
} }
if (x.dims().size() == 0) {
PADDLE_ENFORCE_EQ(axis.empty(),
true,
phi::errors::InvalidArgument(
"The Input(X) with 0-D Tensor, axis must be None"
"But now the axis is %d.",
axis[0]));
}
if (axis.empty()) { if (axis.empty()) {
out->set_dims(phi::make_ddim({-1})); out->set_dims(phi::make_ddim({-1}));
if (return_inverse) { if (return_inverse) {
......
...@@ -43,14 +43,13 @@ enum class AlgorithmType { ...@@ -43,14 +43,13 @@ enum class AlgorithmType {
kConvForward = 1, kConvForward = 1,
kConvBackwardData = 2, kConvBackwardData = 2,
kConvBackwardFilter = 3, kConvBackwardFilter = 3,
kTranspose = 4,
#ifdef PADDLE_WITH_CUDNN_FRONTEND #ifdef PADDLE_WITH_CUDNN_FRONTEND
kConvForwardV8 = 4, kConvForwardV8 = 5,
kConvBackwardDataV8 = 5, kConvBackwardDataV8 = 6,
kConvBackwardFilterV8 = 6, kConvBackwardFilterV8 = 7,
kTranspose = 7,
kAlgorithmCount = 8 kAlgorithmCount = 8
#else #else
kTranspose = 4,
kAlgorithmCount = 5 kAlgorithmCount = 5
#endif #endif
}; };
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. // Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
// //
// Licensed under the Apache License, Version 2.0 (the "License"); // Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License. // you may not use this file except in compliance with the License.
......
...@@ -463,12 +463,17 @@ void DispatchConcatKernel(const phi::GPUContext& ctx, ...@@ -463,12 +463,17 @@ void DispatchConcatKernel(const phi::GPUContext& ctx,
constexpr IndexT MaxVecSize = 16 / sizeof(T); constexpr IndexT MaxVecSize = 16 / sizeof(T);
bool find_vecsize_flag = false; bool find_vecsize_flag = false;
IndexT dispatch_vec_size = 1; IndexT dispatch_vec_size = 1;
auto output_data = reinterpret_cast<std::uintptr_t>(output->data());
for (IndexT vec_size = MaxVecSize; vec_size > 0; vec_size /= 2) { for (IndexT vec_size = MaxVecSize; vec_size > 0; vec_size /= 2) {
for (IndexT idx = 0; idx < in_num + 1; idx++) { const IndexT mov_size = vec_size * sizeof(T);
for (IndexT idx = 1; idx < in_num + 1; idx++) {
auto input_data = reinterpret_cast<std::uintptr_t>(inputs_data[idx - 1]);
// Since input_cols[0] is 0, we need to jump. // Since input_cols[0] is 0, we need to jump.
const IndexT input_col = inputs_col[idx + 1] - inputs_col[idx]; const IndexT input_col = inputs_col[idx] - inputs_col[idx - 1];
if (input_col % vec_size == 0) { if (input_col % vec_size == 0 && output_data % mov_size == 0 &&
if (idx == in_num - 1) { input_data % mov_size == 0) {
if (idx == in_num) {
find_vecsize_flag = true; find_vecsize_flag = true;
} }
} else { } else {
......
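The reworked condition ties the chosen vector width to the byte alignment of the output pointer and of every input pointer, not just to the column sizes. A standalone sketch of the same selection logic with illustrative names (not the kernel's own helper):

    #include <cstddef>
    #include <cstdint>

    // Largest power-of-two vector length (<= 16 bytes) that divides the column
    // width and keeps both pointers aligned for vectorized loads/stores.
    template <typename T>
    int PickVecSize(std::size_t col, const void* in, const void* out) {
      const int max_vec = 16 / static_cast<int>(sizeof(T));
      for (int vec = max_vec; vec > 1; vec /= 2) {
        const std::size_t bytes = vec * sizeof(T);
        if (col % vec == 0 &&
            reinterpret_cast<std::uintptr_t>(in) % bytes == 0 &&
            reinterpret_cast<std::uintptr_t>(out) % bytes == 0) {
          return vec;
        }
      }
      return 1;  // scalar fallback
    }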
...@@ -169,7 +169,7 @@ static void linalg_solve(const Context& dev_ctx, ...@@ -169,7 +169,7 @@ static void linalg_solve(const Context& dev_ctx,
out_tmp.Resize(out->dims()); out_tmp.Resize(out->dims());
out_tmp = *out; out_tmp = *out;
phi::SqueezeInferKernel<T, Context>(dev_ctx, out_tmp, {-1}, out); phi::Squeeze<T, Context>(dev_ctx, out_tmp, {-1}, out);
} else { } else {
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
x_dim[x_dim_size - 1], x_dim[x_dim_size - 1],
......
...@@ -19,37 +19,64 @@ ...@@ -19,37 +19,64 @@
namespace phi { namespace phi {
std::vector<int64_t> ExtendDimsWithOnes(const std::vector<int64_t> &dims, void CalculateMatrixDims(const std::vector<int64_t> &x_dims,
int new_size) { const std::vector<int64_t> &y_dims,
std::vector<int64_t> new_dims(new_size, 1); const std::vector<int64_t> &out_dims,
for (size_t i = 0; i < dims.size(); ++i) { std::vector<int64_t> *x_bd_dims,
new_dims[new_size - dims.size() + i] = dims[i]; std::vector<int64_t> *y_bd_dims,
std::vector<int64_t> *out_bd_dims,
bool trans_x,
bool trans_y) {
if (x_dims.size() == 1) {
(*x_bd_dims)[x_bd_dims->size() - 1] = x_dims[0];
} else if (x_dims.size() == 2) {
(*x_bd_dims)[x_bd_dims->size() - 1] = x_dims[1];
(*x_bd_dims)[x_bd_dims->size() - 2] = x_dims[0];
} else {
for (size_t i = 0; i < x_dims.size(); ++i) {
(*x_bd_dims)[x_bd_dims->size() - x_dims.size() + i] = x_dims[i];
}
}
if (y_dims.size() == 1) {
(*y_bd_dims)[x_bd_dims->size() - 2] = y_dims[0];
} else if (y_dims.size() == 2) {
(*y_bd_dims)[y_bd_dims->size() - 1] = y_dims[1];
(*y_bd_dims)[y_bd_dims->size() - 2] = y_dims[0];
} else {
for (size_t i = 0; i < y_dims.size(); ++i) {
(*y_bd_dims)[y_bd_dims->size() - y_dims.size() + i] = y_dims[i];
}
}
for (size_t i = 0; i < x_bd_dims->size() - 2; ++i) {
(*out_bd_dims)[i] = std::max((*x_bd_dims)[i], (*y_bd_dims)[i]);
} }
int h_idx = trans_x ? x_bd_dims->size() - 1 : x_bd_dims->size() - 2;
int w_idx = trans_y ? y_bd_dims->size() - 2 : y_bd_dims->size() - 1;
return new_dims; (*out_bd_dims)[x_bd_dims->size() - 2] = (*x_bd_dims)[h_idx];
(*out_bd_dims)[y_bd_dims->size() - 1] = (*y_bd_dims)[w_idx];
} }
template <typename T> template <typename T>
void CalculateGradMatrixDims(const OneDNNContext &dev_ctx, void CalculateGradMatrixDims(const OneDNNContext &dev_ctx,
DenseTensor *dx_tmp, DenseTensor *dx_tmp,
DenseTensor *dy_tmp, DenseTensor *dy_tmp,
const std::vector<int64_t> &dx_dims,
const std::vector<int64_t> &dy_dims,
std::vector<int64_t> *dx_bd_dims, std::vector<int64_t> *dx_bd_dims,
std::vector<int64_t> *dy_bd_dims) { std::vector<int64_t> *dy_bd_dims) {
for (size_t i = 0; i < dx_dims.size() - 2; ++i) { for (size_t i = 0; i < dx_bd_dims->size() - 2; ++i) {
if (dx_dims[i] != dy_dims[i]) { if ((*dx_bd_dims)[i] != (*dy_bd_dims)[i]) {
if (dx_dims[i] == 1) { if ((*dx_bd_dims)[i] == 1) {
(*dx_bd_dims)[i] = dy_dims[i]; (*dx_bd_dims)[i] = (*dy_bd_dims)[i];
} else { } else {
(*dy_bd_dims)[i] = dx_dims[i]; (*dy_bd_dims)[i] = (*dx_bd_dims)[i];
} }
} }
} }
dx_tmp->Resize(make_ddim((*dx_bd_dims))); dx_tmp->Resize(make_ddim(*dx_bd_dims));
dev_ctx.template Alloc<T>(dx_tmp); dev_ctx.template Alloc<T>(dx_tmp);
dy_tmp->Resize(make_ddim((*dy_bd_dims))); dy_tmp->Resize(make_ddim(*dy_bd_dims));
dev_ctx.template Alloc<T>(dy_tmp); dev_ctx.template Alloc<T>(dy_tmp);
} }
...@@ -58,7 +85,7 @@ void ReduceSumForMatmulGradOutput(const OneDNNContext &dev_ctx, ...@@ -58,7 +85,7 @@ void ReduceSumForMatmulGradOutput(const OneDNNContext &dev_ctx,
const DenseTensor *dx_tmp, const DenseTensor *dx_tmp,
DenseTensor *dx, DenseTensor *dx,
const std::vector<int64_t> &dx_dims, const std::vector<int64_t> &dx_dims,
const std::vector<int64_t> &squeezed_dims) { const std::vector<int64_t> &x_dims) {
funcs::ReductionOneDNNHandler<T> handler(dnnl::algorithm::reduction_sum, funcs::ReductionOneDNNHandler<T> handler(dnnl::algorithm::reduction_sum,
0.0f, 0.0f,
0.0f, 0.0f,
...@@ -66,7 +93,7 @@ void ReduceSumForMatmulGradOutput(const OneDNNContext &dev_ctx, ...@@ -66,7 +93,7 @@ void ReduceSumForMatmulGradOutput(const OneDNNContext &dev_ctx,
dev_ctx.GetPlace(), dev_ctx.GetPlace(),
dx_tmp, dx_tmp,
dx, dx,
dx_dims); x_dims);
auto src_memory_p = handler.AcquireSrcMemory(dx_tmp); auto src_memory_p = handler.AcquireSrcMemory(dx_tmp);
auto dst_memory_p = handler.AcquireDstMemory(dx); auto dst_memory_p = handler.AcquireDstMemory(dx);
...@@ -79,8 +106,6 @@ void ReduceSumForMatmulGradOutput(const OneDNNContext &dev_ctx, ...@@ -79,8 +106,6 @@ void ReduceSumForMatmulGradOutput(const OneDNNContext &dev_ctx,
reduction_p->execute(astream, reduction_args); reduction_p->execute(astream, reduction_args);
astream.wait(); astream.wait();
dx->set_mem_desc(dst_memory_p->get_desc().reshape(squeezed_dims));
} }
template <typename T, typename Context> template <typename T, typename Context>
...@@ -99,64 +124,67 @@ void MatmulGradKernel(const Context &dev_ctx, ...@@ -99,64 +124,67 @@ void MatmulGradKernel(const Context &dev_ctx,
size_t ndims = std::max(x_dims.size(), y_dims.size()); size_t ndims = std::max(x_dims.size(), y_dims.size());
ndims = std::max<size_t>(ndims, 3); ndims = std::max<size_t>(ndims, 3);
if (x_dims.size() != ndims) {
x_dims = ExtendDimsWithOnes(x_dims, ndims);
}
if (y_dims.size() != ndims) {
y_dims = ExtendDimsWithOnes(y_dims, ndims);
}
if (dout_dims.size() != ndims) {
dout_dims = ExtendDimsWithOnes(dout_dims, ndims);
}
// in broadcasting scenario new memory is required because // in broadcasting scenario new memory is required because
// reduce sum must be calculated upon broadcasted dims // reduce sum must be calculated upon broadcasted dims
DenseTensor dx_tmp, dy_tmp; DenseTensor dx_tmp, dy_tmp;
std::vector<int64_t> dx_bd_dims(x_dims); std::vector<int64_t> dout_bd_dims(ndims, 1);
std::vector<int64_t> dy_bd_dims(y_dims); std::vector<int64_t> x_bd_dims(ndims, 1);
std::vector<int64_t> y_bd_dims(ndims, 1);
CalculateMatrixDims(x_dims,
y_dims,
dout_dims,
&x_bd_dims,
&y_bd_dims,
&dout_bd_dims,
transpose_x,
transpose_y);
std::vector<int64_t> dx_bd_dims(x_bd_dims);
std::vector<int64_t> dy_bd_dims(y_bd_dims);
CalculateGradMatrixDims<T>( CalculateGradMatrixDims<T>(
dev_ctx, &dx_tmp, &dy_tmp, x_dims, y_dims, &dx_bd_dims, &dy_bd_dims); dev_ctx, &dx_tmp, &dy_tmp, &dx_bd_dims, &dy_bd_dims);
if (transpose_x && transpose_y) { if (transpose_x && transpose_y) {
funcs::ExecuteMatmul<T, T>( funcs::ExecuteMatmul<T, T>(
dev_ctx, y, dout, y_dims, dout_dims, true, true, &dx_tmp); dev_ctx, y, dout, y_bd_dims, dout_bd_dims, true, true, &dx_tmp);
funcs::ExecuteMatmul<T, T>( funcs::ExecuteMatmul<T, T>(
dev_ctx, dout, x, dout_dims, x_dims, true, true, &dy_tmp); dev_ctx, dout, x, dout_bd_dims, x_bd_dims, true, true, &dy_tmp);
} else if (transpose_x) { } else if (transpose_x) {
funcs::ExecuteMatmul<T, T>( funcs::ExecuteMatmul<T, T>(
dev_ctx, y, dout, y_dims, dout_dims, false, true, &dx_tmp); dev_ctx, y, dout, y_bd_dims, dout_bd_dims, false, true, &dx_tmp);
funcs::ExecuteMatmul<T, T>( funcs::ExecuteMatmul<T, T>(
dev_ctx, x, dout, x_dims, dout_dims, false, false, &dy_tmp); dev_ctx, x, dout, x_bd_dims, dout_bd_dims, false, false, &dy_tmp);
} else if (transpose_y) { } else if (transpose_y) {
funcs::ExecuteMatmul<T, T>( funcs::ExecuteMatmul<T, T>(
dev_ctx, dout, y, dout_dims, y_dims, false, false, &dx_tmp); dev_ctx, dout, y, dout_bd_dims, y_bd_dims, false, false, &dx_tmp);
funcs::ExecuteMatmul<T, T>( funcs::ExecuteMatmul<T, T>(
dev_ctx, dout, x, dout_dims, x_dims, true, false, &dy_tmp); dev_ctx, dout, x, dout_bd_dims, x_bd_dims, true, false, &dy_tmp);
} else { } else {
funcs::ExecuteMatmul<T, T>( funcs::ExecuteMatmul<T, T>(
dev_ctx, dout, y, dout_dims, y_dims, false, true, &dx_tmp); dev_ctx, dout, y, dout_bd_dims, y_bd_dims, false, true, &dx_tmp);
funcs::ExecuteMatmul<T, T>( funcs::ExecuteMatmul<T, T>(
dev_ctx, x, dout, x_dims, dout_dims, true, false, &dy_tmp); dev_ctx, x, dout, x_bd_dims, dout_bd_dims, true, false, &dy_tmp);
} }
if (x_dims != dx_bd_dims) { if (x_bd_dims != dx_bd_dims) {
ReduceSumForMatmulGradOutput<T>( ReduceSumForMatmulGradOutput<T>(
dev_ctx, &dx_tmp, dx, x_dims, vectorize(x.dims())); dev_ctx, &dx_tmp, dx, dx_bd_dims, x_bd_dims);
} else { } else {
*dx = std::move(dx_tmp); *dx = std::move(dx_tmp);
} }
if (y_dims != dy_bd_dims) { if (y_bd_dims != dy_bd_dims) {
ReduceSumForMatmulGradOutput<T>( ReduceSumForMatmulGradOutput<T>(
dev_ctx, &dy_tmp, dy, y_dims, vectorize(y.dims())); dev_ctx, &dy_tmp, dy, dy_bd_dims, y_bd_dims);
} else { } else {
*dy = std::move(dy_tmp); *dy = std::move(dy_tmp);
} }
dx->set_mem_desc(x.mem_desc());
dx->Resize(x.dims()); dx->Resize(x.dims());
dx->set_mem_desc(x.mem_desc().reshape(vectorize(x.dims()))); dy->set_mem_desc(y.mem_desc());
dy->Resize(y.dims()); dy->Resize(y.dims());
dy->set_mem_desc(y.mem_desc().reshape(vectorize(y.dims())));
} }
template <typename T, typename Context> template <typename T, typename Context>
......
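The reduce_sum path in MatmulGradKernel above exists because the smaller operand is broadcast to the batched shape before multiplying, so the raw gradient comes back carrying the broadcast batch dimensions. A minimal numpy sketch (illustrative only, not Paddle code; the shapes are made up) of why that gradient must then be summed back over the broadcast dims:

```python
# Sketch: x is broadcast over y's batch dim during matmul, so dL/dx has to
# reduce_sum the per-batch gradients back to x's original shape.
import numpy as np

x = np.random.rand(4, 3)        # no batch dim; broadcast against y
y = np.random.rand(2, 3, 5)     # batch dim of size 2
dout = np.ones((2, 4, 5))       # dL/dout for L = (x @ y).sum()

dx_per_batch = dout @ np.swapaxes(y, -1, -2)  # shape (2, 4, 3)
dx = dx_per_batch.sum(axis=0)                 # reduce over the broadcast dim -> (4, 3)
assert dx.shape == x.shape
```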
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. // Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
// //
// Licensed under the Apache License, Version 2.0 (the "License"); // Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License. // you may not use this file except in compliance with the License.
...@@ -27,7 +27,8 @@ void SumKernel(const Context& dev_ctx, ...@@ -27,7 +27,8 @@ void SumKernel(const Context& dev_ctx,
bool keep_dim, bool keep_dim,
DenseTensor* out) { DenseTensor* out) {
bool reduce_all = recompute_reduce_all(x, dims); bool reduce_all = recompute_reduce_all(x, dims);
SumRawKernel<T>(dev_ctx, x, dims, keep_dim, reduce_all, out_dtype, out); SumRawKernel<T, Context>(
dev_ctx, x, dims, keep_dim, reduce_all, out_dtype, out);
} }
} // namespace phi } // namespace phi
...@@ -82,5 +83,8 @@ PD_REGISTER_KERNEL( ...@@ -82,5 +83,8 @@ PD_REGISTER_KERNEL(
#endif #endif
#if defined(PADDLE_WITH_XPU) #if defined(PADDLE_WITH_XPU)
PD_REGISTER_KERNEL(sum, XPU, ALL_LAYOUT, phi::SumKernel, float) {} PD_REGISTER_KERNEL(
sum, XPU, ALL_LAYOUT, phi::SumKernel, float, int8_t, int64_t) {
kernel->OutputAt(0).SetDataType(paddle::experimental::DataType::UNDEFINED);
}
#endif #endif
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. // Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
// //
// Licensed under the Apache License, Version 2.0 (the "License"); // Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License. // you may not use this file except in compliance with the License.
......
...@@ -25,11 +25,7 @@ void SqueezeInferKernel(const Context& dev_ctx, ...@@ -25,11 +25,7 @@ void SqueezeInferKernel(const Context& dev_ctx,
const DenseTensor& x, const DenseTensor& x,
const IntArray& axes, const IntArray& axes,
DenseTensor* out) { DenseTensor* out) {
auto x_dims = x.dims(); auto out_dims = out->dims();
std::vector<int32_t> tmp(axes.GetData().begin(), axes.GetData().end());
auto out_dims = funcs::GetOutputSqueezeShape(tmp, x_dims, true);
out->Resize(out_dims);
dev_ctx.template Alloc<T>(out); dev_ctx.template Alloc<T>(out);
phi::Copy(dev_ctx, x, dev_ctx.GetPlace(), false, out); phi::Copy(dev_ctx, x, dev_ctx.GetPlace(), false, out);
out->Resize(out_dims); // copy will reset the dims. out->Resize(out_dims); // copy will reset the dims.
......
...@@ -17,6 +17,7 @@ ...@@ -17,6 +17,7 @@
#include "paddle/phi/common/int_array.h" #include "paddle/phi/common/int_array.h"
#include "paddle/phi/core/dense_tensor.h" #include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/infermeta/unary.h"
namespace phi { namespace phi {
...@@ -33,4 +34,14 @@ void SqueezeKernel(const Context& dev_ctx, ...@@ -33,4 +34,14 @@ void SqueezeKernel(const Context& dev_ctx,
DenseTensor* out, DenseTensor* out,
DenseTensor* xshape); DenseTensor* xshape);
template <typename T, typename Context>
void Squeeze(const Context& dev_ctx,
const DenseTensor& x,
const IntArray& axes,
DenseTensor* out) {
MetaTensor meta_out(out);
SqueezeInferMeta(x, axes, &meta_out);
SqueezeInferKernel<T, Context>(dev_ctx, x, axes, out);
}
} // namespace phi } // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. // Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
// //
// Licensed under the Apache License, Version 2.0 (the "License"); // Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License. // you may not use this file except in compliance with the License.
...@@ -19,51 +19,18 @@ ...@@ -19,51 +19,18 @@
namespace phi { namespace phi {
template <typename T, typename Context>
void BitwiseAndKernel(const Context& ctx,
const DenseTensor& x,
const DenseTensor& y,
DenseTensor* out) {
ctx.template Alloc<T>(out);
int r = xpu::logical_and(
ctx.x_context(), x.data<T>(), y.data<T>(), out->data<T>(), x.numel());
PADDLE_ENFORCE_XDNN_SUCCESS(r, "bitwise and");
}
template <typename T, typename Context>
void BitwiseOrKernel(const Context& ctx,
const DenseTensor& x,
const DenseTensor& y,
DenseTensor* out) {
ctx.template Alloc<T>(out);
int r = xpu::logical_or(
ctx.x_context(), x.data<T>(), y.data<T>(), out->data<T>(), x.numel());
PADDLE_ENFORCE_XDNN_SUCCESS(r, "bitwise or");
}
template <typename T, typename Context>
void BitwiseXorKernel(const Context& ctx,
const DenseTensor& x,
const DenseTensor& y,
DenseTensor* out) {
ctx.template Alloc<T>(out);
int r = xpu::logical_xor(
ctx.x_context(), x.data<T>(), y.data<T>(), out->data<T>(), x.numel());
PADDLE_ENFORCE_XDNN_SUCCESS(r, "bitwise xor");
}
template <typename T, typename Context> template <typename T, typename Context>
void BitwiseNotKernel(const Context& ctx, void BitwiseNotKernel(const Context& ctx,
const DenseTensor& x, const DenseTensor& x,
DenseTensor* out) { DenseTensor* out) {
using XPUDataType = typename XPUTypeTrait<T>::Type;
ctx.template Alloc<T>(out); ctx.template Alloc<T>(out);
int r = int r = xpu::logical_not(ctx.x_context(),
xpu::logical_not(ctx.x_context(), x.data<T>(), out->data<T>(), x.numel()); reinterpret_cast<const XPUDataType*>(x.data<T>()),
reinterpret_cast<XPUDataType*>(out->data<T>()),
x.numel());
PADDLE_ENFORCE_XDNN_SUCCESS(r, "bitwise not"); PADDLE_ENFORCE_XDNN_SUCCESS(r, "bitwise not");
} }
} // namespace phi } // namespace phi
PD_REGISTER_KERNEL(bitwise_and, XPU, ALL_LAYOUT, phi::BitwiseAndKernel, bool) {}
PD_REGISTER_KERNEL(bitwise_or, XPU, ALL_LAYOUT, phi::BitwiseOrKernel, bool) {}
PD_REGISTER_KERNEL(bitwise_xor, XPU, ALL_LAYOUT, phi::BitwiseXorKernel, bool) {}
PD_REGISTER_KERNEL(bitwise_not, XPU, ALL_LAYOUT, phi::BitwiseNotKernel, bool) {} PD_REGISTER_KERNEL(bitwise_not, XPU, ALL_LAYOUT, phi::BitwiseNotKernel, bool) {}
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/clip_grad_kernel.h"
#include "paddle/phi/backends/xpu/enforce_xpu.h"
#include "paddle/phi/core/kernel_registry.h"
namespace phi {
template <typename T, typename Context>
void ClipGradKernel(const Context& ctx,
const DenseTensor& x,
const DenseTensor& out_grad,
const Scalar& min,
const Scalar& max,
DenseTensor* x_grad) {
ctx.template Alloc<T>(x_grad);
using XPUDataType = typename XPUTypeTrait<T>::Type;
int r =
xpu::clip_grad(ctx.x_context(),
reinterpret_cast<const XPUDataType*>(x.data<T>()),
reinterpret_cast<const XPUDataType*>(out_grad.data<T>()),
reinterpret_cast<XPUDataType*>(x_grad->data<T>()),
x.numel(),
min.to<T>(),
max.to<T>());
PADDLE_ENFORCE_XDNN_SUCCESS(r, "clip_grad");
}
} // namespace phi
PD_REGISTER_KERNEL(
clip_grad, XPU, ALL_LAYOUT, phi::ClipGradKernel, float, int) {}
...@@ -104,7 +104,6 @@ void Pool2dGradKernel(const Context& ctx, ...@@ -104,7 +104,6 @@ void Pool2dGradKernel(const Context& ctx,
} }
if (pooling_type == "max") { if (pooling_type == "max") {
// TODO(zhanghuan05) to bind max_pool2d_grad_indices xpu api
r = xpu::max_pool2d_grad<XPUType>( r = xpu::max_pool2d_grad<XPUType>(
ctx.x_context(), ctx.x_context(),
reinterpret_cast<const XPUType*>(x.data<T>()), reinterpret_cast<const XPUType*>(x.data<T>()),
...@@ -142,6 +141,67 @@ void Pool2dGradKernel(const Context& ctx, ...@@ -142,6 +141,67 @@ void Pool2dGradKernel(const Context& ctx,
} }
PADDLE_ENFORCE_XDNN_SUCCESS(r, "pool2dgrad"); PADDLE_ENFORCE_XDNN_SUCCESS(r, "pool2dgrad");
} }
template <typename T, typename Context>
void MaxPool2dWithIndexGradKernel(const Context& ctx,
const DenseTensor& x,
const DenseTensor& mask,
const DenseTensor& dout,
const std::vector<int>& kernel_size,
const std::vector<int>& strides_t,
const std::vector<int>& paddings_t,
bool global_pooling,
bool adaptive,
DenseTensor* dx) {
using XPUType = typename XPUTypeTrait<T>::Type;
ctx.template Alloc<T>(dx);
auto input_grad = reinterpret_cast<XPUType*>(dx->data<T>());
std::vector<int> ksize(kernel_size);
std::vector<int> strides(strides_t);
std::vector<int> paddings(paddings_t);
const auto* index_data = mask.data<int>();
PADDLE_ENFORCE_NOT_NULL(index_data,
errors::NotFound("index data should not be nullptr"));
PADDLE_ENFORCE_EQ(
ksize.size(),
2,
phi::errors::InvalidArgument("The Pool2d XPU OP only support 2 "
"dimension pooling!, but received "
"%d-dimension pool kernel size",
ksize.size()));
global_pooling = global_pooling || (adaptive && (ksize[0] * ksize[1] == 1));
if (global_pooling) {
for (size_t i = 0; i < ksize.size(); ++i) {
paddings[i] = 0;
ksize[i] = static_cast<int>(dx->dims()[i + 2]);
}
}
const int n = dx->dims()[0];
const int c = dx->dims()[1];
const int in_h = dx->dims()[2];
const int in_w = dx->dims()[3];
auto output_grad = reinterpret_cast<const XPUType*>(dout.data<T>());
int r = xpu::Error_t::SUCCESS;
// passing a nullptr as the input to XDNN is fine as long as index_data exists
r = xpu::max_pool2d_grad<XPUType>(ctx.x_context(),
/*input*/ nullptr,
/*output*/ nullptr,
index_data,
output_grad,
input_grad,
n,
c,
in_h,
in_w,
ksize,
strides,
paddings,
true);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "max_pool2d_with_index_grad");
}
} // namespace phi } // namespace phi
PD_REGISTER_KERNEL(pool2d_grad, PD_REGISTER_KERNEL(pool2d_grad,
...@@ -150,3 +210,9 @@ PD_REGISTER_KERNEL(pool2d_grad, ...@@ -150,3 +210,9 @@ PD_REGISTER_KERNEL(pool2d_grad,
phi::Pool2dGradKernel, phi::Pool2dGradKernel,
float, float,
phi::dtype::float16) {} phi::dtype::float16) {}
PD_REGISTER_KERNEL(max_pool2d_with_index_grad,
XPU,
ALL_LAYOUT,
phi::MaxPool2dWithIndexGradKernel,
float,
phi::dtype::float16) {}
...@@ -154,7 +154,72 @@ void Pool2dKernel(const Context& ctx, ...@@ -154,7 +154,72 @@ void Pool2dKernel(const Context& ctx,
} }
PADDLE_ENFORCE_XDNN_SUCCESS(r, "pool2d"); PADDLE_ENFORCE_XDNN_SUCCESS(r, "pool2d");
} }
template <typename T, typename Context>
void MaxPool2dWithIndexKernel(const Context& ctx,
const DenseTensor& x,
const std::vector<int>& kernel_size,
const std::vector<int>& strides_t,
const std::vector<int>& paddings_t,
bool global_pooling,
bool adaptive,
DenseTensor* out,
DenseTensor* mask) {
using XPUType = typename XPUTypeTrait<T>::Type;
ctx.template Alloc<int>(mask);
auto* index_data = mask->data<int>();
std::vector<int> ksize(kernel_size);
std::vector<int> strides(strides_t);
std::vector<int> paddings(paddings_t);
PADDLE_ENFORCE_EQ(ksize.size(),
2,
phi::errors::InvalidArgument(
"The Pool2d XPU OP only support 2 dimension pooling!"));
PADDLE_ENFORCE_EQ(!adaptive || (ksize[0] * ksize[1] == 1),
true,
phi::errors::InvalidArgument(
"The Pool2d XPU OP does not support (adaptive == "
"true && output_size != 1)"));
global_pooling = global_pooling || (adaptive && (ksize[0] * ksize[1] == 1));
if (global_pooling) {
for (size_t i = 0; i < ksize.size(); ++i) {
paddings[i] = 0;
ksize[i] = static_cast<int>(x.dims()[i + 2]);
}
}
const int n = x.dims()[0];
const int c = x.dims()[1];
const int in_h = x.dims()[2];
const int in_w = x.dims()[3];
auto input = reinterpret_cast<const XPUType*>(x.data<T>());
ctx.template Alloc<T>(out);
auto output = reinterpret_cast<XPUType*>(out->data<T>());
int r = xpu::Error_t::SUCCESS;
r = xpu::max_pool2d<XPUType>(ctx.x_context(),
input,
output,
index_data,
n,
c,
in_h,
in_w,
ksize,
strides,
paddings,
true);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "max_pool2d_with_index");
}
} // namespace phi } // namespace phi
PD_REGISTER_KERNEL( PD_REGISTER_KERNEL(
pool2d, XPU, ALL_LAYOUT, phi::Pool2dKernel, float, phi::dtype::float16) {} pool2d, XPU, ALL_LAYOUT, phi::Pool2dKernel, float, phi::dtype::float16) {}
PD_REGISTER_KERNEL(max_pool2d_with_index,
XPU,
ALL_LAYOUT,
phi::MaxPool2dWithIndexKernel,
float,
phi::dtype::float16) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. // Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
// //
// Licensed under the Apache License, Version 2.0 (the "License"); // Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License. // you may not use this file except in compliance with the License.
...@@ -46,4 +46,5 @@ void SumRawKernel(const Context& dev_ctx, ...@@ -46,4 +46,5 @@ void SumRawKernel(const Context& dev_ctx,
} // namespace phi } // namespace phi
PD_REGISTER_KERNEL(sum_raw, XPU, ALL_LAYOUT, phi::SumRawKernel, float) {} PD_REGISTER_KERNEL(
sum_raw, XPU, ALL_LAYOUT, phi::SumRawKernel, float, int8_t, int64_t) {}
...@@ -26,6 +26,14 @@ void TransposeGradKernel(const Context& dev_ctx, ...@@ -26,6 +26,14 @@ void TransposeGradKernel(const Context& dev_ctx,
DenseTensor* x_grad) { DenseTensor* x_grad) {
using XPUType = typename XPUTypeTrait<T>::Type; using XPUType = typename XPUTypeTrait<T>::Type;
dev_ctx.template Alloc<T>(x_grad); dev_ctx.template Alloc<T>(x_grad);
if (x_grad->numel() == 0) {
return;
}
if (axis.size() == 0) {
phi::Copy<Context>(dev_ctx, out_grad, dev_ctx.GetPlace(), false, x_grad);
return;
}
std::vector<int> reversed_axis(axis); std::vector<int> reversed_axis(axis);
for (size_t i = 0; i < axis.size(); i++) { for (size_t i = 0; i < axis.size(); i++) {
reversed_axis[axis[i]] = i; reversed_axis[axis[i]] = i;
......
...@@ -29,6 +29,7 @@ WHITE_LIST = { ...@@ -29,6 +29,7 @@ WHITE_LIST = {
'conv2d', 'conv2d',
'matmul', 'matmul',
'matmul_v2', 'matmul_v2',
'max_pool2d_with_index',
'mul', 'mul',
'fake_quantize_dequantize_abs_max', 'fake_quantize_dequantize_abs_max',
'fake_quantize_dequantize_moving_average_abs_max', 'fake_quantize_dequantize_moving_average_abs_max',
......
...@@ -18,7 +18,7 @@ from enum import Enum ...@@ -18,7 +18,7 @@ from enum import Enum
import numpy as np import numpy as np
from paddle import _legacy_C_ops from paddle import _C_ops, _legacy_C_ops
from paddle.fluid import core, in_dygraph_mode from paddle.fluid import core, in_dygraph_mode
from paddle.fluid.data_feeder import check_type from paddle.fluid.data_feeder import check_type
from paddle.fluid.dygraph import to_variable from paddle.fluid.dygraph import to_variable
...@@ -228,11 +228,9 @@ class AmpScaler: ...@@ -228,11 +228,9 @@ class AmpScaler:
optimize_ops, params_grads = (None, None) optimize_ops, params_grads = (None, None)
if self._found_inf: optimizer._set_auxiliary_var('found_inf', self._found_inf)
self._cache_founf_inf = True
else:
optimize_ops, params_grads = optimizer.minimize(*args, **kwargs) optimize_ops, params_grads = optimizer.minimize(*args, **kwargs)
self._cache_founf_inf = False self._cache_founf_inf = optimizer._get_auxiliary_var('found_inf')
if self._use_dynamic_loss_scaling: if self._use_dynamic_loss_scaling:
# update the scale # update the scale
...@@ -330,6 +328,9 @@ class AmpScaler: ...@@ -330,6 +328,9 @@ class AmpScaler:
param_grads_fp16, param_grads_fp16,
self._temp_found_inf_fp16, self._temp_found_inf_fp16,
) )
self._found_inf = _C_ops.bitwise_or(
self._found_inf, self._temp_found_inf_fp16
)
if len(param_grads_bf16): if len(param_grads_bf16):
_legacy_C_ops.check_finite_and_unscale( _legacy_C_ops.check_finite_and_unscale(
param_grads_bf16, param_grads_bf16,
...@@ -338,6 +339,9 @@ class AmpScaler: ...@@ -338,6 +339,9 @@ class AmpScaler:
param_grads_bf16, param_grads_bf16,
self._temp_found_inf_bf16, self._temp_found_inf_bf16,
) )
self._found_inf = _C_ops.bitwise_or(
self._found_inf, self._temp_found_inf_bf16
)
if len(param_grads_fp32): if len(param_grads_fp32):
_legacy_C_ops.check_finite_and_unscale( _legacy_C_ops.check_finite_and_unscale(
param_grads_fp32, param_grads_fp32,
...@@ -346,6 +350,9 @@ class AmpScaler: ...@@ -346,6 +350,9 @@ class AmpScaler:
param_grads_fp32, param_grads_fp32,
self._temp_found_inf_fp32, self._temp_found_inf_fp32,
) )
self._found_inf = _C_ops.bitwise_or(
self._found_inf, self._temp_found_inf_fp32
)
else: else:
if len(param_grads_fp16): if len(param_grads_fp16):
_legacy_C_ops.check_finite_and_unscale( _legacy_C_ops.check_finite_and_unscale(
...@@ -354,6 +361,9 @@ class AmpScaler: ...@@ -354,6 +361,9 @@ class AmpScaler:
param_grads_fp16, param_grads_fp16,
self._temp_found_inf_fp16, self._temp_found_inf_fp16,
) )
self._found_inf = _C_ops.bitwise_or(
self._found_inf, self._temp_found_inf_fp16
)
if len(param_grads_bf16): if len(param_grads_bf16):
_legacy_C_ops.check_finite_and_unscale( _legacy_C_ops.check_finite_and_unscale(
param_grads_bf16, param_grads_bf16,
...@@ -361,6 +371,9 @@ class AmpScaler: ...@@ -361,6 +371,9 @@ class AmpScaler:
param_grads_bf16, param_grads_bf16,
self._temp_found_inf_bf16, self._temp_found_inf_bf16,
) )
self._found_inf = _C_ops.bitwise_or(
self._found_inf, self._temp_found_inf_bf16
)
if len(param_grads_fp32): if len(param_grads_fp32):
_legacy_C_ops.check_finite_and_unscale( _legacy_C_ops.check_finite_and_unscale(
param_grads_fp32, param_grads_fp32,
...@@ -368,11 +381,8 @@ class AmpScaler: ...@@ -368,11 +381,8 @@ class AmpScaler:
param_grads_fp32, param_grads_fp32,
self._temp_found_inf_fp32, self._temp_found_inf_fp32,
) )
self._found_inf = _C_ops.bitwise_or(
self._found_inf = ( self._found_inf, self._temp_found_inf_fp32
self._temp_found_inf_fp16
or self._temp_found_inf_bf16
or self._temp_found_inf_fp32
) )
optimizer_state["state"] = OptimizerState.UNSCALED optimizer_state["state"] = OptimizerState.UNSCALED
...@@ -761,11 +771,9 @@ class GradScaler(AmpScaler): ...@@ -761,11 +771,9 @@ class GradScaler(AmpScaler):
if optimizer_state["state"] is OptimizerState.INIT: if optimizer_state["state"] is OptimizerState.INIT:
self._unscale(optimizer) self._unscale(optimizer)
if self._found_inf: optimizer._set_auxiliary_var('found_inf', self._found_inf)
self._cache_founf_inf = True
else:
optimizer.step() optimizer.step()
self._cache_founf_inf = False self._cache_founf_inf = optimizer._get_auxiliary_var('found_inf')
optimizer_state["state"] = OptimizerState.STEPPED optimizer_state["state"] = OptimizerState.STEPPED
......
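The scaler changes above keep `found_inf` as a bool Tensor and fold the per-dtype overflow flags together with bitwise_or instead of branching on a Python bool, so the flag never has to be synced back to the host inside the unscale step. A rough sketch of that accumulation pattern using the public `paddle.bitwise_or` API (the `_C_ops.bitwise_or` calls above behave the same way; the flag tensors here are stand-ins for the real `check_finite_and_unscale` outputs):

```python
import paddle

# Stand-ins for the per-dtype overflow flags written by check_finite_and_unscale.
found_inf = paddle.to_tensor([False])
temp_found_inf_fp16 = paddle.to_tensor([True])
temp_found_inf_fp32 = paddle.to_tensor([False])

# Accumulate on-device; the combined flag is read later by the optimizer.
found_inf = paddle.bitwise_or(found_inf, temp_found_inf_fp16)
found_inf = paddle.bitwise_or(found_inf, temp_found_inf_fp32)
print(found_inf)  # bool Tensor holding [True]
```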
...@@ -26,24 +26,24 @@ class TaskNode: ...@@ -26,24 +26,24 @@ class TaskNode:
self, self,
rank, rank,
max_run_times, max_run_times,
max_slot_times,
role=None, role=None,
node_type=None, node_type=None,
task_id=0, task_id=0,
ops=None, ops=None,
program=None, program=None,
lazy_initialize=False, lazy_initialize=False,
cond_var_name=None,
): ):
""" """
:param rank (int): Current rank of the task node. :param rank (int): Current rank of the task node.
:param max_run_times (int): The max run times of the task node. :param max_run_times (int): The max run times of the task node.
:param max_slot_times (int): The mas slot times of the task node.
:param role (int): The role of the task node. (Will be removed in the future) :param role (int): The role of the task node. (Will be removed in the future)
:param node_type (str): The type of the task node. :param node_type (str): The type of the task node.
:param task_id (int): The id of task node. :param task_id (int): The id of task node.
:param ops (list): A list of op.desc to init the task node. (Will be removed in the future) :param ops (list): A list of op.desc to init the task node. (Will be removed in the future)
:param program (Program): An instance of Program to init the task node. :param program (Program): An instance of Program to init the task node.
:param lazy_initialize (bool): In a user-defined task, the program may change when feed/fetch ops are added; for efficiency, the task node creates its C++ object later. :param lazy_initialize (bool): In a user-defined task, the program may change when feed/fetch ops are added; for efficiency, the task node creates its C++ object later.
:param cond_var_name (string): The name of the condition variable of the while loop.
""" """
assert (ops is not None) ^ ( assert (ops is not None) ^ (
program is not None program is not None
...@@ -54,10 +54,10 @@ class TaskNode: ...@@ -54,10 +54,10 @@ class TaskNode:
self.id = int(task_id) self.id = int(task_id)
self.rank = rank self.rank = rank
self.max_run_times = max_run_times self.max_run_times = max_run_times
self.max_slot_times = max_slot_times
self.node_type = node_type self.node_type = node_type
self.program = program self.program = program
self.lazy_initialize = lazy_initialize self.lazy_initialize = lazy_initialize
self.cond_var_name = cond_var_name
self.run_pre_steps = None self.run_pre_steps = None
self.run_at_offset = None self.run_at_offset = None
self.node = None self.node = None
...@@ -69,11 +69,18 @@ class TaskNode: ...@@ -69,11 +69,18 @@ class TaskNode:
role is not None and task_id is not None role is not None and task_id is not None
), "If init task node with ops, should provide `role` and `task_id`." ), "If init task node with ops, should provide `role` and `task_id`."
self.node = core.TaskNode( self.node = core.TaskNode(
role, ops, rank, task_id, max_run_times, max_slot_times role,
ops,
rank,
task_id,
max_run_times,
) )
else: else:
self.node = core.TaskNode( self.node = core.TaskNode(
program.desc, rank, self.id, max_run_times, max_slot_times program.desc,
rank,
self.id,
max_run_times,
) )
if self.node_type: if self.node_type:
self.node.set_type(self.node_type) self.node.set_type(self.node_type)
...@@ -85,7 +92,6 @@ class TaskNode: ...@@ -85,7 +92,6 @@ class TaskNode:
self.rank, self.rank,
self.id, self.id,
self.max_run_times, self.max_run_times,
self.max_slot_times,
) )
if self.node_type: if self.node_type:
self.node.set_type(self.node_type) self.node.set_type(self.node_type)
...@@ -93,10 +99,12 @@ class TaskNode: ...@@ -93,10 +99,12 @@ class TaskNode:
self.node.set_run_pre_steps(self.run_pre_steps) self.node.set_run_pre_steps(self.run_pre_steps)
if self.run_at_offset: if self.run_at_offset:
self.node.set_run_at_offset(self.run_at_offset) self.node.set_run_at_offset(self.run_at_offset)
if self.cond_var_name:
self.node.set_cond_var_name(self.cond_var_name)
for up in self.upstreams: for up in self.upstreams:
self.node.add_upstream_task(up[0], up[1]) self.node.add_upstream_task(up[0], up[1], up[2])
for down in self.downstreams: for down in self.downstreams:
self.node.add_downstream_task(down[0], down[1]) self.node.add_downstream_task(down[0], down[1], down[2])
self.lazy_initialize = False self.lazy_initialize = False
return self.node return self.node
...@@ -124,17 +132,21 @@ class TaskNode: ...@@ -124,17 +132,21 @@ class TaskNode:
else: else:
self.node.set_run_at_offset(offset) self.node.set_run_at_offset(offset)
def add_upstream_task(self, upstream, buffer_size=2): def add_upstream_task(
self, upstream, buffer_size=2, depend_type=core.DependType.NORMAL
):
if self.lazy_initialize: if self.lazy_initialize:
self.upstreams.append((upstream, buffer_size)) self.upstreams.append((upstream, buffer_size, depend_type))
else: else:
self.node.add_upstream_task(upstream, buffer_size) self.node.add_upstream_task(upstream, buffer_size, depend_type)
def add_downstream_task(self, downstream, buffer_size=2): def add_downstream_task(
self, downstream, buffer_size=2, depend_type=core.DependType.NORMAL
):
if self.lazy_initialize: if self.lazy_initialize:
self.downstreams.append((downstream, buffer_size)) self.downstreams.append((downstream, buffer_size, depend_type))
else: else:
self.node.add_downstream_task(downstream, buffer_size) self.node.add_downstream_task(downstream, buffer_size, depend_type)
def task_id(self): def task_id(self):
return self.id return self.id
...@@ -309,33 +321,28 @@ class FleetExecutorUtils: ...@@ -309,33 +321,28 @@ class FleetExecutorUtils:
return task_node_map return task_node_map
def construct_task_nodes_1f1b(self, program_map): def construct_task_nodes_1f1b(self, program_map):
max_slot_times = int(self.max_run_times - self.coord['pp_idx'])
cur_start_id = int(self.rank * self.num_of_functionality) cur_start_id = int(self.rank * self.num_of_functionality)
lr_task_node = TaskNode( lr_task_node = TaskNode(
rank=self.rank, rank=self.rank,
max_run_times=self.max_run_times, max_run_times=self.max_run_times,
max_slot_times=max_slot_times,
program=program_map["lr"], program=program_map["lr"],
task_id=cur_start_id, task_id=cur_start_id,
) )
fwd_task_node = TaskNode( fwd_task_node = TaskNode(
rank=self.rank, rank=self.rank,
max_run_times=self.max_run_times, max_run_times=self.max_run_times,
max_slot_times=max_slot_times,
program=program_map["fwd"], program=program_map["fwd"],
task_id=cur_start_id + 1, task_id=cur_start_id + 1,
) )
bwd_task_node = TaskNode( bwd_task_node = TaskNode(
rank=self.rank, rank=self.rank,
max_run_times=self.max_run_times, max_run_times=self.max_run_times,
max_slot_times=max_slot_times,
program=program_map["bwd"], program=program_map["bwd"],
task_id=cur_start_id + 2, task_id=cur_start_id + 2,
) )
opt_task_node = TaskNode( opt_task_node = TaskNode(
rank=self.rank, rank=self.rank,
max_run_times=self.max_run_times, max_run_times=self.max_run_times,
max_slot_times=max_slot_times,
program=program_map["opt"], program=program_map["opt"],
task_id=cur_start_id + 3, task_id=cur_start_id + 3,
) )
...@@ -354,12 +361,10 @@ class FleetExecutorUtils: ...@@ -354,12 +361,10 @@ class FleetExecutorUtils:
return task_id_to_rank return task_id_to_rank
def construct_task_nodes_1f1b_op_list(self, op_list_map): def construct_task_nodes_1f1b_op_list(self, op_list_map):
max_slot_times = int(self.max_run_times - self.coord['pp_idx'])
cur_start_id = int(self.rank * self.num_of_functionality) cur_start_id = int(self.rank * self.num_of_functionality)
lr_task_node = TaskNode( lr_task_node = TaskNode(
rank=self.rank, rank=self.rank,
max_run_times=self.max_run_times, max_run_times=self.max_run_times,
max_slot_times=max_slot_times,
role=int(OpRole.Optimize.LRSched), role=int(OpRole.Optimize.LRSched),
ops=op_list_map["lr"], ops=op_list_map["lr"],
task_id=cur_start_id, task_id=cur_start_id,
...@@ -369,7 +374,6 @@ class FleetExecutorUtils: ...@@ -369,7 +374,6 @@ class FleetExecutorUtils:
fwd_task_node = TaskNode( fwd_task_node = TaskNode(
rank=self.rank, rank=self.rank,
max_run_times=self.max_run_times, max_run_times=self.max_run_times,
max_slot_times=max_slot_times,
role=int(OpRole.Forward), role=int(OpRole.Forward),
ops=op_list_map["fwd"], ops=op_list_map["fwd"],
task_id=cur_start_id + 1, task_id=cur_start_id + 1,
...@@ -378,7 +382,6 @@ class FleetExecutorUtils: ...@@ -378,7 +382,6 @@ class FleetExecutorUtils:
bwd_task_node = TaskNode( bwd_task_node = TaskNode(
rank=self.rank, rank=self.rank,
max_run_times=self.max_run_times, max_run_times=self.max_run_times,
max_slot_times=max_slot_times,
role=int(OpRole.Backward), role=int(OpRole.Backward),
ops=op_list_map["bwd"], ops=op_list_map["bwd"],
task_id=cur_start_id + 2, task_id=cur_start_id + 2,
...@@ -387,7 +390,6 @@ class FleetExecutorUtils: ...@@ -387,7 +390,6 @@ class FleetExecutorUtils:
opt_task_node = TaskNode( opt_task_node = TaskNode(
rank=self.rank, rank=self.rank,
max_run_times=self.max_run_times, max_run_times=self.max_run_times,
max_slot_times=max_slot_times,
role=int(OpRole.Optimize), role=int(OpRole.Optimize),
ops=op_list_map["opt"], ops=op_list_map["opt"],
task_id=cur_start_id + 3, task_id=cur_start_id + 3,
...@@ -471,7 +473,6 @@ def origin(program, rank): ...@@ -471,7 +473,6 @@ def origin(program, rank):
rank=rank, rank=rank,
node_type="Compute", node_type="Compute",
max_run_times=1, max_run_times=1,
max_slot_times=1,
) )
task_id_to_rank = {task_node.task_id(): rank} task_id_to_rank = {task_node.task_id(): rank}
return [task_node.task_node()], task_id_to_rank return [task_node.task_node()], task_id_to_rank
...@@ -41,11 +41,9 @@ class HybridParallelGradScaler: ...@@ -41,11 +41,9 @@ class HybridParallelGradScaler:
optimize_ops, params_grads = (None, None) optimize_ops, params_grads = (None, None)
if self._found_inf: optimizer._set_auxiliary_var('found_inf', self._found_inf)
self._cache_founf_inf = True
else:
optimize_ops, params_grads = optimizer.minimize(*args, **kwargs) optimize_ops, params_grads = optimizer.minimize(*args, **kwargs)
self._cache_founf_inf = False self._cache_founf_inf = optimizer._get_auxiliary_var('found_inf')
if self._use_dynamic_loss_scaling: if self._use_dynamic_loss_scaling:
self._update() self._update()
......
...@@ -19,10 +19,10 @@ from types import MethodType ...@@ -19,10 +19,10 @@ from types import MethodType
import numpy as np import numpy as np
import paddle import paddle
from paddle import _legacy_C_ops from paddle import _C_ops, _legacy_C_ops
from paddle.common_ops_import import dygraph_only from paddle.common_ops_import import dygraph_only
from paddle.fluid import core
from paddle.fluid.dygraph import to_variable from paddle.fluid.dygraph import to_variable
from paddle.framework import core
from paddle.nn import clip from paddle.nn import clip
...@@ -231,6 +231,9 @@ def GroupShardedScaler(scaler): ...@@ -231,6 +231,9 @@ def GroupShardedScaler(scaler):
param_grads_fp16, param_grads_fp16,
temp_found_inf_fp16, temp_found_inf_fp16,
) )
self._found_inf = _C_ops.bitwise_or(
self._found_inf, temp_found_inf_fp16
)
if len(param_grads_fp32): if len(param_grads_fp32):
_legacy_C_ops.check_finite_and_unscale( _legacy_C_ops.check_finite_and_unscale(
param_grads_fp32, param_grads_fp32,
...@@ -238,15 +241,17 @@ def GroupShardedScaler(scaler): ...@@ -238,15 +241,17 @@ def GroupShardedScaler(scaler):
param_grads_fp32, param_grads_fp32,
temp_found_inf_fp32, temp_found_inf_fp32,
) )
self._found_inf = _C_ops.bitwise_or(
self._found_inf, temp_found_inf_fp32
)
self._found_inf = 1 if temp_found_inf_fp16 or temp_found_inf_fp32 else 0 self._found_inf = self._found_inf.cast("int32")
is_found_inf = paddle.to_tensor([self._found_inf], dtype="int32")
paddle.distributed.all_reduce( paddle.distributed.all_reduce(
is_found_inf, op=paddle.distributed.ReduceOp.SUM, group=None self._found_inf, op=paddle.distributed.ReduceOp.MAX, group=None
) )
self._found_inf = is_found_inf.numpy()[0] self._found_inf = self._found_inf.cast("bool")
scaler._unscale = MethodType(unscale_method, scaler) scaler._unscale = MethodType(unscale_method, scaler)
return scaler return scaler
......
...@@ -17,7 +17,7 @@ from types import MethodType ...@@ -17,7 +17,7 @@ from types import MethodType
import numpy as np import numpy as np
import paddle import paddle
from paddle import _legacy_C_ops from paddle import _C_ops, _legacy_C_ops
from paddle.distributed import fleet from paddle.distributed import fleet
from paddle.fluid.dygraph import to_variable from paddle.fluid.dygraph import to_variable
from paddle.framework import core from paddle.framework import core
...@@ -73,6 +73,9 @@ def distributed_scaler(scaler): ...@@ -73,6 +73,9 @@ def distributed_scaler(scaler):
param_grads_fp16, param_grads_fp16,
temp_found_inf_fp16, temp_found_inf_fp16,
) )
self._found_inf = _C_ops.bitwise_or(
self._found_inf, temp_found_inf_fp16
)
if len(param_grads_fp32): if len(param_grads_fp32):
_legacy_C_ops.check_finite_and_unscale( _legacy_C_ops.check_finite_and_unscale(
param_grads_fp32, param_grads_fp32,
...@@ -80,17 +83,19 @@ def distributed_scaler(scaler): ...@@ -80,17 +83,19 @@ def distributed_scaler(scaler):
param_grads_fp32, param_grads_fp32,
temp_found_inf_fp32, temp_found_inf_fp32,
) )
self._found_inf = _C_ops.bitwise_or(
self._found_inf, temp_found_inf_fp32
)
self._found_inf = 1 if temp_found_inf_fp16 or temp_found_inf_fp32 else 0 self._found_inf = self._found_inf.cast("int32")
is_found_inf = paddle.to_tensor([self._found_inf], dtype="int32")
# TODO(shenliang03) Since dp allreduce in the optimizer is # TODO(shenliang03) Since dp allreduce in the optimizer is
# after the gradscaler, check_finite needs to synchronize global # after the gradscaler, check_finite needs to synchronize global
# information. In the future, we should use check_group to speed. # information. In the future, we should use check_group to speed.
paddle.distributed.all_reduce( paddle.distributed.all_reduce(
is_found_inf, op=paddle.distributed.ReduceOp.MAX, group=None self._found_inf, op=paddle.distributed.ReduceOp.MAX, group=None
) )
self._found_inf = is_found_inf.numpy()[0] self._found_inf = self._found_inf.cast("bool")
# Only data_parallel doesn't need to modify scaler # Only data_parallel doesn't need to modify scaler
fleet_env = fleet.fleet fleet_env = fleet.fleet
......
...@@ -1275,6 +1275,8 @@ def fftfreq(n, d=1.0, dtype=None, name=None): ...@@ -1275,6 +1275,8 @@ def fftfreq(n, d=1.0, dtype=None, name=None):
# Tensor(shape=[5], dtype=float32, place=CUDAPlace(0), stop_gradient=True, # Tensor(shape=[5], dtype=float32, place=CUDAPlace(0), stop_gradient=True,
# [ 0. , 0.40000001, 0.80000001, -0.80000001, -0.40000001]) # [ 0. , 0.40000001, 0.80000001, -0.80000001, -0.40000001])
""" """
if d * n == 0:
raise ValueError("d or n should not be 0.")
dtype = paddle.framework.get_default_dtype() dtype = paddle.framework.get_default_dtype()
val = 1.0 / (n * d) val = 1.0 / (n * d)
......
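The new guard protects the `1.0 / (n * d)` step in fftfreq. As a hedged reference (assuming the docstring output above was produced with n=5 and d=0.5), numpy computes the same `k / (n * d)` grid, which is exactly what becomes undefined when n or d is 0:

```python
# Reference only: numpy's fftfreq uses the same k / (n * d) formula.
import numpy as np
print(np.fft.fftfreq(5, d=0.5))
# [ 0.   0.4  0.8 -0.8 -0.4]
```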
...@@ -1493,14 +1493,15 @@ def _append_backward_ops_( ...@@ -1493,14 +1493,15 @@ def _append_backward_ops_(
# remove some backward ops # remove some backward ops
# TODO(Jiabin): Support this in prime later, it will prune add_grad, fix this problem # TODO(Jiabin): Support this in prime later, it will prune add_grad, fix this problem
if not core.is_prim_enabled(): if not core._is_bwd_prim_enabled():
not_need_ops = _find_not_need_ops( not_need_ops = _find_not_need_ops(
grad_op_descs, ops, input_grad_names_set grad_op_descs, ops, input_grad_names_set
) )
grad_op_descs = [ grad_op_descs = [
op_desc for op_desc in grad_op_descs if op_desc not in not_need_ops op_desc for op_desc in grad_op_descs if op_desc not in not_need_ops
] ]
else:
logging.debug("Runing backward composite and disable find_not_need_ops")
# append op_desc in grad_op_descs to target_block # append op_desc in grad_op_descs to target_block
op_role_attr_name = core.op_proto_and_checker_maker.kOpRoleAttrName() op_role_attr_name = core.op_proto_and_checker_maker.kOpRoleAttrName()
......
...@@ -98,10 +98,12 @@ def fused_embedding_seq_pool( ...@@ -98,10 +98,12 @@ def fused_embedding_seq_pool(
.. code-block:: python .. code-block:: python
import numpy as np import numpy as np
import paddle.fluid as fluid import paddle.fluid as fluid
import paddle
paddle.enable_static()
dict_size = 20 dict_size = 20
data_t = fluid.layers.data( data_t = paddle.static.data(
name='word', shape=[1], dtype='int64', lod_level=1) name='word', shape=[-1, 1], dtype='int64', lod_level=1)
padding_idx = np.random.randint(1, 10) padding_idx = np.random.randint(1, 10)
out = fluid.contrib.fused_embedding_seq_pool( out = fluid.contrib.fused_embedding_seq_pool(
input=data_t, input=data_t,
...@@ -305,11 +307,13 @@ def multiclass_nms2( ...@@ -305,11 +307,13 @@ def multiclass_nms2(
import paddle.fluid as fluid import paddle.fluid as fluid
boxes = fluid.layers.data(name='bboxes', shape=[81, 4], import paddle
paddle.enable_static()
boxes = paddle.static.data(name='bboxes', shape=[-1, 81, 4],
dtype='float32', lod_level=1) dtype='float32', lod_level=1)
scores = fluid.layers.data(name='scores', shape=[81], scores = paddle.static.data(name='scores', shape=[-1, 81],
dtype='float32', lod_level=1) dtype='float32', lod_level=1)
out, index = fluid.layers.multiclass_nms2(bboxes=boxes, out, index = fluid.contrib.layers.multiclass_nms2(bboxes=boxes,
scores=scores, scores=scores,
background_label=0, background_label=0,
score_threshold=0.5, score_threshold=0.5,
...@@ -501,7 +505,9 @@ def shuffle_batch(x, seed=None): ...@@ -501,7 +505,9 @@ def shuffle_batch(x, seed=None):
.. code-block:: python .. code-block:: python
import paddle.fluid as fluid import paddle.fluid as fluid
x = fluid.layers.data(name="x", shape=[-1, 4]) import paddle
paddle.enable_static()
x = paddle.static.data(name="x", shape=[-1, 4])
out = fluid.contrib.layers.shuffle_batch(x) out = fluid.contrib.layers.shuffle_batch(x)
""" """
helper = LayerHelper('shuffle_batch', **locals()) helper = LayerHelper('shuffle_batch', **locals())
...@@ -1313,7 +1319,7 @@ def _pull_box_extended_sparse(input, size, extend_size=64, dtype='float32'): ...@@ -1313,7 +1319,7 @@ def _pull_box_extended_sparse(input, size, extend_size=64, dtype='float32'):
Examples: Examples:
.. code-block:: python .. code-block:: python
import paddle.fluid as fluid import paddle.fluid as fluid
data = fluid.layers.data(name='sequence', shape=[1], dtype='int64', lod_level=1) data = paddle.static.data(name='sequence', shape=[-1, 1], dtype='int64', lod_level=1)
emb, emb_ex = fluid.contrib.layers._pull_box_extended_sparse(input=data, size=8, extend_size=128) emb, emb_ex = fluid.contrib.layers._pull_box_extended_sparse(input=data, size=8, extend_size=128)
""" """
helper = LayerHelper('pull_box_extended_sparse', **locals()) helper = LayerHelper('pull_box_extended_sparse', **locals())
...@@ -1438,15 +1444,14 @@ def correlation( ...@@ -1438,15 +1444,14 @@ def correlation(
.. code-block:: python .. code-block:: python
import paddle.fluid as fluid import paddle.fluid as fluid
import paddle
x1 = fluid.layers.data(name='x1', paddle.enable_static()
shape=x_shape, x1 = paddle.static.data(name='x1',
dtype=x_type, shape=[2,3,4,5],
append_batch_size=False) dtype="float32")
x2 = fluid.layers.data(name='x2', x2 = paddle.static.data(name='x2',
shape=x_shape, shape=[2,3,4,5],
dtype=x_type, dtype="float32")
append_batch_size=False)
out = fluid.contrib.correlation( out = fluid.contrib.correlation(
...@@ -1555,8 +1560,8 @@ def fused_bn_add_act( ...@@ -1555,8 +1560,8 @@ def fused_bn_add_act(
# required: gpu # required: gpu
def build_program(main_program, startup_program): def build_program(main_program, startup_program):
with fluid.program_guard(main_program, startup_program): with fluid.program_guard(main_program, startup_program):
x = fluid.layers.data(name='x', shape=[1, 28, 28], dtype='float32') x = paddle.static.data(name='x', shape=[-1, 1, 28, 28], dtype='float32')
y = fluid.layers.data(name="y", shape=[1], dtype='int64') y = paddle.static.data(name="y", shape=[-1, 1], dtype='int64')
conv1_1 = paddle.static.nn.conv2d( conv1_1 = paddle.static.nn.conv2d(
input=x, input=x,
filter_size=3, filter_size=3,
......
...@@ -85,20 +85,20 @@ class TestCorrelationOp(unittest.TestCase): ...@@ -85,20 +85,20 @@ class TestCorrelationOp(unittest.TestCase):
np.set_printoptions(threshold=np.inf) np.set_printoptions(threshold=np.inf)
x_shape = (2, 10, 3, 3) x_shape = (2, 10, 3, 3)
x_type = 'float32' x_type = 'float32'
x1 = fluid.layers.data( x1 = paddle.static.data(
name='x1', name='x1',
shape=x_shape, shape=x_shape,
dtype=x_type, dtype=x_type,
append_batch_size=False,
stop_gradient=False,
) )
x2 = fluid.layers.data( x1.desc.set_need_check_feed(False)
x1.stop_gradient = False
x2 = paddle.static.data(
name='x2', name='x2',
shape=x_shape, shape=x_shape,
dtype=x_type, dtype=x_type,
append_batch_size=False,
stop_gradient=False,
) )
x2.desc.set_need_check_feed(False)
x2.stop_gradient = False
x1_np = np.random.randn(2, 3, 4, 5).astype(x_type) x1_np = np.random.randn(2, 3, 4, 5).astype(x_type)
x2_np = np.random.randn(2, 3, 4, 5).astype(x_type) x2_np = np.random.randn(2, 3, 4, 5).astype(x_type)
......
...@@ -110,10 +110,10 @@ def train(net_type, use_cuda, save_dirname, is_local): ...@@ -110,10 +110,10 @@ def train(net_type, use_cuda, save_dirname, is_local):
train_program.random_seed = 123 train_program.random_seed = 123
startup_prog.random_seed = 456 startup_prog.random_seed = 456
with fluid.program_guard(train_program, startup_prog): with fluid.program_guard(train_program, startup_prog):
images = fluid.layers.data( images = paddle.static.data(
name='pixel', shape=data_shape, dtype='float32' name='pixel', shape=[-1] + data_shape, dtype='float32'
) )
label = fluid.layers.data(name='label', shape=[1], dtype='int64') label = paddle.static.data(name='label', shape=[-1, 1], dtype='int64')
if net_type == "vgg": if net_type == "vgg":
print("train vgg net") print("train vgg net")
...@@ -444,11 +444,11 @@ class TestAmpWithNonIterableDataLoader(unittest.TestCase): ...@@ -444,11 +444,11 @@ class TestAmpWithNonIterableDataLoader(unittest.TestCase):
start_prog = paddle.static.Program() start_prog = paddle.static.Program()
with paddle.static.program_guard(main_prog, start_prog): with paddle.static.program_guard(main_prog, start_prog):
with paddle.fluid.unique_name.guard(): with paddle.fluid.unique_name.guard():
image = fluid.layers.data( image = paddle.static.data(
name='image', shape=[3, 224, 224], dtype='float32' name='image', shape=[-1, 3, 224, 224], dtype='float32'
) )
label = fluid.layers.data( label = paddle.static.data(
name='label', shape=[1], dtype='int64' name='label', shape=[-1, 1], dtype='int64'
) )
py_reader = fluid.io.DataLoader.from_generator( py_reader = fluid.io.DataLoader.from_generator(
feed_list=[image, label], feed_list=[image, label],
......