Commit 3ebc0f73 authored by: W wangruting

fix_conflict

......@@ -7,16 +7,20 @@ set(XPU_PROJECT "extern_xpu")
set(XPU_API_LIB_NAME "libxpuapi.so")
set(XPU_RT_LIB_NAME "libxpurt.so")
set(XPU_BASE_DATE "20230114")
set(XPU_XCCL_BASE_VERSION "1.0.7")
if(NOT DEFINED XPU_BASE_URL)
set(XPU_BASE_URL_WITHOUT_DATE
"https://baidu-kunlun-product.su.bcebos.com/KL-SDK/klsdk-dev")
set(XPU_BASE_URL "${XPU_BASE_URL_WITHOUT_DATE}/20230110")
set(XPU_BASE_URL "${XPU_BASE_URL_WITHOUT_DATE}/${XPU_BASE_DATE}")
else()
set(XPU_BASE_URL "${XPU_BASE_URL}")
endif()
set(XPU_XCCL_BASE_URL
"https://klx-sdk-release-public.su.bcebos.com/xccl/release/1.0.6")
"https://klx-sdk-release-public.su.bcebos.com/xccl/release/${XPU_XCCL_BASE_VERSION}"
)
if(WITH_AARCH64)
set(XPU_XRE_DIR_NAME "xre-kylin_aarch64")
......
......@@ -321,8 +321,7 @@ endif()
if(WITH_GPU)
if(${CMAKE_CUDA_COMPILER_VERSION} LESS 11.0
OR (${CMAKE_CUDA_COMPILER_VERSION} GREATER_EQUAL 11.6
AND ${CMAKE_CUDA_COMPILER_VERSION} LESS 11.8))
OR (WIN32 AND ${CMAKE_CUDA_COMPILER_VERSION} GREATER_EQUAL 11.6))
include(external/cub) # download cub
list(APPEND third_party_deps extern_cub)
endif()
......
......@@ -2,18 +2,26 @@ add_subdirectory(auto_parallel)
add_subdirectory(collective)
add_subdirectory(fleet_executor)
if(WITH_PYTHON)
py_proto_compile(pslib_py_proto SRCS ps.proto)
py_proto_compile(ps_py_proto SRCS the_one_ps.proto)
add_custom_target(
ps_py_proto_init ALL
COMMAND ${CMAKE_COMMAND} -E make_directory
${PADDLE_BINARY_DIR}/python/paddle/distributed/fleet/proto)
add_dependencies(ps_py_proto ps_py_proto_init)
set(PSLIB_PROTO_DSTPATH
"${PADDLE_SOURCE_DIR}/python/paddle/fluid/incubate/fleet/parameter_server/pslib/"
)
if(NOT WIN32)
add_custom_command(
TARGET ps_py_proto
POST_BUILD
COMMAND mv the_one_ps_pb2.py
${PADDLE_BINARY_DIR}/python/paddle/distributed/fleet/proto/)
add_custom_command(
TARGET pslib_py_proto
POST_BUILD
COMMAND mv ps_pb2.py "${PSLIB_PROTO_DSTPATH}")
else()
string(
REPLACE "/" "\\" fleet_proto_dstpath
......@@ -25,7 +33,15 @@ if(WITH_PYTHON)
COMMENT
"Copy generated python the_one_ps_pb2 into directory ${fleet_proto_dstpath}."
)
string(REPLACE "/" "\\" PSLIB_PROTO_DSTPATH "${PSLIB_PROTO_DSTPATH}")
add_custom_command(
TARGET pslib_py_proto
POST_BUILD
COMMAND copy /Y ps_pb2.py ${PSLIB_PROTO_DSTPATH})
endif()
message(
STATUS
"Copy generated python ps_pb2.py into directory ${PSLIB_PROTO_DSTPATH}")
endif()
if(WITH_RPC)
......
......@@ -352,41 +352,17 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupBKCL::Reduce(
const phi::DenseTensor& input,
BKCLContext_t comm,
const XPUStream& stream) {
phi::DenseTensor output_t;
paddle::framework::TensorCopy(*output, platform::XPUPlace(), &output_t);
const auto& place = input.place();
auto* calc_ctx = static_cast<phi::XPUContext*>(
platform::DeviceContextPool::Instance().Get(place));
switch (input.dtype()) {
case phi::DataType::FLOAT32:
calc_ctx->template Alloc<float>(&output_t);
break;
case phi::DataType::FLOAT16:
calc_ctx->template Alloc<float16>(&output_t);
break;
case phi::DataType::INT32:
calc_ctx->template Alloc<int>(&output_t);
break;
default:
VLOG(0) << "Error: type " << input.dtype() << " not supported for "
<< GetBackendName();
break;
}
int ret =
bkcl_all_reduce(comm,
input.data(),
output_t.data(),
input.numel(),
platform::ToBKCLDataType(
framework::TransToProtoVarType(input.type())),
ToBKCLRedType(opts.reduce_op),
stream);
if (rank_ == opts.root_rank) {
*output = output_t;
}
return ret;
return bkcl_reduce(comm,
input.data(),
output->data(),
input.numel(),
platform::ToBKCLDataType(
framework::TransToProtoVarType(input.type())),
ToBKCLRedType(opts.reduce_op),
opts.root_rank,
stream);
},
CommType::ALLREDUCE,
CommType::REDUCE,
sync_op,
use_calc_stream);
}
......
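With Reduce now mapped directly onto bkcl_reduce, a minimal call-site sketch for the updated API follows; the instance name pg, the tensors in and out, and the option values are illustrative assumptions, not part of this commit.

// Hypothetical caller, assuming a ProcessGroupBKCL `pg` and XPU DenseTensors
// `in` / `out` that are already allocated on the same device.
paddle::distributed::ReduceOptions opts;
opts.reduce_op = paddle::distributed::ReduceOp::SUM;
opts.root_rank = 0;  // only the root rank needs the reduced result
auto task = pg->Reduce(&out, in, opts, /*sync_op=*/true,
                       /*use_calc_stream=*/false);
task->Wait();  // block until the BKCL reduce has completed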
......@@ -36,6 +36,7 @@ cc_library(
interceptor.cc
compute_interceptor.cc
amplifier_interceptor.cc
cond_interceptor.cc
source_interceptor.cc
sink_interceptor.cc
message_service.cc
......@@ -66,6 +67,8 @@ if(WITH_DISTRIBUTE)
set_source_files_properties(
amplifier_interceptor.cc PROPERTIES COMPILE_FLAGS
${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(
cond_interceptor.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(
source_interceptor.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(
......
......@@ -33,6 +33,7 @@ USE_INTERCEPTOR(Source);
USE_INTERCEPTOR(Compute);
USE_INTERCEPTOR(Amplifier);
USE_INTERCEPTOR(Sink);
USE_INTERCEPTOR(Cond);
void Carrier::Init(
int64_t rank,
......@@ -96,29 +97,30 @@ void Carrier::CopyParameters(
int microbatch_id,
const framework::ProgramDesc& program,
const std::vector<std::string>& inference_root_scope_vars) {
auto& global_block = program.Block(0);
std::map<std::string, int> inference_root_scope_var_map;
for (auto var_name : inference_root_scope_vars) {
inference_root_scope_var_map.insert({var_name, 1});
}
for (auto& var : global_block.AllVars()) {
std::string var_name = var->Name();
bool force_root = inference_root_scope_var_map.find(var_name) !=
inference_root_scope_var_map.end();
if (force_root) {
VLOG(4) << var_name << " will be forced to be created in the root scope.";
}
if ((var->Persistable() || force_root) && microbatch_id == 0) {
auto* ptr = root_scope_->Var(var->Name());
InitializeVariable(ptr, var->GetType());
VLOG(5) << "Create persistable var: " << var->Name()
<< ", which pointer is " << ptr;
} else if (!var->Persistable()) {
auto* ptr = microbatch_scopes_[microbatch_id]->Var(var->Name());
VLOG(5) << "Create variable " << var->Name() << " for microbatch "
<< microbatch_id << ", which pointer is " << ptr << ".";
InitializeVariable(ptr, var->GetType());
for (size_t i = 0; i < program.Size(); ++i) {
for (auto& var : program.Block(i).AllVars()) {
std::string var_name = var->Name();
bool force_root = inference_root_scope_var_map.find(var_name) !=
inference_root_scope_var_map.end();
if (force_root) {
VLOG(4) << var_name
<< " will be forced to be created in the root scope.";
}
if ((var->Persistable() || force_root) && microbatch_id == 0) {
auto* ptr = root_scope_->Var(var->Name());
InitializeVariable(ptr, var->GetType());
VLOG(5) << "Create persistable var: " << var->Name()
<< ", which pointer is " << ptr;
} else if (!var->Persistable()) {
auto* ptr = microbatch_scopes_[microbatch_id]->Var(var->Name());
VLOG(5) << "Create variable " << var->Name() << " for microbatch "
<< microbatch_id << ", which pointer is " << ptr << ".";
InitializeVariable(ptr, var->GetType());
}
}
}
}
......
......@@ -125,6 +125,7 @@ void ComputeInterceptor::SendDataReadyToDownStream() {
InterceptorMessage ready_msg;
ready_msg.set_message_type(DATA_IS_READY);
ready_msg.set_scope_idx(cur_scope_id_);
VLOG(3) << "ComputeInterceptor " << interceptor_id_
<< " Send data_is_ready msg to " << down_id
<< " in scope: " << cur_scope_id_;
......@@ -152,6 +153,7 @@ void ComputeInterceptor::ReplyCompletedToUpStream() {
InterceptorMessage reply_msg;
reply_msg.set_message_type(DATA_IS_USELESS);
reply_msg.set_scope_idx(cur_scope_id_);
Send(up_id, reply_msg);
}
}
......
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/distributed/fleet_executor/cond_interceptor.h"
#include "paddle/fluid/distributed/fleet_executor/task_node.h"
#include "paddle/fluid/framework/executor_gc_helper.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/platform/errors.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/errors.h"
namespace paddle {
namespace distributed {
CondInterceptor::CondInterceptor(int64_t interceptor_id, TaskNode* node)
: Interceptor(interceptor_id, node) {
PrepareDeps();
RegisterMsgHandle([this](const InterceptorMessage& msg) { Run(msg); });
}
void CondInterceptor::PrepareDeps() {
auto& upstream = node_->upstream();
auto& downstream = node_->downstream();
auto& id_to_dep_type = node_->id_to_dep_type();
for (const auto& up : upstream) {
if (id_to_dep_type.at(up.first) == DependType::NORMAL) {
normal_in_id_.insert(up.first);
}
}
for (const auto& down : downstream) {
if (id_to_dep_type.at(down.first) == DependType::NORMAL) {
normal_out_id_.insert(down.first);
} else if (id_to_dep_type.at(down.first) == DependType::STOP_LOOP) {
stop_loop_id_ = down.first;
}
}
}
bool CondInterceptor::GetCondResult() {
PADDLE_ENFORCE_LT(cur_scope_id_,
microbatch_scopes_.size(),
platform::errors::InvalidArgument(
"Step out of range. There are %ld "
"microbatch_scopes, but recevice scope index %ld",
microbatch_scopes_.size(),
cur_scope_id_));
auto* cond_var =
microbatch_scopes_[cur_scope_id_]->FindVar(node_->cond_var());
PADDLE_ENFORCE(cond_var,
platform::errors::NotFound(
"Condition variable %s not exists in scope %ld",
node_->cond_var(),
cur_scope_id_));
const auto& cond_tensor = cond_var->Get<phi::DenseTensor>();
bool res = false;
if (platform::is_gpu_place(cond_tensor.place())) {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
phi::DenseTensor cpu_tensor;
framework::TensorCopy(cond_tensor, platform::CPUPlace(), &cpu_tensor);
platform::DeviceContextPool::Instance().Get(cond_tensor.place())->Wait();
res = cpu_tensor.data<bool>()[0];
#endif
} else if (platform::is_cpu_place(cond_tensor.place())) {
res = cond_tensor.data<bool>()[0];
} else {
PADDLE_THROW(platform::errors::Unimplemented(
"Unsupport device for cond interceptor."));
}
return res;
}
void CondInterceptor::SendDataReady(int64_t down_id) {
InterceptorMessage ready_msg;
ready_msg.set_message_type(DATA_IS_READY);
ready_msg.set_scope_idx(cur_scope_id_);
Send(down_id, ready_msg);
}
void CondInterceptor::ReplyDataIsUseless(int64_t up_id) {
InterceptorMessage ready_msg;
ready_msg.set_message_type(DATA_IS_USELESS);
ready_msg.set_scope_idx(cur_scope_id_);
Send(up_id, ready_msg);
}
void CondInterceptor::Compute() {
cur_scope_id_ = ready_queue_.front();
ready_queue_.pop();
bool cond = GetCondResult();
VLOG(3) << "Cond interceptor get condition var " << node_->cond_var()
<< " with value " << cond;
if (cond) {
VLOG(3) << "Loop again in scope " << cur_scope_id_;
for (auto& down_id : normal_out_id_) {
SendDataReady(down_id);
}
} else {
VLOG(3) << "Finish loop in scope " << cur_scope_id_;
SendDataReady(stop_loop_id_);
}
}
void CondInterceptor::Run(const InterceptorMessage& msg) {
if (msg.message_type() == DATA_IS_READY) {
ready_queue_.push(msg.scope_idx());
Compute();
} else if (msg.message_type() == DATA_IS_USELESS) {
if (node_->id_to_dep_type().at(msg.src_id()) == DependType::STOP_LOOP) {
for (auto& up_id : normal_in_id_) {
ReplyDataIsUseless(up_id);
}
// Gc the variable in while block
int64_t scope_id = msg.scope_idx();
if (gc_) {
VLOG(3) << "Release vars in while block in scope " << scope_id;
framework::DeleteUnusedTensors(*microbatch_scopes_[scope_id],
node_->while_block_vars(),
gc_.get());
}
}
}
}
REGISTER_INTERCEPTOR(Cond, CondInterceptor);
} // namespace distributed
} // namespace paddle
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <queue>
#include "paddle/fluid/distributed/fleet_executor/interceptor.h"
namespace paddle {
namespace distributed {
/* Condition Interceptor
 * A special interceptor whose task node holds only the condition op.
 * It has two kinds of downstreams:
 * 1. If the condition evaluates to true, it dispatches to the normal
 *    downstreams (the loop body); otherwise it dispatches to the STOP_LOOP
 *    downstream.
 * 2. It is used to implement the while op in a program.
 */
class CondInterceptor final : public Interceptor {
public:
CondInterceptor(int64_t interceptor_id, TaskNode* node);
private:
void PrepareDeps();
void Run(const InterceptorMessage& msg);
void Compute();
bool GetCondResult();
void SendDataReady(int64_t down_id);
void ReplyDataIsUseless(int64_t up_id);
std::queue<int64_t> ready_queue_;
int64_t cur_scope_id_;
std::set<int64_t> normal_in_id_;
std::set<int64_t> normal_out_id_;
int64_t stop_loop_id_;
};
} // namespace distributed
} // namespace paddle
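To show how the new DependType values are meant to be used, here is a minimal wiring sketch for a Cond task node, built only from the TaskNode API touched by this commit; the task ids and the condition variable name are hypothetical.

// Hypothetical loop: cond node (id 1) feeds the loop body (id 2, NORMAL) and
// the exit node (id 3, STOP_LOOP).
TaskNode* cond_node = new TaskNode(/*rank=*/0, /*task_id=*/1, /*max_run_times=*/1);
cond_node->SetCondVarName("while_cond");  // bool tensor read by GetCondResult()
cond_node->AddDownstreamTask(/*task_id=*/2, /*buff_size=*/1, DependType::NORMAL);
cond_node->AddDownstreamTask(/*task_id=*/3, /*buff_size=*/1, DependType::STOP_LOOP);
// While the condition stays true, CondInterceptor keeps sending DATA_IS_READY
// to node 2; once it becomes false, it signals node 3 instead.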
......@@ -66,12 +66,11 @@ void FleetExecutor::Init(
"Fleet executor is inited with empty task node"));
// TODO(fleet_exe devs): the unused_vars should be got from run time graph
std::vector<std::unique_ptr<framework::OperatorBase>> ops;
for (auto task_node : task_nodes) {
for (auto op : task_node->ops()) {
ops.emplace_back(std::unique_ptr<framework::OperatorBase>(op));
}
for (const auto& desc : program_desc.Block(0).AllOps()) {
ops.emplace_back(framework::OpRegistry::CreateOp(*desc));
}
auto unused_vars = framework::GetUnusedVars(program_desc.Block(0), ops, {});
// NOTE: For inference, the vars in inference_root_scope_vars
// shouldn't be deleted during inference, because they may hold its results.
// If they are GCed, errors will occur when zero-copying the results.
......@@ -107,6 +106,25 @@ void FleetExecutor::Init(
std::unordered_map<int64_t, TaskNode*> interceptor_id_to_task;
for (auto task_node : task_nodes) {
task_node->SetUnusedVars(unused_vars);
if (task_node->type() == "Cond") {
std::vector<std::string> while_block_vars;
std::vector<std::string> vars_in_parent;
std::vector<std::string> vars_in_sub;
for (auto& var : program_desc.Block(0).AllVars()) {
vars_in_parent.emplace_back(var->Name());
}
for (auto& var : program_desc.Block(1).AllVars()) {
vars_in_sub.emplace_back(var->Name());
}
std::sort(vars_in_parent.begin(), vars_in_parent.end());
std::sort(vars_in_sub.begin(), vars_in_sub.end());
std::set_difference(vars_in_sub.begin(),
vars_in_sub.end(),
vars_in_parent.begin(),
vars_in_parent.end(),
std::back_inserter(while_block_vars));
task_node->SetWhileBlockVars(while_block_vars);
}
int64_t interceptor_id = task_node->task_id();
interceptor_id_to_task.emplace(interceptor_id, task_node);
}
......
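The while_block_vars computation above depends on both name lists being sorted before std::set_difference is applied; a toy reproduction with made-up variable names:

std::vector<std::string> vars_in_parent = {"a", "c"};             // sorted
std::vector<std::string> vars_in_sub = {"a", "b", "c", "tmp_0"};  // sorted
std::vector<std::string> while_block_vars;
std::set_difference(vars_in_sub.begin(), vars_in_sub.end(),
                    vars_in_parent.begin(), vars_in_parent.end(),
                    std::back_inserter(while_block_vars));
// while_block_vars == {"b", "tmp_0"}: names that exist only in the while
// sub-block, so they can be garbage-collected after each loop iteration.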
......@@ -24,33 +24,14 @@ namespace {
using OperatorBase = TaskNode::OperatorBase;
}
TaskNode::TaskNode(paddle::framework::ProgramDesc* program,
int64_t rank,
int64_t max_run_times,
int64_t max_slot_nums)
: program_(program),
rank_(rank),
max_run_times_(max_run_times),
max_slot_nums_(max_slot_nums) {
// Should be serially invoked, not thread-safe
// NOTE: when instantiate TaskNode with program, won't init task node
// immediately, since the provided program may be updated later (with
// high probability) by adding_feed_fetch_ops or by RuntimeGraph.
// So, delay the init part to the Init() function.
static int64_t task_node_cnt = 0;
task_id_ = task_node_cnt++;
}
TaskNode::TaskNode(paddle::framework::ProgramDesc* program,
int64_t rank,
int64_t task_id,
int64_t max_run_times,
int64_t max_slot_nums)
int64_t max_run_times)
: program_(program),
rank_(rank),
task_id_(task_id),
max_run_times_(max_run_times),
max_slot_nums_(max_slot_nums) {
max_run_times_(max_run_times) {
// TODO(liyurui): Will be removed when execute program is supported.
Init();
}
......@@ -58,7 +39,6 @@ TaskNode::TaskNode(paddle::framework::ProgramDesc* program,
TaskNode::TaskNode(paddle::framework::ProgramDesc* program, int64_t rank)
: program_(program), rank_(rank), task_id_(rank) {
max_run_times_ = 1;
max_slot_nums_ = 1;
LOG(INFO)
<< "Constructing TaskNode for DistModelInf. The TaskNode's id is: "
<< rank
......@@ -98,13 +78,11 @@ TaskNode::TaskNode(int32_t role,
const std::vector<framework::OpDesc*>& op_descs,
int64_t rank,
int64_t task_id,
int64_t max_run_times,
int64_t max_slot_nums)
int64_t max_run_times)
: role_(role),
rank_(rank),
task_id_(task_id),
max_run_times_(max_run_times),
max_slot_nums_(max_slot_nums) {
max_run_times_(max_run_times) {
if (op_descs.empty()) {
return;
}
......@@ -121,33 +99,35 @@ TaskNode::TaskNode(int32_t role,
const std::vector<framework::OperatorBase*>& ops,
int64_t rank,
int64_t task_id,
int64_t max_run_times,
int64_t max_slot_nums)
int64_t max_run_times)
: ops_(ops),
role_(role),
rank_(rank),
task_id_(task_id),
max_run_times_(max_run_times),
max_slot_nums_(max_slot_nums) {}
max_run_times_(max_run_times) {}
TaskNode::TaskNode(int32_t role,
int64_t rank,
int64_t task_id,
int64_t max_run_times,
int64_t max_slot_nums)
int64_t max_run_times)
: role_(role),
rank_(rank),
task_id_(task_id),
max_run_times_(max_run_times),
max_slot_nums_(max_slot_nums) {}
max_run_times_(max_run_times) {}
bool TaskNode::AddUpstreamTask(int64_t task_id, int64_t buff_size) {
bool TaskNode::AddUpstreamTask(int64_t task_id,
int64_t buff_size,
DependType type) {
const auto& ret = upstream_.emplace(task_id, buff_size);
id_to_dep_type_.emplace(task_id, type);
return ret.second;
}
bool TaskNode::AddDownstreamTask(int64_t task_id, int64_t buff_size) {
bool TaskNode::AddDownstreamTask(int64_t task_id,
int64_t buff_size,
DependType type) {
const auto& ret = downstream_.emplace(task_id, buff_size);
id_to_dep_type_.emplace(task_id, type);
return ret.second;
}
......
......@@ -14,8 +14,10 @@
#pragma once
#include <cstdint>
#include <functional>
#include <memory>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>
......@@ -29,38 +31,30 @@ class OpDesc;
} // namespace framework
namespace distributed {
enum class DependType { NORMAL, LOOP, STOP_LOOP };
class TaskNode final {
public:
using OperatorBase = paddle::framework::OperatorBase;
TaskNode(int64_t rank, int64_t task_id, int64_t max_run_times);
TaskNode(int32_t role,
int64_t rank,
int64_t task_id,
int64_t max_run_times,
int64_t max_slot_nums);
TaskNode(int32_t role, int64_t rank, int64_t task_id, int64_t max_run_times);
TaskNode(int32_t role,
const std::vector<framework::OpDesc*>& op_descs,
int64_t rank,
int64_t task_id,
int64_t max_run_times,
int64_t max_slot_nums);
int64_t max_run_times);
TaskNode(int32_t role,
const std::vector<framework::OperatorBase*>& ops,
int64_t rank,
int64_t task_id,
int64_t max_run_times,
int64_t max_slot_nums);
TaskNode(paddle::framework::ProgramDesc* program,
int64_t rank,
int64_t max_run_times,
int64_t max_slot_nums);
int64_t max_run_times);
TaskNode(paddle::framework::ProgramDesc* program, int64_t rank);
// TODO(liyurui): This will be the only constructor for task node
TaskNode(paddle::framework::ProgramDesc* program,
int64_t task_id,
int64_t rank,
int64_t max_run_times,
int64_t max_slot_nums);
int64_t max_run_times);
~TaskNode() = default;
void SetProgram(paddle::framework::ProgramDesc* program);
......@@ -69,11 +63,11 @@ class TaskNode final {
int64_t task_id() const { return task_id_; }
int32_t role() const { return role_; }
int64_t max_run_times() const { return max_run_times_; }
int64_t max_slot_nums() const { return max_slot_nums_; }
int64_t run_per_steps() const { return run_per_steps_; }
int64_t run_at_offset() const { return run_at_offset_; }
int64_t reply_up_per_steps() const { return reply_up_per_steps_; }
int64_t send_down_per_steps() const { return send_down_per_steps_; }
const std::string& cond_var() const { return cond_var_; }
const std::unordered_map<int64_t, int64_t>& upstream() const {
return upstream_;
}
......@@ -86,11 +80,20 @@ class TaskNode final {
const std::vector<std::unique_ptr<OperatorBase>>& unique_ops() const {
return ops_vec_;
}
const std::unordered_map<int64_t, DependType> id_to_dep_type() const {
return id_to_dep_type_;
}
const std::unordered_map<const OperatorBase*, std::vector<std::string>>&
unused_vars() const {
return unused_vars_;
}
const std::vector<std::string> while_block_vars() const {
return while_block_vars_;
}
void SetCondVarName(const std::string& cond_var_name) {
cond_var_ = cond_var_name;
}
void SetRunPerSteps(int64_t value);
void SetRunAtOffset(int64_t value);
void SetReplyUpPerSteps(int64_t value);
......@@ -101,10 +104,17 @@ class TaskNode final {
unused_vars) {
unused_vars_ = unused_vars;
}
void SetWhileBlockVars(const std::vector<std::string>& vars) {
while_block_vars_ = vars;
}
// upstream need buffs?
bool AddUpstreamTask(int64_t task_id, int64_t buff_size = 1);
bool AddDownstreamTask(int64_t task_id, int64_t buff_size = 1);
bool AddUpstreamTask(int64_t task_id,
int64_t buff_size = 1,
DependType type = DependType::NORMAL);
bool AddDownstreamTask(int64_t task_id,
int64_t buff_size = 1,
DependType type = DependType::NORMAL);
std::string DebugString() const;
private:
......@@ -115,16 +125,20 @@ class TaskNode final {
// task_id-->buff_size
std::unordered_map<int64_t, int64_t> upstream_;
std::unordered_map<int64_t, int64_t> downstream_;
// task_id-->type
std::unordered_map<int64_t, DependType> id_to_dep_type_;
framework::ProgramDesc* program_;
std::string cond_var_;
std::vector<std::unique_ptr<OperatorBase>> ops_vec_;
std::unordered_map<const OperatorBase*, std::vector<std::string>>
unused_vars_;
std::vector<std::string> while_block_vars_;
int32_t role_;
int64_t rank_;
int64_t task_id_;
int64_t max_run_times_;
int64_t max_slot_nums_;
int64_t run_per_steps_{1};
int64_t run_at_offset_{0};
......
......@@ -77,9 +77,8 @@ TEST(ComputeInterceptor, Compute) {
// FIXME: don't delete, otherwise interceptor will use undefined node
TaskNode* source =
new TaskNode(0, SOURCE_ID, 2); // rank, task_id, max_run_times
TaskNode* node_a =
new TaskNode(0, ops, 0, 0, 2, 0); // role, ops, rank, task_id
TaskNode* node_b = new TaskNode(0, 0, 1, 2, 0);
TaskNode* node_a = new TaskNode(0, ops, 0, 0, 2); // role, ops, rank, task_id
TaskNode* node_b = new TaskNode(0, 0, 1, 2);
TaskNode* sink = new TaskNode(0, SINK_ID, 2);
// source->a->b->sink
......
......@@ -37,8 +37,8 @@ TEST(ComputeInterceptor, Compute) {
// NOTE: don't delete, otherwise interceptor will use undefined node
TaskNode* source =
new TaskNode(0, SOURCE_ID, 3); // rank, task_id, max_run_times
TaskNode* node_a = new TaskNode(0, 0, 0, 3, 0);
TaskNode* node_b = new TaskNode(0, 0, 1, 3, 0);
TaskNode* node_a = new TaskNode(0, 0, 0, 3);
TaskNode* node_b = new TaskNode(0, 0, 1, 3);
TaskNode* sink = new TaskNode(0, SINK_ID, 3);
// source->a->b->sink
......
......@@ -71,12 +71,12 @@ TEST(AmplifierInterceptor, Amplifier) {
// NOTE: don't delete, otherwise interceptor will use undefined node
TaskNode* source =
new TaskNode(0, SOURCE_ID, micro_steps); // rank, task_id, max_run_times
TaskNode* node_a = new TaskNode(0, 0, 0, 1, 0); // role, rank, task_id
TaskNode* node_b = new TaskNode(0, 0, 1, 1, 0);
TaskNode* node_c = new TaskNode(0, 0, 2, 1, 0);
TaskNode* node_d = new TaskNode(0, 0, 3, 1, 0);
TaskNode* node_e = new TaskNode(0, 0, 4, 1, 0);
TaskNode* node_f = new TaskNode(0, 0, 5, 1, 0);
TaskNode* node_a = new TaskNode(0, 0, 0, 1); // role, rank, task_id
TaskNode* node_b = new TaskNode(0, 0, 1, 1);
TaskNode* node_c = new TaskNode(0, 0, 2, 1);
TaskNode* node_d = new TaskNode(0, 0, 3, 1);
TaskNode* node_e = new TaskNode(0, 0, 4, 1);
TaskNode* node_f = new TaskNode(0, 0, 5, 1);
TaskNode* sink = new TaskNode(0, SINK_ID, micro_steps);
// source->a->b->c->d->e->f->sink
......
......@@ -83,11 +83,10 @@ TEST(AmplifierInterceptor, Amplifier) {
// NOTE: don't delete, otherwise interceptor will use undefined node
TaskNode* source =
new TaskNode(0, SOURCE_ID, micro_steps); // rank, task_id, max_run_times
TaskNode* node_a =
new TaskNode(0, 0, 0, micro_steps, 0); // role, rank, task_id
TaskNode* node_b = new TaskNode(0, 0, 1, 3, 0);
TaskNode* node_c = new TaskNode(0, 0, 2, 3, 0);
TaskNode* node_d = new TaskNode(0, 0, 3, micro_steps, 0);
TaskNode* node_a = new TaskNode(0, 0, 0, micro_steps); // role, rank, task_id
TaskNode* node_b = new TaskNode(0, 0, 1, 3);
TaskNode* node_c = new TaskNode(0, 0, 2, 3);
TaskNode* node_d = new TaskNode(0, 0, 3, micro_steps);
TaskNode* sink = new TaskNode(0, SINK_ID, micro_steps);
// source->a->b->c->d->sink
......
......@@ -62,10 +62,9 @@ TEST(SourceInterceptor, Source) {
msg_bus->Init(0, {{0, "127.0.0.0:0"}}, "");
// NOTE: don't delete, otherwise interceptor will use undefined node
TaskNode* source =
new TaskNode(0, SOURCE_ID, 0, 3, 0); // role, rank, task_id
TaskNode* node_a = new TaskNode(0, 0, 0, 3, 0); // role, rank, task_id
TaskNode* sink = new TaskNode(0, SINK_ID, 0, 3, 0); // role, rank, task_id
TaskNode* source = new TaskNode(0, SOURCE_ID, 0, 3); // role, rank, task_id
TaskNode* node_a = new TaskNode(0, 0, 0, 3); // role, rank, task_id
TaskNode* sink = new TaskNode(0, SINK_ID, 0, 3); // role, rank, task_id
source->AddDownstreamTask(0, 1);
node_a->AddUpstreamTask(SOURCE_ID, 1);
......
......@@ -61,9 +61,8 @@ TEST(SourceInterceptor, Source) {
msg_bus->Init(0, {{0, "127.0.0.0:0"}}, "");
// NOTE: don't delete, otherwise interceptor will use undefined node
TaskNode* source =
new TaskNode(0, SOURCE_ID, 0, 3, 0); // role, rank, task_id
TaskNode* node_a = new TaskNode(0, 0, 0, 3, 0); // role, rank, task_id
TaskNode* source = new TaskNode(0, SOURCE_ID, 0, 3); // role, rank, task_id
TaskNode* node_a = new TaskNode(0, 0, 0, 3); // role, rank, task_id
source->AddDownstreamTask(0, 1);
node_a->AddUpstreamTask(SOURCE_ID, 1);
......
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
syntax = "proto2";
package paddle;
option cc_generic_services = true;
option cc_enable_arenas=true;
message PSParameter {
optional string worker_class = 1;
optional string server_class = 2;
optional string instance_class = 3;
optional string init_gflags = 4 [ default = "" ];
optional WorkerParameter worker_param = 101;
optional ServerParameter server_param = 102;
repeated DownpourTrainerParameter trainer_param = 301;
optional FsClientParameter fs_client_param = 501;
}
message WorkerParameter {
optional DownpourWorkerParameter downpour_worker_param = 1;
}
message ServerParameter {
optional DownpourServerParameter downpour_server_param = 1;
}
message DownpourWorkerParameter {
repeated TableParameter downpour_table_param = 1;
}
message DownpourTrainerParameter {
repeated DenseTableParameter dense_table = 1;
repeated SparseTableParameter sparse_table = 2;
optional int32 push_sparse_per_batch = 3;
optional int32 push_dense_per_batch = 4;
repeated string skip_op = 5;
repeated ProgramConfig program_config = 6;
}
message ProgramConfig {
required string program_id = 1;
repeated int32 push_sparse_table_id = 2;
repeated int32 push_dense_table_id = 3;
repeated int32 pull_sparse_table_id = 4;
repeated int32 pull_dense_table_id = 5;
}
message DenseTableParameter {
optional int32 table_id = 1;
repeated string dense_variable_name = 2;
repeated string dense_gradient_variable_name = 3;
optional int32 fea_dim = 4;
}
message SparseTableParameter {
optional int32 table_id = 1;
optional int32 feature_dim = 2;
repeated string slot_key = 3;
repeated string slot_value = 4;
repeated string slot_gradient = 5;
}
message DownpourServerParameter {
repeated TableParameter downpour_table_param = 1;
optional ServerServiceParameter service_param = 2;
}
message ServerServiceParameter {
optional string server_class = 1 [ default = "DownpourBrpcPsServer" ];
optional string client_class = 2 [ default = "DownpourBrpcPsClient" ];
optional string service_class = 3 [ default = "DownpourPsService"];
optional uint32 start_server_port = 4 [ default = 0 ]; // will find an available port from it
optional uint32 server_thread_num = 5 [ default = 12 ];
}
enum TableType {
PS_SPARSE_TABLE = 0;
PS_DENSE_TABLE = 1;
}
message TableParameter {
optional uint64 table_id = 1;
optional string table_class = 2;
optional uint64 shard_num = 3 [ default = 1000 ];
optional TableAccessorParameter accessor = 4;
optional TableType type = 5;
optional bool compress_in_save = 6 [default = false];
//for cache model
optional bool enable_sparse_table_cache = 7 [default = true];
optional double sparse_table_cache_rate = 8 [default = 0.00055];
optional uint32 sparse_table_cache_file_num = 9 [default = 16];
optional double sparse_table_mem_cache_rate = 10 [default = 0.5];
}
message TableAccessorParameter {
optional string accessor_class = 1;
optional SparseSGDRuleParameter sparse_sgd_param = 2;
optional DenseSGDRuleParameter dense_sgd_param = 3;
optional uint32 fea_dim = 4 [default = 11];
optional uint32 embedx_dim = 5 [default = 8];
optional uint32 embedx_threshold = 6 [default = 10];
optional DownpourTableAccessorParameter downpour_accessor_param = 7;
repeated TableAccessorSaveParameter table_accessor_save_param = 8;
optional SparseCommonSGDRuleParameter sparse_commonsgd_param = 9;
optional SparseCommonSGDRuleParameter embed_sgd_param = 10;
optional SparseCommonSGDRuleParameter embedx_sgd_param = 11;
}
message DownpourTableAccessorParameter {
optional float nonclk_coeff = 1 [default = 0.1]; // to calculate show_click_score
optional float click_coeff = 2 [default = 1]; // to calculate show_click_score
optional float base_threshold = 3 [default = 1.5]; // show_click_score > base_threshold, this feature can be saved
optional float delta_threshold = 4 [default = 0.25]; // delta_score > delta_threshold, this feature can be saved
optional float delta_keep_days = 5 [default = 16]; // unseen_day < delta_keep_days, this feature can be saved
optional float show_click_decay_rate = 6 [default = 0.98]; // show/click will update to show/click * show_click_decay_rate after a day
optional float delete_threshold = 7 [default = 0.8]; // threshold to shrink a feasign
optional float delete_after_unseen_days = 8 [default = 30]; // unseen_day > delete_after_unseen_days, this feature will be deleted in shrink_model
optional int32 ssd_unseenday_threshold = 9 [default = 1]; // threshold to save ssd
}
message TableAccessorSaveParameter {
optional uint32 param = 1;
optional string converter = 2;
optional string deconverter = 3;
}
enum PsCmdID {
PS_PULL_DENSE_TABLE = 0;
PS_PUSH_DENSE_TABLE = 1;
PS_PULL_SPARSE_TABLE = 2;
PS_PUSH_SPARSE_TABLE = 3;
PS_SHRINK_TABLE = 4;
PS_SAVE_ONE_TABLE = 5;
PS_SAVE_ALL_TABLE = 6;
PS_LOAD_ONE_TABLE = 7;
PS_LOAD_ALL_TABLE = 8;
PS_CLEAR_ONE_TABLE = 9;
PS_CLEAR_ALL_TABLE = 10;
PS_PUSH_DENSE_PARAM = 11;
PS_STOP_SERVER = 12;
PS_SAVE_ONE_CACHE_TABLE = 13;
PS_GET_CACHE_THRESHOLD = 14;
PS_CACHE_SHUFFLE = 15;
PS_COPY_TABLE = 16;
PS_COPY_TABLE_BY_FEASIGN = 17;
PS_PULL_SPARSE_TABLE_WITH_DEPENDENCY = 18;
PS_PUSH_SPARSE_TABLE_WITH_DEPENDENCY = 19;
PS_PRINT_TABLE_STAT = 20;
PS_SAVE_ONE_TABLE_PREFIX = 21;
PS_SAVE_MEM_CACHE_TABLE = 22;
//pserver2pserver cmd start from 100
PS_S2S_MSG = 101;
//local_client2local_client cmd start from 200
PS_C2C_PULL_SPARSE_TABLE = 201;
}
message PsRequestMessage {
required uint32 cmd_id = 1;
optional uint32 table_id = 2;
repeated bytes params = 3;
optional int32 client_id = 4;
optional bytes data = 5;
};
message SparseSGDRuleParameter {
optional double learning_rate = 1 [default = 0.05];
optional double initial_g2sum = 2 [default = 3.0];
optional double initial_range = 3 [default = 0.0001];
repeated float weight_bounds = 4;
}
message SparseCommonSGDRuleParameter {
optional string name = 1;
optional SparseNaiveSGDRuleParameter naive = 2;
optional SparseAdagradSGDRuleParameter adagrad = 3;
optional SparseAdamSGDParameter adam = 4;
}
message SparseNaiveSGDRuleParameter {
optional double learning_rate = 1 [default = 0.05];
optional double initial_range = 2 [default = 0.0001];
repeated float weight_bounds = 3;
}
message SparseAdagradSGDRuleParameter {
optional double learning_rate = 1 [default = 0.05];
optional double initial_g2sum = 2 [default = 3.0];
optional double initial_range = 3 [default = 0.0001];
repeated float weight_bounds = 4;
}
message SparseAdamSGDParameter {
optional double learning_rate = 1 [default = 0.001];
optional double initial_range = 2 [default = 0.0001];
optional double beta1_decay_rate = 3 [default = 0.9];
optional double beta2_decay_rate = 4 [default = 0.999];
optional double ada_epsilon = 5 [default = 1e-08];
repeated float weight_bounds = 6;
}
message DenseSGDRuleParameter {
optional string name = 1;
optional AdamSGDParameter adam = 2;
optional NaiveSGDParameter naive = 3;
optional SummarySGDParameter summary = 4;
optional MovingAverageRuleParameter moving_average = 5;
}
message AdamSGDParameter {
optional double learning_rate = 1 [default = 5e-06]; // learning rate
optional double avg_decay_rate = 2 [default = 0.999993]; // decay rate of avg_weight
optional double ada_decay_rate = 3 [default = 0.9999];
optional double ada_epsilon = 4 [default = 1e-08];
optional double mom_decay_rate = 5 [default = 0.99];
}
message NaiveSGDParameter {
optional double learning_rate = 1 [default = 0.0002];
optional double avg_decay_rate = 2;
}
message SummarySGDParameter {
optional double summary_decay_rate = 1 [default = 0.999999]; // decay rate of the weights
}
message MovingAverageRuleParameter {
optional double momentum = 1;
}
message PsResponseMessage {
required int32 err_code = 1 [default = 0];
required string err_msg = 2 [default = ""];
optional bytes data = 3;
};
service PsService {
rpc service(PsRequestMessage) returns (PsResponseMessage);
};
message FsClientParameter {
enum FsApiType {
HDFS = 0;
AFS = 1;
}
optional FsApiType fs_type = 1 [default = HDFS];
optional string uri = 2; //such as afs://tianqi.afs.baidu.com:9902
optional string user = 3; //user_name to access fs
optional string passwd = 4; //password
optional int32 buffer_size = 5; //buffer for read/write
optional string hadoop_bin = 51;
optional string afs_conf = 101;
}
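For reference, a minimal sketch of populating the generated C++ message from this proto, assuming it is also compiled with protoc into ps.pb.h; the field values are illustrative only.

#include "ps.pb.h"

paddle::PSParameter param;
param.set_worker_class("DownpourWorker");
param.mutable_server_param()
    ->mutable_downpour_server_param()
    ->mutable_service_param()
    ->set_server_thread_num(12);
// param.DebugString() prints the configured fields in protobuf text format.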
......@@ -1839,9 +1839,9 @@ class DygraphNodeGenerator(DygraphFunctionGeneratorBase):
False if self.composite_func_info == {} else True
)
if is_composite_grad_api:
if is_composite_grad_api and next_grad_node_creation_str != '':
next_grad_node_creation_str = f"""
if (!paddle::prim::PrimCommonUtils::IsPrimEnabled()) {{
if (!paddle::prim::PrimCommonUtils::IsBwdPrimEnabled()) {{
{next_grad_node_creation_str}
}}
"""
......@@ -1982,6 +1982,7 @@ class DygraphNodeGenerator(DygraphFunctionGeneratorBase):
backward_attrs_list = self.backward_attrs_list
backward_inplace_map = self.backward_inplace_map
indent = GetIndent(1)
need_gen_trace_backard_for_inplace = False
# Construct grad_api function args
# Order: TensorWrappers, GradTensors, Attributes
......@@ -2211,6 +2212,7 @@ class DygraphNodeGenerator(DygraphFunctionGeneratorBase):
}} else {{
{inplace_str}
}}"""
need_gen_trace_backard_for_inplace = True
else:
inplace_for_grad_outs_str += inplace_str
......@@ -2259,7 +2261,7 @@ class DygraphNodeGenerator(DygraphFunctionGeneratorBase):
# TODO(Ruting):using composite only when we don't have backward kernel in the future.
elif is_composite_grad_api:
grad_function_call_str = f"""
if (paddle::prim::PrimCommonUtils::IsPrimEnabled()) {{
if (paddle::prim::PrimCommonUtils::IsBwdPrimEnabled()) {{
{indent}{composite_grad_api_namespace}{composite_grad_api_name}{composite_template_name}({composite_grad_api_args_str});
VLOG(4) << "Composite api {composite_grad_api_name} is called ";
}}else{{
......@@ -2282,7 +2284,7 @@ class DygraphNodeGenerator(DygraphFunctionGeneratorBase):
if (
len(next_grad_node_creation_str) > 0
or is_invoke_forward_api
or inplace_for_grad_outs_str != ''
or need_gen_trace_backard_for_inplace
):
compute_require_next_grad_str = f"{indent}bool trace_backward = egr::Controller::Instance().HasGrad() && create_graph;\n"
......
......@@ -618,7 +618,8 @@ if(WITH_PYTHON)
fleet_proto_init
pass_desc_py_proto
ps_py_proto
ps_py_proto_init)
ps_py_proto_init
pslib_py_proto)
if(NOT WIN32)
add_custom_command(
TARGET framework_py_proto
......
......@@ -140,7 +140,116 @@ struct FusedAttentionGradPattern : public PatternBase {
bool do_dropout, // dropout the softmax(qk) or not
bool add_residual); // add residual to out linear or not
// TODO(Yuang Liu): add backward pattern
// post layer norm grad
PATTERN_DECL_NODE(post_layer_norm_grad_op);
PATTERN_DECL_NODE(post_layer_norm_grad_scale);
PATTERN_DECL_NODE(post_layer_norm_grad_bias);
PATTERN_DECL_NODE(post_layer_norm_grad_mean);
PATTERN_DECL_NODE(post_layer_norm_grad_variance);
PATTERN_DECL_NODE(post_layer_norm_grad_x);
PATTERN_DECL_NODE(post_layer_norm_grad_scale_grad);
PATTERN_DECL_NODE(post_layer_norm_grad_bias_grad);
PATTERN_DECL_NODE(post_layer_norm_grad_x_grad);
// residual grad
PATTERN_DECL_NODE(residual_ele_add_grad_op);
PATTERN_DECL_NODE(residual_ele_add_grad_x);
PATTERN_DECL_NODE(residual_ele_add_grad_bias);
PATTERN_DECL_NODE(residual_ele_add_grad_bias_grad);
PATTERN_DECL_NODE(residual_ele_add_grad_x_grad);
// out linear grad
PATTERN_DECL_NODE(out_linear_dropout_grad_op);
PATTERN_DECL_NODE(out_linear_dropout_grad_mask);
PATTERN_DECL_NODE(out_linear_dropout_grad_out);
PATTERN_DECL_NODE(out_linear_ele_add_grad_op);
PATTERN_DECL_NODE(out_linear_ele_add_grad_x);
PATTERN_DECL_NODE(out_linear_ele_add_grad_bias);
PATTERN_DECL_NODE(out_linear_ele_add_grad_x_grad);
PATTERN_DECL_NODE(out_linear_ele_add_grad_bias_grad);
PATTERN_DECL_NODE(out_linear_matmul_grad_op);
PATTERN_DECL_NODE(out_linear_matmul_grad_x);
PATTERN_DECL_NODE(out_linear_matmul_grad_w);
PATTERN_DECL_NODE(out_linear_matmul_grad_x_grad);
PATTERN_DECL_NODE(out_linear_matmul_grad_w_grad);
// core attention grad
PATTERN_DECL_NODE(qkv_reshape_grad_op);
PATTERN_DECL_NODE(qkv_reshape_grad_x_shape);
PATTERN_DECL_NODE(qkv_reshape_grad_out);
PATTERN_DECL_NODE(qkv_transpose_grad_op);
PATTERN_DECL_NODE(qkv_transpose_grad_x_shape);
PATTERN_DECL_NODE(qkv_transpose_grad_out);
PATTERN_DECL_NODE(qkv_matmul_grad_op);
PATTERN_DECL_NODE(qkv_matmul_grad_x);
PATTERN_DECL_NODE(qkv_matmul_grad_w);
PATTERN_DECL_NODE(qkv_matmul_grad_x_grad);
PATTERN_DECL_NODE(qkv_matmul_grad_w_grad);
PATTERN_DECL_NODE(attn_dropout_grad_op);
PATTERN_DECL_NODE(attn_dropout_grad_mask);
PATTERN_DECL_NODE(attn_dropout_grad_out);
PATTERN_DECL_NODE(qk_softmax_grad_op);
PATTERN_DECL_NODE(qk_softmax_grad_fwd_out);
PATTERN_DECL_NODE(qk_softmax_grad_out);
PATTERN_DECL_NODE(add_mask_ele_add_grad_op);
PATTERN_DECL_NODE(add_mask_ele_add_grad_x);
PATTERN_DECL_NODE(add_mask_ele_add_grad_bias);
PATTERN_DECL_NODE(add_mask_ele_add_grad_x_grad);
PATTERN_DECL_NODE(qk_scale_grad_op);
PATTERN_DECL_NODE(qk_scale_grad_out);
PATTERN_DECL_NODE(qk_matmul_grad_op);
PATTERN_DECL_NODE(qk_matmul_grad_x);
PATTERN_DECL_NODE(qk_matmul_grad_w);
PATTERN_DECL_NODE(qk_matmul_grad_x_grad);
PATTERN_DECL_NODE(qk_matmul_grad_w_grad);
// fuse qkv projection grad
PATTERN_DECL_NODE(fuse_qkv_split_grad_op); // concat op
PATTERN_DECL_NODE(fuse_qkv_split_grad_out);
PATTERN_DECL_NODE(fuse_qkv_transpose_grad_op);
PATTERN_DECL_NODE(fuse_qkv_transpose_grad_x_shape);
PATTERN_DECL_NODE(fuse_qkv_transpose_grad_out);
PATTERN_DECL_NODE(fuse_qkv_reshape_grad_op);
PATTERN_DECL_NODE(fuse_qkv_reshape_grad_x_shape);
PATTERN_DECL_NODE(fuse_qkv_reshape_grad_out);
PATTERN_DECL_NODE(fuse_qkv_ele_add_grad_op);
PATTERN_DECL_NODE(fuse_qkv_ele_add_grad_x);
PATTERN_DECL_NODE(fuse_qkv_ele_add_grad_bias);
PATTERN_DECL_NODE(fuse_qkv_ele_add_grad_x_grad);
PATTERN_DECL_NODE(fuse_qkv_ele_add_grad_bias_grad);
PATTERN_DECL_NODE(fuse_qkv_matmul_grad_op);
PATTERN_DECL_NODE(fuse_qkv_matmul_grad_x);
PATTERN_DECL_NODE(fuse_qkv_matmul_grad_w);
PATTERN_DECL_NODE(fuse_qkv_matmul_grad_x_grad);
PATTERN_DECL_NODE(fuse_qkv_matmul_grad_w_grad);
// pre layer norm grad
PATTERN_DECL_NODE(pre_layer_norm_grad_op);
PATTERN_DECL_NODE(pre_layer_norm_grad_scale);
PATTERN_DECL_NODE(pre_layer_norm_grad_bias);
PATTERN_DECL_NODE(pre_layer_norm_grad_mean);
PATTERN_DECL_NODE(pre_layer_norm_grad_variance);
PATTERN_DECL_NODE(pre_layer_norm_grad_x);
PATTERN_DECL_NODE(pre_layer_norm_grad_scale_grad);
PATTERN_DECL_NODE(pre_layer_norm_grad_bias_grad);
PATTERN_DECL_NODE(pre_layer_norm_grad_x_grad);
// grad accumulation
PATTERN_DECL_NODE(grad_accumulation_sum_op);
PATTERN_DECL_NODE(grad_accumulation_out);
};
} // namespace patterns
......
......@@ -53,8 +53,13 @@ void MapOp2AnotherPass::ApplyImpl(ir::Graph* graph) const {
op_desc->SetAttr("shape", std::vector<int>{0, -1});
}
} else if (op_type == "depthwise_conv2d") {
op_desc->SetType(replaced_map[op_type]);
op_desc->SetAttr("use_cudnn", true);
auto groups = PADDLE_GET_CONST(int, op_desc->GetAttr("groups"));
if (groups > 1) {
#if CUDNN_VERSION >= 8100
op_desc->SetType(replaced_map[op_type]);
op_desc->SetAttr("use_cudnn", true);
#endif
}
}
op_desc->Flush();
++found_count;
......
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/framework/op_desc.h"
namespace paddle {
namespace framework {
namespace ir {
inline std::vector<std::string> GetSupportedActivations() {
return std::vector<std::string>{"abs",
"clip",
"gelu",
"hard_sigmoid",
"hard_swish",
"leaky_relu",
"mish",
"relu",
"relu6",
"sigmoid",
"sqrt",
"swish",
"tanh"};
}
inline std::unordered_map<std::string, std::string> GetAttributeMap(
std::string act_type) {
std::unordered_map<std::string, std::string> attr_map;
if (act_type == "swish") {
attr_map.emplace("beta", "fuse_alpha");
} else if (act_type == "relu6") {
attr_map.emplace("threshold", "fuse_alpha");
} else if (act_type == "hard_sigmoid") {
attr_map.emplace("slope", "fuse_alpha");
attr_map.emplace("offset", "fuse_beta");
} else if (act_type == "clip") {
attr_map.emplace("min", "fuse_alpha");
attr_map.emplace("max", "fuse_beta");
} else {
attr_map.emplace("alpha", "fuse_alpha");
attr_map.emplace("beta", "fuse_beta");
}
return attr_map;
}
inline void SetActivationAttrs(paddle::framework::OpDesc* fused_op,
paddle::framework::OpDesc* act_op,
const std::string& act_type) {
if (fused_op->HasAttr("use_mkldnn")) {
PADDLE_ENFORCE(PADDLE_GET_CONST(bool, fused_op->GetAttr("use_mkldnn")),
phi::errors::PreconditionNotMet(
"oneDNN activation fuses require use_mkldnn=True"));
}
fused_op->SetAttr("use_mkldnn", true);
auto attr_map = GetAttributeMap(act_type);
for (const auto& attr : attr_map) {
if (act_op->HasAttr(attr.first)) {
fused_op->SetAttr(attr.second, act_op->GetAttr(attr.first));
}
}
if (act_type == "gelu" && act_op->HasAttr("approximate")) {
std::string gelu_act_type =
PADDLE_GET_CONST(bool, act_op->GetAttr("approximate")) ? "gelu_tanh"
: "gelu_erf";
fused_op->SetAttr("fuse_activation", gelu_act_type);
} else {
fused_op->SetAttr("fuse_activation", act_type);
}
}
} // namespace ir
} // namespace framework
} // namespace paddle
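A small, self-contained sketch of what the new helper does for gelu; the two OpDesc objects are hypothetical stand-ins for the matched fused op and activation op.

paddle::framework::OpDesc fused_conv;  // e.g. a matched "fused_conv2d" op
paddle::framework::OpDesc act;         // the matched activation op
act.SetType("gelu");
act.SetAttr("approximate", true);
paddle::framework::ir::SetActivationAttrs(&fused_conv, &act, "gelu");
// fused_conv now carries use_mkldnn = true and fuse_activation = "gelu_tanh";
// with approximate = false it would have been "gelu_erf".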
......@@ -14,8 +14,8 @@
#include "paddle/fluid/framework/ir/mkldnn/conv_activation_mkldnn_fuse_pass.h"
#include "paddle/fluid/framework/ir/mkldnn/activation_onednn_fuse_pass.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/phi/backends/onednn/onednn_reuse.h"
#include "paddle/utils/string/pretty_log.h"
namespace paddle {
......@@ -25,7 +25,7 @@ namespace ir {
using string::PrettyLogDetail;
void ConvActivationMkldnnFusePass::ApplyImpl(Graph* graph) const {
auto act_types = phi::funcs::GetSupportedActivations();
auto act_types = GetSupportedActivations();
std::vector<std::string> conv_types = {"fused_conv2d", "conv2d"};
for (auto& act_type : act_types) {
......@@ -40,7 +40,7 @@ void ConvActivationMkldnnFusePass::FuseConvAct(Graph* graph,
const std::string& conv_type,
std::string& act_type) const {
PADDLE_ENFORCE_NOT_NULL(
graph, platform::errors::InvalidArgument("Graph cannot be nullptr."));
graph, phi::errors::InvalidArgument("Graph cannot be nullptr."));
FusePassBase::Init(conv_type + "_" + act_type + "_mkldnn_fuse_pass", graph);
GraphPatternDetector gpd;
......@@ -62,28 +62,13 @@ void ConvActivationMkldnnFusePass::FuseConvAct(Graph* graph,
GET_IR_NODE_FROM_SUBGRAPH(activation_out, activation_out, conv_act_pattern);
OpDesc* conv_op = conv->Op();
OpDesc* act_op = activation->Op();
if (conv_op->Type() == "conv2d") {
conv_op->SetType("fused_conv2d");
}
auto attr_map = phi::funcs::GetAttributeMap(act_type);
for (const auto& attrs : attr_map) {
if (act_op->HasAttr(attrs.first)) {
conv_op->SetAttr(attrs.second, act_op->GetAttr(attrs.first));
}
}
SetActivationAttrs(conv_op, activation->Op(), act_type);
if (act_type == "gelu" && activation->Op()->HasAttr("approximate")) {
act_type =
PADDLE_GET_CONST(bool, activation->Op()->GetAttr("approximate"))
? "gelu_tanh"
: "gelu_erf";
conv_op->SetAttr("fuse_alpha", 0.0f);
conv_op->SetAttr("fuse_beta", 0.0f);
}
conv_op->SetAttr("fuse_activation", act_type);
conv_op->SetOutput("Output", {activation_out->Name()});
IR_NODE_LINK_TO(conv, activation_out);
......@@ -105,7 +90,7 @@ void ConvActivationMkldnnFusePass::FuseConvAct(Graph* graph,
void ConvActivationMkldnnFusePass::FuseConvConcatAct(
Graph* graph, std::string& act_type) const {
PADDLE_ENFORCE_NOT_NULL(
graph, platform::errors::InvalidArgument("Graph cannot be nullptr."));
graph, phi::errors::InvalidArgument("Graph cannot be nullptr."));
FusePassBase::Init("conv2d_concat_" + act_type + "_mkldnn_fuse_pass", graph);
GraphPatternDetector gpd;
......@@ -137,13 +122,13 @@ void ConvActivationMkldnnFusePass::FuseConvConcatAct(
return;
}
bool is_not_conv_mkldnn =
bool is_not_conv_onednn =
!(prev_op_nodes[0]->Op()->GetAttrIfExists<bool>("use_mkldnn"));
if ((prev_op_nodes[0]->Op()->Type() != "conv2d" &&
prev_op_nodes[0]->Op()->Type() != "fused_conv2d") ||
is_not_conv_mkldnn) {
LOG(WARNING) << "This fuse pass supports only conv2d(mkldnn) | "
"fused_conv2d(mkldnn) + activation.";
is_not_conv_onednn) {
LOG(WARNING) << "This fuse pass supports only conv2d(oneDNN) | "
"fused_conv2d(oneDNN) + activation.";
return;
}
}
......@@ -153,23 +138,8 @@ void ConvActivationMkldnnFusePass::FuseConvConcatAct(
if (conv_op->Type() == "conv2d") {
conv_op->SetType("fused_conv2d");
}
OpDesc* act_op = activation_op->Op();
auto attr_map = phi::funcs::GetAttributeMap(act_type);
for (const auto& attrs : attr_map) {
if (act_op->HasAttr(attrs.first)) {
conv_op->SetAttr(attrs.second, act_op->GetAttr(attrs.first));
}
}
if (act_type == "gelu" && act_op->HasAttr("approximate")) {
act_type = PADDLE_GET_CONST(bool, act_op->GetAttr("approximate"))
? "gelu_tanh"
: "gelu_erf";
conv_op->SetAttr("fuse_alpha", 0.0f);
conv_op->SetAttr("fuse_beta", 0.0f);
}
conv_op->SetAttr("fuse_activation", act_type);
SetActivationAttrs(conv_op, activation_op->Op(), act_type);
}
concat_op->Op()->SetOutput("Out", {activation_out->Name()});
......
......@@ -15,8 +15,8 @@
#include "paddle/fluid/framework/ir/mkldnn/elt_act_mkldnn_fuse_pass.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/ir/mkldnn/activation_onednn_fuse_pass.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/phi/backends/onednn/onednn_reuse.h"
#include "paddle/phi/core/enforce.h"
#include "paddle/utils/string/pretty_log.h"
......@@ -27,7 +27,7 @@ namespace ir {
using string::PrettyLogDetail;
void ElementwiseActivationOneDNNPass::ApplyImpl(Graph *graph) const {
auto act_types = phi::funcs::GetSupportedActivations();
auto act_types = GetSupportedActivations();
std::vector<std::string> elt_types = {
"elementwise_add", "elementwise_sub", "elementwise_mul"};
......@@ -42,7 +42,7 @@ void ElementwiseActivationOneDNNPass::FuseElementwiseAct(
const std::string &elt_type,
const std::string &act_type) const {
PADDLE_ENFORCE_NOT_NULL(
graph, platform::errors::InvalidArgument("Graph cannot be nullptr."));
graph, phi::errors::InvalidArgument("Graph cannot be nullptr."));
FusePassBase::Init(elt_type + "_" + act_type + "_mkldnn_fuse_pass", graph);
GraphPatternDetector gpd;
......@@ -62,35 +62,8 @@ void ElementwiseActivationOneDNNPass::FuseElementwiseAct(
GET_IR_NODE_FROM_SUBGRAPH(
activation_out, activation_out, elementwise_act_pattern);
auto *elementwise_op = elementwise->Op();
if (elementwise_op->HasAttr("use_mkldnn")) {
const std::string wo_elt_type =
"The " + elt_type; // Workaround for PP error message checking.
PADDLE_ENFORCE_EQ(
PADDLE_GET_CONST(bool, elementwise_op->GetAttr("use_mkldnn")),
true,
platform::errors::PreconditionNotMet(
wo_elt_type + "+Act fusion may happen only when oneDNN library "
"is used."));
}
auto *activation_op = activation->Op();
auto attr_map = phi::funcs::GetAttributeMap(act_type);
for (const auto &attr : attr_map) {
if (activation_op->HasAttr(attr.first)) {
elementwise_op->SetAttr(attr.second,
activation_op->GetAttr(attr.first));
}
}
if (act_type == "gelu" && activation_op->HasAttr("approximate") &&
PADDLE_GET_CONST(bool, activation_op->GetAttr("approximate")))
elementwise_op->SetAttr("fuse_activation", std::string("gelu_tanh"));
else
elementwise_op->SetAttr("fuse_activation", act_type);
elementwise_op->SetOutput("Out", {activation_out->Name()});
SetActivationAttrs(elementwise->Op(), activation->Op(), act_type);
elementwise->Op()->SetOutput("Out", {activation_out->Name()});
IR_OP_VAR_LINK(elementwise, activation_out);
GraphSafeRemoveNodes(g, {activation, elementwise_out});
......
......@@ -14,8 +14,8 @@
#include "paddle/fluid/framework/ir/mkldnn/fc_act_mkldnn_fuse_pass.h"
#include "paddle/fluid/framework/ir/mkldnn/activation_onednn_fuse_pass.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/phi/backends/onednn/onednn_reuse.h"
#include "paddle/utils/string/pretty_log.h"
namespace paddle {
......@@ -25,7 +25,7 @@ namespace ir {
using string::PrettyLogDetail;
void FuseFCActOneDNNPass::ApplyImpl(Graph *graph) const {
auto act_types = phi::funcs::GetSupportedActivations();
auto act_types = GetSupportedActivations();
for (auto act_type : act_types) FuseFCAct(graph, act_type);
}
......@@ -33,7 +33,7 @@ void FuseFCActOneDNNPass::ApplyImpl(Graph *graph) const {
void FuseFCActOneDNNPass::FuseFCAct(Graph *graph,
const std::string &act_type) const {
PADDLE_ENFORCE_NOT_NULL(
graph, platform::errors::InvalidArgument("Graph cannot be nullptr."));
graph, phi::errors::InvalidArgument("Graph cannot be nullptr."));
FusePassBase::Init("fc_" + act_type + "_mkldnn_fuse_pass", graph);
GraphPatternDetector gpd;
......@@ -50,35 +50,8 @@ void FuseFCActOneDNNPass::FuseFCAct(Graph *graph,
GET_IR_NODE_FROM_SUBGRAPH(act, activation, fc_act_pattern);
GET_IR_NODE_FROM_SUBGRAPH(act_out, activation_out, fc_act_pattern);
auto *fc_op = fc->Op();
auto *act_op = act->Op();
if (fc_op->HasAttr("use_mkldnn")) {
PADDLE_ENFORCE(
PADDLE_GET_CONST(bool, fc_op->GetAttr("use_mkldnn")),
platform::errors::PreconditionNotMet(
"The FC+Act fusion may happen only when oneDNN library "
"is used."));
}
auto attr_map = phi::funcs::GetAttributeMap(act_type);
for (const auto &attr : attr_map) {
if (act_op->HasAttr(attr.first)) {
fc_op->SetAttr(attr.second, act_op->GetAttr(attr.first));
}
}
if (act_type == "gelu" && act_op->HasAttr("approximate")) {
std::string gelu_act_type =
PADDLE_GET_CONST(bool, act_op->GetAttr("approximate")) ? "gelu_tanh"
: "gelu_erf";
fc_op->SetAttr("fuse_activation", gelu_act_type);
} else {
fc_op->SetAttr("fuse_activation", act_type);
}
fc_op->SetAttr("use_mkldnn", true);
fc_op->SetOutput("Out", {act_out->Name()});
SetActivationAttrs(fc->Op(), act->Op(), act_type);
fc->Op()->SetOutput("Out", {act_out->Name()});
IR_OP_VAR_LINK(fc, act_out);
GraphSafeRemoveNodes(g, {act, fc_out});
......
......@@ -14,8 +14,8 @@
#include "paddle/fluid/framework/ir/mkldnn/matmul_activation_mkldnn_fuse_pass.h"
#include "paddle/fluid/framework/ir/mkldnn/activation_onednn_fuse_pass.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/phi/backends/onednn/onednn_reuse.h"
#include "paddle/utils/string/pretty_log.h"
namespace paddle {
......@@ -25,7 +25,7 @@ namespace ir {
using string::PrettyLogDetail;
void MatmulActivationMkldnnFusePass::ApplyImpl(Graph* graph) const {
auto act_types = phi::funcs::GetSupportedActivations();
auto act_types = GetSupportedActivations();
auto matmul_types = {"matmul", "matmul_v2"};
for (const auto& matmul_type : matmul_types)
......@@ -37,7 +37,7 @@ void MatmulActivationMkldnnFusePass::ApplyImpl(Graph* graph) const {
void MatmulActivationMkldnnFusePass::FuseMatmulAct(
Graph* graph, const std::string& matmul_type, std::string& act_type) const {
PADDLE_ENFORCE_NOT_NULL(
graph, platform::errors::InvalidArgument("Graph cannot be nullptr."));
graph, phi::errors::InvalidArgument("Graph cannot be nullptr."));
FusePassBase::Init(matmul_type + "_" + act_type + "_mkldnn_fuse_pass", graph);
GraphPatternDetector gpd;
......@@ -61,24 +61,8 @@ void MatmulActivationMkldnnFusePass::FuseMatmulAct(
GET_IR_NODE_FROM_SUBGRAPH(
activation_out, activation_out, matmul_act_pattern);
OpDesc* matmul_op = matmul->Op();
OpDesc* act_op = activation->Op();
auto attr_map = phi::funcs::GetAttributeMap(act_type);
for (const auto& attrs : attr_map) {
if (act_op->HasAttr(attrs.first)) {
matmul_op->SetAttr(attrs.second, act_op->GetAttr(attrs.first));
}
}
if (act_type == "gelu" && activation->Op()->HasAttr("approximate")) {
act_type =
PADDLE_GET_CONST(bool, activation->Op()->GetAttr("approximate"))
? "gelu_tanh"
: "gelu_erf";
}
matmul_op->SetAttr("fuse_activation", act_type);
matmul_op->SetOutput("Out", {activation_out->Name()});
SetActivationAttrs(matmul->Op(), activation->Op(), act_type);
matmul->Op()->SetOutput("Out", {activation_out->Name()});
IR_NODE_LINK_TO(matmul, activation_out);
GraphSafeRemoveNodes(graph, {activation, matmul_out});
......
......@@ -15,8 +15,8 @@
#include "paddle/fluid/framework/ir/mkldnn/softplus_activation_mkldnn_fuse_pass.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/ir/mkldnn/activation_onednn_fuse_pass.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/phi/backends/onednn/onednn_reuse.h"
#include "paddle/phi/core/enforce.h"
#include "paddle/utils/string/pretty_log.h"
......@@ -27,7 +27,7 @@ namespace ir {
using string::PrettyLogDetail;
void SoftplusActivationOneDNNPass::ApplyImpl(Graph *graph) const {
auto act_types = phi::funcs::GetSupportedActivations();
auto act_types = GetSupportedActivations();
// Currently softplus can't be fused with hard_sigmoid
act_types.erase(
......@@ -42,7 +42,7 @@ void SoftplusActivationOneDNNPass::ApplyImpl(Graph *graph) const {
void SoftplusActivationOneDNNPass::FuseSoftplusActivation(
Graph *graph, const std::string &act_type) const {
PADDLE_ENFORCE_NOT_NULL(
graph, platform::errors::InvalidArgument("Graph cannot be nullptr."));
graph, phi::errors::InvalidArgument("Graph cannot be nullptr."));
FusePassBase::Init("softplus_activation", graph);
GraphPatternDetector gpd;
......@@ -63,34 +63,8 @@ void SoftplusActivationOneDNNPass::FuseSoftplusActivation(
GET_IR_NODE_FROM_SUBGRAPH(
activation, activation, softplus_activation_pattern);
auto *softplus_op = softplus->Op();
if (softplus_op->HasAttr("use_mkldnn")) {
PADDLE_ENFORCE_EQ(
PADDLE_GET_CONST(bool, softplus_op->GetAttr("use_mkldnn")),
true,
platform::errors::PreconditionNotMet("The softplus + activation "
"fusion may happen only when "
"oneDNN library is used."));
}
auto *activation_op = activation->Op();
auto attr_map = phi::funcs::GetAttributeMap(act_type);
for (const auto &attr : attr_map) {
if (activation_op->HasAttr(attr.first)) {
softplus_op->SetAttr(attr.second, activation_op->GetAttr(attr.first));
}
}
if (act_type == "gelu" && activation_op->HasAttr("approximate") &&
PADDLE_GET_CONST(bool, activation_op->GetAttr("approximate")))
softplus_op->SetAttr("fuse_activation", std::string("gelu_tanh"));
else
softplus_op->SetAttr("fuse_activation", act_type);
softplus_op->SetAttr("use_mkldnn", true);
softplus_op->SetOutput("Out", {activation_out->Name()});
SetActivationAttrs(softplus->Op(), activation->Op(), act_type);
softplus->Op()->SetOutput("Out", {activation_out->Name()});
IR_OP_VAR_LINK(softplus, activation_out);
GraphSafeRemoveNodes(g, {activation, softplus_out});
......
......@@ -144,8 +144,6 @@ std::unordered_set<std::string> OpTransInfo::GetDenyVarNames(
const auto& arg_names = desc->Input(param_name);
for (const auto& arg_name : arg_names) {
deny_var_set.insert(arg_name);
VLOG(4) << "deny param [" << param_name << "]'s argument name"
<< " is [" << arg_name << "].";
}
}
......@@ -153,8 +151,6 @@ std::unordered_set<std::string> OpTransInfo::GetDenyVarNames(
const auto& arg_names = desc->Output(param_name);
for (const auto& arg_name : arg_names) {
deny_var_set.insert(arg_name);
VLOG(4) << "deny param [" << param_name << "]'s argument name"
<< " is [" << arg_name << "].";
}
}
}
......@@ -166,48 +162,25 @@ std::unordered_set<std::string> OpTransInfo::GetDenyVarNames(
return deny_var_set;
}
std::unordered_set<std::string> OpTransInfo::GetIgnoreInplaceVarNames(
const OpDesc& op_desc) const {
if (!ignore_inplace_param_cond_.count(op_desc.Type())) {
return {};
}
std::unordered_set<std::string> OpTransInfo::GetInplaceVarNames(
const GraphNodeSet& cluster_inputs, const GraphNodeSet& cluster_outputs) {
std::unordered_set<std::string> all_inputs, all_outputs;
const auto& ignore_inplace_names =
ignore_inplace_param_cond_.at(op_desc.Type());
VLOG(4) << "We found ignore inplace param "
<< GetDebugInfo(ignore_inplace_names) << " in op [" << op_desc.Type()
<< "].";
std::unordered_set<std::string> ignore_inplace_set;
for (const auto& param_name : ignore_inplace_names) {
if (op_desc.HasOutput(param_name)) {
const auto& arg_names = op_desc.Output(param_name);
ignore_inplace_set.insert(arg_names.begin(), arg_names.end());
}
for (auto* var : cluster_inputs) {
all_inputs.insert(var->Name());
}
for (auto* var : cluster_outputs) {
all_outputs.insert(var->Name());
}
VLOG(4) << "All ignore inplace var names are "
<< GetDebugInfo(ignore_inplace_set);
return ignore_inplace_set;
}
bool OpTransInfo::IsInplaceOp(
const OpDesc& op_desc,
const std::unordered_set<std::string>& deny_var_names) const {
const auto& ignore_inplace_set = GetIgnoreInplaceVarNames(op_desc);
auto inputs = op_desc.InputArgumentNames();
std::unordered_set<std::string> input_set(inputs.begin(), inputs.end());
for (auto& name : op_desc.OutputArgumentNames()) {
if (input_set.count(name) > 0 && !deny_var_names.count(name) &&
!ignore_inplace_set.count(name)) {
VLOG(4) << "The argument " << name << " in op " << op_desc.Type()
<< " is a inplace op, skip!";
return true;
std::unordered_set<std::string> inplace_var_set;
for (const auto& var_name : all_inputs) {
if (all_outputs.count(var_name)) {
inplace_var_set.insert(var_name);
}
}
return false;
return inplace_var_set;
}
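// Editorial sketch, not part of this change: the inplace detection above is
// simply the intersection of the cluster's input and output variable names.
// A hypothetical standalone helper with the same behaviour over plain name
// sets (the name IntersectVarNames is an assumption, not an existing symbol):
static std::unordered_set<std::string> IntersectVarNames(
    const std::unordered_set<std::string>& inputs,
    const std::unordered_set<std::string>& outputs) {
  std::unordered_set<std::string> intersection;
  for (const auto& name : inputs) {
    if (outputs.count(name)) {
      intersection.insert(name);
    }
  }
  return intersection;
}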
namespace {
......@@ -503,6 +476,14 @@ std::unique_ptr<Graph> CreateNewSubGraph(const GraphNodeSet& cluster,
// initialize empty map for kMemOptVarInfoFromMainGraph attribute,
// it will be filled on the share_mem_opt_info_to_subgraph pass
subgraph->GetOrInit<Name2VarInfoMap>(kMemOptVarInfoFromMainGraph);
auto inplace_var_names = std::make_unique<std::unordered_set<std::string>>(
OpTransInfo::GetInplaceVarNames(cluster_inputs, cluster_outputs));
VLOG_IF(4, !inplace_var_names->empty())
<< "Inplace var in cluster are: " << GetDebugInfo(*inplace_var_names);
subgraph->Set<std::unordered_set<std::string>>(kInplaceVarNames,
inplace_var_names.release());
return subgraph;
}
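// Editorial sketch, not part of this change: a downstream pass could read the
// attribute back from the subgraph roughly like this (GetClusterInplaceVars is
// a hypothetical helper name):
static const std::unordered_set<std::string>& GetClusterInplaceVars(
    Graph* subgraph) {
  return subgraph->Get<std::unordered_set<std::string>>(kInplaceVarNames);
}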
......@@ -594,7 +575,6 @@ void AddCinnOpToGraph(const GraphNodeSet& cluster,
const GraphNodeSet& cluster_inputs,
const GraphNodeSet& cluster_outputs,
int64_t compilation_key,
const std::unordered_set<std::string>& deny_var_set,
Graph* graph) {
// Add the cinn launch op
framework::OpDesc cinn_op_desc;
......@@ -615,6 +595,7 @@ void AddCinnOpToGraph(const GraphNodeSet& cluster,
cinn_op_desc.SetAttr(operators::kCompilationKey, compilation_key);
cinn_op_desc.SetAttr(OpProtoAndCheckerMaker::OpRoleAttrName(),
ExtractOpRole(cluster));
cinn_op_desc.Flush();
auto* cinn_op_node = graph->CreateOpNode(&cinn_op_desc);
// Add new links from or to the cinn launch op node
......@@ -639,21 +620,15 @@ void RemoveSubGraphFromGraph(const GraphNodeSet& cluster,
// kCinnLaunchOp, and inputs are cluster_inputs and outputs are
// cluster_outputs.
// Meanwhile, move all links of cluster to the cinn op.
void ReplaceSubGraphWithCinnOpNode(
const GraphNodeSet& cluster,
const GraphNodeSet& cluster_inputs,
const GraphNodeSet& cluster_outputs,
const GraphNodeSet& cluster_internals,
int64_t compilation_key,
const std::unordered_set<std::string>& deny_var_set,
Graph* graph) {
void ReplaceSubGraphWithCinnOpNode(const GraphNodeSet& cluster,
const GraphNodeSet& cluster_inputs,
const GraphNodeSet& cluster_outputs,
const GraphNodeSet& cluster_internals,
int64_t compilation_key,
Graph* graph) {
// Add the cinn op node whose name is "kCinnLaunchOp" into graph
AddCinnOpToGraph(cluster,
cluster_inputs,
cluster_outputs,
compilation_key,
deny_var_set,
graph);
AddCinnOpToGraph(
cluster, cluster_inputs, cluster_outputs, compilation_key, graph);
// Remove the cinn subgraph from graph
RemoveSubGraphFromGraph(cluster, cluster_internals, graph);
}
......@@ -667,9 +642,7 @@ void SearchAllSubgraphs(Graph* graph, bool is_inference_stage) {
auto deny_ops = StringSplit(FLAGS_deny_cinn_ops, kDelim);
OpTransInfo trans_info;
const auto& deny_var_set = trans_info.GetDenyVarNames(graph->Nodes());
auto teller = [&allow_ops, &deny_ops, &trans_info, &deny_var_set](
const Node* node) {
auto teller = [&allow_ops, &deny_ops, &trans_info](const Node* node) {
const auto& node_name = node->Name();
bool registered = ::cinn::frontend::OpMapperRegistry::Global()->Find(
node_name) != nullptr;
......@@ -679,10 +652,9 @@ void SearchAllSubgraphs(Graph* graph, bool is_inference_stage) {
is_dynamic = trans_info.dynamic_op_cond().at(node_name)(*node);
}
bool is_support =
registered && !trans_info.default_deny_ops().count(node_name) &&
!is_dynamic &&
(node->IsOp() && !trans_info.IsInplaceOp(*node->Op(), deny_var_set));
bool is_support = registered &&
!trans_info.default_deny_ops().count(node_name) &&
!is_dynamic;
// if the op type is registered in CINN and allow_ops is not empty, return
// true only when it is in allow_ops
if (!allow_ops.empty()) {
......@@ -714,19 +686,23 @@ void SearchAllSubgraphs(Graph* graph, bool is_inference_stage) {
return res;
};
std::unordered_set<std::string> skip_gc_var_names;
std::unordered_set<std::string> all_skip_gc_vars;
if (graph->Has(kSkipGcVarNames)) {
skip_gc_var_names =
all_skip_gc_vars =
graph->Get<std::unordered_set<std::string>>(kSkipGcVarNames);
VLOG_IF(4, !all_skip_gc_vars.empty())
<< "All skip gc var names are: " << GetDebugInfo(all_skip_gc_vars);
}
const auto& deny_var_set = trans_info.GetDenyVarNames(graph->Nodes());
VLOG_IF(4, !deny_var_set.empty())
<< "All deny var names are: " << GetDebugInfo(deny_var_set);
auto* cinn_compiler = CinnCompiler::GetInstance();
for (const auto& node_vec : clusters) {
// Classify var node to inputs, outputs, and internals.
GraphNodeSet cluster_set(node_vec.begin(), node_vec.end());
auto deny_var_set = trans_info.GetDenyVarNames(cluster_set);
GraphNodeSet cluster_inputs, cluster_outputs, cluster_internals;
AnalyseClusterVariables(cluster_set,
deny_var_set,
......@@ -734,7 +710,7 @@ void SearchAllSubgraphs(Graph* graph, bool is_inference_stage) {
&cluster_outputs,
&cluster_internals,
is_inference_stage,
skip_gc_var_names);
all_skip_gc_vars);
VLOG(4) << "Cluster Ops: " << cluster_debug_info(cluster_set);
VLOG(4) << "Cluster input vars: " << cluster_debug_info(cluster_inputs);
......@@ -747,8 +723,6 @@ void SearchAllSubgraphs(Graph* graph, bool is_inference_stage) {
cluster_set, cluster_internals, cluster_inputs, cluster_outputs);
// Deliver the kSkipGcVarNames attr (if exists) to the subgraph
if (graph->Has(kSkipGcVarNames)) {
const auto& all_skip_gc_vars =
graph->Get<std::unordered_set<std::string>>(kSkipGcVarNames);
auto& sub_skip_gc_vars =
subgraph->GetOrInit<std::unordered_set<std::string>>(kSkipGcVarNames);
sub_skip_gc_vars = all_skip_gc_vars;
......@@ -763,7 +737,6 @@ void SearchAllSubgraphs(Graph* graph, bool is_inference_stage) {
cluster_outputs,
cluster_internals,
compilation_key,
deny_var_set,
graph);
}
}
......
......@@ -39,6 +39,7 @@ constexpr char kOutputVars[] = "OutputVars";
constexpr char kMemOptVarInfoFromMainGraph[] =
"mem_opt_var_info_from_main_graph";
constexpr char kSkipGcVarNames[] = "skip_gc_vars";
constexpr char kInplaceVarNames[] = "InplaceVars";
using Name2VarInfoMap =
std::unordered_map<std::string,
......@@ -67,11 +68,8 @@ class OpTransInfo {
std::unordered_set<std::string> GetDenyVarNames(
const GraphNodeSet& cluster) const;
std::unordered_set<std::string> GetIgnoreInplaceVarNames(
const OpDesc& op_desc) const;
bool IsInplaceOp(const OpDesc& op_desc,
const std::unordered_set<std::string>& deny_var_names) const;
static std::unordered_set<std::string> GetInplaceVarNames(
const GraphNodeSet& cluster_inputs, const GraphNodeSet& cluster_outputs);
private:
DyOpCondT dynamic_op_cond_;
......@@ -79,9 +77,6 @@ class OpTransInfo {
DeParamCondT deny_param_cond_{{"batch_norm", {"ReserveSpace"}},
{"batch_norm_grad", {"ReserveSpace"}}};
DeParamCondT ignore_inplace_param_cond_{
{"batch_norm", {"MeanOut", "VarianceOut"}}};
std::unordered_set<std::string> default_deny_ops_{"feed", "fetch"};
};
......
......@@ -258,17 +258,16 @@ void CinnGraphSymbolization::RunGraph(const OpMapperContext& ctx) const {
std::unordered_set<std::string> CinnGraphSymbolization::GetFetchIds() const {
std::unordered_set<std::string> fetch_names;
fetch_names.reserve(fetch_var_names_.size());
std::for_each(
fetch_var_names_.begin(),
fetch_var_names_.end(),
[this, &fetch_names](const std::string& name) {
PADDLE_ENFORCE_EQ(
var_model_to_program_map_.count(name),
1,
platform::errors::PreconditionNotMet(
"Cannot find %s in var_model_to_program_map_", name.c_str()));
fetch_names.insert(var_model_to_program_map_.at(name));
});
std::for_each(fetch_var_names_.begin(),
fetch_var_names_.end(),
[this, &fetch_names](const std::string& name) {
PADDLE_ENFORCE_EQ(
var_map_.count(name),
1,
platform::errors::PreconditionNotMet(
"Cannot find %s in var_map_", name.c_str()));
fetch_names.insert(var_map_.at(name)->id);
});
return fetch_names;
}
......
......@@ -337,6 +337,11 @@ NameVarMap<VarType> AutoCastInputs(const std::string& op_type,
pair.first != "X") {
continue;
}
if ((op_type == "max_pool2d_with_index_grad" ||
op_type == "max_pool2d_with_index") &&
pair.first == "Mask") {
continue;
}
if ((op_type == "fused_attention" || op_type == "fused_feedforward")) {
if (pair.first == "LnScale" || pair.first == "LnBias" ||
......@@ -381,6 +386,11 @@ NameVarMap<VarType> AutoCastInputs(const std::string& op_type,
pair.first == "X" && dst_type == framework::proto::VarType::FP32) {
continue;
}
if ((op_type == "max_pool2d_with_index_grad" ||
op_type == "max_pool2d_with_index") &&
pair.first != "Mask" && dst_type == framework::proto::VarType::FP32) {
continue;
}
if ((op_type == "fused_attention" || op_type == "fused_feedforwad") &&
dst_type == framework::proto::VarType::FP32) {
if (pair.first != "LnScale" && pair.first != "LnBias" &&
......@@ -428,6 +438,11 @@ NameVarMap<VarType> CastPureFp16Inputs(const std::string& op_type,
pair.first != "X") {
continue;
}
if ((op_type == "max_pool2d_with_index_grad" ||
op_type == "max_pool2d_with_index") &&
pair.first == "Mask") {
continue;
}
if ((op_type == "fused_attention" || op_type == "fused_feedforward")) {
if (pair.first == "LnScale" || pair.first == "LnBias" ||
pair.first == "Ln2Scale" || pair.first == "Ln2Bias" ||
......
......@@ -1609,6 +1609,51 @@ std::vector<std::string> AnalysisPredictor::GetOutputNames() {
return output_names;
}
std::map<std::string, std::vector<int64_t>>
AnalysisPredictor::GetOutputTensorShape() {
std::map<std::string, std::vector<int64_t>> output_shapes;
std::vector<std::string> names = GetOutputNames();
for (std::string name : names) {
auto *var = inference_program_->Block(0).FindVar(name);
PADDLE_ENFORCE_NOT_NULL(var,
platform::errors::PreconditionNotMet(
"Output %s does not exist.", name));
output_shapes[name] = var->GetShape();
}
return output_shapes;
}
std::map<std::string, paddle_infer::DataType>
AnalysisPredictor::GetOutputTypes() {
std::map<std::string, paddle_infer::DataType> output_type;
std::vector<std::string> names = GetOutputNames();
for (const auto &name : names) {
auto *var = inference_program_->Block(0).FindVar(name);
PADDLE_ENFORCE_NOT_NULL(
var,
platform::errors::PreconditionNotMet(
"Output %s does not exist inference_program_.", name));
auto dtype = var->GetDataType();
if (dtype == paddle::framework::proto::VarType::FP32) {
output_type[name] = paddle_infer::DataType::FLOAT32;
} else if (dtype == paddle::framework::proto::VarType::FP16) {
output_type[name] = paddle_infer::DataType::FLOAT16;
} else if (dtype == paddle::framework::proto::VarType::INT64) {
output_type[name] = paddle_infer::DataType::INT64;
} else if (dtype == paddle::framework::proto::VarType::INT32) {
output_type[name] = paddle_infer::DataType::INT32;
} else if (dtype == paddle::framework::proto::VarType::UINT8) {
output_type[name] = paddle_infer::DataType::UINT8;
} else if (dtype == paddle::framework::proto::VarType::INT8) {
output_type[name] = paddle_infer::DataType::INT8;
} else {
PADDLE_THROW(paddle::platform::errors::Unimplemented(
"Unsupported data type `%s` when get output dtype ", dtype));
}
}
return output_type;
}
std::unique_ptr<ZeroCopyTensor> AnalysisPredictor::GetInputTensor(
const std::string &name) {
framework::Scope *scope;
......@@ -2477,6 +2522,10 @@ std::vector<std::string> Predictor::GetInputNames() {
return predictor_->GetInputNames();
}
std::map<std::string, std::vector<int64_t>> Predictor::GetInputTensorShape() {
return predictor_->GetInputTensorShape();
}
std::map<std::string, DataType> Predictor::GetInputTypes() {
return predictor_->GetInputTypes();
}
......@@ -2493,6 +2542,14 @@ std::unique_ptr<Tensor> Predictor::GetOutputHandle(const std::string &name) {
return predictor_->GetOutputTensor(name);
}
std::map<std::string, std::vector<int64_t>> Predictor::GetOutputTensorShape() {
return predictor_->GetOutputTensorShape();
}
std::map<std::string, DataType> Predictor::GetOutputTypes() {
return predictor_->GetOutputTypes();
}
bool Predictor::Run() { return predictor_->ZeroCopyRun(); }
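// Editorial usage sketch, not part of this patch: with the metadata getters
// added above, callers can inspect output names, shapes and dtypes before
// running inference (LogOutputMeta is a hypothetical helper name):
static void LogOutputMeta(Predictor* predictor) {
  auto shapes = predictor->GetOutputTensorShape();
  auto dtypes = predictor->GetOutputTypes();
  for (const auto& name : predictor->GetOutputNames()) {
    LOG(INFO) << "output " << name << " rank=" << shapes[name].size()
              << " dtype=" << static_cast<int>(dtypes[name]);
  }
}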
std::unique_ptr<Predictor> Predictor::Clone(void *stream) {
......
......@@ -191,6 +191,18 @@ class AnalysisPredictor : public PaddlePredictor {
/// \return the map of input names and type
///
std::map<std::string, paddle_infer::DataType> GetInputTypes() override;
///
/// \brief Get all output names and their corresponding shapes
///
/// \return the map of output names and shapes
///
std::map<std::string, std::vector<int64_t>> GetOutputTensorShape() override;
///
/// \brief Get all output names and their corresponding types
///
/// \return the map of output names and types
///
std::map<std::string, paddle_infer::DataType> GetOutputTypes() override;
///
/// \brief Run the prediction engine
......
......@@ -106,6 +106,8 @@ TEST(AnalysisPredictor, analysis_on) {
ASSERT_EQ(predictor->scope_->parent(), nullptr);
ASSERT_EQ(predictor->sub_scope_->parent(), predictor->scope_.get());
ASSERT_EQ(predictor->GetInputTypes().size(), 4UL);
ASSERT_EQ(predictor->GetOutputTypes().size(), 1UL);
ASSERT_EQ(predictor->GetOutputTensorShape().size(), 1UL);
// 2. Dummy Input Data
int64_t data[4] = {1, 2, 3, 4};
PaddleTensor tensor;
......@@ -430,6 +432,8 @@ TEST(Predictor, Run) {
auto predictor = CreatePredictor(config);
ASSERT_EQ(predictor->GetInputTypes().size(), 4UL);
ASSERT_EQ(predictor->GetOutputTypes().size(), 1UL);
ASSERT_EQ(predictor->GetOutputTensorShape().size(), 1UL);
auto w0 = predictor->GetInputHandle("firstw");
auto w1 = predictor->GetInputHandle("secondw");
......
......@@ -243,6 +243,19 @@ class PD_INFER_DECL PaddlePredictor {
/// \return Output tensor names.
virtual std::vector<std::string> GetOutputNames() { return {}; }
/// \brief Get the output shape of the model.
/// \return A map containing all the output names and shapes defined in the
/// model.
virtual std::map<std::string, std::vector<int64_t>> GetOutputTensorShape() {
return {};
}
/// \brief Get the output type of the model.
/// \return A map containing all the output names and types defined in the model.
virtual std::map<std::string, paddle_infer::DataType> GetOutputTypes() {
return {};
}
/// \brief Get the input ZeroCopyTensor by name.
/// Be inherited by AnalysisPredictor, Only used in ZeroCopy scenarios.
/// The name is obtained from the GetInputNames() interface.
......
......@@ -92,6 +92,13 @@ class PD_INFER_DECL Predictor {
///
explicit Predictor(const Config& config);
///
/// \brief Get all input names and their corresponding shapes
///
/// \return the map of input names and shapes
///
std::map<std::string, std::vector<int64_t>> GetInputTensorShape();
///
/// \brief Get all input names and their corresponding type
///
......@@ -136,6 +143,20 @@ class PD_INFER_DECL Predictor {
///
std::unique_ptr<Tensor> GetOutputHandle(const std::string& name);
///
/// \brief Get all output names and their corresponding shapes
///
/// \return the map of output names and shapes
///
std::map<std::string, std::vector<int64_t>> GetOutputTensorShape();
///
/// \brief Get all output names and their corresponding types
///
/// \return the map of output names and types
///
std::map<std::string, DataType> GetOutputTypes();
///
/// \brief Clone to get the new predictor. thread safe.
///
......
......@@ -55,8 +55,9 @@ __pd_give PD_Config* PD_ConfigCreate() {
}
void PD_ConfigDestroy(__pd_take PD_Config* pd_config) {
CHECK_AND_CONVERT_PD_CONFIG;
delete reinterpret_cast<Config*>(config);
if (pd_config != NULL) {
delete reinterpret_cast<Config*>(pd_config);
}
}
void PD_ConfigSetModel(__pd_keep PD_Config* pd_config,
......@@ -116,9 +117,12 @@ PD_Bool PD_ConfigUseFcPadding(__pd_keep PD_Config* pd_config) {
void PD_ConfigEnableUseGpu(__pd_keep PD_Config* pd_config,
uint64_t memory_pool_init_size_mb,
int32_t device_id) {
int32_t device_id,
PD_PrecisionType precision_mode) {
CHECK_AND_CONVERT_PD_CONFIG;
config->EnableUseGpu(memory_pool_init_size_mb, device_id);
config->EnableUseGpu(memory_pool_init_size_mb,
device_id,
ConvertToCxxPrecisionType(precision_mode));
}
void PD_ConfigDisableGpu(__pd_keep PD_Config* pd_config) {
CHECK_AND_CONVERT_PD_CONFIG;
......@@ -427,6 +431,14 @@ void PD_ConfigSetBfloat16Op(__pd_keep PD_Config* pd_config,
}
config->SetBfloat16Op(std::move(op_names));
}
void PD_ConfigEnableMkldnnInt8(__pd_keep PD_Config* pd_config) {
CHECK_AND_CONVERT_PD_CONFIG;
config->EnableMkldnnInt8();
}
PD_Bool PD_ConfigMkldnnInt8Enabled(__pd_keep PD_Config* pd_config) {
CHECK_AND_CONVERT_PD_CONFIG;
return config->mkldnn_int8_enabled();
}
PD_Bool PD_ConfigThreadLocalStreamEnabled(__pd_keep PD_Config* pd_config) {
CHECK_AND_CONVERT_PD_CONFIG;
return config->thread_local_stream_enabled();
......@@ -484,6 +496,10 @@ void PD_ConfigEnableGpuMultiStream(__pd_keep PD_Config* pd_config) {
CHECK_AND_CONVERT_PD_CONFIG;
config->EnableGpuMultiStream();
}
void PD_ConfigSetExecStream(__pd_keep PD_Config* pd_config, void* stream) {
CHECK_AND_CONVERT_PD_CONFIG;
return config->SetExecStream(stream);
}
void PD_ConfigPartiallyRelease(__pd_keep PD_Config* pd_config) {
CHECK_AND_CONVERT_PD_CONFIG;
config->PartiallyRelease();
......
......@@ -132,11 +132,13 @@ PADDLE_CAPI_EXPORT extern PD_Bool PD_ConfigUseFcPadding(
/// \param[in] memory_pool_init_size_mb initial size of the GPU memory pool in
/// MB.
/// \param[in] device_id device_id the GPU card to use.
/// \param[in] precision_mode the precision used in Paddle-GPU inference.
///
PADDLE_CAPI_EXPORT extern void PD_ConfigEnableUseGpu(
__pd_keep PD_Config* pd_config,
uint64_t memory_pool_init_size_mb,
int32_t device_id);
int32_t device_id,
PD_PrecisionType precision_mode);
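//
// Editorial usage sketch, not part of this change: PD_PRECISION_FLOAT32 (0)
// keeps the behaviour of the previous three-argument form, e.g.
//   PD_ConfigEnableUseGpu(config, /*memory_pool_init_size_mb=*/100,
//                         /*device_id=*/0, PD_PRECISION_FLOAT32);
//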
///
/// \brief Turn off GPU.
///
......@@ -607,6 +609,22 @@ PADDLE_CAPI_EXPORT extern PD_Bool PD_ConfigMkldnnBfloat16Enabled(
PADDLE_CAPI_EXPORT extern void PD_ConfigSetBfloat16Op(
__pd_keep PD_Config* pd_config, size_t ops_num, const char** op_list);
///
/// \brief Turn on MKLDNN int8.
///
/// \param[in] pd_config config
///
PADDLE_CAPI_EXPORT extern void PD_ConfigEnableMkldnnInt8(
__pd_keep PD_Config* pd_config);
///
/// \brief A boolean state telling whether to use the MKLDNN int8.
///
/// \param[in] pd_config config
/// \return Whether to use the MKLDNN int8.
///
PADDLE_CAPI_EXPORT extern PD_Bool PD_ConfigMkldnnInt8Enabled(
__pd_keep PD_Config* pd_config);
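//
// Editorial usage sketch, not part of this change: a minimal hypothetical flow
// for the new oneDNN int8 switch, assuming a config created elsewhere:
//   PD_ConfigEnableMkldnnInt8(config);
//   if (PD_ConfigMkldnnInt8Enabled(config)) { /* int8 kernels will be used */ }
//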
///
/// \brief Enable the GPU multi-computing stream feature.
/// NOTE: The current behavior of this interface is to bind the computation
/// stream to the thread, and this behavior may be changed in the future.
......@@ -625,6 +643,12 @@ PADDLE_CAPI_EXPORT extern void PD_ConfigEnableGpuMultiStream(
PADDLE_CAPI_EXPORT extern PD_Bool PD_ConfigThreadLocalStreamEnabled(
__pd_keep PD_Config* pd_config);
///
/// \brief Set execution stream. If not set, a stream will be created
/// internally.
///
/// \param[in] pd_config config
/// \param[in] stream the external stream on which inference runs.
///
PADDLE_CAPI_EXPORT extern void PD_ConfigSetExecStream(
__pd_keep PD_Config* pd_config, void* stream);
///
/// \brief Specify the memory buffer of program and parameter.
/// Used when model and params are loaded directly from memory.
///
......
......@@ -15,6 +15,7 @@
#include "paddle/fluid/inference/capi_exp/pd_predictor.h"
#include "paddle/fluid/inference/api/paddle_inference_api.h"
#include "paddle/fluid/inference/capi_exp/pd_config.h"
#include "paddle/fluid/inference/capi_exp/pd_types.h"
#include "paddle/fluid/inference/capi_exp/pd_utils.h"
#include "paddle/fluid/inference/capi_exp/types_internal.h"
......@@ -38,7 +39,6 @@ __pd_give PD_Predictor* PD_PredictorCreate(__pd_take PD_Config* pd_config) {
paddle_infer::Config* config =
reinterpret_cast<paddle_infer::Config*>(pd_config);
pd_predictor->predictor = paddle_infer::CreatePredictor(*config);
delete config;
return pd_predictor;
}
......@@ -57,6 +57,30 @@ __pd_give PD_OneDimArrayCstr* PD_PredictorGetInputNames(
return paddle_infer::CvtVecToOneDimArrayCstr(names);
}
__pd_give PD_IOInfos* PD_PredictorGetInputInfos(
__pd_keep PD_Predictor* pd_predictor) {
CHECK_AND_CONVERT_PD_PREDICTOR;
std::vector<std::string> names = predictor->GetInputNames();
std::map<std::string, std::vector<int64_t>> input_shapes =
predictor->GetInputTensorShape();
std::map<std::string, paddle_infer::DataType> input_dtypes =
predictor->GetInputTypes();
PD_IOInfos* input_infos = new PD_IOInfos;
input_infos->size = names.size();
input_infos->io_info = names.empty() ? NULL : new PD_IOInfo*[names.size()];
for (size_t i = 0; i < names.size(); i++) {
const std::string& name = names[i];
input_infos->io_info[i] = new PD_IOInfo;
input_infos->io_info[i]->name = paddle_infer::CvtStrToCstr(name);
input_infos->io_info[i]->shape =
paddle_infer::CvtVecToOneDimArrayInt64(input_shapes[name]);
input_infos->io_info[i]->dtype =
paddle_infer::CvtFromCxxDatatype(input_dtypes[name]);
}
return input_infos;
}
__pd_give PD_OneDimArrayCstr* PD_PredictorGetOutputNames(
__pd_keep PD_Predictor* pd_predictor) {
CHECK_AND_CONVERT_PD_PREDICTOR;
......@@ -64,6 +88,30 @@ __pd_give PD_OneDimArrayCstr* PD_PredictorGetOutputNames(
return paddle_infer::CvtVecToOneDimArrayCstr(names);
}
__pd_give PD_IOInfos* PD_PredictorGetOutputInfos(
__pd_keep PD_Predictor* pd_predictor) {
CHECK_AND_CONVERT_PD_PREDICTOR;
std::vector<std::string> names = predictor->GetOutputNames();
std::map<std::string, std::vector<int64_t>> output_shapes =
predictor->GetOutputTensorShape();
std::map<std::string, paddle_infer::DataType> output_dtypes =
predictor->GetOutputTypes();
PD_IOInfos* output_infos = new PD_IOInfos;
output_infos->size = names.size();
output_infos->io_info = names.empty() ? NULL : new PD_IOInfo*[names.size()];
for (size_t i = 0; i < names.size(); i++) {
const std::string& name = names[i];
output_infos->io_info[i] = new PD_IOInfo;
output_infos->io_info[i]->name = paddle_infer::CvtStrToCstr(name);
output_infos->io_info[i]->shape =
paddle_infer::CvtVecToOneDimArrayInt64(output_shapes[name]);
output_infos->io_info[i]->dtype =
paddle_infer::CvtFromCxxDatatype(output_dtypes[name]);
}
return output_infos;
}
size_t PD_PredictorGetInputNum(__pd_keep PD_Predictor* pd_predictor) {
CHECK_AND_CONVERT_PD_PREDICTOR;
return predictor->GetInputNames().size();
......
......@@ -30,6 +30,7 @@ typedef struct PD_Predictor PD_Predictor;
typedef struct PD_Config PD_Config;
typedef struct PD_Tensor PD_Tensor;
typedef struct PD_OneDimArrayCstr PD_OneDimArrayCstr;
typedef struct PD_IOInfos PD_IOInfos;
#ifdef __cplusplus
extern "C" {
......@@ -60,6 +61,14 @@ PADDLE_CAPI_EXPORT extern __pd_give PD_Predictor* PD_PredictorClone(
PADDLE_CAPI_EXPORT extern __pd_give PD_OneDimArrayCstr*
PD_PredictorGetInputNames(__pd_keep PD_Predictor* pd_predictor);
///
/// \brief Get the input infos (name/shape/dtype)
///
/// \param[in] pd_predictor predictor
/// \return input infos (name/shape/dtype)
///
PADDLE_CAPI_EXPORT extern __pd_give PD_IOInfos* PD_PredictorGetInputInfos(
__pd_keep PD_Predictor* pd_predictor);
///
/// \brief Get the output names
///
/// \param[in] pd_predictor predictor
......@@ -67,7 +76,14 @@ PD_PredictorGetInputNames(__pd_keep PD_Predictor* pd_predictor);
///
PADDLE_CAPI_EXPORT extern __pd_give PD_OneDimArrayCstr*
PD_PredictorGetOutputNames(__pd_keep PD_Predictor* pd_predictor);
///
/// \brief Get the output infos (name/shape/dtype)
///
/// \param[in] pd_predictor predictor
/// \return output infos (name/shape/dtype)
///
PADDLE_CAPI_EXPORT extern __pd_give PD_IOInfos* PD_PredictorGetOutputInfos(
__pd_keep PD_Predictor* pd_predictor);
///
/// \brief Get the input number
///
......
......@@ -29,6 +29,11 @@ typedef struct PD_OneDimArraySize {
size_t* data;
} PD_OneDimArraySize; // std::vector<size_t>
typedef struct PD_OneDimArrayInt64 {
size_t size;
int64_t* data;
} PD_OneDimArrayInt64; // std::vector<int64_t>
typedef struct PD_OneDimArrayCstr {
size_t size;
char** data;
......@@ -43,3 +48,14 @@ typedef struct PD_TwoDimArraySize {
size_t size;
PD_OneDimArraySize** data;
} PD_TwoDimArraySize; // std::vector<std::vector<size_t>>
typedef struct PD_IOInfo {
PD_Cstr* name;
PD_OneDimArrayInt64* shape;
PD_DataType dtype;
} PD_IOInfo; // input or output info
typedef struct PD_IOInfos {
size_t size;
PD_IOInfo** io_info;
} PD_IOInfos; // inputs or outputs info
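// Editorial sketch, not part of this change: hypothetical traversal of a
// PD_IOInfos returned by PD_PredictorGetInputInfos / PD_PredictorGetOutputInfos;
// the caller owns the object and releases it with PD_IOInfosDestroy:
//   for (size_t i = 0; i < infos->size; ++i) {
//     PD_IOInfo* info = infos->io_info[i];
//     // info->name->data, info->shape->data (length info->shape->size), info->dtype
//   }
//   PD_IOInfosDestroy(infos);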
......@@ -11,12 +11,10 @@
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/inference/capi_exp/pd_utils.h"
#include <string>
#include "paddle/fluid/inference/api/paddle_inference_api.h"
#include "paddle/fluid/inference/capi_exp/pd_utils.h"
#include "paddle/fluid/inference/capi_exp/utils_internal.h"
#include "paddle/fluid/platform/enforce.h"
......@@ -62,6 +60,7 @@
ONE_DIM_ARRAY_UTILS_FUNC_IMPL(int32_t, Int32, int)
ONE_DIM_ARRAY_UTILS_FUNC_IMPL(size_t, Size, size_t)
ONE_DIM_ARRAY_UTILS_FUNC_IMPL(int64_t, Int64, int64_t)
#undef ONE_DIM_ARRAY_UTILS_FUNC_IMPL
#undef CONVERT_ONE_DIM_ARRAY_TO_VEC
......@@ -178,6 +177,38 @@ TWO_DIM_ARRAY_UTILS_FUNC_IMPL(size_t, Size, size_t)
#undef CONVERT_VEC_TO_TWO_DIM_ARRAY
#undef DESTROY_TWO_DIM_ARRAY
#ifdef __cplusplus
extern "C" {
#endif
void PD_IOInfoDestroy(__pd_take PD_IOInfo* io_info) {
if (io_info != NULL) {
PD_CstrDestroy(io_info->name);
io_info->name = NULL;
PD_OneDimArrayInt64Destroy(io_info->shape);
io_info->shape = NULL;
delete io_info;
}
}
void PD_IOInfosDestroy(__pd_take PD_IOInfos* io_infos) {
if (io_infos != NULL) {
if (io_infos->size != 0) {
for (size_t index = 0; index < io_infos->size; ++index) {
PD_IOInfoDestroy(io_infos->io_info[index]);
}
io_infos->size = 0;
}
delete[] io_infos->io_info;
io_infos->io_info = NULL;
delete io_infos;
}
}
#ifdef __cplusplus
} // extern "C"
#endif
namespace paddle_infer {
PlaceType CvtToCxxPlaceType(PD_PlaceType place_type) {
......
......@@ -41,6 +41,14 @@ extern "C" {
PADDLE_CAPI_EXPORT extern void PD_OneDimArrayInt32Destroy(
__pd_take PD_OneDimArrayInt32* array);
///
/// \brief Destroy the PD_OneDimArrayInt64 object pointed to by the pointer.
///
/// \param[in] array pointer to the PD_OneDimArrayInt64 object.
///
PADDLE_CAPI_EXPORT extern void PD_OneDimArrayInt64Destroy(
__pd_take PD_OneDimArrayInt64* array);
///
/// \brief Destroy the PD_OneDimArrayCstr object pointed to by the pointer.
///
......@@ -74,6 +82,21 @@ PADDLE_CAPI_EXPORT extern void PD_TwoDimArraySizeDestroy(
///
PADDLE_CAPI_EXPORT extern void PD_CstrDestroy(__pd_take PD_Cstr* cstr);
///
/// \brief Destroy the PD_IOInfo object pointed to by the pointer.
///
/// \param[in] io_info pointer to the PD_IOInfo object.
///
PADDLE_CAPI_EXPORT extern void PD_IOInfoDestroy(__pd_take PD_IOInfo* io_info);
///
/// \brief Destroy the PD_IOInfos object pointed to by the pointer.
///
/// \param[in] io_infos pointer to the PD_IOInfos object.
///
PADDLE_CAPI_EXPORT extern void PD_IOInfosDestroy(
__pd_take PD_IOInfos* io_infos);
#ifdef __cplusplus
} // extern "C"
#endif
......@@ -44,6 +44,16 @@ namespace paddle_infer {
__pd_give PD_OneDimArrayInt32* CvtVecToOneDimArrayInt32(
const std::vector<int>& vec);
///
/// \brief Convert the 'std::vector<int64_t>' object to a 'PD_OneDimArrayInt64'
/// object.
///
/// \param[in] vec source object.
/// \return target object.
///
__pd_give PD_OneDimArrayInt64* CvtVecToOneDimArrayInt64(
const std::vector<int64_t>& vec);
///
/// \brief Convert the 'PD_OneDimArrayInt32' object to a 'std::vector<int>'
/// object.
......@@ -54,6 +64,16 @@ __pd_give PD_OneDimArrayInt32* CvtVecToOneDimArrayInt32(
std::vector<int> CvtOneDimArrayToVecInt32(
__pd_keep const PD_OneDimArrayInt32* array);
///
/// \brief Convert the 'PD_OneDimArrayInt64' object to a 'std::vector<int64_t>'
/// object.
///
/// \param[in] array source object.
/// \return target object.
///
std::vector<int64_t> CvtOneDimArrayToVecInt64(
__pd_keep const PD_OneDimArrayInt64* array);
///
/// \brief Convert the 'std::vector<size_t>' object to a 'PD_OneDimArraySize'
/// object.
......
......@@ -157,7 +157,7 @@ func (config *Config) UseFcPadding() bool {
/// \param deviceId the GPU card to use.
///
func (config *Config) EnableUseGpu(memorySize uint64, deviceId int32) {
C.PD_ConfigEnableUseGpu(config.c, C.uint64_t(memorySize), C.int32_t(deviceId))
C.PD_ConfigEnableUseGpu(config.c, C.uint64_t(memorySize), C.int32_t(deviceId), 0)
}
///
......
......@@ -46,6 +46,7 @@ class CastOpConverter : public OpConverter {
layer->setOutputType(0, nvinfer1::DataType::kBOOL);
break;
case 2: // INT32 = 2
case 3:  // INT64 = 3; TensorRT has no int64, fall through to INT32
layer->setOutputType(0, nvinfer1::DataType::kINT32);
break;
case 4: // FP16 = 4
......
......@@ -19,6 +19,10 @@ limitations under the License. */
#include <string>
#include <vector>
#if defined(PADDLE_WITH_CUDA)
#include <cuda_runtime.h>
#endif
#include "paddle/fluid/inference/capi_exp/pd_inference_api.h"
#include "paddle/fluid/inference/tests/api/tester_helper.h"
......@@ -37,7 +41,7 @@ TEST(PD_Config, gpu_interface) {
PD_ConfigSetModel(config, prog_file.c_str(), param_file.c_str());
PD_ConfigSetOptimCacheDir(config, opt_cache_dir.c_str());
PD_ConfigEnableUseGpu(config, 100, 0);
PD_ConfigEnableUseGpu(config, 100, 0, 0);
bool use_gpu = PD_ConfigUseGpu(config);
EXPECT_TRUE(use_gpu);
int init_size = PD_ConfigMemoryPoolInitSizeMb(config);
......@@ -84,6 +88,14 @@ TEST(PD_Config, gpu_interface) {
bool thread_local_thread = PD_ConfigThreadLocalStreamEnabled(config);
EXPECT_TRUE(thread_local_thread);
#if defined(PADDLE_WITH_CUDA)
{
cudaStream_t external_stream;
cudaStreamCreate(&external_stream);
PD_ConfigSetExecStream(config, external_stream);
}
#endif
PD_ConfigDisableGpu(config);
PD_ConfigDestroy(config);
}
......@@ -104,7 +116,7 @@ TEST(PD_Config, use_gpu) {
const char* model_dir_ = PD_ConfigGetModelDir(config);
LOG(INFO) << model_dir_;
PD_ConfigEnableUseGpu(config, 100, 0);
PD_ConfigEnableUseGpu(config, 100, 0, 0);
bool use_gpu = PD_ConfigUseGpu(config);
EXPECT_TRUE(use_gpu);
int device_id = PD_ConfigGpuDeviceId(config);
......@@ -142,7 +154,7 @@ TEST(PD_Config, use_gpu) {
TEST(PD_Config, trt_int8) {
std::string model_dir = FLAGS_infer_model + "/mobilenet";
PD_Config* config = PD_ConfigCreate();
PD_ConfigEnableUseGpu(config, 100, 0);
PD_ConfigEnableUseGpu(config, 100, 0, 0);
PD_ConfigEnableTensorRtEngine(
config, 1 << 20, 1, 3, PD_PRECISION_INT8, FALSE, TRUE);
bool trt_enable = PD_ConfigTensorRtEngineEnabled(config);
......@@ -153,7 +165,7 @@ TEST(PD_Config, trt_int8) {
TEST(PD_Config, trt_fp16) {
std::string model_dir = FLAGS_infer_model + "/mobilenet";
PD_Config* config = PD_ConfigCreate();
PD_ConfigEnableUseGpu(config, 100, 0);
PD_ConfigEnableUseGpu(config, 100, 0, 0);
PD_ConfigEnableTensorRtEngine(
config, 1 << 20, 1, 3, PD_PRECISION_HALF, FALSE, FALSE);
bool trt_enable = PD_ConfigTensorRtEngineEnabled(config);
......
......@@ -37,6 +37,9 @@ void predictor_run() {
PD_OneDimArrayCstr* input_names = PD_PredictorGetInputNames(predictor);
LOG(INFO) << "The inputs' size is: " << input_names->size;
EXPECT_EQ(input_names->size, 2u);
PD_IOInfos* in_infos = PD_PredictorGetInputInfos(predictor);
EXPECT_EQ(in_infos->size, 2u);
PD_IOInfos* out_infos = PD_PredictorGetOutputInfos(predictor);
int32_t shape_0[4] = {1, 3, 224, 224};
float data_0[1 * 3 * 224 * 224] = {0};
......@@ -79,6 +82,8 @@ void predictor_run() {
PD_TensorDestroy(input_1);
PD_TensorDestroy(input_0);
PD_OneDimArrayCstrDestroy(input_names);
PD_IOInfosDestroy(in_infos);
PD_IOInfosDestroy(out_infos);
PD_PredictorDestroy(predictor);
}
......
......@@ -85,6 +85,10 @@ TEST(PD_Config, interface) {
PD_ConfigEnableMkldnnBfloat16(config);
PD_ConfigSetBfloat16Op(config, 1, &ops_name);
PD_ConfigEnableMkldnnInt8(config);
bool mkldnn_int8_enabled = PD_ConfigMkldnnInt8Enabled(config);
EXPECT_TRUE(mkldnn_int8_enabled);
#endif
PD_ConfigEnableONNXRuntime(config);
......
......@@ -198,8 +198,7 @@ void RefcountedMemoryMapAllocation::close() {
MemoryMapAllocationPool::Instance().Insert(MemoryMapInfo(
flags_, map_size_ - mmap_alignment, ipc_name_, map_ptr_));
} else {
if (info->refcount == 0 &&
shm_open(ipc_name_.c_str(), O_RDWR, (mode_t)0600) != -1) {
if (info->refcount == 0) {
shm_unlink(ipc_name_.c_str());
VLOG(6) << "shm_unlink file: " << ipc_name_;
}
......
......@@ -305,24 +305,15 @@ PD_REGISTER_GENERAL_KERNEL(
ALL_LAYOUT,
paddle::operators::FeedStringsKernel<phi::CustomContext>,
ALL_DTYPE) {}
#elif defined(PADDLE_WITH_CUSTOM_DEVICE)
PD_REGISTER_GENERAL_KERNEL(
feed_dense_tensor,
custom_cpu,
ALL_LAYOUT,
paddle::operators::FeedDenseTensorKernel<phi::CustomContext>,
ALL_DTYPE) {}
PD_REGISTER_GENERAL_KERNEL(
feed_sparse_coo_tensor,
custom_cpu,
ALL_LAYOUT,
paddle::operators::FeedSparseCooTensorKernel<phi::CustomContext>,
ALL_DTYPE) {}
PD_REGISTER_GENERAL_KERNEL(
feed_strings,
custom_cpu,
ALL_LAYOUT,
paddle::operators::FeedStringsKernel<phi::CustomContext>,
ALL_DTYPE) {}
#endif
#ifdef PADDLE_WITH_CUSTOM_DEVICE
namespace paddle {
namespace operators {
template void FeedDenseTensorKernel<phi::CustomContext>(
const phi::CustomContext& dev_ctx,
const phi::ExtendedTensor& x,
int col,
phi::DenseTensor* out);
} // namespace operators
} // namespace paddle
#endif
......@@ -17,6 +17,7 @@ limitations under the License. */
#include "paddle/fluid/operators/run_program_op.h"
#include "paddle/fluid/operators/save_combine_op.h"
#include "paddle/phi/backends/device_manager.h"
#include "paddle/phi/core/kernel_registry.h"
#define REGISTER_OP_CUSTOM_DEVICE_KERNEL(op_type, dev_type, ...) \
static paddle::framework::OpKernelRegistrar<phi::CustomPlace, __VA_ARGS__> \
......@@ -26,10 +27,30 @@ limitations under the License. */
paddle::framework::OpKernelType::kDefaultCustomizedTypeValue); \
__op_custom_device_kernel_registrar_##op_type##_##__acosf##__.Touch();
#define REGISTER_CUSTOM_DEVICE_GENERAL_KERNEL( \
kernel_name, dev_type, layout, kernel_fn) \
static phi::KernelRegistrar \
__reg_custom_device_phi_kernel_##kernel_name##_##backend##_##layout( \
phi::RegType::INNER, \
#kernel_name, \
dev_type, \
DATALAYOUT(layout), \
::phi::KernelArgsParseFunctor<decltype(&kernel_fn)>::Parse, \
[](const phi::KernelKey& kernel_key, phi::Kernel* kernel) {}, \
PHI_KERNEL(kernel_fn), \
PHI_VARIADIC_KERNEL(kernel_fn))
namespace paddle {
namespace operators {
template <typename Context>
void FeedDenseTensorKernel(const Context& dev_ctx,
const phi::ExtendedTensor& x,
int col,
phi::DenseTensor* out);
void RegisterCustomDeviceCommonKernel(const std::string& dev_type) {
#ifdef PADDLE_WITH_CUSTOM_DEVICE
auto device_type = dev_type.c_str();
/* see [Why use single type kernel] */
REGISTER_OP_CUSTOM_DEVICE_KERNEL(
......@@ -66,9 +87,16 @@ void RegisterCustomDeviceCommonKernel(const std::string& dev_type) {
LoadCombineOpKernel<paddle::platform::CustomDeviceContext, int8_t>,
paddle::operators::
LoadCombineOpKernel<paddle::platform::CustomDeviceContext, int64_t>);
REGISTER_CUSTOM_DEVICE_GENERAL_KERNEL(
feed_dense_tensor,
device_type,
ALL_LAYOUT,
paddle::operators::FeedDenseTensorKernel<phi::CustomContext>);
#endif
}
} // namespace operators
} // namespace paddle
#undef REGISTER_OP_CUSTOM_DEVICE_KERNEL
#undef REGISTER_CUSTOM_DEVICE_GENERAL_KERNEL
......@@ -56,6 +56,7 @@ class SequencePadOp : public framework::OperatorWithKernel {
auto pad_value_dims = ctx->GetInputDim("PadValue");
PADDLE_ENFORCE_EQ(
pad_value_dims == phi::make_ddim({1}) ||
pad_value_dims == phi::make_ddim({}) ||
pad_value_dims == time_step_dims,
true,
platform::errors::InvalidArgument(
......
......@@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include <string.h>
#include <memory>
#include <sstream>
#include <string>
......@@ -166,7 +167,16 @@ Tensor full<DescTensor>(const IntArray& shape,
phi::errors::InvalidArgument(
"We only support float32/float16 for full, but we got data type: %s",
phi::DataTypeToString(dtype)));
op->SetAttr("value", value.to<float>());
if (dtype == phi::DataType::FLOAT32) {
op->SetAttr("value", value.to<float>());
} else if (dtype == phi::DataType::FLOAT64) {
op->SetAttr("str_value", std::to_string(value.to<double>()));
} else if (dtype == phi::DataType::FLOAT16) {
op->SetAttr("str_value", std::to_string(value.to<float>()));
} else {
PADDLE_THROW(phi::errors::Unimplemented(
"We only support float64/float32/float16 for full"));
}
op->SetAttr("dtype", paddle::framework::TransToProtoVarType(dtype));
op->SetOutput(
"Out", {std::static_pointer_cast<prim::DescTensor>(out.impl())->Name()});
......
......@@ -32,7 +32,7 @@ void tanh_grad(const Tensor& out, const Tensor& grad_out, Tensor* grad_x) {
auto tmp = pow<T>(out, 2.0);
tmp = scale<T>(tmp, -1.0, 1.0, true);
auto grad_x_tmp = multiply<T>(grad_out, tmp);
set_output<T>(grad_x_tmp.impl(), grad_x);
set_output<T>(grad_x_tmp, grad_x);
}
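// Editorial note, not part of this change: the function above implements
// d tanh(x)/dx = 1 - tanh(x)^2, i.e. grad_x = grad_out * (1 - out * out),
// built only from primitive ops (pow, scale, multiply).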
template <typename T>
......@@ -53,7 +53,7 @@ void subtract_grad(const Tensor& x,
auto dy_reduce_res = sum<T>(
scale_out_grad, phi::vectorize(reduce_dim), y.dtype(), false);
auto dy_tmp = reshape<T>(dy_reduce_res, phi::vectorize(y.dims()));
set_output<T>(dy_tmp.impl(), dy);
set_output<T>(dy_tmp, dy);
}
} else {
by_pass<T>(scale_out_grad, dy);
......@@ -69,7 +69,7 @@ void subtract_grad(const Tensor& x,
auto dx_reduce_res =
sum<T>(out_grad, phi::vectorize(reduce_dim), x.dtype(), false);
auto dx_tmp = reshape<T>(dx_reduce_res, phi::vectorize(x.dims()));
set_output<T>(dx_tmp.impl(), dx);
set_output<T>(dx_tmp, dx);
}
} else {
by_pass<T>(out_grad, dx);
......@@ -94,7 +94,7 @@ void add_grad(const Tensor& x,
auto dy_reduce_res =
sum<T>(out_grad, phi::vectorize(reduce_dim), y.dtype(), false);
auto dy_tmp = reshape<T>(dy_reduce_res, phi::vectorize(y.dims()));
set_output<T>(dy_tmp.impl(), dy);
set_output<T>(dy_tmp, dy);
}
} else {
......@@ -111,7 +111,7 @@ void add_grad(const Tensor& x,
auto dx_reduce_res =
sum<T>(out_grad, phi::vectorize(reduce_dim), x.dtype(), false);
auto dx_tmp = reshape<T>(dx_reduce_res, phi::vectorize(x.dims()));
set_output<T>(dx_tmp.impl(), dx);
set_output<T>(dx_tmp, dx);
}
} else {
by_pass<T>(out_grad, dx);
......@@ -139,22 +139,26 @@ void sum_grad(const Tensor& x,
reduce_all = false;
}
auto x_grad_tmp = Tensor();
if (!keepdim) {
auto axis_ = std::vector<int64_t>();
if (reduce_all) {
for (int64_t i = 1; i < x_dim_size; i++) {
axis_.push_back(i);
if (x_dim_size == 1) {
x_grad_tmp = expand<T>(out_grad, IntArray(x_dim));
} else {
if (!keepdim) {
auto axis_ = std::vector<int64_t>();
if (reduce_all) {
for (int64_t i = 1; i < x_dim_size; i++) {
axis_.push_back(i);
}
} else {
axis_ = axis.GetData();
}
auto out_grad_ = unsqueeze<T>(out_grad, axis_);
x_grad_tmp = expand<T>(out_grad_, IntArray(x_dim));
} else {
axis_ = axis.GetData();
x_grad_tmp = expand<T>(out_grad, IntArray(x_dim));
}
auto out_grad_ = unsqueeze<T>(out_grad, axis_);
x_grad_tmp = expand<T>(out_grad_, IntArray(x_dim));
} else {
x_grad_tmp = expand<T>(out_grad, IntArray(x_dim));
}
set_output<T>(x_grad_tmp.impl(), x_grad);
set_output<T>(x_grad_tmp, x_grad);
}
template <typename T>
......@@ -175,36 +179,36 @@ void divide_grad(const Tensor& x,
// Maybe need reduce here
phi::DDim reduce_dim = get_reduce_dims(y.dims(), x.dims());
if (!reduce_dim.size()) {
set_output<T>(dy_res.impl(), dy);
set_output<T>(dy_res, dy);
} else {
auto dy_reduce_res =
sum<T>(dy_res, phi::vectorize(reduce_dim), y.dtype(), false);
auto dy_tmp = reshape<T>(dy_reduce_res, phi::vectorize(y.dims()));
set_output<T>(dy_tmp.impl(), dy);
set_output<T>(dy_tmp, dy);
}
} else {
set_output<T>(dy_res.impl(), dy);
set_output<T>(dy_res, dy);
}
} // indicate we will compute dy
if (dx) {
// dx = (1/y) * dout
auto one_tensor = full<T>(phi::vectorize(y.dims()), 1.0);
auto one_tensor = full<T>(phi::vectorize(y.dims()), 1.0, y.dtype());
auto tmp0 = divide<T>(one_tensor, y);
auto dx_res = multiply<T>(tmp0, out_grad);
if (y.dims() != x.dims()) {
// Maybe need reduce here
auto reduce_dim = get_reduce_dims(x.dims(), y.dims());
if (!reduce_dim.size()) {
set_output<T>(dx_res.impl(), dx);
set_output<T>(dx_res, dx);
} else {
auto dx_reduce_res =
sum<T>(dx_res, phi::vectorize(reduce_dim), x.dtype(), false);
auto dx_tmp = reshape<T>(dx_reduce_res, phi::vectorize(x.dims()));
set_output<T>(dx_tmp.impl(), dx);
set_output<T>(dx_tmp, dx);
}
} else {
set_output<T>(dx_res.impl(), dx);
set_output<T>(dx_res, dx);
}
} // indicate we will compute dx
}
......@@ -215,7 +219,7 @@ void sqrt_grad(const Tensor& out, const Tensor& out_grad, Tensor* x_grad) {
auto div_x = full<T>(phi::vectorize(out.dims()), 0.5);
auto tmp = divide<T>(div_x, out);
auto x_grad_tmp = multiply<T>(out_grad, tmp);
set_output<T>(x_grad_tmp.impl(), x_grad);
set_output<T>(x_grad_tmp, x_grad);
}
}
......@@ -231,7 +235,7 @@ void multiply_grad(const Tensor& x,
if (x.dims() != y.dims()) {
auto axes = get_reduce_dims(x.dims(), y.dims());
if (!axes.size()) {
set_output<T>(x_grad_unreduce.impl(), x_grad);
set_output<T>(x_grad_unreduce, x_grad);
} else {
auto x_grad_reduced = sum<T>(x_grad_unreduce,
phi::vectorize(axes),
......@@ -240,10 +244,10 @@ void multiply_grad(const Tensor& x,
if (x_grad_reduced.dims().size() != x.dims().size()) {
x_grad_reduced = reshape<T>(x_grad_reduced, x.shape());
}
set_output<T>(x_grad_reduced.impl(), x_grad);
set_output<T>(x_grad_reduced, x_grad);
}
} else {
set_output<T>(x_grad_unreduce.impl(), x_grad);
set_output<T>(x_grad_unreduce, x_grad);
}
}
if (y_grad) {
......@@ -251,7 +255,7 @@ void multiply_grad(const Tensor& x,
if (y.dims() != x.dims()) {
auto axes = get_reduce_dims(y.dims(), x.dims());
if (!axes.size()) {
set_output<T>(y_grad_unreduce.impl(), y_grad);
set_output<T>(y_grad_unreduce, y_grad);
} else {
auto y_grad_reduced = sum<T>(y_grad_unreduce,
phi::vectorize(axes),
......@@ -260,10 +264,10 @@ void multiply_grad(const Tensor& x,
if (y_grad_reduced.dims().size() != y.dims().size()) {
y_grad_reduced = reshape<T>(y_grad_reduced, y.shape());
}
set_output<T>(y_grad_reduced.impl(), y_grad);
set_output<T>(y_grad_reduced, y_grad);
}
} else {
set_output<T>(y_grad_unreduce.impl(), y_grad);
set_output<T>(y_grad_unreduce, y_grad);
}
}
}
......@@ -284,7 +288,7 @@ void expand_grad(const Tensor& x,
if (reduced.dims().size() != x.dims().size()) {
reduced = reshape<T>(reduced, x.shape());
}
set_output<T>(reduced.impl(), x_grad);
set_output<T>(reduced, x_grad);
}
} else {
by_pass<T>(out_grad, x_grad);
......@@ -295,7 +299,7 @@ void expand_grad(const Tensor& x,
template <typename T>
void exp_grad(const Tensor& out, const Tensor& out_grad, Tensor* x_grad) {
if (x_grad) {
set_output<T>(multiply<T>(out_grad, out).impl(), x_grad);
set_output<T>(multiply<T>(out_grad, out), x_grad);
}
}
......
......@@ -49,7 +49,6 @@ void set_output<Tensor>(const paddle::experimental::Tensor& x_tmp,
template <>
void by_pass<Tensor>(const paddle::experimental::Tensor& x, Tensor* out) {
set_output<Tensor>(x, out);
// out->set_impl(x.impl());
}
} // namespace prim
......
......@@ -69,7 +69,6 @@ void by_pass<DescTensor>(const paddle::experimental::Tensor& x,
op->InferVarType(block);
op->InferShape(*block);
set_output<DescTensor>(new_out, out);
// out->set_impl(new_out.impl());
}
} // namespace prim
......
......@@ -68,16 +68,16 @@ TEST(EagerPrim, TanhBackwardTest) {
paddle::experimental::Tensor out0 = tanh_ad_func(tensor0);
std::vector<paddle::experimental::Tensor> outs0 = {out0};
// Disable prim
PrimCommonUtils::SetPrimEnabled(false);
ASSERT_FALSE(PrimCommonUtils::IsPrimEnabled());
PrimCommonUtils::SetBwdPrimEnabled(false);
ASSERT_FALSE(PrimCommonUtils::IsBwdPrimEnabled());
// 4. Run Backward
egr::Backward(outs0, {}, false);
paddle::experimental::Tensor out1 = tanh_ad_func(tensor1);
std::vector<paddle::experimental::Tensor> outs1 = {out1};
// Disable prim
PrimCommonUtils::SetPrimEnabled(true);
ASSERT_TRUE(PrimCommonUtils::IsPrimEnabled());
PrimCommonUtils::SetBwdPrimEnabled(true);
ASSERT_TRUE(PrimCommonUtils::IsBwdPrimEnabled());
// 4. Run Backward
::egr::Backward(outs1, {}, false);
VLOG(7)
......@@ -99,10 +99,10 @@ TEST(EagerPrim, TanhBackwardTest) {
}
TEST(EagerPrim, TestFlags) {
PrimCommonUtils::SetPrimEnabled(true);
ASSERT_TRUE(PrimCommonUtils::IsPrimEnabled());
PrimCommonUtils::SetPrimEnabled(false);
ASSERT_FALSE(PrimCommonUtils::IsPrimEnabled());
PrimCommonUtils::SetBwdPrimEnabled(true);
ASSERT_TRUE(PrimCommonUtils::IsBwdPrimEnabled());
PrimCommonUtils::SetBwdPrimEnabled(false);
ASSERT_FALSE(PrimCommonUtils::IsBwdPrimEnabled());
}
} // namespace prim
......
......@@ -341,10 +341,10 @@ TEST(StaticCompositeGradMaker, TestMutiOutputMethod) {
}
TEST(StaticPrim, TestFlags) {
PrimCommonUtils::SetPrimEnabled(true);
ASSERT_TRUE(PrimCommonUtils::IsPrimEnabled());
PrimCommonUtils::SetPrimEnabled(false);
ASSERT_FALSE(PrimCommonUtils::IsPrimEnabled());
PrimCommonUtils::SetBwdPrimEnabled(true);
ASSERT_TRUE(PrimCommonUtils::IsBwdPrimEnabled());
PrimCommonUtils::SetBwdPrimEnabled(false);
ASSERT_FALSE(PrimCommonUtils::IsBwdPrimEnabled());
}
} // namespace prim
......
......@@ -18,6 +18,7 @@ namespace paddle {
namespace prim {
StaticCompositeContext* StaticCompositeContext::static_composite_context_ =
new StaticCompositeContext();
thread_local bool StaticCompositeContext::enable_prim_ = false;
thread_local bool StaticCompositeContext::enable_bwd_prim_ = false;
thread_local bool StaticCompositeContext::enable_fwd_prim_ = false;
} // namespace prim
} // namespace paddle
......@@ -56,9 +56,18 @@ class StaticCompositeContext {
return generator_->Generate(key);
}
void SetPrimEnabled(bool enable_prim) { enable_prim_ = enable_prim; }
void SetBwdPrimEnabled(bool enable_prim) { enable_bwd_prim_ = enable_prim; }
bool IsPrimEnabled() { return enable_prim_; }
bool IsBwdPrimEnabled() { return enable_bwd_prim_; }
void SetFwdPrimEnabled(bool enable_prim) { enable_fwd_prim_ = enable_prim; }
bool IsFwdPrimEnabled() { return enable_fwd_prim_; }
void SetAllPrimEnabled(bool enable_prim) {
enable_fwd_prim_ = enable_prim;
enable_bwd_prim_ = enable_prim;
}
private:
StaticCompositeContext()
......@@ -66,7 +75,8 @@ class StaticCompositeContext {
framework::BlockDesc* current_block_desc_;
std::unique_ptr<UniqueNameGenerator> generator_;
static thread_local bool enable_prim_;
static thread_local bool enable_bwd_prim_;
static thread_local bool enable_fwd_prim_;
static StaticCompositeContext* static_composite_context_;
DISABLE_COPY_AND_ASSIGN(StaticCompositeContext);
};
......
......@@ -19,12 +19,24 @@
PADDLE_DEFINE_EXPORTED_bool(prim_enabled, false, "enable_prim or not");
namespace paddle {
namespace prim {
bool PrimCommonUtils::IsPrimEnabled() {
return StaticCompositeContext::Instance().IsPrimEnabled();
bool PrimCommonUtils::IsBwdPrimEnabled() {
return StaticCompositeContext::Instance().IsBwdPrimEnabled();
}
void PrimCommonUtils::SetPrimEnabled(bool enable_prim) {
return StaticCompositeContext::Instance().SetPrimEnabled(enable_prim);
void PrimCommonUtils::SetBwdPrimEnabled(bool enable_prim) {
return StaticCompositeContext::Instance().SetBwdPrimEnabled(enable_prim);
}
bool PrimCommonUtils::IsFwdPrimEnabled() {
return StaticCompositeContext::Instance().IsFwdPrimEnabled();
}
void PrimCommonUtils::SetFwdPrimEnabled(bool enable_prim) {
return StaticCompositeContext::Instance().SetFwdPrimEnabled(enable_prim);
}
void PrimCommonUtils::SetAllPrimEnabled(bool enable_prim) {
return StaticCompositeContext::Instance().SetAllPrimEnabled(enable_prim);
}
} // namespace prim
} // namespace paddle
......@@ -18,8 +18,11 @@ namespace paddle {
namespace prim {
class PrimCommonUtils {
public:
static bool IsPrimEnabled();
static void SetPrimEnabled(bool enabled);
static bool IsBwdPrimEnabled();
static void SetBwdPrimEnabled(bool enabled);
static bool IsFwdPrimEnabled();
static void SetFwdPrimEnabled(bool enabled);
static void SetAllPrimEnabled(bool enabled);
};
} // namespace prim
} // namespace paddle
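// Editorial sketch, not part of this change: the split switches can be toggled
// independently, and SetAllPrimEnabled(flag) is equivalent to setting both:
//   paddle::prim::PrimCommonUtils::SetFwdPrimEnabled(true);  // composite forward ops
//   paddle::prim::PrimCommonUtils::SetBwdPrimEnabled(true);  // composite grad ops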
......@@ -65,6 +65,7 @@ struct npy_format_descriptor<paddle::platform::float16> {
namespace paddle {
namespace pybind {
using paddle::distributed::DependType;
using paddle::distributed::DistModel;
using paddle::distributed::DistModelConfig;
using paddle::distributed::DistModelDataBuf;
......@@ -164,18 +165,17 @@ void BindFleetExecutor(py::module* m) {
.def(
"run", &FleetExecutor::Run, py::call_guard<py::gil_scoped_release>());
py::enum_<DependType>(*m, "DependType")
.value("NORMAL", DependType::NORMAL)
.value("LOOP", DependType::LOOP)
.value("STOP_LOOP", DependType::STOP_LOOP);
py::class_<TaskNode>(*m, "TaskNode")
.def(py::init<framework::ProgramDesc*,
int64_t,
int64_t,
int64_t,
int64_t>())
.def(py::init<framework::ProgramDesc*, int64_t, int64_t, int64_t>())
.def(py::init<int32_t,
const std::vector<framework::OpDesc*>&,
int64_t,
int64_t,
int64_t,
int64_t>())
.def("task_id", &TaskNode::task_id)
.def("add_upstream_task", &TaskNode::AddUpstreamTask)
......@@ -183,6 +183,7 @@ void BindFleetExecutor(py::module* m) {
.def("set_run_pre_steps", &TaskNode::SetRunPerSteps)
.def("set_run_at_offset", &TaskNode::SetRunAtOffset)
.def("set_type", &TaskNode::SetType)
.def("set_cond_var_name", &TaskNode::SetCondVarName)
.def("role", &TaskNode::role)
.def("init", [](TaskNode& self) { self.Init(); })
.def("set_program", &TaskNode::SetProgram);
......
......@@ -660,8 +660,16 @@ PYBIND11_MODULE(libpaddle, m) {
return oss.str();
});
m.def("set_prim_enabled", &paddle::prim::PrimCommonUtils::SetPrimEnabled);
m.def("is_prim_enabled", &paddle::prim::PrimCommonUtils::IsPrimEnabled);
m.def("__set_bwd_prim_enabled",
&paddle::prim::PrimCommonUtils::SetBwdPrimEnabled);
m.def("_is_bwd_prim_enabled",
&paddle::prim::PrimCommonUtils::IsBwdPrimEnabled);
m.def("__set_fwd_prim_enabled",
&paddle::prim::PrimCommonUtils::SetFwdPrimEnabled);
m.def("_is_fwd_prim_enabled",
&paddle::prim::PrimCommonUtils::IsFwdPrimEnabled);
m.def("__set_all_prim_enabled",
&paddle::prim::PrimCommonUtils::SetAllPrimEnabled);
m.def("set_num_threads", &platform::SetNumThreads);
m.def("disable_signal_handler", &DisableSignalHandler);
......@@ -1264,8 +1272,9 @@ All parameter, weight, gradient are variables in Paddle.
// priority of GradCompOpMaker is less than GradCompMaker for better
// performance.
std::vector<std::unique_ptr<OpDesc>> grad_op_descs;
if (paddle::prim::PrimCommonUtils::IsPrimEnabled()) {
if (paddle::prim::PrimCommonUtils::IsBwdPrimEnabled()) {
if (grad_comp_op_maker != nullptr) {
VLOG(3) << "Runing composite fun for " << op_desc.Type();
grad_op_descs = grad_comp_op_maker(op_desc,
no_grad_set,
&grad_to_var,
......
......@@ -42,7 +42,7 @@
kernel :
func : add_grad
no_need_buffer : x, y
composite : add_grad(Tensor x, Tensor y, Tensor out_grad, int axis)
composite : add_grad(x, y, out_grad, axis)
backward : add_double_grad
inplace : (out_grad -> x_grad)
......@@ -390,7 +390,7 @@
param : [x, y]
kernel :
func : divide_grad
composite : divide_grad(Tensor x, Tensor y, Tensor out, Tensor out_grad, int axis = -1)
composite : divide_grad(x, y, out, out_grad, -1)
backward : divide_double_grad
- backward_op : dropout_grad
......@@ -1319,7 +1319,7 @@
kernel :
func : subtract_grad
no_need_buffer : x, y
composite : subtract_grad(Tensor x, Tensor y, Tensor out_grad, int axis)
composite : subtract_grad(x, y, out_grad, axis)
backward : subtract_double_grad
inplace : (out_grad -> x_grad)
......
......@@ -112,42 +112,6 @@ static void AppendActivation(const OneDNNContext& dev_ctx,
}
}
static std::unordered_map<std::string, std::string> GetAttributeMap(
std::string act_type) {
std::unordered_map<std::string, std::string> attr_map;
if (act_type == "swish") {
attr_map.emplace("beta", "fuse_alpha");
} else if (act_type == "relu6") {
attr_map.emplace("threshold", "fuse_alpha");
} else if (act_type == "hard_sigmoid") {
attr_map.emplace("slope", "fuse_alpha");
attr_map.emplace("offset", "fuse_beta");
} else if (act_type == "clip") {
attr_map.emplace("min", "fuse_alpha");
attr_map.emplace("max", "fuse_beta");
} else {
attr_map.emplace("alpha", "fuse_alpha");
attr_map.emplace("beta", "fuse_beta");
}
return attr_map;
}
static std::vector<std::string> GetSupportedActivations() {
return std::vector<std::string>{"abs",
"clip",
"gelu",
"hard_sigmoid",
"hard_swish",
"leaky_relu",
"mish",
"relu",
"relu6",
"sigmoid",
"sqrt",
"swish",
"tanh"};
}
template <typename T,
typename TForward,
typename TBackward = onednn_dummy_primitive,
......@@ -1756,22 +1720,22 @@ static std::vector<int64_t> TransposeAxis(const std::vector<int64_t>& x,
auto axis_set = std::set<int>(axis.begin(), axis.end());
PADDLE_ENFORCE_EQ(axis_set.size(),
axis_size,
paddle::platform::errors::InvalidArgument(
phi::errors::InvalidArgument(
"In an axis array, elements must be unique."));
PADDLE_ENFORCE_EQ(in_rank,
axis_size,
paddle::platform::errors::InvalidArgument(
"The input dimension's size "
"should be equal to the axis's size. "
"But received dimension is %d, "
"axis's size is %d",
in_rank,
axis_size));
PADDLE_ENFORCE_EQ(
in_rank,
axis_size,
phi::errors::InvalidArgument("The input dimension's size "
"should be equal to the axis's size. "
"But received dimension is %d, "
"axis's size is %d",
in_rank,
axis_size));
PADDLE_ENFORCE_LT(*std::max_element(axis.begin(), axis.end()),
axis_size,
paddle::platform::errors::InvalidArgument(
phi::errors::InvalidArgument(
"Axis values must be ranging from 0 to (dims - 1)."));
std::vector<int64_t> new_x(x.size());
......
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
......@@ -67,10 +67,7 @@ XPUOpMap& get_kl2_ops() {
phi::DataType::INT64})},
{"bilinear_interp_v2", XPUKernelSet({phi::DataType::FLOAT32})},
{"bilinear_interp_v2_grad", XPUKernelSet({phi::DataType::FLOAT32})},
{"bitwise_and", XPUKernelSet({phi::DataType::BOOL})},
{"bitwise_not", XPUKernelSet({phi::DataType::BOOL})},
{"bitwise_or", XPUKernelSet({phi::DataType::BOOL})},
{"bitwise_xor", XPUKernelSet({phi::DataType::BOOL})},
{"broadcast", XPUKernelSet({phi::DataType::FLOAT32})},
{"c_allgather",
XPUKernelSet({phi::DataType::FLOAT16,
......@@ -109,6 +106,8 @@ XPUOpMap& get_kl2_ops() {
XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
{"clip", XPUKernelSet({phi::DataType::FLOAT32})},
{"clip_by_norm", XPUKernelSet({phi::DataType::FLOAT32})},
{"clip_grad",
XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::INT32})},
{"coalesce_tensor",
XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
{"concat_grad",
......@@ -374,6 +373,10 @@ XPUOpMap& get_kl2_ops() {
phi::DataType::BOOL,
phi::DataType::FLOAT16,
phi::DataType::FLOAT32})},
{"max_pool2d_with_index",
XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
{"max_pool2d_with_index_grad",
XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
{"matmul_grad",
XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
{"matmul_v2_grad",
......@@ -435,7 +438,10 @@ XPUOpMap& get_kl2_ops() {
{"reduce_min", XPUKernelSet({phi::DataType::FLOAT32})},
{"reduce_prod", XPUKernelSet({phi::DataType::FLOAT32})},
{"reduce_sum_grad", XPUKernelSet({phi::DataType::FLOAT32})},
{"reduce_sum", XPUKernelSet({phi::DataType::FLOAT32})},
{"reduce_sum",
XPUKernelSet({phi::DataType::FLOAT32,
phi::DataType::INT8,
phi::DataType::INT64})},
{"relu6", XPUKernelSet({phi::DataType::FLOAT32})},
{"relu6_grad", XPUKernelSet({phi::DataType::FLOAT32})},
{"relu_grad",
......
......@@ -146,17 +146,17 @@ PADDLE_DEFINE_EXPORTED_bool(
 * CUDA related FLAG
* Name: FLAGS_gemm_use_half_precision_compute_type
* Since Version: 2.4
* Value Range: bool, default=true
* Value Range: bool, default=false
* Example:
 * Note: whether to use fp16 compute type when the input and output are fp16;
 * faster, but it may lose precision.
*/
PADDLE_DEFINE_EXPORTED_bool(
gemm_use_half_precision_compute_type,
true,
false,
"Whether to use fp16 compute type when the input and output is fp16, "
"faster but it may loss precision in most case. If true, the compute "
"type will be set to fp32. Default is true.");
"type will be set to fp16. Default is false.");
/**
* CUDA related FLAG
......
......@@ -4596,17 +4596,26 @@ void UniqueRawInferMeta(const MetaTensor& x,
MetaTensor* index,
MetaTensor* counts) {
if (!is_sorted) {
PADDLE_ENFORCE_EQ(
x.dims().size(),
1,
phi::errors::InvalidArgument("The Input(X) should be 1-D Tensor, "
"But now the dims of Input(X) is %d.",
x.dims().size()));
PADDLE_ENFORCE_EQ(x.dims().size() == 1 || x.dims().size() == 0,
true,
phi::errors::InvalidArgument(
"The Input(X) should be 0-D or 1-D Tensor, "
"But now the dims of Input(X) is %d.",
x.dims().size()));
out->set_dims(phi::make_ddim({-1}));
index->set_dims(x.dims());
return;
}
if (x.dims().size() == 0) {
PADDLE_ENFORCE_EQ(axis.empty(),
true,
phi::errors::InvalidArgument(
"The Input(X) with 0-D Tensor, axis must be None"
"But now the axis is %d.",
axis[0]));
}
if (axis.empty()) {
out->set_dims(phi::make_ddim({-1}));
if (return_inverse) {
......
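A hedged usage sketch of what the relaxed check above permits, assuming 0-D tensors are created from Python scalars via paddle.to_tensor:

import paddle

# 0-D input on the unsorted path: previously rejected, now accepted as long as
# axis stays None (per the new 0-D check above).
x = paddle.to_tensor(2.5)
out = paddle.unique(x)
print(out.shape)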
......@@ -43,14 +43,13 @@ enum class AlgorithmType {
kConvForward = 1,
kConvBackwardData = 2,
kConvBackwardFilter = 3,
kTranspose = 4,
#ifdef PADDLE_WITH_CUDNN_FRONTEND
kConvForwardV8 = 4,
kConvBackwardDataV8 = 5,
kConvBackwardFilterV8 = 6,
kTranspose = 7,
kConvForwardV8 = 5,
kConvBackwardDataV8 = 6,
kConvBackwardFilterV8 = 7,
kAlgorithmCount = 8
#else
kTranspose = 4,
kAlgorithmCount = 5
#endif
};
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
......
......@@ -463,12 +463,17 @@ void DispatchConcatKernel(const phi::GPUContext& ctx,
constexpr IndexT MaxVecSize = 16 / sizeof(T);
bool find_vecsize_flag = false;
IndexT dispatch_vec_size = 1;
auto output_data = reinterpret_cast<std::uintptr_t>(output->data());
for (IndexT vec_size = MaxVecSize; vec_size > 0; vec_size /= 2) {
for (IndexT idx = 0; idx < in_num + 1; idx++) {
const IndexT mov_size = vec_size * sizeof(T);
for (IndexT idx = 1; idx < in_num + 1; idx++) {
auto input_data = reinterpret_cast<std::uintptr_t>(inputs_data[idx - 1]);
      // Since inputs_col[0] is 0, start from idx = 1 and look back one slot.
const IndexT input_col = inputs_col[idx + 1] - inputs_col[idx];
if (input_col % vec_size == 0) {
if (idx == in_num - 1) {
const IndexT input_col = inputs_col[idx] - inputs_col[idx - 1];
if (input_col % vec_size == 0 && output_data % mov_size == 0 &&
input_data % mov_size == 0) {
if (idx == in_num) {
find_vecsize_flag = true;
}
} else {
......
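The new condition gates the candidate vector width on pointer alignment as well as on the per-input column counts; a framework-free Python sketch of that selection rule, with all names chosen here for illustration:

# Hedged sketch: a vec_size is usable only if every column count divides by it
# and every data pointer (inputs and output) is aligned to vec_size * sizeof(T);
# otherwise fall back to the next smaller power of two.
def pick_vec_size(col_counts, addresses, elem_bytes, max_vec=8):
    vec = max_vec
    while vec > 1:
        mov = vec * elem_bytes
        if all(c % vec == 0 for c in col_counts) and all(a % mov == 0 for a in addresses):
            return vec
        vec //= 2
    return 1

print(pick_vec_size([64, 128], [0x7f0000000100, 0x7f0000000200], elem_bytes=4))  # 8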
......@@ -169,7 +169,7 @@ static void linalg_solve(const Context& dev_ctx,
out_tmp.Resize(out->dims());
out_tmp = *out;
phi::SqueezeInferKernel<T, Context>(dev_ctx, out_tmp, {-1}, out);
phi::Squeeze<T, Context>(dev_ctx, out_tmp, {-1}, out);
} else {
PADDLE_ENFORCE_EQ(
x_dim[x_dim_size - 1],
......
......@@ -19,37 +19,64 @@
namespace phi {
std::vector<int64_t> ExtendDimsWithOnes(const std::vector<int64_t> &dims,
int new_size) {
std::vector<int64_t> new_dims(new_size, 1);
for (size_t i = 0; i < dims.size(); ++i) {
new_dims[new_size - dims.size() + i] = dims[i];
void CalculateMatrixDims(const std::vector<int64_t> &x_dims,
const std::vector<int64_t> &y_dims,
const std::vector<int64_t> &out_dims,
std::vector<int64_t> *x_bd_dims,
std::vector<int64_t> *y_bd_dims,
std::vector<int64_t> *out_bd_dims,
bool trans_x,
bool trans_y) {
if (x_dims.size() == 1) {
(*x_bd_dims)[x_bd_dims->size() - 1] = x_dims[0];
} else if (x_dims.size() == 2) {
(*x_bd_dims)[x_bd_dims->size() - 1] = x_dims[1];
(*x_bd_dims)[x_bd_dims->size() - 2] = x_dims[0];
} else {
for (size_t i = 0; i < x_dims.size(); ++i) {
(*x_bd_dims)[x_bd_dims->size() - x_dims.size() + i] = x_dims[i];
}
}
if (y_dims.size() == 1) {
(*y_bd_dims)[x_bd_dims->size() - 2] = y_dims[0];
} else if (y_dims.size() == 2) {
(*y_bd_dims)[y_bd_dims->size() - 1] = y_dims[1];
(*y_bd_dims)[y_bd_dims->size() - 2] = y_dims[0];
} else {
for (size_t i = 0; i < y_dims.size(); ++i) {
(*y_bd_dims)[y_bd_dims->size() - y_dims.size() + i] = y_dims[i];
}
}
for (size_t i = 0; i < x_bd_dims->size() - 2; ++i) {
(*out_bd_dims)[i] = std::max((*x_bd_dims)[i], (*y_bd_dims)[i]);
}
int h_idx = trans_x ? x_bd_dims->size() - 1 : x_bd_dims->size() - 2;
int w_idx = trans_y ? y_bd_dims->size() - 2 : y_bd_dims->size() - 1;
return new_dims;
(*out_bd_dims)[x_bd_dims->size() - 2] = (*x_bd_dims)[h_idx];
(*out_bd_dims)[y_bd_dims->size() - 1] = (*y_bd_dims)[w_idx];
}
template <typename T>
void CalculateGradMatrixDims(const OneDNNContext &dev_ctx,
DenseTensor *dx_tmp,
DenseTensor *dy_tmp,
const std::vector<int64_t> &dx_dims,
const std::vector<int64_t> &dy_dims,
std::vector<int64_t> *dx_bd_dims,
std::vector<int64_t> *dy_bd_dims) {
for (size_t i = 0; i < dx_dims.size() - 2; ++i) {
if (dx_dims[i] != dy_dims[i]) {
if (dx_dims[i] == 1) {
(*dx_bd_dims)[i] = dy_dims[i];
for (size_t i = 0; i < dx_bd_dims->size() - 2; ++i) {
if ((*dx_bd_dims)[i] != (*dy_bd_dims)[i]) {
if ((*dx_bd_dims)[i] == 1) {
(*dx_bd_dims)[i] = (*dy_bd_dims)[i];
} else {
(*dy_bd_dims)[i] = dx_dims[i];
(*dy_bd_dims)[i] = (*dx_bd_dims)[i];
}
}
}
dx_tmp->Resize(make_ddim((*dx_bd_dims)));
dx_tmp->Resize(make_ddim(*dx_bd_dims));
dev_ctx.template Alloc<T>(dx_tmp);
dy_tmp->Resize(make_ddim((*dy_bd_dims)));
dy_tmp->Resize(make_ddim(*dy_bd_dims));
dev_ctx.template Alloc<T>(dy_tmp);
}
......@@ -58,7 +85,7 @@ void ReduceSumForMatmulGradOutput(const OneDNNContext &dev_ctx,
const DenseTensor *dx_tmp,
DenseTensor *dx,
const std::vector<int64_t> &dx_dims,
const std::vector<int64_t> &squeezed_dims) {
const std::vector<int64_t> &x_dims) {
funcs::ReductionOneDNNHandler<T> handler(dnnl::algorithm::reduction_sum,
0.0f,
0.0f,
......@@ -66,7 +93,7 @@ void ReduceSumForMatmulGradOutput(const OneDNNContext &dev_ctx,
dev_ctx.GetPlace(),
dx_tmp,
dx,
dx_dims);
x_dims);
auto src_memory_p = handler.AcquireSrcMemory(dx_tmp);
auto dst_memory_p = handler.AcquireDstMemory(dx);
......@@ -79,8 +106,6 @@ void ReduceSumForMatmulGradOutput(const OneDNNContext &dev_ctx,
reduction_p->execute(astream, reduction_args);
astream.wait();
dx->set_mem_desc(dst_memory_p->get_desc().reshape(squeezed_dims));
}
template <typename T, typename Context>
......@@ -99,64 +124,67 @@ void MatmulGradKernel(const Context &dev_ctx,
size_t ndims = std::max(x_dims.size(), y_dims.size());
ndims = std::max<size_t>(ndims, 3);
if (x_dims.size() != ndims) {
x_dims = ExtendDimsWithOnes(x_dims, ndims);
}
if (y_dims.size() != ndims) {
y_dims = ExtendDimsWithOnes(y_dims, ndims);
}
if (dout_dims.size() != ndims) {
dout_dims = ExtendDimsWithOnes(dout_dims, ndims);
}
  // In a broadcasting scenario new memory is required because
  // reduce_sum must be calculated over the broadcast dims.
DenseTensor dx_tmp, dy_tmp;
std::vector<int64_t> dx_bd_dims(x_dims);
std::vector<int64_t> dy_bd_dims(y_dims);
std::vector<int64_t> dout_bd_dims(ndims, 1);
std::vector<int64_t> x_bd_dims(ndims, 1);
std::vector<int64_t> y_bd_dims(ndims, 1);
CalculateMatrixDims(x_dims,
y_dims,
dout_dims,
&x_bd_dims,
&y_bd_dims,
&dout_bd_dims,
transpose_x,
transpose_y);
std::vector<int64_t> dx_bd_dims(x_bd_dims);
std::vector<int64_t> dy_bd_dims(y_bd_dims);
CalculateGradMatrixDims<T>(
dev_ctx, &dx_tmp, &dy_tmp, x_dims, y_dims, &dx_bd_dims, &dy_bd_dims);
dev_ctx, &dx_tmp, &dy_tmp, &dx_bd_dims, &dy_bd_dims);
if (transpose_x && transpose_y) {
funcs::ExecuteMatmul<T, T>(
dev_ctx, y, dout, y_dims, dout_dims, true, true, &dx_tmp);
dev_ctx, y, dout, y_bd_dims, dout_bd_dims, true, true, &dx_tmp);
funcs::ExecuteMatmul<T, T>(
dev_ctx, dout, x, dout_dims, x_dims, true, true, &dy_tmp);
dev_ctx, dout, x, dout_bd_dims, x_bd_dims, true, true, &dy_tmp);
} else if (transpose_x) {
funcs::ExecuteMatmul<T, T>(
dev_ctx, y, dout, y_dims, dout_dims, false, true, &dx_tmp);
dev_ctx, y, dout, y_bd_dims, dout_bd_dims, false, true, &dx_tmp);
funcs::ExecuteMatmul<T, T>(
dev_ctx, x, dout, x_dims, dout_dims, false, false, &dy_tmp);
dev_ctx, x, dout, x_bd_dims, dout_bd_dims, false, false, &dy_tmp);
} else if (transpose_y) {
funcs::ExecuteMatmul<T, T>(
dev_ctx, dout, y, dout_dims, y_dims, false, false, &dx_tmp);
dev_ctx, dout, y, dout_bd_dims, y_bd_dims, false, false, &dx_tmp);
funcs::ExecuteMatmul<T, T>(
dev_ctx, dout, x, dout_dims, x_dims, true, false, &dy_tmp);
dev_ctx, dout, x, dout_bd_dims, x_bd_dims, true, false, &dy_tmp);
} else {
funcs::ExecuteMatmul<T, T>(
dev_ctx, dout, y, dout_dims, y_dims, false, true, &dx_tmp);
dev_ctx, dout, y, dout_bd_dims, y_bd_dims, false, true, &dx_tmp);
funcs::ExecuteMatmul<T, T>(
dev_ctx, x, dout, x_dims, dout_dims, true, false, &dy_tmp);
dev_ctx, x, dout, x_bd_dims, dout_bd_dims, true, false, &dy_tmp);
}
if (x_dims != dx_bd_dims) {
if (x_bd_dims != dx_bd_dims) {
ReduceSumForMatmulGradOutput<T>(
dev_ctx, &dx_tmp, dx, x_dims, vectorize(x.dims()));
dev_ctx, &dx_tmp, dx, dx_bd_dims, x_bd_dims);
} else {
*dx = std::move(dx_tmp);
}
if (y_dims != dy_bd_dims) {
if (y_bd_dims != dy_bd_dims) {
ReduceSumForMatmulGradOutput<T>(
dev_ctx, &dy_tmp, dy, y_dims, vectorize(y.dims()));
dev_ctx, &dy_tmp, dy, dy_bd_dims, y_bd_dims);
} else {
*dy = std::move(dy_tmp);
}
dx->set_mem_desc(x.mem_desc());
dx->Resize(x.dims());
dx->set_mem_desc(x.mem_desc().reshape(vectorize(x.dims())));
dy->set_mem_desc(y.mem_desc());
dy->Resize(y.dims());
dy->set_mem_desc(y.mem_desc().reshape(vectorize(y.dims())));
}
template <typename T, typename Context>
......
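As an aside, the batch-dim bookkeeping that CalculateMatrixDims performs can be pictured with a small framework-free sketch (assuming both operands are at least 2-D; the 1-D special cases above are ignored here):

# Hedged sketch: pad both shapes to a common rank, broadcast the batch dims
# elementwise, then place the (possibly transposed) matrix dims last.
def broadcast_matmul_dims(x_dims, y_dims, ndims, trans_x=False, trans_y=False):
    pad = lambda d: [1] * (ndims - len(d)) + list(d)
    xb, yb = pad(x_dims), pad(y_dims)
    out = [max(a, b) for a, b in zip(xb[:-2], yb[:-2])]
    out += [xb[-1] if trans_x else xb[-2], yb[-2] if trans_y else yb[-1]]
    return xb, yb, out

print(broadcast_matmul_dims([2, 3, 4], [4, 5], ndims=3))  # ([2, 3, 4], [1, 4, 5], [2, 3, 5])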
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
......@@ -27,7 +27,8 @@ void SumKernel(const Context& dev_ctx,
bool keep_dim,
DenseTensor* out) {
bool reduce_all = recompute_reduce_all(x, dims);
SumRawKernel<T>(dev_ctx, x, dims, keep_dim, reduce_all, out_dtype, out);
SumRawKernel<T, Context>(
dev_ctx, x, dims, keep_dim, reduce_all, out_dtype, out);
}
} // namespace phi
......@@ -82,5 +83,8 @@ PD_REGISTER_KERNEL(
#endif
#if defined(PADDLE_WITH_XPU)
PD_REGISTER_KERNEL(sum, XPU, ALL_LAYOUT, phi::SumKernel, float) {}
PD_REGISTER_KERNEL(
sum, XPU, ALL_LAYOUT, phi::SumKernel, float, int8_t, int64_t) {
kernel->OutputAt(0).SetDataType(paddle::experimental::DataType::UNDEFINED);
}
#endif
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
......
......@@ -25,11 +25,7 @@ void SqueezeInferKernel(const Context& dev_ctx,
const DenseTensor& x,
const IntArray& axes,
DenseTensor* out) {
auto x_dims = x.dims();
std::vector<int32_t> tmp(axes.GetData().begin(), axes.GetData().end());
auto out_dims = funcs::GetOutputSqueezeShape(tmp, x_dims, true);
out->Resize(out_dims);
auto out_dims = out->dims();
dev_ctx.template Alloc<T>(out);
phi::Copy(dev_ctx, x, dev_ctx.GetPlace(), false, out);
out->Resize(out_dims); // copy will reset the dims.
......
......@@ -17,6 +17,7 @@
#include "paddle/phi/common/int_array.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/infermeta/unary.h"
namespace phi {
......@@ -33,4 +34,14 @@ void SqueezeKernel(const Context& dev_ctx,
DenseTensor* out,
DenseTensor* xshape);
template <typename T, typename Context>
void Squeeze(const Context& dev_ctx,
const DenseTensor& x,
const IntArray& axes,
DenseTensor* out) {
MetaTensor meta_out(out);
SqueezeInferMeta(x, axes, &meta_out);
SqueezeInferKernel<T, Context>(dev_ctx, x, axes, out);
}
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
......@@ -19,51 +19,18 @@
namespace phi {
template <typename T, typename Context>
void BitwiseAndKernel(const Context& ctx,
const DenseTensor& x,
const DenseTensor& y,
DenseTensor* out) {
ctx.template Alloc<T>(out);
int r = xpu::logical_and(
ctx.x_context(), x.data<T>(), y.data<T>(), out->data<T>(), x.numel());
PADDLE_ENFORCE_XDNN_SUCCESS(r, "bitwise and");
}
template <typename T, typename Context>
void BitwiseOrKernel(const Context& ctx,
const DenseTensor& x,
const DenseTensor& y,
DenseTensor* out) {
ctx.template Alloc<T>(out);
int r = xpu::logical_or(
ctx.x_context(), x.data<T>(), y.data<T>(), out->data<T>(), x.numel());
PADDLE_ENFORCE_XDNN_SUCCESS(r, "bitwise or");
}
template <typename T, typename Context>
void BitwiseXorKernel(const Context& ctx,
const DenseTensor& x,
const DenseTensor& y,
DenseTensor* out) {
ctx.template Alloc<T>(out);
int r = xpu::logical_xor(
ctx.x_context(), x.data<T>(), y.data<T>(), out->data<T>(), x.numel());
PADDLE_ENFORCE_XDNN_SUCCESS(r, "bitwise xor");
}
template <typename T, typename Context>
void BitwiseNotKernel(const Context& ctx,
const DenseTensor& x,
DenseTensor* out) {
using XPUDataType = typename XPUTypeTrait<T>::Type;
ctx.template Alloc<T>(out);
int r =
xpu::logical_not(ctx.x_context(), x.data<T>(), out->data<T>(), x.numel());
int r = xpu::logical_not(ctx.x_context(),
reinterpret_cast<const XPUDataType*>(x.data<T>()),
reinterpret_cast<XPUDataType*>(out->data<T>()),
x.numel());
PADDLE_ENFORCE_XDNN_SUCCESS(r, "bitwise not");
}
} // namespace phi
PD_REGISTER_KERNEL(bitwise_and, XPU, ALL_LAYOUT, phi::BitwiseAndKernel, bool) {}
PD_REGISTER_KERNEL(bitwise_or, XPU, ALL_LAYOUT, phi::BitwiseOrKernel, bool) {}
PD_REGISTER_KERNEL(bitwise_xor, XPU, ALL_LAYOUT, phi::BitwiseXorKernel, bool) {}
PD_REGISTER_KERNEL(bitwise_not, XPU, ALL_LAYOUT, phi::BitwiseNotKernel, bool) {}
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/clip_grad_kernel.h"
#include "paddle/phi/backends/xpu/enforce_xpu.h"
#include "paddle/phi/core/kernel_registry.h"
namespace phi {
template <typename T, typename Context>
void ClipGradKernel(const Context& ctx,
const DenseTensor& x,
const DenseTensor& out_grad,
const Scalar& min,
const Scalar& max,
DenseTensor* x_grad) {
ctx.template Alloc<T>(x_grad);
using XPUDataType = typename XPUTypeTrait<T>::Type;
int r =
xpu::clip_grad(ctx.x_context(),
reinterpret_cast<const XPUDataType*>(x.data<T>()),
reinterpret_cast<const XPUDataType*>(out_grad.data<T>()),
reinterpret_cast<XPUDataType*>(x_grad->data<T>()),
x.numel(),
min.to<T>(),
max.to<T>());
PADDLE_ENFORCE_XDNN_SUCCESS(r, "clip_grad");
}
} // namespace phi
PD_REGISTER_KERNEL(
clip_grad, XPU, ALL_LAYOUT, phi::ClipGradKernel, float, int) {}
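A hedged end-to-end check that exercises the newly registered XPU clip_grad kernel; the same script runs on CPU/GPU, the device only changes which registered kernel is dispatched:

import paddle

# Clip in the forward pass, then backprop so the clip_grad kernel is hit.
x = paddle.rand([4, 4], dtype='float32')
x.stop_gradient = False
loss = paddle.clip(x, min=0.2, max=0.8).sum()
loss.backward()
print(x.grad.shape)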
......@@ -104,7 +104,6 @@ void Pool2dGradKernel(const Context& ctx,
}
if (pooling_type == "max") {
// TODO(zhanghuan05) to bind max_pool2d_grad_indices xpu api
r = xpu::max_pool2d_grad<XPUType>(
ctx.x_context(),
reinterpret_cast<const XPUType*>(x.data<T>()),
......@@ -142,6 +141,67 @@ void Pool2dGradKernel(const Context& ctx,
}
PADDLE_ENFORCE_XDNN_SUCCESS(r, "pool2dgrad");
}
template <typename T, typename Context>
void MaxPool2dWithIndexGradKernel(const Context& ctx,
const DenseTensor& x,
const DenseTensor& mask,
const DenseTensor& dout,
const std::vector<int>& kernel_size,
const std::vector<int>& strides_t,
const std::vector<int>& paddings_t,
bool global_pooling,
bool adaptive,
DenseTensor* dx) {
using XPUType = typename XPUTypeTrait<T>::Type;
ctx.template Alloc<T>(dx);
auto input_grad = reinterpret_cast<XPUType*>(dx->data<T>());
std::vector<int> ksize(kernel_size);
std::vector<int> strides(strides_t);
std::vector<int> paddings(paddings_t);
const auto* index_data = mask.data<int>();
PADDLE_ENFORCE_NOT_NULL(index_data,
errors::NotFound("index data should not be nullptr"));
PADDLE_ENFORCE_EQ(
ksize.size(),
2,
phi::errors::InvalidArgument("The Pool2d XPU OP only support 2 "
"dimension pooling!, but received "
"%d-dimension pool kernel size",
ksize.size()));
global_pooling = global_pooling || (adaptive && (ksize[0] * ksize[1] == 1));
if (global_pooling) {
for (size_t i = 0; i < ksize.size(); ++i) {
paddings[i] = 0;
ksize[i] = static_cast<int>(dx->dims()[i + 2]);
}
}
const int n = dx->dims()[0];
const int c = dx->dims()[1];
const int in_h = dx->dims()[2];
const int in_w = dx->dims()[3];
auto output_grad = reinterpret_cast<const XPUType*>(dout.data<T>());
int r = xpu::Error_t::SUCCESS;
  // passing a nullptr as input to XDNN is fine as long as index_data exists
r = xpu::max_pool2d_grad<XPUType>(ctx.x_context(),
/*input*/ nullptr,
/*output*/ nullptr,
index_data,
output_grad,
input_grad,
n,
c,
in_h,
in_w,
ksize,
strides,
paddings,
true);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "max_pool2d_with_index_grad");
}
} // namespace phi
PD_REGISTER_KERNEL(pool2d_grad,
......@@ -150,3 +210,9 @@ PD_REGISTER_KERNEL(pool2d_grad,
phi::Pool2dGradKernel,
float,
phi::dtype::float16) {}
PD_REGISTER_KERNEL(max_pool2d_with_index_grad,
XPU,
ALL_LAYOUT,
phi::MaxPool2dWithIndexGradKernel,
float,
phi::dtype::float16) {}
......@@ -154,7 +154,72 @@ void Pool2dKernel(const Context& ctx,
}
PADDLE_ENFORCE_XDNN_SUCCESS(r, "pool2d");
}
template <typename T, typename Context>
void MaxPool2dWithIndexKernel(const Context& ctx,
const DenseTensor& x,
const std::vector<int>& kernel_size,
const std::vector<int>& strides_t,
const std::vector<int>& paddings_t,
bool global_pooling,
bool adaptive,
DenseTensor* out,
DenseTensor* mask) {
using XPUType = typename XPUTypeTrait<T>::Type;
ctx.template Alloc<int>(mask);
auto* index_data = mask->data<int>();
std::vector<int> ksize(kernel_size);
std::vector<int> strides(strides_t);
std::vector<int> paddings(paddings_t);
PADDLE_ENFORCE_EQ(ksize.size(),
2,
phi::errors::InvalidArgument(
"The Pool2d XPU OP only support 2 dimension pooling!"));
PADDLE_ENFORCE_EQ(!adaptive || (ksize[0] * ksize[1] == 1),
true,
phi::errors::InvalidArgument(
"The Pool2d XPU OP does not support (adaptive == "
"true && output_size != 1)"));
global_pooling = global_pooling || (adaptive && (ksize[0] * ksize[1] == 1));
if (global_pooling) {
for (size_t i = 0; i < ksize.size(); ++i) {
paddings[i] = 0;
ksize[i] = static_cast<int>(x.dims()[i + 2]);
}
}
const int n = x.dims()[0];
const int c = x.dims()[1];
const int in_h = x.dims()[2];
const int in_w = x.dims()[3];
auto input = reinterpret_cast<const XPUType*>(x.data<T>());
ctx.template Alloc<T>(out);
auto output = reinterpret_cast<XPUType*>(out->data<T>());
int r = xpu::Error_t::SUCCESS;
r = xpu::max_pool2d<XPUType>(ctx.x_context(),
input,
output,
index_data,
n,
c,
in_h,
in_w,
ksize,
strides,
paddings,
true);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "max_pool2d_with_index");
}
} // namespace phi
PD_REGISTER_KERNEL(
pool2d, XPU, ALL_LAYOUT, phi::Pool2dKernel, float, phi::dtype::float16) {}
PD_REGISTER_KERNEL(max_pool2d_with_index,
XPU,
ALL_LAYOUT,
phi::MaxPool2dWithIndexKernel,
float,
phi::dtype::float16) {}
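A hedged usage sketch for the two kernels registered above: max pooling with return_mask=True lowers to max_pool2d_with_index, and its backward pass to max_pool2d_with_index_grad:

import paddle
import paddle.nn.functional as F

x = paddle.rand([2, 3, 8, 8], dtype='float32')
x.stop_gradient = False
out, mask = F.max_pool2d(x, kernel_size=2, stride=2, return_mask=True)
out.sum().backward()          # drives max_pool2d_with_index_grad
print(out.shape, mask.shape)  # [2, 3, 4, 4] [2, 3, 4, 4]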
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
......@@ -46,4 +46,5 @@ void SumRawKernel(const Context& dev_ctx,
} // namespace phi
PD_REGISTER_KERNEL(sum_raw, XPU, ALL_LAYOUT, phi::SumRawKernel, float) {}
PD_REGISTER_KERNEL(
sum_raw, XPU, ALL_LAYOUT, phi::SumRawKernel, float, int8_t, int64_t) {}
......@@ -26,6 +26,14 @@ void TransposeGradKernel(const Context& dev_ctx,
DenseTensor* x_grad) {
using XPUType = typename XPUTypeTrait<T>::Type;
dev_ctx.template Alloc<T>(x_grad);
if (x_grad->numel() == 0) {
return;
}
if (axis.size() == 0) {
phi::Copy<Context>(dev_ctx, out_grad, dev_ctx.GetPlace(), false, x_grad);
return;
}
std::vector<int> reversed_axis(axis);
for (size_t i = 0; i < axis.size(); i++) {
reversed_axis[axis[i]] = i;
......
......@@ -29,6 +29,7 @@ WHITE_LIST = {
'conv2d',
'matmul',
'matmul_v2',
'max_pool2d_with_index',
'mul',
'fake_quantize_dequantize_abs_max',
'fake_quantize_dequantize_moving_average_abs_max',
......
......@@ -41,11 +41,9 @@ class HybridParallelGradScaler:
optimize_ops, params_grads = (None, None)
if self._found_inf:
self._cache_founf_inf = True
else:
optimize_ops, params_grads = optimizer.minimize(*args, **kwargs)
self._cache_founf_inf = False
optimizer._set_auxiliary_var('found_inf', self._found_inf)
optimize_ops, params_grads = optimizer.minimize(*args, **kwargs)
self._cache_founf_inf = optimizer._get_auxiliary_var('found_inf')
if self._use_dynamic_loss_scaling:
self._update()
......
......@@ -19,10 +19,10 @@ from types import MethodType
import numpy as np
import paddle
from paddle import _legacy_C_ops
from paddle import _C_ops, _legacy_C_ops
from paddle.common_ops_import import dygraph_only
from paddle.fluid import core
from paddle.fluid.dygraph import to_variable
from paddle.framework import core
from paddle.nn import clip
......@@ -231,6 +231,9 @@ def GroupShardedScaler(scaler):
param_grads_fp16,
temp_found_inf_fp16,
)
self._found_inf = _C_ops.bitwise_or(
self._found_inf, temp_found_inf_fp16
)
if len(param_grads_fp32):
_legacy_C_ops.check_finite_and_unscale(
param_grads_fp32,
......@@ -238,15 +241,17 @@ def GroupShardedScaler(scaler):
param_grads_fp32,
temp_found_inf_fp32,
)
self._found_inf = _C_ops.bitwise_or(
self._found_inf, temp_found_inf_fp32
)
self._found_inf = 1 if temp_found_inf_fp16 or temp_found_inf_fp32 else 0
is_found_inf = paddle.to_tensor([self._found_inf], dtype="int32")
self._found_inf = self._found_inf.cast("int32")
paddle.distributed.all_reduce(
is_found_inf, op=paddle.distributed.ReduceOp.SUM, group=None
self._found_inf, op=paddle.distributed.ReduceOp.MAX, group=None
)
self._found_inf = is_found_inf.numpy()[0]
self._found_inf = self._found_inf.cast("bool")
scaler._unscale = MethodType(unscale_method, scaler)
return scaler
......
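Put together, the new overflow synchronization in unscale reads roughly like this hedged sketch; the two per-dtype flags stand in for the tensors produced by check_finite_and_unscale in the diff above, and a launched distributed environment is assumed:

import paddle
import paddle.distributed as dist

# OR the per-dtype overflow flags into one bool tensor, all-reduce with MAX via
# int32 (bool is not used directly in the collective), then cast back to bool.
temp_found_inf_fp16 = paddle.to_tensor([False])  # stand-in values
temp_found_inf_fp32 = paddle.to_tensor([True])
found_inf = paddle.bitwise_or(temp_found_inf_fp16, temp_found_inf_fp32)
found_inf = found_inf.cast("int32")
dist.all_reduce(found_inf, op=dist.ReduceOp.MAX)
found_inf = found_inf.cast("bool")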