Commit 75aca6d4 authored by phlrain

Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into move_slice_to_pten

...@@ -51,3 +51,5 @@ paddle/infrt/dialect/pd_ops_info.h
.lit_test_times.txt
paddle/infrt/tests/dialect/Output
paddle/infrt/tests/lit.cfg.py
paddle/fluid/pybind/eager_final_state_op_function_impl.h
paddle/fluid/pybind/tmp_eager_final_state_op_function_impl.h
...@@ -243,6 +243,7 @@ option(NEW_RELEASE_JIT "PaddlePaddle next-level release strategy for backup ji
option(WITH_ASCEND_INT64 "Compile with int64 kernel for ascend NPU" OFF)
option(WITH_POCKETFFT "Compile with pocketfft support" ON)
option(WITH_RECORD_BUILDTIME "Compile PaddlePaddle with record all targets build time" OFF)
option(WITH_CUSTOM_DEVICE "Compile with custom device support" OFF)
if(WITH_RECORD_BUILDTIME)
set_property(GLOBAL PROPERTY RULE_LAUNCH_COMPILE "${CMAKE_CURRENT_SOURCE_DIR}/tools/get_build_time.sh")
...@@ -265,6 +266,10 @@ if(SANITIZER_TYPE AND NOT "${SANITIZER_TYPE}" MATCHES "^(Address|Leak|Memory|Thr
return()
endif()
if (LINUX AND NOT WITH_CUSTOM_DEVICE AND NOT ON_INFER)
set(WITH_CUSTOM_DEVICE ON)
endif()
if(WIN32)
if(WITH_DISTRIBUTE)
MESSAGE(WARNING
......
...@@ -219,3 +219,7 @@ endif(ON_INFER)
if(WITH_CRYPTO)
add_definitions(-DPADDLE_WITH_CRYPTO)
endif(WITH_CRYPTO)
if(WITH_CUSTOM_DEVICE AND NOT WIN32)
add_definitions(-DPADDLE_WITH_CUSTOM_DEVICE)
endif()
...@@ -55,6 +55,7 @@ IF(NOT WIN32)
INSTALL_COMMAND make install NO_SHARED=1 NO_LAPACK=1 PREFIX=<INSTALL_DIR>
UPDATE_COMMAND ""
CONFIGURE_COMMAND ""
BUILD_BYPRODUCTS ${CBLAS_LIBRARIES}
)
ELSE(NOT WIN32)
SET(CBLAS_LIBRARIES
...@@ -83,6 +84,8 @@ ELSE(NOT WIN32)
CMAKE_CACHE_ARGS -DCMAKE_INSTALL_PREFIX:PATH=${CBLAS_INSTALL_DIR}
-DCMAKE_POSITION_INDEPENDENT_CODE:BOOL=ON
-DCMAKE_BUILD_TYPE:STRING=${THIRD_PARTY_BUILD_TYPE}
# ninja need to know where openblas.lib comes from
BUILD_BYPRODUCTS ${CBLAS_LIBRARIES}
)
SET(OPENBLAS_SHARED_LIB ${CBLAS_INSTALL_DIR}/bin/openblas${CMAKE_SHARED_LIBRARY_SUFFIX})
ENDIF(NOT WIN32)
...@@ -53,7 +53,6 @@ bool LoadDataFromDistModelTensor(const DistModelTensor &input_data,
} else if (input_data.dtype == DistModelDataType::INT32) {
input_tensor_ptr = input_tensor->mutable_data<int32_t>(dims, place);
} else {
// Q(fleet exe dev): for input/output, should we support fp16
LOG(ERROR) << "unsupported feed type " << input_data.dtype; LOG(ERROR) << "unsupported feed type " << input_data.dtype;
return false; return false;
} }
...@@ -113,14 +112,6 @@ std::string DistModelDTypeToString(DistModelDataType dtype) { ...@@ -113,14 +112,6 @@ std::string DistModelDTypeToString(DistModelDataType dtype) {
return "NOT SUPPORT DTYPE"; return "NOT SUPPORT DTYPE";
} }
bool IsPPFirstStage(const DistModelConfig &config) {
return config.local_rank - config.mp_degree < 0;
}
bool IsPPLastStage(const DistModelConfig &config) {
return config.local_rank + config.mp_degree >= config.nranks;
}
class DistModelTimer {
public:
void tic() { tic_time = std::chrono::high_resolution_clock::now(); }
...@@ -197,65 +188,34 @@ bool DistModel::PreparePlace() {
}
bool DistModel::CommInit() {
// NOTE (Yuang Liu): The peer endpoints will be obtained with the assumption
// that mp part is always on inner side and pp part is always on outer side.
// TODO(fleet exe dev): The peer endpoints could be configured by users.
PADDLE_ENFORCE_EQ(
config_.pp_degree * config_.mp_degree, config_.nranks,
platform::errors::InvalidArgument(
"The mp_degree multiplies pp_degree is not equal with nranks"));
std::unique_ptr<framework::ProgramDesc> comm_init_program(
new framework::ProgramDesc());
framework::BlockDesc *comm_init_block = comm_init_program->MutableBlock(0);
if (config_.mp_degree > 1) {
PADDLE_ENFORCE_GE(
config_.mp_ring_id, 0,
platform::errors::InvalidArgument(
"mp ring id must be provided for inference under mp."));
VLOG(3) << "Init comm group for mp.";
std::vector<int64_t> &ring_ids =
config_.rank_to_ring_ids_[config_.local_rank];
int64_t order = 0;
std::string var_name_base = "comm_init_";
for (int64_t ring_id : ring_ids) {
VLOG(3) << "Init comm for ring id: " << ring_id;
int64_t ranks_in_group = config_.ring_id_to_ranks_[ring_id].size();
int64_t rank_in_group = 0;
std::vector<int64_t> &ranks = config_.ring_id_to_ranks_[ring_id];
for (int64_t rank : ranks) {
if (config_.local_rank == rank) {
break;
}
rank_in_group += 1;
}
std::vector<std::string> peer_endpoints;
for (int64_t
idx = (config_.local_rank / config_.mp_degree) * config_.mp_degree,
i = 0;
i < config_.mp_degree; ++idx, ++i) {
if (config_.trainer_endpoints[idx] == config_.current_endpoint) {
for (int64_t rank : ranks) {
if (config_.local_rank == rank) {
continue;
}
peer_endpoints.emplace_back(config_.trainer_endpoints[idx]);
peer_endpoints.emplace_back(config_.trainer_endpoints[rank]);
}
// get nranks in a mp group and inner group rank for local rank
int64_t mp_group_nranks = config_.nranks / config_.pp_degree;
int64_t mp_group_rank = config_.local_rank % config_.mp_degree;
InsertCommOp("mp_comm_id", mp_group_nranks, mp_group_rank, peer_endpoints,
comm_init_block, config_.mp_ring_id);
}
if (config_.pp_degree > 1) {
VLOG(3) << "Init comm group for pp.";
if (!IsPPFirstStage(config_)) {
PADDLE_ENFORCE_EQ(config_.pp_upstream_ring_id >= 0, true,
platform::errors::InvalidArgument(
"pp upstream ring id must be provided for "
"non-first pp stage if inference under pp."));
// not the first pp stage, has upstream
std::vector<std::string> upstream_peer_endpoints;
upstream_peer_endpoints.emplace_back(
config_.trainer_endpoints[config_.local_rank - config_.mp_degree]);
InsertCommOp("pp_upstream_comm_id", 2, 1, upstream_peer_endpoints,
comm_init_block, config_.pp_upstream_ring_id);
}
if (!IsPPLastStage(config_)) {
PADDLE_ENFORCE_EQ(config_.pp_downstream_ring_id >= 0, true,
platform::errors::InvalidArgument(
"pp downstream ring id must be provided for "
"non-last pp stage if inference under pp."));
// not the last pp stage, has downstream
std::vector<std::string> downstream_peer_endpoints;
downstream_peer_endpoints.emplace_back(
config_.trainer_endpoints[config_.local_rank + config_.mp_degree]);
InsertCommOp("pp_downstream_comm_id", 2, 0, downstream_peer_endpoints,
comm_init_block, config_.pp_downstream_ring_id);
}
}
InsertCommOp(var_name_base + std::to_string(order), ranks_in_group,
rank_in_group, peer_endpoints, comm_init_block, ring_id);
order += 1;
}
framework::NaiveExecutor e(place_);
e.CreateVariables(*comm_init_program, 0, true, scope_.get());
...@@ -409,12 +369,7 @@ bool DistModel::LoadParameters() {
bool DistModel::PrepareFleetExe() {
task_node_.reset(new TaskNode(program_.get(), config_.local_rank));
if (config_.local_rank - config_.mp_degree >= 0) {
task_node_->AddUpstreamTask(config_.local_rank - config_.mp_degree);
}
if (config_.local_rank + config_.mp_degree < config_.nranks) {
task_node_->AddDownstreamTask(config_.local_rank + config_.mp_degree);
}
// With auto cut, there is no concept of pp, no need to add dependency.
task_node_->SetType("Compute");
task_node_->Init();
executor_desc_ = FleetExecutorDesc();
...@@ -473,40 +428,13 @@ bool DistModel::PrepareFeedAndFetch() {
}
}
if (config_.pp_degree == 1) {
if (feeds_.size() == 0) {
LOG(ERROR) << "No feed ops in the inf program, please check the program.";
return false;
}
if (fetches_.size() == 0) {
LOG(ERROR) << "No fetch op in the inf program, please check the program.";
return false;
}
} else {
if (IsPPFirstStage(config_)) {
if (feeds_.size() == 0) {
LOG(ERROR) << "Feed ops are needed for the first pp stage.";
return false;
}
} else {
if (feeds_.size() > 0) {
LOG(WARNING) << "Feed op is found in the non-first stage of pp.";
} else {
LOG(INFO) << "No feed ops in non-first pp stage.";
}
}
if (IsPPLastStage(config_)) {
if (fetches_.size() == 0) {
LOG(WARNING) << "No fetch op was found in the last pp stage. Make sure "
"the result has been sent to frist pp stage.";
}
} else {
if (fetches_.size() > 0) {
LOG(WARNING) << "Fetch op is found in the non-last stage of pp.";
} else {
LOG(INFO) << "No fetch op in non-last pp stage.";
}
}
if (feeds_.size() == 0) {
LOG(ERROR) << "No feed ops in the inf program, please check the program.";
return false;
}
if (fetches_.size() == 0) {
LOG(ERROR) << "No fetch op in the inf program, please check the program.";
return false;
}
return true;
}
...@@ -606,7 +534,6 @@ bool DistModel::FetchResult(const framework::LoDTensor &fetch,
bool DistModel::Run(const std::vector<DistModelTensor> &input_data,
std::vector<DistModelTensor> *output_data) {
// TODO(fleet exe dev): support pipeline inf mode
VLOG(3) << "DistModel run for once."; VLOG(3) << "DistModel run for once.";
DistModelTimer timer; DistModelTimer timer;
......
...@@ -13,6 +13,7 @@
// limitations under the License.
#pragma once
#include <map>
#include <memory>
#include <string>
#include <vector>
...@@ -47,12 +48,9 @@ struct DistModelConfig {
std::string current_endpoint{};
int64_t nranks{1};
int64_t local_rank{0};
int64_t mp_degree{1};
int64_t pp_degree{1};
int64_t mp_ring_id{-1};
int64_t pp_upstream_ring_id{-1};
int64_t pp_downstream_ring_id{-1};
bool enable_timer{false};
std::map<int64_t, std::vector<int64_t>> ring_id_to_ranks_{};
std::map<int64_t, std::vector<int64_t>> rank_to_ring_ids_{};
};
class DistModel {
......
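The DistModelConfig hunk above replaces the mp/pp degree and ring-id fields with two explicit maps, so each rank's communication rings are described directly and the same CommInit() loop covers any grouping. A minimal illustrative sketch (not part of this commit; the header path, namespace, and values are assumptions) of how a caller might fill the new fields for a hypothetical two-rank job sharing a single ring with id 0:

#include "paddle/fluid/distributed/fleet_executor/dist_model.h"  // assumed path

void FillRingConfigExample(paddle::distributed::DistModelConfig* config) {
  // Two ranks, one communication ring (id 0) containing both of them.
  config->nranks = 2;
  config->local_rank = 0;
  config->trainer_endpoints = {"127.0.0.1:6170", "127.0.0.1:6171"};
  config->current_endpoint = "127.0.0.1:6170";
  config->ring_id_to_ranks_[0] = {0, 1};  // ring 0 -> ranks {0, 1}
  config->rank_to_ring_ids_[0] = {0};     // rank 0 joins ring 0
  config->rank_to_ring_ids_[1] = {0};     // rank 1 joins ring 0
  // CommInit() walks rank_to_ring_ids_[local_rank] and, for each ring id,
  // derives rank_in_group and the peer endpoints from ring_id_to_ranks_.
}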
...@@ -25,7 +25,7 @@ limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/program_desc.h"
USE_OP(elementwise_add);
USE_OP_ITSELF(elementwise_add);
USE_OP(fill_constant);
namespace paddle {
......
...@@ -1227,11 +1227,11 @@ static std::pair<std::string, std::string> GenerateForwardFunctionContents(
// Forward Function Body
// According to fwd_inputs_name_pos_map
std::map<std::string, std::vector<std::shared_ptr<egr::EagerTensor>>>
std::map<std::string, std::vector<std::shared_ptr<egr::EagerVariable>>>
ins =
{ {"X" , TrySyncToVars(X)}, { "Y" , TrySyncToVars(Y)} };
std::map<std::string, std::vector<std::shared_ptr<egr::EagerTensor>>>
std::map<std::string, std::vector<std::shared_ptr<egr::EagerVariable>>>
outs =
{
{"Out0" , CreateVars(Out0Num)}, {"Out1"
...@@ -1316,7 +1316,7 @@ static std::pair<std::string, std::string> GenerateForwardFunctionContents(
const char* FWD_INS_MAP_TEMPLATE =
" std::map<std::string, "
"std::vector<std::shared_ptr<egr::EagerTensor>>> ins = { "
"std::vector<std::shared_ptr<egr::EagerVariable>>> ins = { "
"%s };\n";
std::string ins_map_str =
paddle::string::Sprintf(FWD_INS_MAP_TEMPLATE, ins_contents_str);
...@@ -1353,8 +1353,9 @@ static std::pair<std::string, std::string> GenerateForwardFunctionContents(
if (op_passing_outs_map[op_type].count(output_name)) {
const std::string output_var_name = output_name + "Var";
// Pass Output from function argument(EagerTensor*/vector<EagerTensor*>&),
// in form of shared_ptr<EagerTensor>/vector<shared_ptr<EagerTensor>>
// Pass Output from function
// argument(EagerVariable*/vector<EagerVariable*>&),
// in form of shared_ptr<EagerVariable>/vector<shared_ptr<EagerVariable>>
if (output.duplicable()) {
const char* FWD_NUM_ARG_TEMPLATE =
", std::vector<paddle::experimental::Tensor*>& %s";
...@@ -1395,7 +1396,7 @@ static std::pair<std::string, std::string> GenerateForwardFunctionContents(
} else {
const char* FWD_OUTS_CONTENT_TEMPLATE =
"{ \"%s\", "
"{std::make_shared<egr::EagerTensor>(egr::Controller::Instance()."
"{std::make_shared<egr::EagerVariable>(egr::Controller::Instance()."
"GenerateUniqueName())}},";
outs_contents_str +=
paddle::string::Sprintf(FWD_OUTS_CONTENT_TEMPLATE, output_name);
...@@ -1407,7 +1408,7 @@ static std::pair<std::string, std::string> GenerateForwardFunctionContents(
const char* FWD_OUTS_MAP_TEMPLATE =
" std::map<std::string, "
"std::vector<std::shared_ptr<egr::EagerTensor>>> outs = { "
"std::vector<std::shared_ptr<egr::EagerVariable>>> outs = { "
"%s };\n";
std::string outs_map_str =
paddle::string::Sprintf(FWD_OUTS_MAP_TEMPLATE, outs_contents_str);
...@@ -1482,7 +1483,7 @@ static std::pair<std::string, std::string> GenerateForwardFunctionContents(
generated_function_body += out_tensor_str;
}
generated_function_body += "\n";
VLOG(6) << "Converted Output VarBase to EagerTensor(s)";
VLOG(6) << "Converted Output VarBase to EagerVariable(s)";
// [Generation] Handle core_ops_returns_info
core_ops_returns_info[op_type] = return_contents;
...@@ -1627,7 +1628,7 @@ static std::string GenerateSingleOpBase(
const char* BWD_INS_MAP_TEMPLATE =
" std::map<std::string, "
"std::vector<std::shared_ptr<egr::EagerTensor>>> %s = { "
"std::vector<std::shared_ptr<egr::EagerVariable>>> %s = { "
"%s };\n";
std::string ins_map_str =
paddle::string::Sprintf(BWD_INS_MAP_TEMPLATE, ins_name, ins_contents_str);
...@@ -1704,7 +1705,7 @@ static std::string GenerateSingleOpBase(
} else {
const char* GRAD_OUTS_CONTENT_TEMPLATE =
"{ \"%s\", "
"{std::make_shared<egr::EagerTensor>(egr::Controller::Instance("
"{std::make_shared<egr::EagerVariable>(egr::Controller::Instance("
")."
"GenerateUniqueName())}},";
outs_contents_str += paddle::string::Sprintf(
...@@ -1723,7 +1724,7 @@ static std::string GenerateSingleOpBase(
const char* BWD_OUTS_MAP_TEMPLATE =
" std::map<std::string, "
"std::vector<std::shared_ptr<egr::EagerTensor>>> %s = { "
"std::vector<std::shared_ptr<egr::EagerVariable>>> %s = { "
"%s };\n";
std::string outs_map_str = paddle::string::Sprintf(
BWD_OUTS_MAP_TEMPLATE, outs_name, outs_contents_str);
......
...@@ -40,36 +40,28 @@
* **/
namespace egr {
class EagerTensor final {
class EagerVariable final {
public:
/* Default constructor and name constructor should only be used for contruct
* output and in fluid*/
EagerTensor() = default;
explicit EagerTensor(const std::string& name) : name_(name) {}
explicit EagerTensor(const paddle::experimental::Tensor& tensor)
EagerVariable() = default;
explicit EagerVariable(const std::string& name) : name_(name) {}
explicit EagerVariable(const paddle::experimental::Tensor& tensor)
: name_(tensor.name()) {
if (tensor.defined()) {
if (tensor.is_dense_tensor()) {
auto* framework_tensor =
var_.GetMutable<paddle::framework::LoDTensor>();
// Contruct framework::Tensor from egr::EagerTensor
auto tensor_dense =
std::dynamic_pointer_cast<pten::DenseTensor>(tensor.impl());
PADDLE_ENFORCE_EQ((tensor_dense.get() && tensor_dense), true,
paddle::platform::errors::Fatal(
"Failed to Trans Tensor to EagerVariable since "
"we got Tensor with type DenseTensor, and we got "
"EagerVariable with another type."));
*framework_tensor = *tensor_dense;
ConstructVariableFromTensor(tensor);
} else if (tensor.is_selected_rows()) {
ConstructVariableFromSelectedRows(tensor);
} else {
PADDLE_THROW(paddle::platform::errors::Fatal(
"Unrecognized egr::EagerVariable type, only "
"DenseTensor and SelectedRows is supported for now."));
"DenseTensor and SelectedRows are supported for now."));
}
} else {
VLOG(6) << "Build Empty EagerTensor with name " << name_;
VLOG(6) << "Build Empty EagerVariable with name " << name_;
}
}
...@@ -77,21 +69,20 @@ class EagerTensor final {
std::shared_ptr<pten::TensorBase> GetTensorBase() {
// Construct allocation only once.
if (var_.IsInitialized()) {
if (var_.IsType<paddle::framework::LoDTensor>()) {
return SetImplWithLegacyTensor<pten::DenseTensor>();
} else if (var_.IsType<paddle::framework::Tensor>()) {
return SetImplWithLegacyTensor<pten::DenseTensor>();
if (var_.IsType<paddle::framework::LoDTensor>() ||
var_.IsType<paddle::framework::Tensor>()) {
return SetImplWithLegacyTensor();
} else if (var_.IsType<pten::SelectedRows>()) {
return SetImplWithSelectedRows();
return SetImplWithLegacySelectedRows();
} else {
PADDLE_THROW(paddle::platform::errors::Fatal(
"Unable to fetch underlying tensor "
"from EagerTensor, only LoDTensor and "
"from EagerVariable, only LoDTensor and "
"Tensor are supported for now"));
}
} else {
PADDLE_THROW(paddle::platform::errors::Fatal(
"Can not Sync EagerTensor %s whose paddle::framework::Variable is "
"Can not Sync EagerVariable %s whose paddle::framework::Variable is "
"not initialized!",
name()));
}
...@@ -107,23 +98,52 @@ class EagerTensor final {
void set_name(const std::string& name) { name_ = name; }
private:
template <typename LEGACY_TYPE>
std::shared_ptr<pten::TensorBase> SetImplWithLegacyTensor() {
const auto& framework_tensor = var_.Get<LEGACY_TYPE>();
const auto& framework_tensor = var_.Get<pten::DenseTensor>();
VLOG(8) << "Sync Var to tensor for: " << name();
return std::make_shared<LEGACY_TYPE>(std::move(framework_tensor));
return std::make_shared<pten::DenseTensor>(framework_tensor);
}
std::shared_ptr<pten::TensorBase> SetImplWithSelectedRows() {
auto* selected_rows = var_.GetMutable<pten::SelectedRows>();
auto res = std::make_shared<pten::SelectedRows>(selected_rows->rows_,
selected_rows->height_);
res->value_.reset(selected_rows->value_.release());
res->id_to_index_ = std::move(selected_rows->id_to_index_);
res->rwlock_.reset(selected_rows->rwlock_.release());
std::shared_ptr<pten::TensorBase> SetImplWithLegacySelectedRows() {
auto* framework_tensor = var_.GetMutable<pten::SelectedRows>();
VLOG(8) << "Sync SelectedRows to tensor for: " << name();
auto res =
std::make_shared<pten::SelectedRows>(std::move(*framework_tensor));
var_.Clear();
return res;
}
void ConstructVariableFromTensor(const paddle::experimental::Tensor& tensor) {
auto* framework_tensor = var_.GetMutable<pten::DenseTensor>();
// Contruct framework::Tensor from egr::EagerVariable
auto tensor_dense =
std::dynamic_pointer_cast<pten::DenseTensor>(tensor.impl());
PADDLE_ENFORCE_EQ(
(tensor_dense.get() && tensor_dense), true,
paddle::platform::errors::Fatal(
"Tensor %s does not hold pten::SelectedRows or pten::DenseTensor. "
"Or it holds empty impl, this should not happend since we should "
"treat all kinds of tensor as what they are.",
tensor.name()));
*framework_tensor = *tensor_dense;
}
void ConstructVariableFromSelectedRows(
const paddle::experimental::Tensor& tensor) {
auto* framework_tensor = var_.GetMutable<pten::SelectedRows>();
// Contruct framework::Tensor from egr::EagerVariable
auto tensor_dense =
std::dynamic_pointer_cast<pten::SelectedRows>(tensor.impl());
PADDLE_ENFORCE_EQ(
(tensor_dense.get() && tensor_dense), true,
paddle::platform::errors::Fatal(
"Tensor %s does not hold pten::SelectedRows or pten::DenseTensor. "
"Or it holds empty impl, this should not happend since we should "
"treat all kinds of tensor as what they are.",
tensor.name()));
*framework_tensor = std::move(*tensor_dense);
}
private:
std::string name_{""};
paddle::framework::Variable var_;
......
...@@ -78,9 +78,9 @@ void GradTensorHolder::add(size_t slot_id, size_t rank,
if (buffer_tensor.is_dense_tensor()) {
paddle::imperative::SelectedRowsAddToTensor(t, &buffer_tensor);
} else {
PADDLE_THROW(paddle::platform::errors::Fatal(
"We don't support Selected Rows merge for now, support it later "
"and make all kinds of grads can be merged."));
buffer_tensor =
std::move(*paddle::imperative::SelectedRowsMerge<
paddle::experimental::Tensor>(t, buffer_tensor));
}
}
}
......
...@@ -115,7 +115,7 @@ TEST(Tensor, MemberFunction) {
CHECK_EQ(tmp_autograd_meta_test->val_, 2);
}
TEST(EagerTensor, Constructor) {
TEST(EagerVariable, Constructor) {
paddle::experimental::Tensor t3;
pten::DenseTensorMeta meta = pten::DenseTensorMeta(
pten::DataType::FLOAT32, paddle::framework::make_ddim({1, 2}));
...@@ -134,7 +134,7 @@ TEST(EagerTensor, Constructor) {
CHECK_EQ(t3.defined(), false);
t3.set_impl(dt);
egr::EagerTensor et3 = egr::EagerTensor(t3);
egr::EagerVariable et3 = egr::EagerVariable(t3);
VLOG(6) << "SyncToVar";
CHECK_EQ(et3.Var().Get<paddle::framework::LoDTensor>().data<float>()[0],
5.0f);
......
...@@ -20,6 +20,7 @@
#include "paddle/fluid/eager/grad_node_info.h"
#include "paddle/fluid/eager/grad_tensor_holder.h"
#include "paddle/pten/api/lib/utils/allocator.h"
#include "paddle/pten/core/selected_rows.h"
#include "paddle/pten/core/kernel_registry.h" #include "paddle/pten/core/kernel_registry.h"
...@@ -102,3 +103,69 @@ TEST(GradTensorHolder, Interfaces) { ...@@ -102,3 +103,69 @@ TEST(GradTensorHolder, Interfaces) {
CHECK_EQ(holder_et0_ptr[0], 1.0f); CHECK_EQ(holder_et0_ptr[0], 1.0f);
CHECK_EQ(holder_et1_ptr[0], 30.0f); CHECK_EQ(holder_et1_ptr[0], 30.0f);
} }
TEST(GradTensorHolder, SelectedRowsMergeAdd) {
pten::CPUPlace cpu;
std::vector<int64_t> rows{0, 1, 2, 3, 4, 5, 6, 7, 8, 9};
int64_t table_size = 10;
int64_t embedding_width = 10;
auto sr1 = std::make_shared<pten::SelectedRows>(rows, table_size);
auto sr2 = std::make_shared<pten::SelectedRows>(rows, table_size);
// initialize a sparse table 1
sr1->mutable_value()->Resize(
pten::framework::make_ddim({table_size, embedding_width}));
auto* data_sr1 = sr1->mutable_value()->mutable_data<float>(cpu);
for (int64_t i = 0; i < table_size; ++i) {
for (int64_t j = 0; j < embedding_width; ++j) {
data_sr1[i * embedding_width + j] = static_cast<float>(i);
}
}
// initialize a sparse table 2
sr2->mutable_value()->Resize(
pten::framework::make_ddim({table_size, embedding_width}));
auto* data_sr2 = sr2->mutable_value()->mutable_data<float>(cpu);
for (int64_t i = 0; i < table_size; ++i) {
for (int64_t j = 0; j < embedding_width; ++j) {
data_sr2[i * embedding_width + j] = static_cast<float>(i);
}
}
// new 2 pten::Tensor
paddle::experimental::Tensor t1(sr1);
paddle::experimental::Tensor t2(sr2);
// Constructor empty GradTensorHolder
GradSlotMeta slot_meta;
slot_meta.Init(1);
GradTensorHolder grad_tensor_holder =
GradTensorHolder({slot_meta, slot_meta});
// accumulation
grad_tensor_holder.add(0, 0, t1, false);
grad_tensor_holder.add(0, 0, t2, false);
// Buffers()
const auto& buffers = grad_tensor_holder.Buffers();
CHECK_EQ(static_cast<int>(buffers.size()), 2);
CHECK_EQ(static_cast<int>(buffers[0].size()), 1);
CHECK_EQ(static_cast<int>(buffers[1].size()), 1);
// operator[]
const auto& holder_et0 = grad_tensor_holder[0][0];
auto* tmp_buffer_tensor =
static_cast<pten::SelectedRows*>(holder_et0.impl().get());
auto* tmp_buffer_data_sr =
tmp_buffer_tensor->mutable_value()->mutable_data<float>(cpu);
// verify the MergeAdd result (accumulation result)
for (int64_t i = 0; i < table_size; ++i) {
for (int64_t j = 0; j < embedding_width; ++j) {
EXPECT_EQ(tmp_buffer_data_sr[i * embedding_width + j],
(static_cast<float>(i) + static_cast<float>(i)));
}
}
}
...@@ -176,6 +176,6 @@ TEST(Benchmark, EagerIntermediateMLPCPU) {
}
USE_OP_ITSELF(scale);
USE_OP(elementwise_add);
USE_OP_ITSELF(elementwise_add);
USE_OP(matmul_v2);
USE_OP(reduce_sum);
...@@ -189,6 +189,6 @@ USE_OP_ITSELF(scale);
USE_OP(matmul_v2);
USE_OP(reduce_sum);
USE_OP(reduce_sum_grad);
USE_OP(elementwise_add);
USE_OP_ITSELF(elementwise_add);
#endif // PADDLE_WITH_CUDA || PADDLE_WITH_HIP
...@@ -212,6 +212,6 @@ TEST(Benchmark, FluidMLPCPU) {
} // namespace paddle
USE_OP_ITSELF(scale);
USE_OP(elementwise_add);
USE_OP_ITSELF(elementwise_add);
USE_OP(matmul_v2);
USE_OP(reduce_sum);
...@@ -249,6 +249,6 @@ USE_OP_ITSELF(scale);
USE_OP(matmul_v2);
USE_OP(reduce_sum);
USE_OP(reduce_sum_grad);
USE_OP(elementwise_add);
USE_OP_ITSELF(elementwise_add);
#endif // PADDLE_WITH_CUDA || PADDLE_WITH_HIP
...@@ -16,6 +16,7 @@
#include "gtest/gtest.h"
#include "paddle/fluid/eager/accumulation/accumulation_node.h"
#include "paddle/fluid/eager/eager_tensor.h" #include "paddle/fluid/eager/eager_tensor.h"
#include "paddle/fluid/eager/grad_node_info.h" #include "paddle/fluid/eager/grad_node_info.h"
#include "paddle/fluid/eager/tests/data_structure_tests/grad_node_test.h" #include "paddle/fluid/eager/tests/data_structure_tests/grad_node_test.h"
...@@ -167,7 +168,7 @@ TEST(EagerUtils, PassStopGradient) { ...@@ -167,7 +168,7 @@ TEST(EagerUtils, PassStopGradient) {
TEST(EagerUtils, TrySyncToVar) { TEST(EagerUtils, TrySyncToVar) {
paddle::framework::DDim ddim = paddle::framework::make_ddim({2, 4, 4, 4}); paddle::framework::DDim ddim = paddle::framework::make_ddim({2, 4, 4, 4});
auto tensor = CreateTestCPUTensor(5.0f, ddim); auto tensor = CreateTestCPUTensor(5.0f, ddim);
std::vector<std::shared_ptr<egr::EagerTensor>> var_bases = { std::vector<std::shared_ptr<egr::EagerVariable>> var_bases = {
egr::EagerUtils::TrySyncToVar(tensor)}; egr::EagerUtils::TrySyncToVar(tensor)};
paddle::framework::Variable* var = var_bases[0]->MutableVar(); paddle::framework::Variable* var = var_bases[0]->MutableVar();
...@@ -187,7 +188,7 @@ TEST(EagerUtils, TrySyncToVars) { ...@@ -187,7 +188,7 @@ TEST(EagerUtils, TrySyncToVars) {
std::vector<paddle::experimental::Tensor> tensors = { std::vector<paddle::experimental::Tensor> tensors = {
CreateTestCPUTensor(1.0f, ddim), CreateTestCPUTensor(2.0f, ddim)}; CreateTestCPUTensor(1.0f, ddim), CreateTestCPUTensor(2.0f, ddim)};
std::vector<std::shared_ptr<egr::EagerTensor>> var_bases = std::vector<std::shared_ptr<egr::EagerVariable>> var_bases =
egr::EagerUtils::TrySyncToVars(tensors); egr::EagerUtils::TrySyncToVars(tensors);
{ {
...@@ -218,10 +219,32 @@ TEST(EagerUtils, TrySyncToVars) { ...@@ -218,10 +219,32 @@ TEST(EagerUtils, TrySyncToVars) {
TEST(EagerUtils, CreateVars) { TEST(EagerUtils, CreateVars) {
VLOG(6) << "Check CreateVars"; VLOG(6) << "Check CreateVars";
std::vector<std::shared_ptr<egr::EagerTensor>> outs = std::vector<std::shared_ptr<egr::EagerVariable>> outs =
egr::EagerUtils::CreateVars(2); egr::EagerUtils::CreateVars(2);
CHECK_EQ(outs.size(), size_t(2)); CHECK_EQ(outs.size(), size_t(2));
CHECK(outs[0]->Var().IsInitialized() == false); CHECK(outs[0]->Var().IsInitialized() == false);
} }
TEST(EagerUtils, GetGradAccumulationNode) {
VLOG(6) << "Check GetGradAccumulationNode";
paddle::experimental::Tensor t0("test_tensor");
ASSERT_EQ(egr::EagerUtils::GetGradAccumulationNode(t0), nullptr);
auto autograd_ptr0 = egr::EagerUtils::autograd_meta(&t0);
autograd_ptr0->SetStopGradient(true);
ASSERT_EQ(egr::EagerUtils::GetGradAccumulationNode(t0), nullptr);
autograd_ptr0->SetStopGradient(false);
auto res = std::dynamic_pointer_cast<egr::GradNodeAccumulation>(
egr::EagerUtils::GetGradAccumulationNode(t0));
ASSERT_TRUE(res != nullptr);
auto res2 = egr::EagerUtils::GetGradAccumulationNode(t0);
ASSERT_EQ(res2.get(), res.get());
autograd_ptr0->SetStopGradient(true);
auto res3 = egr::EagerUtils::GetGradAccumulationNode(t0);
ASSERT_EQ(res3, nullptr);
autograd_ptr0->SetStopGradient(false);
autograd_ptr0->SetGradNode(
std::make_shared<eager_test::GradTestNode>(1, 2.0, 3));
ASSERT_ANY_THROW(egr::EagerUtils::GetGradAccumulationNode(t0));
}
} // namespace egr
...@@ -123,5 +123,5 @@ TEST(Generated, ElementwiseAdd) {
} // namespace egr
USE_OP(sigmoid);
USE_OP(elementwise_add);
USE_OP_ITSELF(elementwise_add);
USE_OP(matmul_v2);
...@@ -21,6 +21,7 @@
#include "paddle/pten/common/layout.h"
#include "paddle/pten/core/tensor_meta.h"
#include "paddle/fluid/eager/accumulation/accumulation_node.h"
#include "paddle/fluid/framework/data_layout.h" #include "paddle/fluid/framework/data_layout.h"
#include "paddle/fluid/framework/pten_utils.h" #include "paddle/fluid/framework/pten_utils.h"
#include "paddle/fluid/framework/variable.h" #include "paddle/fluid/framework/variable.h"
...@@ -131,17 +132,17 @@ void EagerUtils::SetOutRankWithSlot(AutogradMeta* target, size_t slot_id) { ...@@ -131,17 +132,17 @@ void EagerUtils::SetOutRankWithSlot(AutogradMeta* target, size_t slot_id) {
target->SetSingleOutRankWithSlot(slot_id, 0); target->SetSingleOutRankWithSlot(slot_id, 0);
} }
std::shared_ptr<egr::EagerTensor> EagerUtils::TrySyncToVar(
std::shared_ptr<egr::EagerVariable> EagerUtils::TrySyncToVar(
const paddle::experimental::Tensor& tensor) {
return std::make_shared<egr::EagerTensor>(tensor);
return std::make_shared<egr::EagerVariable>(tensor);
}
std::vector<std::shared_ptr<egr::EagerTensor>> EagerUtils::TrySyncToVars(
std::vector<std::shared_ptr<egr::EagerVariable>> EagerUtils::TrySyncToVars(
const paddle::experimental::Tensor& tensor) {
return {TrySyncToVar(tensor)};
}
std::vector<std::shared_ptr<egr::EagerTensor>> EagerUtils::TrySyncToVars(
std::vector<std::shared_ptr<egr::EagerVariable>> EagerUtils::TrySyncToVars(
paddle::experimental::Tensor* tensor) {
PADDLE_ENFORCE_NOT_NULL(
tensor,
...@@ -151,9 +152,9 @@ std::vector<std::shared_ptr<egr::EagerTensor>> EagerUtils::TrySyncToVars(
return {TrySyncToVar(*tensor)};
}
std::vector<std::shared_ptr<egr::EagerTensor>> EagerUtils::TrySyncToVars(
std::vector<std::shared_ptr<egr::EagerVariable>> EagerUtils::TrySyncToVars(
const std::vector<paddle::experimental::Tensor*>& tensors) {
std::vector<std::shared_ptr<EagerTensor>> res;
std::vector<std::shared_ptr<EagerVariable>> res;
size_t num = tensors.size();
res.reserve(num);
for (size_t i = 0; i < num; i++) {
...@@ -169,9 +170,9 @@ std::vector<std::shared_ptr<egr::EagerTensor>> EagerUtils::TrySyncToVars(
return res;
}
std::vector<std::shared_ptr<egr::EagerTensor>> EagerUtils::TrySyncToVars(
std::vector<std::shared_ptr<egr::EagerVariable>> EagerUtils::TrySyncToVars(
const std::vector<paddle::experimental::Tensor>& tensors) {
std::vector<std::shared_ptr<EagerTensor>> res;
std::vector<std::shared_ptr<EagerVariable>> res;
size_t num = tensors.size();
res.reserve(num);
for (size_t i = 0; i < num; i++) {
...@@ -180,19 +181,19 @@ std::vector<std::shared_ptr<egr::EagerTensor>> EagerUtils::TrySyncToVars(
return res;
}
std::vector<std::shared_ptr<EagerTensor>> EagerUtils::CreateVars(
std::vector<std::shared_ptr<EagerVariable>> EagerUtils::CreateVars(
const size_t num) {
std::vector<std::shared_ptr<EagerTensor>> res;
std::vector<std::shared_ptr<EagerVariable>> res;
res.reserve(num);
for (size_t i = 0; i < num; i++) {
res.emplace_back(
new EagerTensor(egr::Controller::Instance().GenerateUniqueName()));
new EagerVariable(egr::Controller::Instance().GenerateUniqueName()));
}
return res;
}
std::vector<paddle::experimental::Tensor> EagerUtils::GetOutputs(
const std::vector<std::shared_ptr<EagerTensor>>& outs) {
const std::vector<std::shared_ptr<EagerVariable>>& outs) {
std::vector<paddle::experimental::Tensor> res;
res.reserve(outs.size());
for (const auto& out : outs) {
...@@ -209,7 +210,7 @@ std::vector<paddle::experimental::Tensor> EagerUtils::GetOutputs(
}
paddle::experimental::Tensor EagerUtils::GetOutput(
const std::shared_ptr<EagerTensor>& out) {
const std::shared_ptr<EagerVariable>& out) {
PADDLE_ENFORCE_NOT_NULL(
out.get(), paddle::platform::errors::Fatal(
"Eager Tensor %s is null and cannot be copied. We "
...@@ -219,7 +220,7 @@ paddle::experimental::Tensor EagerUtils::GetOutput(
return paddle::experimental::Tensor(out->GetTensorBase(), out->name());
}
void EagerUtils::OverwriteOutputs(const std::shared_ptr<EagerTensor>& out,
void EagerUtils::OverwriteOutputs(const std::shared_ptr<EagerVariable>& out,
paddle::experimental::Tensor* tensor) {
PADDLE_ENFORCE_NOT_NULL(
tensor, paddle::platform::errors::Fatal(
...@@ -231,7 +232,7 @@ void EagerUtils::OverwriteOutputs(const std::shared_ptr<EagerTensor>& out,
}
void EagerUtils::OverwriteOutputs(
const std::vector<std::shared_ptr<EagerTensor>>& outs,
const std::vector<std::shared_ptr<EagerVariable>>& outs,
const std::vector<paddle::experimental::Tensor*>& tensors) {
PADDLE_ENFORCE_EQ(
outs.size(), tensors.size(),
...@@ -303,4 +304,41 @@ void EagerUtils::CheckAndRetainGrad(
}
}
std::shared_ptr<egr::GradNodeBase> EagerUtils::GetGradAccumulationNode(
const paddle::experimental::Tensor& tensor) {
auto* autograd_ptr = nullable_autograd_meta(tensor);
if (!autograd_ptr) {
return nullptr;
}
auto node_ptr = autograd_ptr->GetMutableGradNode();
if (node_ptr && node_ptr.get()) {
if (!autograd_ptr->StopGradient()) {
auto accumulation_ptr =
std::dynamic_pointer_cast<GradNodeAccumulation>(node_ptr);
if (accumulation_ptr) {
return accumulation_ptr;
} else {
// Current GradNode is not a egr::GradNodeAccumulation
PADDLE_THROW(paddle::platform::errors::Fatal(
"GetGradAccumulationNode should only be called on leaf tensor, but "
"target tensor: %s has GradNode which is not a "
"GradNodeAccumulation, and this should not happend unless target "
"tensor is modified by some ops and calling set history for it.",
tensor.name()));
}
} else {
// Current Tensor does not have grad since it's stop_gradient is true;
return nullptr;
}
} else {
if (!autograd_ptr->StopGradient()) {
VLOG(6) << "Add GradNodeAccumulation for tensor: " << tensor.name();
autograd_ptr->SetGradNode(std::make_shared<egr::GradNodeAccumulation>());
return autograd_ptr->GetMutableGradNode();
} else {
return nullptr;
}
}
}
} // namespace egr
...@@ -88,7 +88,7 @@ class EagerUtils {
/**
* We have to use autograd_meta and multi_autograd_meta to initialize
* autograd_meta for tensor, since we can't init it in
* egr::EagerTensor's
* egr::EagerVariable's
* constructor (it's abstract class there)
*
* **/
...@@ -151,34 +151,35 @@ class EagerUtils {
// Intermidate needed remove this once we don't need legacy
// Inner Method
static std::shared_ptr<egr::EagerTensor> TrySyncToVar(
static std::shared_ptr<egr::EagerVariable> TrySyncToVar(
const paddle::experimental::Tensor& tensor);
// Basic Input
static std::vector<std::shared_ptr<egr::EagerTensor>> TrySyncToVars(
static std::vector<std::shared_ptr<egr::EagerVariable>> TrySyncToVars(
const paddle::experimental::Tensor& tensor);
// Basic Output
static std::vector<std::shared_ptr<egr::EagerTensor>> TrySyncToVars(
static std::vector<std::shared_ptr<egr::EagerVariable>> TrySyncToVars(
paddle::experimental::Tensor* tensor);
// Multi Output
static std::vector<std::shared_ptr<egr::EagerTensor>> TrySyncToVars(
static std::vector<std::shared_ptr<egr::EagerVariable>> TrySyncToVars(
const std::vector<paddle::experimental::Tensor*>& tensors);
// Multi Input
static std::vector<std::shared_ptr<egr::EagerTensor>> TrySyncToVars(
static std::vector<std::shared_ptr<egr::EagerVariable>> TrySyncToVars(
const std::vector<paddle::experimental::Tensor>& tensors);
// Construct empty output
static std::vector<std::shared_ptr<EagerTensor>> CreateVars(const size_t num);
static std::vector<std::shared_ptr<EagerVariable>> CreateVars(
const size_t num);
// Construct Tensor From var
static std::vector<paddle::experimental::Tensor> GetOutputs(
const std::vector<std::shared_ptr<EagerTensor>>& outs);
const std::vector<std::shared_ptr<EagerVariable>>& outs);
static paddle::experimental::Tensor GetOutput(
const std::shared_ptr<EagerTensor>& out);
const std::shared_ptr<EagerVariable>& out);
// Sync Back to origin output Tensor
static void OverwriteOutputs(const std::shared_ptr<EagerTensor>& out,
static void OverwriteOutputs(const std::shared_ptr<EagerVariable>& out,
paddle::experimental::Tensor* tensor);
static void OverwriteOutputs(const paddle::experimental::Tensor& out,
paddle::experimental::Tensor* tensor);
static void OverwriteOutputs(
const std::vector<std::shared_ptr<EagerTensor>>& outs,
const std::vector<std::shared_ptr<EagerVariable>>& outs,
const std::vector<paddle::experimental::Tensor*>& tensors);
static void OverwriteOutputs(
const std::vector<paddle::experimental::Tensor>& outs,
...@@ -188,6 +189,8 @@ class EagerUtils {
static void CheckAndRetainGrad(const paddle::experimental::Tensor& tensor);
static void CheckAndRetainGrad(
const std::vector<paddle::experimental::Tensor>& tensors);
static std::shared_ptr<egr::GradNodeBase> GetGradAccumulationNode(
const paddle::experimental::Tensor& tensor);
};
} // namespace egr
...@@ -413,7 +413,7 @@ cc_test(save_load_util_test SRCS save_load_util_test.cc DEPS save_load_util tens
cc_library(generator SRCS generator.cc DEPS enforce place)
cc_library(infershape_utils SRCS infershape_utils.cc DEPS lod_tensor selected_rows_utils attribute place pten var_type_traits pten pten_api_utils op_info shape_inference)
cc_test(infershape_utils_test SRCS infershape_utils_test.cc DEPS infershape_utils infermeta_utils meta_tensor)
# Get the current working branch
execute_process(
...@@ -458,4 +458,5 @@ if(WITH_GPU OR WITH_ROCM)
else()
cc_library(fluid_convert_utils SRCS convert_utils.cc DEPS data_type place)
endif()
cc_test(convert_utils_test SRCS convert_utils_test.cc DEPS fluid_convert_utils)
cc_test(custom_kernel_test SRCS custom_kernel_test.cc DEPS custom_kernel pten_tensor)
...@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "gtest/gtest.h"
#include "paddle/fluid/framework/convert_utils.h" #include "paddle/fluid/framework/convert_utils.h"
#include "gtest/gtest.h"
namespace pten {
namespace tests {
......
...@@ -100,6 +100,11 @@ struct DLDeviceVisitor : public boost::static_visitor<::DLDevice> {
platform::errors::Unimplemented("platform::MLUPlace is not supported"));
}
inline ::DLDevice operator()(const platform::CustomPlace &place) const {
PADDLE_THROW(platform::errors::Unimplemented(
"platform::CustomPlace is not supported"));
}
inline ::DLDevice operator()(const platform::CUDAPlace &place) const {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
::DLDevice device;
......
...@@ -494,6 +494,20 @@ void Executor::RunPartialPreparedContext(ExecutorPrepareContext* ctx,
#else
PADDLE_THROW(
platform::errors::Unimplemented("No MLU gc found in CPU/MLU paddle"));
#endif
} else if (platform::is_custom_place(place_)) {
#ifdef PADDLE_WITH_CUSTOM_DEVICE
if (IsFastEagerDeletionModeEnabled()) {
VLOG(4) << "Use unsafe fast gc for " << place_ << ".";
gc.reset(new CustomDeviceUnsafeFastGarbageCollector(place_,
max_memory_size));
} else {
VLOG(4) << "Use default stream gc for " << place_ << ".";
gc.reset(
new CustomDefaultStreamGarbageCollector(place_, max_memory_size));
}
#else
PADDLE_THROW(platform::errors::Unimplemented("No CustomDevice gc found"));
#endif
}
}
......
...@@ -18,6 +18,7 @@
#endif
#include "gflags/gflags.h"
#include "paddle/fluid/framework/garbage_collector.h"
#include "paddle/fluid/platform/device/device_wrapper.h"
DECLARE_double(eager_delete_tensor_gb);
DECLARE_double(memory_fraction_of_eager_deletion);
...@@ -202,6 +203,58 @@ void MLUStreamGarbageCollector::ClearCallback(
}
#endif
#ifdef PADDLE_WITH_CUSTOM_DEVICE
CustomDefaultStreamGarbageCollector::CustomDefaultStreamGarbageCollector(
const platform::CustomPlace &place, size_t max_memory_size)
: GarbageCollector(place, max_memory_size) {}
void CustomDefaultStreamGarbageCollector::Wait() const {
static_cast<platform::CustomDeviceContext *>(this->dev_ctx_)
->WaitStreamCallback();
}
void CustomDefaultStreamGarbageCollector::ClearCallback(
const std::function<void()> &callback) {
static_cast<platform::CustomDeviceContext *>(this->dev_ctx_)
->AddStreamCallback(callback);
}
CustomDeviceUnsafeFastGarbageCollector::CustomDeviceUnsafeFastGarbageCollector(
const platform::CustomPlace &place, size_t max_memory_size)
: GarbageCollector(place, max_memory_size) {}
void CustomDeviceUnsafeFastGarbageCollector::ClearCallback(
const std::function<void()> &callback) {
callback();
}
CustomStreamGarbageCollector::CustomStreamGarbageCollector(
const platform::CustomPlace &place, size_t max_memory_size)
: GarbageCollector(place, max_memory_size) {
platform::DeviceGuard guard(place);
stream_.reset(new platform::stream::Stream);
stream_->Init(place);
callback_manager_.reset(new platform::CallbackManager(stream_.get()));
}
CustomStreamGarbageCollector::~CustomStreamGarbageCollector() {
platform::DeviceGuard guard(this->dev_ctx_->GetPlace());
stream_->Synchronize();
stream_->Destroy();
}
platform::stream::Stream *CustomStreamGarbageCollector::stream() const {
return stream_.get();
}
void CustomStreamGarbageCollector::Wait() const { callback_manager_->Wait(); }
void CustomStreamGarbageCollector::ClearCallback(
const std::function<void()> &callback) {
callback_manager_->AddCallback(callback);
}
#endif
int64_t GetEagerDeletionThreshold() {
return FLAGS_eager_delete_tensor_gb < 0
? -1
......
...@@ -200,6 +200,47 @@ class MLUStreamGarbageCollector : public GarbageCollector {
};
#endif
#ifdef PADDLE_WITH_CUSTOM_DEVICE
class CustomDefaultStreamGarbageCollector : public GarbageCollector {
public:
CustomDefaultStreamGarbageCollector(const platform::CustomPlace &place,
size_t max_memory_size);
void Wait() const override;
protected:
void ClearCallback(const std::function<void()> &callback) override;
};
class CustomDeviceUnsafeFastGarbageCollector : public GarbageCollector {
public:
CustomDeviceUnsafeFastGarbageCollector(const platform::CustomPlace &place,
size_t max_memory_size);
protected:
void ClearCallback(const std::function<void()> &callback) override;
};
class CustomStreamGarbageCollector : public GarbageCollector {
public:
CustomStreamGarbageCollector(const platform::CustomPlace &place,
size_t max_memory_size);
~CustomStreamGarbageCollector();
void Wait() const override;
platform::stream::Stream *stream() const;
protected:
void ClearCallback(const std::function<void()> &callback) override;
private:
std::unique_ptr<platform::stream::Stream> stream_;
std::unique_ptr<platform::CallbackManager> callback_manager_;
};
#endif
template <typename Container>
void GarbageCollector::Add(Container &&objs) {
Add(std::forward<Container>(objs), []() {});
......
...@@ -14,6 +14,8 @@ limitations under the License. */
#include "paddle/fluid/framework/infershape_utils.h"
#include <string>
#include "paddle/fluid/framework/convert_utils.h" #include "paddle/fluid/framework/convert_utils.h"
#include "paddle/fluid/framework/framework.pb.h" #include "paddle/fluid/framework/framework.pb.h"
#include "paddle/fluid/framework/pten_utils.h" #include "paddle/fluid/framework/pten_utils.h"
...@@ -303,13 +305,45 @@ pten::InferMetaContext BuildInferMetaContext(InferShapeContext* ctx, ...@@ -303,13 +305,45 @@ pten::InferMetaContext BuildInferMetaContext(InferShapeContext* ctx,
auto& attr = attr_reader.GetAttr(attr_name);
if (std::type_index(attr.type()) == std::type_index(typeid(bool))) {
infer_meta_context.EmplaceBackAttr(BOOST_GET_CONST(bool, attr));
} else if (std::type_index(attr.type()) == std::type_index(typeid(int))) {
infer_meta_context.EmplaceBackAttr(BOOST_GET_CONST(int, attr));
} else if (std::type_index(attr.type()) ==
std::type_index(typeid(int64_t))) {
infer_meta_context.EmplaceBackAttr(BOOST_GET_CONST(int64_t, attr));
} else if (std::type_index(attr.type()) ==
std::type_index(typeid(float))) {
infer_meta_context.EmplaceBackAttr(BOOST_GET_CONST(float, attr));
} else if (std::type_index(attr.type()) ==
std::type_index(typeid(std::string))) {
infer_meta_context.EmplaceBackAttr(BOOST_GET_CONST(std::string, attr));
} else if (std::type_index(attr.type()) ==
std::type_index(typeid(std::vector<bool>))) {
infer_meta_context.EmplaceBackAttr(
BOOST_GET_CONST(std::vector<bool>, attr));
} else if (std::type_index(attr.type()) ==
std::type_index(typeid(std::vector<int>))) {
infer_meta_context.EmplaceBackAttr(
BOOST_GET_CONST(std::vector<int>, attr));
} else if (std::type_index(attr.type()) ==
std::type_index(typeid(std::vector<int64_t>))) {
infer_meta_context.EmplaceBackAttr(
BOOST_GET_CONST(std::vector<int64_t>, attr));
} else if (std::type_index(attr.type()) ==
std::type_index(typeid(std::vector<float>))) {
infer_meta_context.EmplaceBackAttr(
BOOST_GET_CONST(std::vector<float>, attr));
} else if (std::type_index(attr.type()) ==
std::type_index(typeid(std::vector<double>))) {
infer_meta_context.EmplaceBackAttr(
BOOST_GET_CONST(std::vector<double>, attr));
} else if (std::type_index(attr.type()) ==
std::type_index(typeid(std::vector<std::string>))) {
infer_meta_context.EmplaceBackAttr(
BOOST_GET_CONST(std::vector<std::string>, attr));
} else { } else {
// do nothing, skip useless attrs now PADDLE_THROW(platform::errors::Unimplemented(
// TODO(chenweihang): support other attr type later and throw error "Unsupported attribute type is received when call "
// if attr is cannot parsed "InferShapeFunctor."));
} }
} else { } else {
// do nothing // do nothing
......
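The dispatch chain added above forwards each framework attribute into the InferMetaContext by comparing std::type_index values and extracting the payload with BOOST_GET_CONST, throwing for types it does not know. A minimal, framework-free sketch of the same dispatch idiom follows; it uses std::any purely for illustration (Paddle's Attribute is a variant type), so the types and names here are stand-ins, not Paddle APIs.

#include <any>
#include <iostream>
#include <stdexcept>
#include <string>
#include <typeindex>
#include <vector>

// Dispatch on the runtime type of a type-erased attribute and hand the
// concrete value to a sink, mirroring the type_index chain in the diff above.
void EmplaceAttr(const std::any& attr) {
  const std::type_index tid(attr.type());
  if (tid == std::type_index(typeid(bool))) {
    std::cout << "bool: " << std::any_cast<bool>(attr) << "\n";
  } else if (tid == std::type_index(typeid(int))) {
    std::cout << "int: " << std::any_cast<int>(attr) << "\n";
  } else if (tid == std::type_index(typeid(std::vector<int64_t>))) {
    std::cout << "vec<int64_t> of size "
              << std::any_cast<const std::vector<int64_t>&>(attr).size() << "\n";
  } else {
    // analogous to the PADDLE_THROW(Unimplemented(...)) branch above
    throw std::runtime_error("unsupported attribute type");
  }
}

int main() {
  EmplaceAttr(std::any(true));
  EmplaceAttr(std::any(42));
  EmplaceAttr(std::any(std::vector<int64_t>{1, 2, 3}));
}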
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <string>
#include <vector>
#include "gtest/gtest.h"
#include "paddle/fluid/framework/attribute.h"
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_info.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/pten/core/compat/op_utils.h"
#include "paddle/pten/core/infermeta_utils.h"
namespace paddle {
namespace framework {
void TestInferMeta(bool bool_attr, int int_attr, int64_t int64_attr,
float float_attr, const std::string& str_attr,
const std::vector<bool>& vec_bool_attr,
const std::vector<int>& vec_int_attr,
const std::vector<int64_t>& vec_int64_attr,
const std::vector<float>& vec_float_attr,
const std::vector<double>& vec_double_attr,
const std::vector<std::string>& vec_str_attr) {
ASSERT_EQ(bool_attr, true);
ASSERT_EQ(int_attr, 10);
ASSERT_EQ(int64_attr, 100);
ASSERT_NEAR(float_attr, 3.14, 1e-6);
ASSERT_EQ(str_attr, "test");
ASSERT_EQ(vec_bool_attr.at(0), true);
ASSERT_EQ(vec_bool_attr.at(1), true);
ASSERT_EQ(vec_int_attr.at(0), 10);
ASSERT_EQ(vec_int_attr.at(1), 10);
ASSERT_EQ(vec_int64_attr.at(0), 100L);
ASSERT_EQ(vec_int64_attr.at(1), 100L);
ASSERT_NEAR(vec_float_attr.at(0), 3.14, 1e-6);
ASSERT_NEAR(vec_float_attr.at(1), 3.14, 1e-6);
ASSERT_NEAR(vec_double_attr.at(0), 3.1415, 1e-6);
ASSERT_NEAR(vec_double_attr.at(1), 3.1415, 1e-6);
ASSERT_EQ(vec_str_attr.at(0), "test_vec");
ASSERT_EQ(vec_str_attr.at(1), "test_vec");
}
class InferShapeUtilsTestOpMaker : public OpProtoAndCheckerMaker {
public:
void Make() {
AddAttr<bool>("bool", "bool attr of test op");
AddAttr<int>("int", "int attr of test op");
AddAttr<int64_t>("int64", "int64 attr of test op");
AddAttr<float>("float", "float attr of test op");
AddAttr<std::string>("string", "string attr of test op");
AddAttr<std::vector<bool>>("vec_bool", "vec_bool attr of test op");
AddAttr<std::vector<int>>("vec_int", "vec_int attr of test op");
AddAttr<std::vector<int64_t>>("vec_int64", "vec_int attr of test op");
AddAttr<std::vector<float>>("vec_float", "vec_int attr of test op");
AddAttr<std::vector<double>>("vec_double", "vec_int attr of test op");
AddAttr<std::vector<std::string>>("vec_str", "vec_int attr of test op");
AddComment("This is test op");
}
};
class InferShapeUtilsTestOp : public OperatorWithKernel {
public:
using OperatorWithKernel::OperatorWithKernel;
OpKernelType GetExpectedKernelType(
const ExecutionContext& ctx) const override {
return OpKernelType(proto::VarType::FP32, ctx.GetPlace());
}
};
pten::KernelSignature InferShapeUtilsTestOpArgumentMapping(
const pten::ArgumentMappingContext& ctx) {
return pten::KernelSignature(
"infer_shape_utils_test", {},
{"bool", "int", "int64", "float", "string", "vec_bool", "vec_int",
"vec_int64", "vec_float", "vec_double", "vec_str"},
{});
}
} // namespace framework
} // namespace paddle
DELCARE_INFER_SHAPE_FUNCTOR(infer_shape_utils_test,
InferShapeUtilsTestInferShapeFunctor,
PT_INFER_META(paddle::framework::TestInferMeta));
REGISTER_OPERATOR(infer_shape_utils_test,
paddle::framework::InferShapeUtilsTestOp,
paddle::framework::InferShapeUtilsTestOpMaker,
InferShapeUtilsTestInferShapeFunctor);
TEST(InferShapeUtilsTest, ALL) {
paddle::framework::ProgramDesc prog;
paddle::framework::proto::BlockDesc proto_block;
paddle::framework::BlockDesc block_desc(&prog, &proto_block);
auto* op = block_desc.AppendOp();
op->SetType("infer_shape_utils_test");
paddle::framework::Attribute bool_attr(true);
op->SetAttr("bool", bool_attr);
paddle::framework::Attribute int_attr(10);
op->SetAttr("int", int_attr);
int64_t int64_val = 100;
paddle::framework::Attribute int64_attr(int64_val);
op->SetAttr("int64", int64_attr);
float float_value = 3.14;
paddle::framework::Attribute float_attr(float_value);
op->SetAttr("float", float_attr);
std::string str_value("test");
paddle::framework::Attribute str_attr(str_value);
op->SetAttr("string", str_attr);
std::vector<bool> vec_bool(2, true);
paddle::framework::Attribute vec_bool_attr = vec_bool;
op->SetAttr("vec_bool", vec_bool_attr);
std::vector<int> vec_int(2, 10);
paddle::framework::Attribute vec_int_attr = vec_int;
op->SetAttr("vec_int", vec_int_attr);
std::vector<int64_t> vec_int64(2, 100);
paddle::framework::Attribute vec_int64_attr = vec_int64;
op->SetAttr("vec_int64", vec_int64_attr);
std::cout << "after set vec_int64" << std::endl;
std::vector<float> vec_float(2, 3.14);
paddle::framework::Attribute vec_float_attr = vec_float;
op->SetAttr("vec_float", vec_float_attr);
std::vector<double> vec_double(2, 3.1415);
paddle::framework::Attribute vec_double_attr = vec_double;
op->SetAttr("vec_double", vec_double_attr);
std::vector<std::string> vec_str(2, "test_vec");
paddle::framework::Attribute vec_str_attr = vec_str;
op->SetAttr("vec_str", vec_str_attr);
pten::OpUtilsMap::Instance().InsertArgumentMappingFn(
"infer_shape_utils_test",
paddle::framework::InferShapeUtilsTestOpArgumentMapping);
op->InferShape(block_desc);
}
...@@ -103,6 +103,8 @@ target_link_libraries(generate_pass pass_desc_proto) ...@@ -103,6 +103,8 @@ target_link_libraries(generate_pass pass_desc_proto)
if(WITH_TENSORRT) if(WITH_TENSORRT)
pass_library(trt_map_matmul_to_mul_pass inference) pass_library(trt_map_matmul_to_mul_pass inference)
pass_library(preln_embedding_eltwise_layernorm_fuse_pass inference)
pass_library(preln_skip_layernorm_fuse_pass inference)
endif() endif()
if(WITH_GPU OR WITH_ROCM) if(WITH_GPU OR WITH_ROCM)
......
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
#include "paddle/fluid/framework/ir/fuse_optimizer_ops_pass/fuse_optimizer_op_pass.h" #include "paddle/fluid/framework/ir/fuse_optimizer_ops_pass/fuse_optimizer_op_pass.h"
#include "paddle/fluid/framework/ir/graph_helper.h" #include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/operator.h" #include "paddle/fluid/framework/operator.h"
#include "paddle/pten/core/kernel_factory.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
...@@ -271,25 +272,41 @@ bool FuseOptimizerOpPass::OpWithKernelSupportCPUAndGPU( ...@@ -271,25 +272,41 @@ bool FuseOptimizerOpPass::OpWithKernelSupportCPUAndGPU(
if (op_type == "c_sync_calc_stream" || op_type == "c_sync_comm_stream") { if (op_type == "c_sync_calc_stream" || op_type == "c_sync_comm_stream") {
return true; return true;
} }
auto &all_kernels = OperatorWithKernel::AllOpKernels(); bool support_cpu = false;
auto it = all_kernels.find(op_type); bool support_gpu = false;
// skip op not has kernel auto &kernel_factory = pten::KernelFactory::Instance();
if (it != all_kernels.end()) { auto kernel_key_map =
bool support_cpu = false; kernel_factory.SelectKernelMap(pten::TransToPtenKernelName(op_type));
bool support_gpu = false; bool has_op_kernel = kernel_key_map.size() > 0 ? true : false;
for (auto &kernel_pair : it->second) { for (auto &kernel : kernel_key_map) {
if (platform::is_cpu_place(kernel_pair.first.place_)) { if (platform::is_gpu_place(
support_cpu = true; pten::TransToPtenPlace(kernel.first.backend()))) {
} support_gpu = true;
if (platform::is_gpu_place(kernel_pair.first.place_)) { } else if (platform::is_cpu_place(
support_gpu = true; pten::TransToPtenPlace(kernel.first.backend()))) {
support_cpu = true;
}
}
if (!support_cpu || !support_gpu) {
auto &all_kernels = OperatorWithKernel::AllOpKernels();
auto it = all_kernels.find(op_type);
// skip op not has kernel
if (it != all_kernels.end()) {
has_op_kernel = true;
for (auto &kernel_pair : it->second) {
if (platform::is_cpu_place(kernel_pair.first.place_)) {
support_cpu = true;
} else if (platform::is_gpu_place(kernel_pair.first.place_)) {
support_gpu = true;
}
} }
} }
VLOG(6) << "Op check: " << op_type << ", support CPU: " << support_cpu
<< ", support GPU: " << support_gpu;
return support_cpu && support_gpu;
} }
return true;
VLOG(6) << "Op check: " << op_type << ", support CPU: " << support_cpu
<< ", support GPU: " << support_gpu;
return has_op_kernel ? (support_cpu && support_gpu) : true;
} }
bool FuseOptimizerOpPass::GradGeneratedOpKernelCheck( bool FuseOptimizerOpPass::GradGeneratedOpKernelCheck(
......
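The reworked check above consults the new pten kernel registry first and only falls back to the fluid OpKernel map when the pten registry does not already prove both CPU and GPU support; ops with no kernel in either registry are still treated as supported. A condensed sketch of that decision logic is below, with stand-in registry types rather than Paddle's real KernelFactory and AllOpKernels map.

#include <string>
#include <unordered_map>
#include <vector>

enum class Backend { kCPU, kGPU, kOther };

// Stand-in for both kernel registries: op type -> backends with a kernel.
using KernelRegistry = std::unordered_map<std::string, std::vector<Backend>>;

// True when the op either has no registered kernel at all, or has kernels
// covering both CPU and GPU across the two registries.
bool OpSupportsCPUAndGPU(const std::string& op_type,
                         const KernelRegistry& pten_kernels,
                         const KernelRegistry& fluid_kernels) {
  bool has_kernel = false, cpu = false, gpu = false;
  for (const auto* reg : {&pten_kernels, &fluid_kernels}) {
    auto it = reg->find(op_type);
    if (it == reg->end()) continue;
    has_kernel = true;
    for (Backend b : it->second) {
      cpu |= (b == Backend::kCPU);
      gpu |= (b == Backend::kGPU);
    }
    if (cpu && gpu) break;  // fall back to the second registry only if needed
  }
  return has_kernel ? (cpu && gpu) : true;
}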
...@@ -26,7 +26,7 @@ ...@@ -26,7 +26,7 @@
USE_OP(mul); USE_OP(mul);
USE_OP(cinn_launch); USE_OP(cinn_launch);
USE_OP(elementwise_add); USE_OP_ITSELF(elementwise_add);
namespace paddle::framework { namespace paddle::framework {
using Name2VarInfoMap = using Name2VarInfoMap =
......
...@@ -23,8 +23,8 @@ ...@@ -23,8 +23,8 @@
USE_OP_ITSELF(scale); USE_OP_ITSELF(scale);
USE_OP(elementwise_mul); USE_OP(elementwise_mul);
USE_OP(elementwise_add); USE_OP_ITSELF(elementwise_add);
USE_OP(elementwise_add_grad); USE_OP_ITSELF(elementwise_add_grad);
DECLARE_double(eager_delete_tensor_gb); DECLARE_double(eager_delete_tensor_gb);
......
...@@ -29,7 +29,7 @@ USE_OP(batch_norm); ...@@ -29,7 +29,7 @@ USE_OP(batch_norm);
USE_OP_DEVICE_KERNEL(batch_norm, MKLDNN); USE_OP_DEVICE_KERNEL(batch_norm, MKLDNN);
USE_OP(conv2d_transpose); USE_OP(conv2d_transpose);
USE_OP_DEVICE_KERNEL(conv2d_transpose, MKLDNN); USE_OP_DEVICE_KERNEL(conv2d_transpose, MKLDNN);
USE_OP(elementwise_add); USE_OP_ITSELF(elementwise_add);
USE_OP_DEVICE_KERNEL(elementwise_add, MKLDNN); USE_OP_DEVICE_KERNEL(elementwise_add, MKLDNN);
USE_OP(gelu); USE_OP(gelu);
USE_OP_DEVICE_KERNEL(gelu, MKLDNN); USE_OP_DEVICE_KERNEL(gelu, MKLDNN);
......
...@@ -22,7 +22,7 @@ ...@@ -22,7 +22,7 @@
USE_OP(softmax); USE_OP(softmax);
USE_OP_DEVICE_KERNEL(softmax, MKLDNN); USE_OP_DEVICE_KERNEL(softmax, MKLDNN);
USE_OP(elementwise_add); USE_OP_ITSELF(elementwise_add);
USE_OP_DEVICE_KERNEL(elementwise_add, MKLDNN); USE_OP_DEVICE_KERNEL(elementwise_add, MKLDNN);
USE_OP(leaky_relu); USE_OP(leaky_relu);
USE_OP_DEVICE_KERNEL(leaky_relu, MKLDNN); USE_OP_DEVICE_KERNEL(leaky_relu, MKLDNN);
......
...@@ -47,6 +47,8 @@ constexpr char kPassRecorder[] = "pass_recorder"; ...@@ -47,6 +47,8 @@ constexpr char kPassRecorder[] = "pass_recorder";
constexpr char kEmbEltwiseLayernormPass[] = constexpr char kEmbEltwiseLayernormPass[] =
"embedding_eltwise_layernorm_fuse_pass_flag"; "embedding_eltwise_layernorm_fuse_pass_flag";
constexpr char kMultiheadMatmulPass[] = "multihead_matmul_fuse_pass_flag"; constexpr char kMultiheadMatmulPass[] = "multihead_matmul_fuse_pass_flag";
constexpr char kPrelnEmbEltwiseLayernormPass[] =
"preln_embedding_eltwise_layernorm_fuse_pass_flag";
class Pass { class Pass {
public: public:
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/preln_embedding_eltwise_layernorm_fuse_pass.h"
#include <string>
#include "paddle/fluid/framework/op_version_registry.h"
namespace paddle {
namespace framework {
namespace ir {
class Node;
} // namespace ir
} // namespace framework
} // namespace paddle
namespace paddle {
namespace framework {
namespace ir {
namespace patterns {
static PDNode* create_emb_vars(PDPattern* pattern, const std::string& name,
const std::string& arg,
bool is_persist = false) {
std::unordered_set<std::string> embedding_ops{"lookup_table",
"lookup_table_v2"};
PDNode* node =
pattern->NewNode(name)->assert_is_ops_input(embedding_ops, arg);
if (is_persist) return node->assert_is_persistable_var();
return node;
}
static PDNode* create_emb_out_vars(PDPattern* pattern, const std::string& name,
const std::string& arg) {
std::unordered_set<std::string> embedding_ops{"lookup_table",
"lookup_table_v2"};
PDNode* node = pattern->NewNode(name)
->assert_is_only_output_of_ops(embedding_ops)
->assert_is_op_input("elementwise_add", arg)
->AsIntermediate();
return node;
}
void PrelnEmbedding2Eltwise1Pattern::operator()() {
auto* lookup_table1_x =
create_emb_vars(pattern, lookup_table1_x_repr(), "Ids");
auto* lookup_table2_x =
create_emb_vars(pattern, lookup_table2_x_repr(), "Ids");
auto* lookup_table1_w =
create_emb_vars(pattern, lookup_table1_w_repr(), "W", true);
auto* lookup_table2_w =
create_emb_vars(pattern, lookup_table2_w_repr(), "W", true);
std::unordered_set<std::string> embedding_ops{"lookup_table",
"lookup_table_v2"};
auto* lookup_table1 =
pattern->NewNode(lookup_table1_repr())->assert_is_ops(embedding_ops);
auto* lookup_table2 =
pattern->NewNode(lookup_table2_repr())->assert_is_ops(embedding_ops);
auto* lookup_table1_out =
create_emb_out_vars(pattern, lookup_table1_out_repr(), "X");
auto* lookup_table2_out =
create_emb_out_vars(pattern, lookup_table2_out_repr(), "Y");
auto* eltwise_add =
pattern->NewNode(eltwise_add_repr())->assert_is_op("elementwise_add");
auto* eltwise_add_out = pattern->NewNode(eltwise_add_out_repr())
->assert_is_op_output("elementwise_add");
lookup_table1->LinksFrom({lookup_table1_x, lookup_table1_w})
.LinksTo({lookup_table1_out});
lookup_table2->LinksFrom({lookup_table2_x, lookup_table2_w})
.LinksTo({lookup_table2_out});
eltwise_add->LinksFrom({lookup_table1_out, lookup_table2_out})
.LinksTo({eltwise_add_out});
}
void PrelnEmbedding1Eltwise1Pattern::operator()() {
auto* lookup_table1_x =
create_emb_vars(pattern, lookup_table1_x_repr(), "Ids");
auto* lookup_table1_w =
create_emb_vars(pattern, lookup_table1_w_repr(), "W", true);
std::unordered_set<std::string> embedding_ops{"lookup_table",
"lookup_table_v2"};
auto* lookup_table1 =
pattern->NewNode(lookup_table1_repr())->assert_is_ops(embedding_ops);
auto* lookup_table1_out =
create_emb_out_vars(pattern, lookup_table1_out_repr(), "Y");
auto* eltwise_add =
pattern->NewNode(eltwise_add_repr())->assert_is_op("elementwise_add");
auto* eltwise_add_in = pattern->NewNode(eltwise_add_in_repr())
->assert_is_op_input("elementwise_add", "X")
->assert_is_op_output("elementwise_add");
auto* eltwise_add_out = pattern->NewNode(eltwise_add_out_repr())
->assert_is_op_output("elementwise_add");
lookup_table1->LinksFrom({lookup_table1_x, lookup_table1_w})
.LinksTo({lookup_table1_out});
eltwise_add->LinksFrom({lookup_table1_out, eltwise_add_in})
.LinksTo({eltwise_add_out});
}
void PrelnSkipLayerNorm::operator()() {
auto* eltwise_add =
pattern->NewNode(eltwise_add_repr())->assert_is_op("elementwise_add");
auto* eltwise_add_out = pattern->NewNode(eltwise_add_out_repr())
->assert_is_op_output("elementwise_add")
->assert_is_op_input("layer_norm", "X")
->assert_is_op_input("elementwise_add", "Y");
auto* layer_norm =
pattern->NewNode(layer_norm_repr())->assert_is_op("layer_norm");
auto* layer_norm_out = pattern->NewNode(layer_norm_out_repr())
->assert_is_op_output("layer_norm", "Y")
->AsOutput();
auto* layer_norm_bias_var = pattern->NewNode(layer_norm_bias_repr())
->AsInput()
->assert_is_persistable_var()
->assert_is_op_input("layer_norm", "Bias");
auto* layer_norm_scale_var = pattern->NewNode(layer_norm_scale_repr())
->AsInput()
->assert_is_persistable_var()
->assert_is_op_input("layer_norm", "Scale");
auto* layer_norm_mean_var = pattern->NewNode(layer_norm_mean_repr())
->AsOutput()
->assert_is_op_output("layer_norm", "Mean");
auto* layer_norm_variance_var =
pattern->NewNode(layer_norm_variance_repr())
->AsOutput()
->assert_is_op_output("layer_norm", "Variance");
eltwise_add->LinksTo({eltwise_add_out});
layer_norm
->LinksFrom({eltwise_add_out, layer_norm_bias_var, layer_norm_scale_var})
.LinksTo({layer_norm_out, layer_norm_mean_var, layer_norm_variance_var});
}
} // namespace patterns
int PrelnEmbeddingEltwiseLayerNormFusePass::BuildFusion(
Graph* graph, const std::string& name_scope
/*const Scope* scope*/) const {
GraphPatternDetector gpd;
auto* pattern = gpd.mutable_pattern();
std::vector<std::vector<std::pair<Node*, Node*>>> start_pattern_in_nodes;
std::vector<Node*> start_pattern_out_node;
std::vector<std::unordered_set<Node*>> start_pattern_remove_nodes;
// Create pattern.
patterns::PrelnEmbedding2Eltwise1Pattern start_pattern(pattern,
name_scope + "/start");
start_pattern();
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
GET_IR_NODE_FROM_SUBGRAPH(lookup_table1_x, lookup_table1_x, start_pattern);
GET_IR_NODE_FROM_SUBGRAPH(lookup_table2_x, lookup_table2_x, start_pattern);
GET_IR_NODE_FROM_SUBGRAPH(lookup_table1_w, lookup_table1_w, start_pattern);
GET_IR_NODE_FROM_SUBGRAPH(lookup_table2_w, lookup_table2_w, start_pattern);
GET_IR_NODE_FROM_SUBGRAPH(lookup_table1, lookup_table1, start_pattern);
GET_IR_NODE_FROM_SUBGRAPH(lookup_table2, lookup_table2, start_pattern);
GET_IR_NODE_FROM_SUBGRAPH(lookup_table1_out, lookup_table1_out,
start_pattern);
GET_IR_NODE_FROM_SUBGRAPH(lookup_table2_out, lookup_table2_out,
start_pattern);
GET_IR_NODE_FROM_SUBGRAPH(eltwise_add, eltwise_add, start_pattern);
GET_IR_NODE_FROM_SUBGRAPH(eltwise_add_out, eltwise_add_out, start_pattern);
if (!IsCompat(subgraph, graph)) {
LOG(WARNING)
<< "Pass(PrelnEmbedding2Eltwise1Pattern) in op compat failed.";
return;
}
std::vector<std::pair<Node*, Node*>> ins;
ins.push_back(std::make_pair(lookup_table1_x, lookup_table1_w));
ins.push_back(std::make_pair(lookup_table2_x, lookup_table2_w));
start_pattern_in_nodes.push_back(ins);
start_pattern_out_node.push_back(eltwise_add_out);
std::unordered_set<Node*> rm_nodes;
rm_nodes.insert({lookup_table1, lookup_table2, lookup_table1_out,
lookup_table2_out, eltwise_add, eltwise_add_out});
start_pattern_remove_nodes.push_back(rm_nodes);
};
gpd(graph, handler);
std::vector<std::pair<Node*, Node*>> inner_pattern_ins;
std::vector<Node*> inner_pattern_tmp_in;
std::vector<Node*> inner_pattern_out;
std::vector<std::unordered_set<Node*>> inner_pattern_remove_nodes;
GraphPatternDetector gpd2;
auto* pattern2 = gpd2.mutable_pattern();
patterns::PrelnEmbedding1Eltwise1Pattern second_pattern(
pattern2, name_scope + "/second");
second_pattern();
auto handler2 = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
GET_IR_NODE_FROM_SUBGRAPH(lookup_table1_x, lookup_table1_x, second_pattern);
GET_IR_NODE_FROM_SUBGRAPH(lookup_table1_w, lookup_table1_w, second_pattern);
GET_IR_NODE_FROM_SUBGRAPH(lookup_table1, lookup_table1, second_pattern);
GET_IR_NODE_FROM_SUBGRAPH(lookup_table1_out, lookup_table1_out,
second_pattern);
GET_IR_NODE_FROM_SUBGRAPH(eltwise_add_in, eltwise_add_in, second_pattern);
GET_IR_NODE_FROM_SUBGRAPH(eltwise_add, eltwise_add, second_pattern);
GET_IR_NODE_FROM_SUBGRAPH(eltwise_add_out, eltwise_add_out, second_pattern);
if (!IsCompat(subgraph, graph)) {
LOG(WARNING)
<< "Pass(PrelnEmbedding1Eltwise1Pattern) in op compat failed.";
return;
}
auto in = std::make_pair(lookup_table1_x, lookup_table1_w);
inner_pattern_ins.push_back(in);
inner_pattern_tmp_in.push_back(eltwise_add_in);
inner_pattern_out.push_back(eltwise_add_out);
std::unordered_set<Node*> rm_nodes;
rm_nodes.insert({lookup_table1, lookup_table1_out, eltwise_add});
inner_pattern_remove_nodes.push_back(rm_nodes);
};
gpd2(graph, handler2);
std::vector<Node*> end_pattern_elt_out;
std::vector<Node*> end_pattern_scales;
std::vector<Node*> end_pattern_biases;
std::vector<Node*> end_pattern_out;
std::vector<Node*> end_patter_layernorms;
std::vector<Node*> end_patter_elementwise;
std::vector<std::unordered_set<Node*>> end_pattern_remove_nodes;
GraphPatternDetector gpd3;
auto* pattern3 = gpd3.mutable_pattern();
patterns::PrelnSkipLayerNorm skip_layernorm_pattern(pattern3,
name_scope + "/third");
skip_layernorm_pattern();
auto handler3 = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
GET_IR_NODE_FROM_SUBGRAPH(eltwise_add, eltwise_add, skip_layernorm_pattern);
GET_IR_NODE_FROM_SUBGRAPH(eltwise_add_out, eltwise_add_out,
skip_layernorm_pattern);
GET_IR_NODE_FROM_SUBGRAPH(layer_norm, layer_norm, skip_layernorm_pattern);
GET_IR_NODE_FROM_SUBGRAPH(layer_norm_out, layer_norm_out,
skip_layernorm_pattern);
GET_IR_NODE_FROM_SUBGRAPH(layer_norm_bias, layer_norm_bias,
skip_layernorm_pattern);
GET_IR_NODE_FROM_SUBGRAPH(layer_norm_scale, layer_norm_scale,
skip_layernorm_pattern);
GET_IR_NODE_FROM_SUBGRAPH(layer_norm_mean, layer_norm_mean,
skip_layernorm_pattern);
GET_IR_NODE_FROM_SUBGRAPH(layer_norm_variance, layer_norm_variance,
skip_layernorm_pattern);
if (!IsCompat(subgraph, graph)) {
LOG(WARNING) << "Pass(PrelnSkipLayerNorm) in op compat failed.";
return;
}
end_pattern_elt_out.push_back(eltwise_add_out);
std::unordered_set<Node*> rm_nodes;
rm_nodes.insert({layer_norm, layer_norm_mean, layer_norm_variance});
end_pattern_remove_nodes.push_back(rm_nodes);
end_pattern_biases.push_back(layer_norm_bias);
end_pattern_scales.push_back(layer_norm_scale);
end_pattern_out.push_back(layer_norm_out);
end_patter_layernorms.push_back(layer_norm);
end_patter_elementwise.push_back(eltwise_add);
};
gpd3(graph, handler3);
if (start_pattern_in_nodes.empty() || end_pattern_elt_out.empty()) {
return 0;
}
// only keep the subgraphs that are in connected domains.
int fusion_count = 0;
// fusion_id for (i, k, js)
std::vector<std::pair<size_t, std::pair<size_t, std::vector<size_t>>>>
fusion_ids;
for (size_t i = 0; i < start_pattern_in_nodes.size(); ++i) {
Node* tmp = start_pattern_out_node[i];
Node* old_tmp = nullptr;
// get correct inner pattern node order.
std::vector<size_t> js;
while (tmp != old_tmp) {
old_tmp = tmp;
for (size_t j = 0; j < inner_pattern_tmp_in.size(); ++j) {
if (inner_pattern_tmp_in[j] == tmp) {
tmp = inner_pattern_out[j];
js.push_back(j);
break;
}
}
}
for (size_t k = 0; k < end_pattern_elt_out.size(); ++k) {
if (tmp == end_pattern_elt_out[k]) {
fusion_ids.push_back(std::make_pair(i, std::make_pair(k, js)));
break;
}
}
}
for (size_t num = 0; num < fusion_ids.size(); ++num) {
int i = fusion_ids[num].first;
int k = fusion_ids[num].second.first;
std::vector<size_t> js = fusion_ids[num].second.second;
std::vector<std::string> ids;
std::vector<std::string> embs;
for (size_t iter = 0; iter < start_pattern_in_nodes[i].size(); ++iter) {
ids.push_back(start_pattern_in_nodes[i][iter].first->Name());
embs.push_back(start_pattern_in_nodes[i][iter].second->Name());
}
for (size_t iter = 0; iter < js.size(); ++iter) {
ids.push_back(inner_pattern_ins[js[iter]].first->Name());
embs.push_back(inner_pattern_ins[js[iter]].second->Name());
}
OpDesc new_op_desc;
new_op_desc.SetType("fused_preln_embedding_eltwise_layernorm");
new_op_desc.SetInput("Ids", ids);
new_op_desc.SetInput("Embs", embs);
new_op_desc.SetInput("WordId", {ids[0]});
new_op_desc.SetInput("PosId", {ids[1]});
if (ids.size() > 2) {
new_op_desc.SetInput("SentId", {ids[2]});
}
new_op_desc.SetInput("WordEmbedding", {embs[0]});
new_op_desc.SetInput("PosEmbedding", {embs[1]});
if (embs.size() > 2) {
new_op_desc.SetInput("SentEmbedding", {embs[2]});
}
new_op_desc.SetInput("Bias", {end_pattern_biases[k]->Name()});
new_op_desc.SetInput("Scale", {end_pattern_scales[k]->Name()});
new_op_desc.SetOutput("Out_0", {end_pattern_out[k]->Name()});
new_op_desc.SetOutput("Out_1", {inner_pattern_out[k]->Name()});
new_op_desc.SetAttr("epsilon",
end_patter_layernorms[k]->Op()->GetAttr("epsilon"));
if (end_patter_layernorms[k]->Op()->HasAttr("out_threshold") &&
end_patter_elementwise[k]->Op()->HasAttr("out_threshold")) {
new_op_desc.SetAttr("enable_int8", true);
new_op_desc.SetAttr(
"out_0_threshold",
end_patter_layernorms[k]->Op()->GetAttr("out_threshold"));
new_op_desc.SetAttr(
"out_1_threshold",
end_patter_elementwise[k]->Op()->GetAttr("out_threshold"));
}
auto* preln_embedding_eltwise_layernorm = graph->CreateOpNode(&new_op_desc);
for (size_t iter = 0; iter < start_pattern_in_nodes[i].size(); ++iter) {
IR_NODE_LINK_TO(start_pattern_in_nodes[i][iter].first,
preln_embedding_eltwise_layernorm);
IR_NODE_LINK_TO(start_pattern_in_nodes[i][iter].second,
preln_embedding_eltwise_layernorm);
}
for (size_t iter = 0; iter < js.size(); ++iter) {
IR_NODE_LINK_TO(inner_pattern_ins[js[iter]].first,
preln_embedding_eltwise_layernorm);
IR_NODE_LINK_TO(inner_pattern_ins[js[iter]].second,
preln_embedding_eltwise_layernorm);
}
IR_NODE_LINK_TO(end_pattern_biases[k], preln_embedding_eltwise_layernorm);
IR_NODE_LINK_TO(end_pattern_scales[k], preln_embedding_eltwise_layernorm);
IR_NODE_LINK_TO(preln_embedding_eltwise_layernorm, end_pattern_out[k]);
IR_NODE_LINK_TO(preln_embedding_eltwise_layernorm, inner_pattern_out[k]);
// Remove unneeded nodes.
std::unordered_set<const Node*> marked_nodes;
marked_nodes.insert(start_pattern_remove_nodes[i].begin(),
start_pattern_remove_nodes[i].end());
marked_nodes.insert(end_pattern_remove_nodes[k].begin(),
end_pattern_remove_nodes[k].end());
for (size_t iter = 0; iter < js.size(); ++iter) {
marked_nodes.insert(inner_pattern_remove_nodes[js[iter]].begin(),
inner_pattern_remove_nodes[js[iter]].end());
}
GraphSafeRemoveNodes(graph, marked_nodes);
++fusion_count;
}
return fusion_count;
}
PrelnEmbeddingEltwiseLayerNormFusePass::
PrelnEmbeddingEltwiseLayerNormFusePass() {
AddOpCompat(OpCompat("elementwise_add"))
.AddInput("X")
.IsTensor()
.End()
.AddInput("Y")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddAttr("axis")
.End();
AddOpCompat(OpCompat("layer_norm"))
.AddInput("X")
.IsTensor()
.End()
.AddInput("Scale")
.IsTensor()
.End()
.AddInput("Bias")
.IsTensor()
.End()
.AddOutput("Y")
.IsTensor()
.End()
.AddOutput("Mean")
.IsTensor()
.End()
.AddOutput("Variance")
.IsTensor()
.End()
.AddAttr("epsilon")
.IsNumGE(0.0f)
.IsNumLE(0.001f)
.End()
.AddAttr("begin_norm_axis")
.IsNumGT(0)
.End();
}
void PrelnEmbeddingEltwiseLayerNormFusePass::ApplyImpl(Graph* graph) const {
FusePassBase::Init(name_scope_, graph);
int fusion_count =
PrelnEmbeddingEltwiseLayerNormFusePass::BuildFusion(graph, name_scope_);
if (fusion_count > 0) {
graph->Set(kPrelnEmbEltwiseLayernormPass, new bool(true));
}
AddStatis(fusion_count);
}
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(preln_embedding_eltwise_layernorm_fuse_pass,
paddle::framework::ir::PrelnEmbeddingEltwiseLayerNormFusePass);
REGISTER_PASS_CAPABILITY(preln_embedding_eltwise_layernorm_fuse_pass)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination()
.LE("lookup_table", 1)
.LE("lookup_table_v2", 1)
.LE("elementweise_add", 1));
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <string>
#include <utility>
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
namespace paddle {
namespace framework {
namespace ir {
class Graph;
} // namespace ir
} // namespace framework
} // namespace paddle
namespace paddle {
namespace framework {
namespace ir {
namespace patterns {
// detect start pattern.
//
//           in_var  emb       in_var   emb
//             |      |           |      |
//           lookup_table      lookup_table
//                  |                 |
//               lkt_var           lkt_var
//                    \              /
//                     elementwise_add
//                            |
//                       elt_out_var
//
struct PrelnEmbedding2Eltwise1Pattern : public PatternBase {
PrelnEmbedding2Eltwise1Pattern(PDPattern* pattern,
const std::string& name_scope)
: PatternBase(pattern, name_scope, "Prelnembedding2_eltwise1") {}
void operator()();
PATTERN_DECL_NODE(lookup_table1_x);
PATTERN_DECL_NODE(lookup_table2_x);
PATTERN_DECL_NODE(lookup_table1_w);
PATTERN_DECL_NODE(lookup_table2_w);
PATTERN_DECL_NODE(lookup_table1);
PATTERN_DECL_NODE(lookup_table2);
PATTERN_DECL_NODE(lookup_table1_out);
PATTERN_DECL_NODE(lookup_table2_out);
PATTERN_DECL_NODE(eltwise_add);
PATTERN_DECL_NODE(eltwise_add_out);
};
// detect the repeated inner pattern
//
//    elt_out_var              in_var   emb
//         \                     |       |
//          \                   lookup_table
//           \                       |
//            \                   lkt_var
//             \                    /
//              elementwise_add
//                |          |
//   elementwise_add      elt_out_var
//
struct PrelnEmbedding1Eltwise1Pattern : public PatternBase {
PrelnEmbedding1Eltwise1Pattern(PDPattern* pattern,
const std::string& name_scope)
: PatternBase(pattern, name_scope, "Prelnembedding1_eltwise1") {}
void operator()();
PATTERN_DECL_NODE(lookup_table1_x);
PATTERN_DECL_NODE(lookup_table1_w);
PATTERN_DECL_NODE(lookup_table1);
PATTERN_DECL_NODE(lookup_table1_out);
PATTERN_DECL_NODE(eltwise_add_in);
PATTERN_DECL_NODE(eltwise_add);
PATTERN_DECL_NODE(eltwise_add_out);
};
// detect end pattern
//
//          elementwise_add
//           |          |
//           |     elt_out_var
//           |      scale | bias
//           |        \   |   /
//   elementwise_add   layer_norm
//
struct PrelnSkipLayerNorm : public PatternBase {
PrelnSkipLayerNorm(PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern, name_scope, "Prelnskip_layernorm") {}
void operator()();
PATTERN_DECL_NODE(eltwise_add);
PATTERN_DECL_NODE(eltwise_add_out);
PATTERN_DECL_NODE(layer_norm);
PATTERN_DECL_NODE(layer_norm_bias);
PATTERN_DECL_NODE(layer_norm_scale);
PATTERN_DECL_NODE(layer_norm_out);
// Delete the mean and var nodes in the graph.
PATTERN_DECL_NODE(layer_norm_mean);
PATTERN_DECL_NODE(layer_norm_variance);
};
} // namespace patterns
// The PrelnEmbeddingEltwiseLayerNormFusePass detects the following pattern:
//
// inputs                            operator               output
// --------------------------------------------------------------------
// (word, weights_0)                 lookup_table        -> word_emb
// (pos, weights_1)                  lookup_table        -> pos_emb
// (sent, weights_2)                 lookup_table        -> sent_emb
// (word_emb, pos_emb)               elementwise_add     -> elementwise_out_0
// (elementwise_out_0, sent_emb)     elementwise_add     -> elementwise_out_1
// (elementwise_out_1, scale, bias)  layer_norm          -> layer_norm_out
//
// and then converts the corresponding subgraph to:
//
// (word, pos, sent, weights_0, weights_1, weights_2,
//  scale, bias) Prelnembedding_eltwise_layernorm -> layer_norm_out +
//  elementwise_add_out
//
//
//  in_var  emb_var   in_var  emb_var   in_var  emb_var        in_var  emb_var
//    |        |        |        |        |        |             |        |
//   lookup_table      lookup_table      lookup_table    ...    lookup_table
//        |                 |                 |                      |
//     lkt_var           lkt_var           lkt_var                lkt_var
//        \                 /                 |        ...           |
//         elementwise_add                    |                      |
//                 \                          /                      |
//                       elementwise_add                             |
//                               |                                   |
//                            elt_var                                /
//                               \                                  /
//                                     elementwise_add
//                                       |           |
//                          elementwise_add       layer_norm
class PrelnEmbeddingEltwiseLayerNormFusePass : public FusePassBase {
public:
PrelnEmbeddingEltwiseLayerNormFusePass();
virtual ~PrelnEmbeddingEltwiseLayerNormFusePass() {}
protected:
void ApplyImpl(Graph* graph) const;
int BuildFusion(Graph* graph, const std::string& name_scope
/*const Scope* scope*/) const;
const std::string name_scope_{"preln_embedding_eltwise_layernorm_fuse"};
};
} // namespace ir
} // namespace framework
} // namespace paddle
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/ir/preln_skip_layernorm_fuse_pass.h"
#include <string>
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/op_version_registry.h"
namespace paddle {
namespace framework {
namespace ir {
class Node;
} // namespace ir
} // namespace framework
} // namespace paddle
namespace paddle {
namespace framework {
namespace ir {
namespace patterns {
struct PrelnSkipLayerNorm : public PatternBase {
PrelnSkipLayerNorm(PDPattern *pattern, const std::string &name_scope)
: PatternBase(pattern, name_scope, "preln_skip_layernorm") {}
void operator()(PDNode *x, PDNode *y);
// declare operator node's name
PATTERN_DECL_NODE(fused_skipe_layernorm);
PATTERN_DECL_NODE(elementwise);
PATTERN_DECL_NODE(layer_norm);
// declare variable node's name
PATTERN_DECL_NODE(
elementwise_out); // (elementwise_input_x,elementwise_input_y) ->
// elementwise_out
PATTERN_DECL_NODE(layer_norm_bias);
PATTERN_DECL_NODE(layer_norm_scale);
PATTERN_DECL_NODE(layer_norm_out);
PATTERN_DECL_NODE(layer_norm_mean);
PATTERN_DECL_NODE(layer_norm_variance);
};
void PrelnSkipLayerNorm::operator()(PDNode *x, PDNode *y) {
// Create nodes for elementwise add op.
x->assert_is_op_input("elementwise_add", "X");
y->assert_is_op_input("elementwise_add", "Y");
auto *elementwise =
pattern->NewNode(elementwise_repr())->assert_is_op("elementwise_add");
auto *elementwise_out_var = pattern->NewNode(elementwise_out_repr())
->assert_is_op_output("elementwise_add")
->assert_is_op_input("layer_norm", "X")
->assert_is_op_input("elementwise_add", "Y");
// Add links for elementwise_add op.
elementwise->LinksFrom({x, y}).LinksTo({elementwise_out_var});
// Create nodes for layer_norm op.
auto *layer_norm =
pattern->NewNode(layer_norm_repr())->assert_is_op("layer_norm");
auto *layer_norm_bias_var = pattern->NewNode(layer_norm_bias_repr())
->AsInput()
->assert_is_persistable_var()
->assert_is_op_input("layer_norm", "Bias");
auto *layer_norm_scale_var = pattern->NewNode(layer_norm_scale_repr())
->AsInput()
->assert_is_persistable_var()
->assert_is_op_input("layer_norm", "Scale");
auto *layer_norm_out_var = pattern->NewNode(layer_norm_out_repr())
->AsOutput()
->assert_is_op_output("layer_norm", "Y");
auto *layer_norm_mean_var = pattern->NewNode(layer_norm_mean_repr())
->AsOutput()
->assert_is_op_output("layer_norm", "Mean");
auto *layer_norm_variance_var =
pattern->NewNode(layer_norm_variance_repr())
->AsOutput()
->assert_is_op_output("layer_norm", "Variance");
// Add links for layer_norm op.
layer_norm
->LinksFrom(
{elementwise_out_var, layer_norm_bias_var, layer_norm_scale_var})
.LinksTo(
{layer_norm_out_var, layer_norm_mean_var, layer_norm_variance_var});
}
} // namespace patterns
void PrelnSkipLayerNormFusePass::ApplyImpl(ir::Graph *graph) const {
PADDLE_ENFORCE_NOT_NULL(
graph, platform::errors::PreconditionNotMet("graph should not be null."));
FusePassBase::Init("preln_skip_layernorm_fuse", graph);
int found_subgraph_count = 0;
GraphPatternDetector gpd;
auto *x = gpd.mutable_pattern()
->NewNode("preln_skip_layernorm_fuse/x")
->AsInput()
->assert_is_op_input("elementwise_add", "X")
->assert_var_not_persistable();
auto *y = gpd.mutable_pattern()
->NewNode("preln_skip_layernorm_fuse/y")
->AsInput()
->assert_is_op_input("elementwise_add", "Y")
->assert_var_not_persistable();
patterns::PrelnSkipLayerNorm fused_pattern(gpd.mutable_pattern(),
"preln_skip_layernorm_fuse");
fused_pattern(x, y);
auto handler = [&](const GraphPatternDetector::subgraph_t &subgraph,
Graph *graph) {
if (subgraph.count(x) <= 0 || subgraph.count(y) <= 0) {
LOG(WARNING) << "The subgraph is empty.";
return;
}
if (!IsCompat(subgraph, graph)) {
LOG(WARNING) << "preln_skip_layernorm pass in op compat failed.";
return;
}
VLOG(4) << "handle PrelnSkipLayerNorm fuse";
GET_IR_NODE_FROM_SUBGRAPH(elementwise, elementwise, fused_pattern);
GET_IR_NODE_FROM_SUBGRAPH(elementwise_out, elementwise_out, fused_pattern);
GET_IR_NODE_FROM_SUBGRAPH(layer_norm, layer_norm, fused_pattern);
GET_IR_NODE_FROM_SUBGRAPH(layer_norm_bias, layer_norm_bias, fused_pattern);
GET_IR_NODE_FROM_SUBGRAPH(layer_norm_scale, layer_norm_scale,
fused_pattern);
GET_IR_NODE_FROM_SUBGRAPH(layer_norm_out, layer_norm_out, fused_pattern);
GET_IR_NODE_FROM_SUBGRAPH(layer_norm_mean, layer_norm_mean, fused_pattern);
GET_IR_NODE_FROM_SUBGRAPH(layer_norm_variance, layer_norm_variance,
fused_pattern);
std::unordered_set<const Node *> del_node_set;
// Create a PrelnSkipLayerNorm op node
OpDesc new_desc;
new_desc.SetType("preln_skip_layernorm");
// inputs
new_desc.SetInput("X", {subgraph.at(x)->Name()});
new_desc.SetInput("Y", {subgraph.at(y)->Name()});
new_desc.SetInput("Scale", {layer_norm_scale->Name()});
new_desc.SetInput("Bias", {layer_norm_bias->Name()});
if (elementwise->Op()->HasAttr("out_threshold") &&
layer_norm->Op()->HasAttr("out_threshold")) {
new_desc.SetAttr("enable_int8", true);
new_desc.SetAttr("out_0_threshold",
layer_norm->Op()->GetAttr("out_threshold"));
new_desc.SetAttr("out_1_threshold",
elementwise->Op()->GetAttr("out_threshold"));
}
// outputs
new_desc.SetOutput("Out_0", {layer_norm_out->Name()});
new_desc.SetOutput("Out_1", {elementwise_out->Name()});
// attrs
new_desc.SetAttr("epsilon", layer_norm->Op()->GetAttr("epsilon"));
new_desc.SetAttr("begin_norm_axis",
layer_norm->Op()->GetAttr("begin_norm_axis"));
auto fused_node = graph->CreateOpNode(&new_desc); // OpDesc will be copied.
del_node_set.insert(elementwise);
del_node_set.insert(layer_norm);
del_node_set.insert(layer_norm_mean);
del_node_set.insert(layer_norm_variance);
GraphSafeRemoveNodes(graph, del_node_set);
IR_NODE_LINK_TO(subgraph.at(x), fused_node);
IR_NODE_LINK_TO(subgraph.at(y), fused_node);
IR_NODE_LINK_TO(layer_norm_scale, fused_node);
IR_NODE_LINK_TO(layer_norm_bias, fused_node);
IR_NODE_LINK_TO(fused_node, layer_norm_out);
IR_NODE_LINK_TO(fused_node, elementwise_out);
found_subgraph_count++;
};
gpd(graph, handler);
AddStatis(found_subgraph_count);
}
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(preln_skip_layernorm_fuse_pass,
paddle::framework::ir::PrelnSkipLayerNormFusePass);
REGISTER_PASS_CAPABILITY(preln_skip_layernorm_fuse_pass)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination()
.LE("elementwise_add", 1)
.EQ("layer_norm", 0));
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
namespace paddle {
namespace framework {
namespace ir {
//            |            |                      |            |
//        other_op1    other_op2              other_op1    other_op2
//            |            |          fuse         \          /
//            |------elementwise_add      ->    skip_layernorm
//            |            |                      |           |
//        other_op4    layer_norm             other_op4   other_op3
//                         |
//                     other_op3
class Graph;
class PrelnSkipLayerNormFusePass : public FusePassBase {
public:
PrelnSkipLayerNormFusePass() {
AddOpCompat(OpCompat("elementwise_add"))
.AddInput("X")
.IsTensor()
.End()
.AddInput("Y")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddAttr("axis")
.IsIntIn({0, -1})
.End();
AddOpCompat(OpCompat("layer_norm"))
.AddInput("X")
.IsTensor()
.End()
.AddInput("Scale")
.IsTensor()
.End()
.AddInput("Bias")
.IsTensor()
.End()
.AddOutput("Y")
.IsTensor()
.End()
.AddOutput("Mean")
.IsTensor()
.End()
.AddOutput("Variance")
.IsTensor()
.End()
.AddAttr("epsilon")
.IsNumGE(0.0f)
.IsNumLE(0.001f)
.End()
.AddAttr("begin_norm_axis")
.IsNumGT(0)
.End();
}
virtual ~PrelnSkipLayerNormFusePass() {}
protected:
void ApplyImpl(ir::Graph* graph) const override;
};
} // namespace ir
} // namespace framework
} // namespace paddle
...@@ -67,4 +67,4 @@ TEST(NaiveExecutor, Basic) { ...@@ -67,4 +67,4 @@ TEST(NaiveExecutor, Basic) {
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
USE_OP(elementwise_add); USE_OP_ITSELF(elementwise_add);
...@@ -25,12 +25,12 @@ USE_OP(fill_constant); ...@@ -25,12 +25,12 @@ USE_OP(fill_constant);
USE_OP(uniform_random); USE_OP(uniform_random);
USE_OP(lookup_table); USE_OP(lookup_table);
USE_OP(transpose2); USE_OP(transpose2);
USE_OP(reshape2); USE_OP_ITSELF(reshape2);
USE_OP(split); USE_OP(split);
USE_OP(slice); USE_OP(slice);
USE_OP(concat); USE_OP(concat);
USE_OP(matmul); USE_OP(matmul);
USE_OP(elementwise_add); USE_OP_ITSELF(elementwise_add);
USE_OP(sigmoid); USE_OP(sigmoid);
USE_OP(tanh); USE_OP(tanh);
USE_OP(elementwise_mul); USE_OP(elementwise_mul);
...@@ -39,9 +39,9 @@ USE_OP(reduce_mean); ...@@ -39,9 +39,9 @@ USE_OP(reduce_mean);
USE_OP(reduce_sum); USE_OP(reduce_sum);
USE_OP(reduce_sum_grad); USE_OP(reduce_sum_grad);
USE_OP(reduce_mean_grad); USE_OP(reduce_mean_grad);
USE_OP(reshape2_grad); USE_OP_ITSELF(reshape2_grad);
USE_OP(softmax_with_cross_entropy_grad); USE_OP(softmax_with_cross_entropy_grad);
USE_OP(elementwise_add_grad); USE_OP_ITSELF(elementwise_add_grad);
USE_OP(matmul_grad); USE_OP(matmul_grad);
USE_OP(square); USE_OP(square);
USE_OP(transpose2_grad); USE_OP(transpose2_grad);
......
...@@ -47,10 +47,20 @@ size_t OpKernelType::Hash::operator()(const OpKernelType& key) const { ...@@ -47,10 +47,20 @@ size_t OpKernelType::Hash::operator()(const OpKernelType& key) const {
"Too many OpKernel attribute values, expected maximum " "Too many OpKernel attribute values, expected maximum "
"value is 64, received value is %d.", "value is 64, received value is %d.",
cur_loc)); cur_loc));
#ifdef PADDLE_WITH_CUSTOM_DEVICE
std::hash<int> hasher;
size_t seed =
hasher(place + data_type + data_layout + library_type + customized_value);
if (platform::is_custom_place(key.place_)) {
seed ^= std::hash<std::string>{}(key.place_.GetDeviceType()) + 0x9e3779b9 +
(seed << 6) + (seed >> 2) + 4;
}
return seed;
#else
std::hash<int> hasher; std::hash<int> hasher;
return hasher(place + data_type + data_layout + library_type + return hasher(place + data_type + data_layout + library_type +
customized_value); customized_value);
#endif
} }
bool OpKernelType::operator==(const OpKernelType& o) const { bool OpKernelType::operator==(const OpKernelType& o) const {
......
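The custom-device branch above folds the device-type string into the existing integer hash with the usual 0x9e3779b9 hash-combine constant. A self-contained illustration of that combining idiom follows; the field names and the "custom_npu" string are made up for the example and are not Paddle's actual OpKernelType members.

#include <cstddef>
#include <functional>
#include <iostream>
#include <string>

// Classic boost-style hash_combine: mixes a new hash value into an existing seed.
inline void HashCombine(std::size_t* seed, std::size_t value) {
  *seed ^= value + 0x9e3779b9 + (*seed << 6) + (*seed >> 2);
}

int main() {
  std::size_t seed = std::hash<int>{}(42);                      // packed kernel fields
  HashCombine(&seed, std::hash<std::string>{}("custom_npu"));   // device type string
  std::cout << seed << "\n";
}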
...@@ -29,6 +29,7 @@ limitations under the License. */ ...@@ -29,6 +29,7 @@ limitations under the License. */
#include "paddle/fluid/framework/transfer_scope_cache.h" #include "paddle/fluid/framework/transfer_scope_cache.h"
#include "paddle/fluid/framework/unused_var_check.h" #include "paddle/fluid/framework/unused_var_check.h"
#include "paddle/fluid/framework/var_type.h" #include "paddle/fluid/framework/var_type.h"
#include "paddle/fluid/platform/device/device_wrapper.h"
#include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/profiler.h" #include "paddle/fluid/platform/profiler.h"
#include "paddle/pten/common/scalar.h" #include "paddle/pten/common/scalar.h"
...@@ -244,6 +245,15 @@ void OperatorBase::Run(const Scope& scope, const platform::Place& place) { ...@@ -244,6 +245,15 @@ void OperatorBase::Run(const Scope& scope, const platform::Place& place) {
#else #else
auto dev_id = place.device; auto dev_id = place.device;
platform::SetMLUDeviceId(dev_id); platform::SetMLUDeviceId(dev_id);
#endif
} else if (platform::is_custom_place(place)) {
#ifndef PADDLE_WITH_CUSTOM_DEVICE
PADDLE_THROW(platform::errors::Unavailable(
"Cannot run operator on place %s, please recompile paddle or "
"reinstall Paddle with CustomDevice support.",
place));
#else
platform::DeviceManager::SetDevice(place);
#endif #endif
} }
...@@ -1326,8 +1336,6 @@ void OperatorWithKernel::RunImpl(const Scope& scope, ...@@ -1326,8 +1336,6 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
OpKernelType OperatorWithKernel::InnerGetExpectedKernelType( OpKernelType OperatorWithKernel::InnerGetExpectedKernelType(
const ExecutionContext& ctx) const { const ExecutionContext& ctx) const {
auto& dev_ctx = ctx.device_context();
auto expected_kernel_key = this->GetExpectedKernelType(ctx); auto expected_kernel_key = this->GetExpectedKernelType(ctx);
if (HasAttr("op_device")) { if (HasAttr("op_device")) {
if (Attr<std::string>("op_device") == "cpu") { if (Attr<std::string>("op_device") == "cpu") {
...@@ -1344,12 +1352,20 @@ OpKernelType OperatorWithKernel::InnerGetExpectedKernelType( ...@@ -1344,12 +1352,20 @@ OpKernelType OperatorWithKernel::InnerGetExpectedKernelType(
} }
// when the Op that only has CPUKernel is assigned to GPU, the CPUKernel // when the Op that only has CPUKernel is assigned to GPU, the CPUKernel
// will be executed and a warning will be given at the same time. // will be executed and a warning will be given at the same time.
expected_kernel_key.place_ = platform::CPUPlace();
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
if (SupportGPU()) { if (SupportGPU()) {
auto& dev_ctx = ctx.device_context();
expected_kernel_key.place_ = dev_ctx.GetPlace(); expected_kernel_key.place_ = dev_ctx.GetPlace();
} else if (SupportNPU()) { }
#endif
#ifdef PADDLE_WITH_ASCEND_CL
if (SupportNPU()) {
auto& dev_ctx = ctx.device_context();
expected_kernel_key.place_ = dev_ctx.GetPlace(); expected_kernel_key.place_ = dev_ctx.GetPlace();
} else { }
expected_kernel_key.place_ = platform::CPUPlace(); #endif
if (platform::is_cpu_place(expected_kernel_key.place_)) {
LOG_FIRST_N(WARNING, 1) LOG_FIRST_N(WARNING, 1)
<< "Op(" << type_ << "Op(" << type_
<< ") has no CUDA implementation. It will be assigned to CPUPlace."; << ") has no CUDA implementation. It will be assigned to CPUPlace.";
...@@ -1924,12 +1940,10 @@ Scope* OperatorWithKernel::PreparePtenData( ...@@ -1924,12 +1940,10 @@ Scope* OperatorWithKernel::PreparePtenData(
for (size_t i = 0; i < input_defs.size(); ++i) { for (size_t i = 0; i < input_defs.size(); ++i) {
auto& in_def = input_defs.at(i); auto& in_def = input_defs.at(i);
auto it = ctx->inputs.find(input_names[i]); if (ctx->inputs.find(input_names[i]) == ctx->inputs.end()) {
if (it == ctx->inputs.end()) {
continue; continue;
} }
auto& ins_vector = ctx->inputs.at(input_names[i]);
auto& ins_vector = it->second;
auto& name_vec = name_map.at(input_names[i]); auto& name_vec = name_map.at(input_names[i]);
bool should_skip_input = bool should_skip_input =
no_buffer_ins && no_buffer_ins->count(input_names[i]) > 0; no_buffer_ins && no_buffer_ins->count(input_names[i]) > 0;
...@@ -1940,7 +1954,6 @@ Scope* OperatorWithKernel::PreparePtenData( ...@@ -1940,7 +1954,6 @@ Scope* OperatorWithKernel::PreparePtenData(
if (var == nullptr || !VarIsTensor(*var)) { if (var == nullptr || !VarIsTensor(*var)) {
continue; continue;
} }
auto* tensor_in = GetLoDTensorOrSelectedRowsValueFromVar(*var); auto* tensor_in = GetLoDTensorOrSelectedRowsValueFromVar(*var);
// When no_buffer_ins then checking of Tensor::holder_ is // When no_buffer_ins then checking of Tensor::holder_ is
...@@ -2165,6 +2178,8 @@ void OperatorWithKernel::BuildPtenKernelContext( ...@@ -2165,6 +2178,8 @@ void OperatorWithKernel::BuildPtenKernelContext(
pt_kernel_context->EmplaceBackAttr(BOOST_GET_CONST(float, attr)); pt_kernel_context->EmplaceBackAttr(BOOST_GET_CONST(float, attr));
} else if (attr_defs[i].type_index == std::type_index(typeid(bool))) { } else if (attr_defs[i].type_index == std::type_index(typeid(bool))) {
pt_kernel_context->EmplaceBackAttr(BOOST_GET_CONST(bool, attr)); pt_kernel_context->EmplaceBackAttr(BOOST_GET_CONST(bool, attr));
} else if (attr_defs[i].type_index == std::type_index(typeid(int64_t))) {
pt_kernel_context->EmplaceBackAttr(BOOST_GET_CONST(int64_t, attr));
} else if (attr_defs[i].type_index == } else if (attr_defs[i].type_index ==
std::type_index(typeid(std::string))) { std::type_index(typeid(std::string))) {
pt_kernel_context->EmplaceBackAttr(BOOST_GET_CONST(std::string, attr)); pt_kernel_context->EmplaceBackAttr(BOOST_GET_CONST(std::string, attr));
......
...@@ -661,6 +661,6 @@ TEST(BuildCinnPassTest, NoNeedBufferInput) { ...@@ -661,6 +661,6 @@ TEST(BuildCinnPassTest, NoNeedBufferInput) {
USE_PASS(build_cinn_pass); USE_PASS(build_cinn_pass);
USE_OP(mul); USE_OP(mul);
USE_OP(relu); USE_OP(relu);
USE_OP(elementwise_add); USE_OP_ITSELF(elementwise_add);
USE_OP(relu_grad); USE_OP(relu_grad);
USE_OP(elementwise_add_grad); USE_OP_ITSELF(elementwise_add_grad);
...@@ -88,7 +88,7 @@ const CinnCompiledObject& CinnCompiler::Compile( ...@@ -88,7 +88,7 @@ const CinnCompiledObject& CinnCompiler::Compile(
if (cache_by_struct_.count(cur_key_by_struct) != 0) { if (cache_by_struct_.count(cur_key_by_struct) != 0) {
exist = true; exist = true;
cache_by_address_[cur_key_by_address] = cache_by_address_[cur_key_by_address] =
cache_by_struct_.at(cur_key_by_struct).get(); cache_by_struct_.at(cur_key_by_struct);
} }
} }
} }
...@@ -98,12 +98,13 @@ const CinnCompiledObject& CinnCompiler::Compile( ...@@ -98,12 +98,13 @@ const CinnCompiledObject& CinnCompiler::Compile(
CompileGraph(graph, input_tensors, target, compiled_num, stream); CompileGraph(graph, input_tensors, target, compiled_num, stream);
pten::AutoWRLock w_guard{&rwlock_}; pten::AutoWRLock w_guard{&rwlock_};
if (!cache_by_struct_.count(cur_key_by_struct)) { if (!cache_by_struct_.count(cur_key_by_struct)) {
cache_by_address_[cur_key_by_address] = compiled_res.get(); cache_by_address_[cur_key_by_address] = compiled_num;
cache_by_struct_[cur_key_by_struct] = std::move(compiled_res); cache_by_struct_[cur_key_by_struct] = compiled_num;
index2cache_.emplace(compiled_num, std::move(compiled_res));
} }
} }
pten::AutoRDLock guard{&rwlock_}; pten::AutoRDLock guard{&rwlock_};
const auto& cached_boj = *cache_by_address_[cur_key_by_address]; const auto& cached_boj = *index2cache_[cache_by_address_[cur_key_by_address]];
return cached_boj; return cached_boj;
} }
...@@ -115,6 +116,15 @@ const CinnCompiledObject& CinnCompiler::Compile( ...@@ -115,6 +116,15 @@ const CinnCompiledObject& CinnCompiler::Compile(
return Compile(graph, input_tensors, target, stream); return Compile(graph, input_tensors, target, stream);
} }
const CinnCompiledObject& CinnCompiler::GetCompiledObject(
int64_t cached_index) const {
auto res = index2cache_.find(cached_index);
PADDLE_ENFORCE_NE(res, index2cache_.end(),
platform::errors::InvalidArgument(
"Index(%ld) not found in cache", cached_index));
return *res->second;
}
std::string CinnCompiler::AddGraph(std::unique_ptr<Graph> graph) { std::string CinnCompiler::AddGraph(std::unique_ptr<Graph> graph) {
std::string graph_key; std::string graph_key;
ProgramDesc program; ProgramDesc program;
...@@ -202,6 +212,7 @@ void CinnCompiler::Clear() { ...@@ -202,6 +212,7 @@ void CinnCompiler::Clear() {
graphs_.clear(); graphs_.clear();
cache_by_address_.clear(); cache_by_address_.clear();
cache_by_struct_.clear(); cache_by_struct_.clear();
index2cache_.clear();
} }
real_compiled_num_.store(0); real_compiled_num_.store(0);
} }
...@@ -240,6 +251,7 @@ std::unique_ptr<CinnCompiledObject> CinnCompiler::CompileGraph( ...@@ -240,6 +251,7 @@ std::unique_ptr<CinnCompiledObject> CinnCompiler::CompileGraph(
compiled_obj->launch_context = compiled_obj->launch_context =
std::make_unique<operators::details::CinnLaunchContext>( std::make_unique<operators::details::CinnLaunchContext>(
compiled_obj->paddle2cinn_varmap, compiled_obj->scope); compiled_obj->paddle2cinn_varmap, compiled_obj->scope);
compiled_obj->cached_index = compiled_num;
return compiled_obj; return compiled_obj;
} }
......
...@@ -53,6 +53,7 @@ struct CinnCompiledObject { ...@@ -53,6 +53,7 @@ struct CinnCompiledObject {
std::shared_ptr<::cinn::hlir::framework::Scope> scope; std::shared_ptr<::cinn::hlir::framework::Scope> scope;
std::unordered_map<std::string, std::string> paddle2cinn_varmap; std::unordered_map<std::string, std::string> paddle2cinn_varmap;
std::unique_ptr<operators::details::CinnLaunchContext> launch_context; std::unique_ptr<operators::details::CinnLaunchContext> launch_context;
std::int64_t cached_index;
}; };
// Entrance to use CINN. // Entrance to use CINN.
...@@ -76,6 +77,8 @@ class CinnCompiler { ...@@ -76,6 +77,8 @@ class CinnCompiler {
const std::map<std::string, const LoDTensor*>& input_tensors, const std::map<std::string, const LoDTensor*>& input_tensors,
const ::cinn::common::Target& target, void* stream = nullptr); const ::cinn::common::Target& target, void* stream = nullptr);
const CinnCompiledObject& GetCompiledObject(int64_t cached_index) const;
std::string AddGraph(std::unique_ptr<ir::Graph> graph); std::string AddGraph(std::unique_ptr<ir::Graph> graph);
const ir::Graph& FindGraph(const std::string& graph_key) const; const ir::Graph& FindGraph(const std::string& graph_key) const;
...@@ -101,12 +104,12 @@ class CinnCompiler { ...@@ -101,12 +104,12 @@ class CinnCompiler {
void* stream = nullptr) const; void* stream = nullptr) const;
std::unordered_map<std::string, std::unique_ptr<ir::Graph>> graphs_; std::unordered_map<std::string, std::unique_ptr<ir::Graph>> graphs_;
std::unordered_map<CinnCacheKeyByAddress, CinnCompiledObject*, std::unordered_map<CinnCacheKeyByAddress, std::int64_t, CinnCacheKey::Hash>
CinnCacheKey::Hash>
cache_by_address_; cache_by_address_;
std::unordered_map<CinnCacheKeyByStructure, std::unordered_map<CinnCacheKeyByStructure, std::int64_t, CinnCacheKey::Hash>
std::unique_ptr<CinnCompiledObject>, CinnCacheKey::Hash>
cache_by_struct_; cache_by_struct_;
std::unordered_map<std::int64_t, std::unique_ptr<CinnCompiledObject>>
index2cache_;
std::atomic_int64_t real_compiled_num_{0}; std::atomic_int64_t real_compiled_num_{0};
mutable pten::RWLock rwlock_; mutable pten::RWLock rwlock_;
......
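With this change both cache maps store an int64 index instead of owning (or pointing at) the compiled object, and the single owning map index2cache_ maps that index to the object, so a compiled result can later be fetched by GetCompiledObject(cached_index). A stripped-down sketch of that two-level layout is shown below with generic types; CompiledObject and CompileCache here are illustrative stand-ins, not the CINN classes.

#include <cstdint>
#include <memory>
#include <stdexcept>
#include <string>
#include <unordered_map>

struct CompiledObject {
  std::int64_t cached_index = -1;
  std::string payload;  // stands in for runtime_program / scope / launch_context
};

class CompileCache {
 public:
  // Returns the cached object for `key`, "compiling" it on a miss and
  // recording the index it was stored under (mirrors CinnCompiler::Compile).
  const CompiledObject& Compile(const std::string& key) {
    auto it = key_to_index_.find(key);
    if (it != key_to_index_.end()) return *index_to_obj_.at(it->second);
    auto obj = std::make_unique<CompiledObject>();
    obj->cached_index = next_index_++;
    obj->payload = "compiled:" + key;
    auto index = obj->cached_index;
    key_to_index_[key] = index;
    index_to_obj_.emplace(index, std::move(obj));
    return *index_to_obj_.at(index);
  }

  // Mirrors CinnCompiler::GetCompiledObject: lookup purely by index.
  const CompiledObject& GetCompiledObject(std::int64_t index) const {
    auto it = index_to_obj_.find(index);
    if (it == index_to_obj_.end()) throw std::out_of_range("index not cached");
    return *it->second;
  }

 private:
  std::unordered_map<std::string, std::int64_t> key_to_index_;
  std::unordered_map<std::int64_t, std::unique_ptr<CompiledObject>> index_to_obj_;
  std::int64_t next_index_ = 0;
};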
...@@ -270,13 +270,20 @@ TEST(CinnCompilerTest, Compile) { ...@@ -270,13 +270,20 @@ TEST(CinnCompilerTest, Compile) {
auto compile_fn = [&](const Target& target) { auto compile_fn = [&](const Target& target) {
const auto& compiled_obj = const auto& compiled_obj =
cinn_compiler->Compile(compiling_graph, input_tensors, target); cinn_compiler->Compile(compiling_graph, input_tensors, target);
ASSERT_NE(compiled_obj.compiler, nullptr);
ASSERT_NE(compiled_obj.runtime_program, nullptr); ASSERT_NE(compiled_obj.runtime_program, nullptr);
ASSERT_NE(compiled_obj.scope, nullptr); ASSERT_NE(compiled_obj.scope, nullptr);
ASSERT_FALSE(compiled_obj.paddle2cinn_varmap.empty()); ASSERT_FALSE(compiled_obj.paddle2cinn_varmap.empty());
ASSERT_NE(compiled_obj.launch_context, nullptr);
const auto& cached_obj = const auto& cached_obj =
cinn_compiler->Compile(compilation_key, input_tensors, target); cinn_compiler->Compile(compilation_key, input_tensors, target);
ASSERT_EQ(reinterpret_cast<std::uint64_t>(&compiled_obj), ASSERT_EQ(reinterpret_cast<std::uint64_t>(&compiled_obj),
reinterpret_cast<std::uint64_t>(&cached_obj)); reinterpret_cast<std::uint64_t>(&cached_obj));
ASSERT_EQ(cached_obj.cached_index + 1, cinn_compiler->real_compiled_num());
const auto& ret_obj =
cinn_compiler->GetCompiledObject(cached_obj.cached_index);
ASSERT_EQ(reinterpret_cast<std::uint64_t>(&compiled_obj),
reinterpret_cast<std::uint64_t>(&ret_obj));
}; };
// GPU Compilation // GPU Compilation
...@@ -295,4 +302,4 @@ USE_PASS(build_cinn_pass); ...@@ -295,4 +302,4 @@ USE_PASS(build_cinn_pass);
USE_PASS(graph_viz_pass); USE_PASS(graph_viz_pass);
USE_OP(mul); USE_OP(mul);
USE_OP(relu); USE_OP(relu);
USE_OP(elementwise_add); USE_OP_ITSELF(elementwise_add);
...@@ -532,6 +532,21 @@ ir::Graph *ParallelExecutorPrivate::ApplyMemoryOptimizePass(ir::Graph *graph) { ...@@ -532,6 +532,21 @@ ir::Graph *ParallelExecutorPrivate::ApplyMemoryOptimizePass(ir::Graph *graph) {
PADDLE_THROW(platform::errors::PermissionDenied( PADDLE_THROW(platform::errors::PermissionDenied(
"Paddle can't use XPU device since it's not compiled with XPU," "Paddle can't use XPU device since it's not compiled with XPU,"
"Please recompile or reinstall Paddle with XPU support.")); "Please recompile or reinstall Paddle with XPU support."));
#endif
} else if (platform::is_custom_place(place)) {
#if defined(PADDLE_WITH_CUSTOM_DEVICE)
if (IsFastEagerDeletionModeEnabled()) {
gc.reset(
new CustomDeviceUnsafeFastGarbageCollector(place, max_memory_size));
} else {
gc.reset(new CustomStreamGarbageCollector(place, max_memory_size));
}
VLOG(10) << "Created " << i << "-th GarbageCollector at " << place;
#else
PADDLE_THROW(platform::errors::PermissionDenied(
"Paddle can't use custom device since it's not compiled with "
"CustomDevice,"
"Please recompile or reinstall Paddle with CustomDevice support."));
#endif #endif
} else if (platform::is_cpu_place(place)) { } else if (platform::is_cpu_place(place)) {
gc.reset(new CPUGarbageCollector(place, max_memory_size)); gc.reset(new CPUGarbageCollector(place, max_memory_size));
......
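The garbage-collector hunk above follows the pattern used throughout this commit: dispatch on the place type at runtime, guard the backend-specific branch with an `#ifdef`, and throw when the binary was built without that backend. A self-contained sketch of the pattern with illustrative (non-Paddle) names:

```cpp
// Sketch of the "dispatch on place, guard with #ifdef" pattern; not Paddle code.
#include <memory>
#include <stdexcept>

struct GarbageCollectorSketch { virtual ~GarbageCollectorSketch() = default; };
struct CustomDeviceGCSketch : GarbageCollectorSketch {};
struct CpuGCSketch : GarbageCollectorSketch {};

enum class PlaceKind { kCPU, kCustom };

std::unique_ptr<GarbageCollectorSketch> MakeGC(PlaceKind place) {
  if (place == PlaceKind::kCustom) {
#if defined(PADDLE_WITH_CUSTOM_DEVICE)
    return std::make_unique<CustomDeviceGCSketch>();
#else
    // Built without the backend: fail loudly instead of silently falling back.
    throw std::runtime_error(
        "custom device support was not compiled into this binary");
#endif
  }
  return std::make_unique<CpuGCSketch>();
}
```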
...@@ -186,8 +186,9 @@ KernelArgsNameMakerByOpProto::GetAttrsArgsNames() { ...@@ -186,8 +186,9 @@ KernelArgsNameMakerByOpProto::GetAttrsArgsNames() {
} }
KernelSignature KernelArgsNameMakerByOpProto::GetKernelSignature() { KernelSignature KernelArgsNameMakerByOpProto::GetKernelSignature() {
return KernelSignature(op_proto_->type(), GetInputArgsNames(), return KernelSignature(pten::TransToPtenKernelName(op_proto_->type()),
GetAttrsArgsNames(), GetOutputArgsNames()); GetInputArgsNames(), GetAttrsArgsNames(),
GetOutputArgsNames());
} }
std::once_flag kernel_sig_map_init_flag; std::once_flag kernel_sig_map_init_flag;
......
...@@ -91,7 +91,29 @@ void TensorCopyImpl(const TENSOR& src, const platform::Place& dst_place, ...@@ -91,7 +91,29 @@ void TensorCopyImpl(const TENSOR& src, const platform::Place& dst_place,
memory::Copy(dst_place, dst_ptr, src_place, src_ptr, size); memory::Copy(dst_place, dst_ptr, src_place, src_ptr, size);
} }
#endif #endif
#ifdef PADDLE_WITH_CUSTOM_DEVICE
else if (platform::is_custom_place(src_place) && // NOLINT
platform::is_cpu_place(dst_place)) {
auto stream =
reinterpret_cast<const platform::CustomDeviceContext&>(ctx).stream();
memory::Copy(dst_place, dst_ptr, src_place, src_ptr, size, stream);
} else if (platform::is_cpu_place(src_place) && // NOLINT
platform::is_custom_place(dst_place)) {
auto stream =
reinterpret_cast<const platform::CustomDeviceContext&>(ctx).stream();
memory::Copy(dst_place, dst_ptr, src_place, src_ptr, size, stream);
} else if (platform::is_custom_place(src_place) && // NOLINT
platform::is_custom_place(dst_place)) {
if (src_ptr == dst_ptr) {
VLOG(3) << "Skip copy the same data async from " << src_place << " to "
<< dst_place;
return;
}
auto stream =
reinterpret_cast<const platform::CustomDeviceContext&>(ctx).stream();
memory::Copy(dst_place, dst_ptr, src_place, src_ptr, size, stream);
}
#endif
#ifdef PADDLE_WITH_XPU #ifdef PADDLE_WITH_XPU
else if (platform::is_xpu_place(src_place) && // NOLINT else if (platform::is_xpu_place(src_place) && // NOLINT
platform::is_cpu_place(dst_place)) { platform::is_cpu_place(dst_place)) {
...@@ -376,7 +398,8 @@ void TensorCopyImpl(const TENSOR& src, const platform::Place& dst_place, ...@@ -376,7 +398,8 @@ void TensorCopyImpl(const TENSOR& src, const platform::Place& dst_place,
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
const platform::DeviceContext* dev_ctx; const platform::DeviceContext* dev_ctx;
if (platform::is_gpu_place(dst_place) || platform::is_npu_place(dst_place) || if (platform::is_gpu_place(dst_place) || platform::is_npu_place(dst_place) ||
platform::is_mlu_place(dst_place)) { platform::is_mlu_place(dst_place) ||
platform::is_custom_place(dst_place)) {
dev_ctx = pool.Get(dst_place); dev_ctx = pool.Get(dst_place);
} else { } else {
dev_ctx = pool.Get(src.place()); dev_ctx = pool.Get(src.place());
...@@ -436,6 +459,26 @@ void TensorCopySync(const Tensor& src, const platform::Place& dst_place, ...@@ -436,6 +459,26 @@ void TensorCopySync(const Tensor& src, const platform::Place& dst_place,
"Copy from %s to %s is not supported.", src_place, dst_place)); "Copy from %s to %s is not supported.", src_place, dst_place));
} }
#endif #endif
#ifdef PADDLE_WITH_CUSTOM_DEVICE
else if (platform::is_custom_place(src_place) && // NOLINT
platform::is_cpu_place(dst_place)) { /* custom_device -> cpu*/
memory::Copy(dst_place, dst_ptr, src_place, src_ptr, size, nullptr);
}
else if (platform::is_cpu_place(src_place) && // NOLINT
platform::is_custom_place(dst_place)) { /* cpu -> custom_device*/
memory::Copy(dst_place, dst_ptr, src_place, src_ptr, size, nullptr);
}
else if (platform::is_custom_place(src_place) && // NOLINT
platform::is_custom_place(
dst_place)) { /* custom_device -> custom_device*/
if (src_ptr == dst_ptr) {
VLOG(3) << "Skip copy the same data sync from " << src_place << " to "
<< dst_place;
return;
}
memory::Copy(dst_place, dst_ptr, src_place, src_ptr, size, nullptr);
}
#endif
#ifdef PADDLE_WITH_XPU #ifdef PADDLE_WITH_XPU
else if (platform::is_xpu_place(src_place) && // NOLINT else if (platform::is_xpu_place(src_place) && // NOLINT
platform::is_cpu_place(dst_place)) { platform::is_cpu_place(dst_place)) {
...@@ -664,6 +707,13 @@ class AnyVisitor : public boost::static_visitor<bool> { ...@@ -664,6 +707,13 @@ class AnyVisitor : public boost::static_visitor<bool> {
const platform::CUDAPinnedPlace& cpu) const { const platform::CUDAPinnedPlace& cpu) const {
return *out.data<bool>(); return *out.data<bool>();
} }
bool GetResult(const framework::Tensor& out,
const platform::CustomPlace& custom_dev) const {
PADDLE_THROW(platform::errors::Unimplemented("Not supported on place (%s) ",
custom_dev));
return false;
}
}; };
template <typename Predicate> template <typename Predicate>
...@@ -903,6 +953,11 @@ struct BothFalseVisitor : public boost::static_visitor<> { ...@@ -903,6 +953,11 @@ struct BothFalseVisitor : public boost::static_visitor<> {
out_ptr[i] = lhs && rhs; out_ptr[i] = lhs && rhs;
} }
} }
void VisitorImpl(const platform::CustomPlace& custom_dev) const {
PADDLE_THROW(
platform::errors::Unimplemented("CustomPlace is not supported"));
}
}; };
void TensorIsfinite(const framework::Tensor& tensor, framework::Tensor* out) { void TensorIsfinite(const framework::Tensor& tensor, framework::Tensor* out) {
...@@ -1036,6 +1091,29 @@ void TensorToStream(std::ostream& os, const Tensor& tensor, ...@@ -1036,6 +1091,29 @@ void TensorToStream(std::ostream& os, const Tensor& tensor,
#else #else
PADDLE_THROW(platform::errors::Unimplemented( PADDLE_THROW(platform::errors::Unimplemented(
"NPUPlace is not supported when not compiled with NPU")); "NPUPlace is not supported when not compiled with NPU"));
#endif
} else if (platform::is_custom_place(tensor.place())) {
#ifdef PADDLE_WITH_CUSTOM_DEVICE
constexpr size_t kBufSize = 1024 * 1024 * 64; // 64MB
std::unique_ptr<char[]> buf(new char[kBufSize]);
auto& custom_device_context =
static_cast<const platform::CustomDeviceContext&>(dev_ctx);
platform::CPUPlace cpu;
uintptr_t data = reinterpret_cast<uintptr_t>(data_ptr);
while (size != 0) {
size_t size_to_write = std::min(kBufSize, static_cast<size_t>(size));
memory::Copy(cpu, buf.get(), tensor.place(),
reinterpret_cast<const void*>(data), size_to_write,
custom_device_context.stream());
custom_device_context.Wait();
os.write(buf.get(), size_to_write);
data += size_to_write;
size -= size_to_write;
}
#else
PADDLE_THROW(platform::errors::Unimplemented(
"CustomPlace is not supported when not compiled with "
"CustomDevice"));
#endif #endif
} else { } else {
os.write(static_cast<const char*>(data_ptr), os.write(static_cast<const char*>(data_ptr),
...@@ -1093,10 +1171,11 @@ void TensorFromStream(std::istream& is, Tensor* tensor, ...@@ -1093,10 +1171,11 @@ void TensorFromStream(std::istream& is, Tensor* tensor,
if (platform::is_gpu_place(dev_ctx.GetPlace()) || if (platform::is_gpu_place(dev_ctx.GetPlace()) ||
platform::is_xpu_place(dev_ctx.GetPlace()) || platform::is_xpu_place(dev_ctx.GetPlace()) ||
platform::is_mlu_place(dev_ctx.GetPlace()) || platform::is_mlu_place(dev_ctx.GetPlace()) ||
platform::is_npu_place(dev_ctx.GetPlace())) { platform::is_npu_place(dev_ctx.GetPlace()) ||
platform::is_custom_place(dev_ctx.GetPlace())) {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) || \ #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) || \
defined(PADDLE_WITH_XPU) || defined(PADDLE_WITH_MLU) || \ defined(PADDLE_WITH_XPU) || defined(PADDLE_WITH_MLU) || \
defined(PADDLE_WITH_ASCEND_CL) defined(PADDLE_WITH_ASCEND_CL) || defined(PADDLE_WITH_CUSTOM_DEVICE)
Tensor cpu_tensor; Tensor cpu_tensor;
cpu_tensor.Resize(framework::make_ddim(shape)); cpu_tensor.Resize(framework::make_ddim(shape));
framework::VisitDataType( framework::VisitDataType(
...@@ -1105,7 +1184,8 @@ void TensorFromStream(std::istream& is, Tensor* tensor, ...@@ -1105,7 +1184,8 @@ void TensorFromStream(std::istream& is, Tensor* tensor,
is.read(static_cast<char*>(buf), size); is.read(static_cast<char*>(buf), size);
auto dst_place = dev_ctx.GetPlace(); auto dst_place = dev_ctx.GetPlace();
framework::TensorCopy(cpu_tensor, dst_place, dev_ctx, tensor); framework::TensorCopy(cpu_tensor, dst_place, dev_ctx, tensor);
if (platform::is_npu_place(dev_ctx.GetPlace())) { if (platform::is_npu_place(dev_ctx.GetPlace()) ||
platform::is_custom_place(dev_ctx.GetPlace())) {
dev_ctx.Wait(); dev_ctx.Wait();
} }
#else #else
...@@ -1163,10 +1243,11 @@ void TensorFromStream(std::istream& is, Tensor* tensor, ...@@ -1163,10 +1243,11 @@ void TensorFromStream(std::istream& is, Tensor* tensor,
if (platform::is_gpu_place(dev_ctx.GetPlace()) || if (platform::is_gpu_place(dev_ctx.GetPlace()) ||
platform::is_xpu_place(dev_ctx.GetPlace()) || platform::is_xpu_place(dev_ctx.GetPlace()) ||
platform::is_mlu_place(dev_ctx.GetPlace()) || platform::is_mlu_place(dev_ctx.GetPlace()) ||
platform::is_npu_place(dev_ctx.GetPlace())) { platform::is_npu_place(dev_ctx.GetPlace()) ||
platform::is_custom_place(dev_ctx.GetPlace())) {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) || \ #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) || \
defined(PADDLE_WITH_XPU) || defined(PADDLE_WITH_MLU) || \ defined(PADDLE_WITH_XPU) || defined(PADDLE_WITH_MLU) || \
defined(PADDLE_WITH_ASCEND_CL) defined(PADDLE_WITH_ASCEND_CL) || defined(PADDLE_WITH_CUSTOM_DEVICE)
Tensor cpu_tensor; Tensor cpu_tensor;
cpu_tensor.Resize(framework::make_ddim(dims)); cpu_tensor.Resize(framework::make_ddim(dims));
framework::VisitDataType( framework::VisitDataType(
...@@ -1175,7 +1256,8 @@ void TensorFromStream(std::istream& is, Tensor* tensor, ...@@ -1175,7 +1256,8 @@ void TensorFromStream(std::istream& is, Tensor* tensor,
is.read(static_cast<char*>(buf), size); is.read(static_cast<char*>(buf), size);
auto dst_place = dev_ctx.GetPlace(); auto dst_place = dev_ctx.GetPlace();
framework::TensorCopy(cpu_tensor, dst_place, dev_ctx, tensor); framework::TensorCopy(cpu_tensor, dst_place, dev_ctx, tensor);
if (platform::is_npu_place(dev_ctx.GetPlace())) { if (platform::is_npu_place(dev_ctx.GetPlace()) ||
platform::is_custom_place(dev_ctx.GetPlace())) {
dev_ctx.Wait(); dev_ctx.Wait();
} }
#else #else
...@@ -1188,9 +1270,12 @@ void TensorFromStream(std::istream& is, Tensor* tensor, ...@@ -1188,9 +1270,12 @@ void TensorFromStream(std::istream& is, Tensor* tensor,
} else if (platform::is_mlu_place(dev_ctx.GetPlace())) { } else if (platform::is_mlu_place(dev_ctx.GetPlace())) {
PADDLE_THROW(platform::errors::Unimplemented( PADDLE_THROW(platform::errors::Unimplemented(
"MLUPlace is not supported when not compiled with MLU")); "MLUPlace is not supported when not compiled with MLU"));
} else { } else if (platform::is_npu_place(dev_ctx.GetPlace())) {
PADDLE_THROW(platform::errors::Unimplemented( PADDLE_THROW(platform::errors::Unimplemented(
"NPUPlace is not supported when not compiled with NPU")); "NPUPlace is not supported when not compiled with NPU"));
} else {
PADDLE_THROW(platform::errors::Unimplemented(
"CutomPlace is not supported when not compiled with CustomDevice"));
} }
#endif #endif
} else { } else {
......
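The serialization branch added above stages device memory to the host in 64 MB chunks: copy a chunk, wait on the device context, write it to the stream, then advance. Below is a self-contained sketch of that loop; `std::memcpy` stands in for `memory::Copy` plus the device wait, so this is an illustration of the chunking logic rather than Paddle code:

```cpp
// Sketch of the 64 MB chunked device-to-host serialization loop.
#include <algorithm>
#include <cstring>
#include <memory>
#include <ostream>

void WriteDeviceBufferSketch(std::ostream& os, const char* device_data,
                             std::size_t size) {
  constexpr std::size_t kBufSize = 1024 * 1024 * 64;  // 64MB staging buffer
  std::unique_ptr<char[]> buf(new char[kBufSize]);
  while (size != 0) {
    const std::size_t chunk = std::min(kBufSize, size);
    std::memcpy(buf.get(), device_data, chunk);  // device -> host copy + wait
    os.write(buf.get(), chunk);
    device_data += chunk;
    size -= chunk;
  }
}
```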
...@@ -180,6 +180,17 @@ void TensorFromArray(const T* src, const size_t& array_size, ...@@ -180,6 +180,17 @@ void TensorFromArray(const T* src, const size_t& array_size,
reinterpret_cast<const platform::NPUDeviceContext&>(ctx).stream()); reinterpret_cast<const platform::NPUDeviceContext&>(ctx).stream());
} }
#endif #endif
#ifdef PADDLE_WITH_CUSTOM_DEVICE
else if (platform::is_custom_place(dst_place)) { // NOLINT
memory::Copy(
dst_place, dst_ptr, src_place, src_ptr, size,
reinterpret_cast<const platform::CustomDeviceContext&>(ctx).stream());
}
#endif
else { // NOLINT
PADDLE_THROW(platform::errors::Unimplemented(
"TensorFromArray on %s is not supported.", dst_place));
}
} }
template <typename T> template <typename T>
...@@ -241,6 +252,17 @@ void TensorFromVector(const std::vector<T>& src, ...@@ -241,6 +252,17 @@ void TensorFromVector(const std::vector<T>& src,
reinterpret_cast<const platform::MLUDeviceContext&>(ctx).stream()); reinterpret_cast<const platform::MLUDeviceContext&>(ctx).stream());
} }
#endif #endif
#ifdef PADDLE_WITH_CUSTOM_DEVICE
else if (platform::is_custom_place(dst_place)) { // NOLINT
memory::Copy(
dst_place, dst_ptr, src_place, src_ptr, size,
reinterpret_cast<const platform::CustomDeviceContext&>(ctx).stream());
}
#endif
else { // NOLINT
PADDLE_THROW(platform::errors::Unimplemented(
"TensorFromVector on %s is not supported.", dst_place));
}
} }
// The fully specialized function should be inline to avoid // The fully specialized function should be inline to avoid
...@@ -300,6 +322,17 @@ inline void TensorFromVector(const std::vector<bool>& src, ...@@ -300,6 +322,17 @@ inline void TensorFromVector(const std::vector<bool>& src,
reinterpret_cast<const platform::NPUDeviceContext&>(ctx).stream()); reinterpret_cast<const platform::NPUDeviceContext&>(ctx).stream());
} }
#endif #endif
#ifdef PADDLE_WITH_CUSTOM_DEVICE
else if (platform::is_custom_place(dst_place)) { // NOLINT
auto stream =
reinterpret_cast<const platform::CustomDeviceContext&>(ctx).stream();
memory::Copy(dst_place, dst_ptr, src_place, src_ptr, size, stream);
}
#endif
else { // NOLINT
PADDLE_THROW(platform::errors::Unimplemented(
"TensorFromVector on %s is not supported.", dst_place));
}
delete[] array; delete[] array;
} }
...@@ -369,6 +402,15 @@ void TensorToVector(const Tensor& src, const platform::DeviceContext& ctx, ...@@ -369,6 +402,15 @@ void TensorToVector(const Tensor& src, const platform::DeviceContext& ctx,
reinterpret_cast<const platform::MLUDeviceContext&>(ctx).stream()); reinterpret_cast<const platform::MLUDeviceContext&>(ctx).stream());
} }
#endif #endif
#ifdef PADDLE_WITH_CUSTOM_DEVICE
else if (platform::is_custom_place(src.place())) { // NOLINT
memory::Copy(dst_place, dst_ptr, src.place(), src_ptr, size, nullptr);
}
#endif
else { // NOLINT
PADDLE_THROW(platform::errors::Unimplemented(
"TensorToVector on %s is not supported.", src.place()));
}
} }
template <> template <>
...@@ -410,6 +452,11 @@ inline void TensorToVector(const Tensor& src, ...@@ -410,6 +452,11 @@ inline void TensorToVector(const Tensor& src,
dst_place, dst_ptr, src.place(), src_ptr, size, dst_place, dst_ptr, src.place(), src_ptr, size,
reinterpret_cast<const platform::MLUDeviceContext&>(ctx).stream()); reinterpret_cast<const platform::MLUDeviceContext&>(ctx).stream());
} }
#endif
#ifdef PADDLE_WITH_CUSTOM_DEVICE
else if (platform::is_custom_place(src.place())) { // NOLINT
memory::Copy(dst_place, dst_ptr, src.place(), src_ptr, size, nullptr);
}
#endif #endif
for (unsigned int i = 0; i < src.numel(); i++) { for (unsigned int i = 0; i < src.numel(); i++) {
(*dst)[i] = static_cast<bool>(array[i]); (*dst)[i] = static_cast<bool>(array[i]);
......
...@@ -44,9 +44,9 @@ if(WITH_GLOO) ...@@ -44,9 +44,9 @@ if(WITH_GLOO)
endif() endif()
if(NOT WITH_ASCEND_CL) if(NOT WITH_ASCEND_CL)
cc_library(gradient_accumulator SRCS gradient_accumulator.cc DEPS blas operator lod_tensor selected_rows_utils selected_rows_functor var_type_traits layer math_function) cc_library(gradient_accumulator SRCS gradient_accumulator.cc DEPS blas operator lod_tensor selected_rows_utils selected_rows_functor var_type_traits layer math_function pten_tensor)
else() else()
cc_library(gradient_accumulator SRCS gradient_accumulator.cc DEPS blas operator lod_tensor selected_rows_utils selected_rows_functor var_type_traits layer math_function npu_op_runner) cc_library(gradient_accumulator SRCS gradient_accumulator.cc DEPS blas operator lod_tensor selected_rows_utils selected_rows_functor var_type_traits layer math_function npu_op_runner pten_tensor)
endif() endif()
add_subdirectory(tests) add_subdirectory(tests)
...@@ -340,8 +340,8 @@ NameVarMap<VarType> AutoCastInputs(const std::string& op_type, ...@@ -340,8 +340,8 @@ NameVarMap<VarType> AutoCastInputs(const std::string& op_type,
} }
template NameVarMap<VarBase> AutoCastInputs<VarBase>( template NameVarMap<VarBase> AutoCastInputs<VarBase>(
const std::string& op_type, const NameVarMap<VarBase>& ins); const std::string& op_type, const NameVarMap<VarBase>& ins);
template NameVarMap<egr::EagerTensor> AutoCastInputs<egr::EagerTensor>( template NameVarMap<egr::EagerVariable> AutoCastInputs<egr::EagerVariable>(
const std::string& op_type, const NameVarMap<egr::EagerTensor>& ins); const std::string& op_type, const NameVarMap<egr::EagerVariable>& ins);
template <typename VarType> template <typename VarType>
NameVarMap<VarType> CastPureFp16Inputs(const std::string& op_type, NameVarMap<VarType> CastPureFp16Inputs(const std::string& op_type,
const NameVarMap<VarType>& ins) { const NameVarMap<VarType>& ins) {
...@@ -384,7 +384,7 @@ NameVarMap<VarType> CastPureFp16Inputs(const std::string& op_type, ...@@ -384,7 +384,7 @@ NameVarMap<VarType> CastPureFp16Inputs(const std::string& op_type,
} }
template NameVarMap<VarBase> CastPureFp16Inputs<VarBase>( template NameVarMap<VarBase> CastPureFp16Inputs<VarBase>(
const std::string& op_type, const NameVarMap<VarBase>& ins); const std::string& op_type, const NameVarMap<VarBase>& ins);
template NameVarMap<egr::EagerTensor> CastPureFp16Inputs<egr::EagerTensor>( template NameVarMap<egr::EagerVariable> CastPureFp16Inputs<egr::EagerVariable>(
const std::string& op_type, const NameVarMap<egr::EagerTensor>& ins); const std::string& op_type, const NameVarMap<egr::EagerVariable>& ins);
} // namespace imperative } // namespace imperative
} // namespace paddle } // namespace paddle
...@@ -35,6 +35,9 @@ ...@@ -35,6 +35,9 @@
#ifdef PADDLE_WITH_ASCEND_CL #ifdef PADDLE_WITH_ASCEND_CL
#include "paddle/fluid/platform/device/npu/npu_op_runner.h" #include "paddle/fluid/platform/device/npu/npu_op_runner.h"
#endif #endif
#ifdef PADDLE_WITH_MLU
#include "paddle/fluid/operators/mlu/mlu_baseop.h"
#endif
namespace paddle { namespace paddle {
namespace imperative { namespace imperative {
...@@ -180,6 +183,12 @@ class TensorAddFunctor : public boost::static_visitor<> { ...@@ -180,6 +183,12 @@ class TensorAddFunctor : public boost::static_visitor<> {
"is not supported in imperative mode", "is not supported in imperative mode",
place)); place));
} }
void operator()(const platform::CustomPlace& place) const {
PADDLE_THROW(platform::errors::PermissionDenied(
"Gradient accumulation on place (%s) "
"is not supported in imperative mode",
place));
}
private: private:
int64_t numel_; int64_t numel_;
...@@ -243,6 +252,23 @@ TType& GetInnerTensor(const paddle::experimental::Tensor& src) { ...@@ -243,6 +252,23 @@ TType& GetInnerTensor(const paddle::experimental::Tensor& src) {
return *src_tensor; return *src_tensor;
} }
template <typename TType>
TType* GetEmptyInnerTensor(paddle::experimental::Tensor* dst) {
PADDLE_ENFORCE_EQ(
dst->defined(), false,
platform::errors::Fatal(
"The underlying Tensor implementation should be nullptr"));
dst->set_impl(std::make_shared<TType>());
auto* dst_tensor = static_cast<TType*>(dst->impl().get());
return dst_tensor;
}
template <typename TType>
TType* GetEmptyInnerTensor(paddle::imperative::VariableWrapper* dst) {
auto* dst_tensor = dst->MutableVar()->GetMutable<TType>();
return dst_tensor;
}
template <typename VarType> template <typename VarType>
void TensorAdd(const VarType& src, VarType* dst) { void TensorAdd(const VarType& src, VarType* dst) {
pten::DenseTensor* dst_tensor = GetInnerMutableTensor<pten::DenseTensor>(dst); pten::DenseTensor* dst_tensor = GetInnerMutableTensor<pten::DenseTensor>(dst);
...@@ -314,7 +340,14 @@ void TensorAdd(const VarType& src, VarType* dst) { ...@@ -314,7 +340,14 @@ void TensorAdd(const VarType& src, VarType* dst) {
return; return;
} }
#endif #endif
#ifdef PADDLE_WITH_CUSTOM_DEVICE
if (platform::is_custom_place(place)) {
PADDLE_THROW(platform::errors::Unimplemented(
"Gradient accumulation of data type (%s) on place (%s) is not "
"supported in imperative mode",
framework::DataTypeToString(data_type), place));
}
#endif
#ifdef PADDLE_WITH_XPU #ifdef PADDLE_WITH_XPU
if (platform::is_xpu_place(place)) { if (platform::is_xpu_place(place)) {
if (data_type == framework::DataTypeTrait<float>::DataType()) { if (data_type == framework::DataTypeTrait<float>::DataType()) {
...@@ -332,6 +365,35 @@ void TensorAdd(const VarType& src, VarType* dst) { ...@@ -332,6 +365,35 @@ void TensorAdd(const VarType& src, VarType* dst) {
} }
#endif #endif
#ifdef PADDLE_WITH_MLU
if (platform::is_mlu_place(place)) {
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
platform::DeviceContext* ctx = pool.Get(place);
auto dev_ctx = dynamic_cast<platform::MLUDeviceContext*>(ctx);
if (data_type == framework::DataTypeTrait<float>::DataType()) {
dst_tensor->mutable_data<float>(place);
} else if (data_type ==
framework::DataTypeTrait<platform::float16>::DataType()) {
dst_tensor->mutable_data<platform::float16>(place);
} else {
PADDLE_THROW(platform::errors::Unimplemented(
"Gradient accumulation of data type (%s) on place (%s) is not "
"supported in imperative mode",
framework::DataTypeToString(data_type), place));
}
static const float alpha = 1.f;
static const float beta = 1.f;
operators::MLUCnnlTensorDesc src_tensor_desc(src_tensor);
operators::MLUCnnlTensorDesc dst_tensor_desc(*dst_tensor);
PADDLE_ENFORCE_MLU_SUCCESS(cnnlAssignAdd(
dev_ctx->cnnl_handle(), static_cast<void*>(&alpha),
src_tensor_desc.get(), operators::GetBasePtr(&src_tensor), nullptr, 0,
static_cast<void*>(&beta), dst_tensor_desc.get(),
operators::GetBasePtr(dst_tensor)));
return;
}
#endif
PADDLE_TENSOR_ADD(float); PADDLE_TENSOR_ADD(float);
#ifndef PADDLE_WITH_XPU #ifndef PADDLE_WITH_XPU
...@@ -473,13 +535,14 @@ template void SelectedRowsAddTensor( ...@@ -473,13 +535,14 @@ template void SelectedRowsAddTensor(
// Note(chenweihang): when two selected rows need to be added, // Note(chenweihang): when two selected rows need to be added,
// adding one to another is not equal to merging two selected rows // adding one to another is not equal to merging two selected rows
// to one then add it to an empty selected rows, the latter is correct // to one then add it to an empty selected rows, the latter is correct
// Note(chenweihang): when two selected rows need to be added, template <typename ReturnVarType, typename VarType>
// adding one to another is not equal to merging two selected rows std::shared_ptr<ReturnVarType> SelectedRowsMerge(const VarType& src1,
// to one then add it to a empty selected rows, the after is correct const VarType& src2) {
std::shared_ptr<VariableWrapper> SelectedRowsMerge( const pten::SelectedRows& src_selected_rows1 =
const framework::Variable& src1, const framework::Variable& src2) { GetInnerTensor<pten::SelectedRows>(src1);
auto& src_selected_rows1 = src1.Get<pten::SelectedRows>(); const pten::SelectedRows& src_selected_rows2 =
auto& src_selected_rows2 = src2.Get<pten::SelectedRows>(); GetInnerTensor<pten::SelectedRows>(src2);
auto place = src_selected_rows1.value().place(); auto place = src_selected_rows1.value().place();
auto data_type = auto data_type =
framework::TransToProtoVarType(src_selected_rows1.value().dtype()); framework::TransToProtoVarType(src_selected_rows1.value().dtype());
...@@ -488,9 +551,10 @@ std::shared_ptr<VariableWrapper> SelectedRowsMerge( ...@@ -488,9 +551,10 @@ std::shared_ptr<VariableWrapper> SelectedRowsMerge(
std::vector<const pten::SelectedRows*> src_selected_rows; std::vector<const pten::SelectedRows*> src_selected_rows;
src_selected_rows.emplace_back(&src_selected_rows1); src_selected_rows.emplace_back(&src_selected_rows1);
src_selected_rows.emplace_back(&src_selected_rows2); src_selected_rows.emplace_back(&src_selected_rows2);
auto dst_var = std::make_shared<VariableWrapper>("Temp");
auto* dst_selected_rows = auto dst_var = std::make_shared<ReturnVarType>("Temp");
dst_var->MutableVar()->GetMutable<pten::SelectedRows>(); pten::SelectedRows* dst_selected_rows =
GetEmptyInnerTensor<pten::SelectedRows>(dst_var.get());
#define PADDLE_SELECTED_ROWS_ADD(dev_ctx_type, cpp_type) \ #define PADDLE_SELECTED_ROWS_ADD(dev_ctx_type, cpp_type) \
if (data_type == framework::DataTypeTrait<cpp_type>::DataType()) { \ if (data_type == framework::DataTypeTrait<cpp_type>::DataType()) { \
...@@ -515,12 +579,17 @@ std::shared_ptr<VariableWrapper> SelectedRowsMerge( ...@@ -515,12 +579,17 @@ std::shared_ptr<VariableWrapper> SelectedRowsMerge(
#endif #endif
#undef PADDLE_SELECTED_ROWS_ADD #undef PADDLE_SELECTED_ROWS_ADD
PADDLE_THROW(platform::errors::InvalidArgument( PADDLE_THROW(platform::errors::InvalidArgument(
"Not supported data type %s for SelectedRowsMerge", "Not supported data type %s for SelectedRowsMerge",
framework::DataTypeToString(data_type))); framework::DataTypeToString(data_type)));
} }
template std::shared_ptr<paddle::experimental::Tensor> SelectedRowsMerge(
const paddle::experimental::Tensor& src1,
const paddle::experimental::Tensor& src2);
template std::shared_ptr<paddle::imperative::VariableWrapper> SelectedRowsMerge(
const framework::Variable& src1, const framework::Variable& src2);
void VariableWrapperAdd(std::shared_ptr<VariableWrapper> var, void VariableWrapperAdd(std::shared_ptr<VariableWrapper> var,
VariableWrapper* dst_var, bool unchange_input) { VariableWrapper* dst_var, bool unchange_input) {
auto& src = var->Var(); auto& src = var->Var();
...@@ -547,7 +616,7 @@ void VariableWrapperAdd(std::shared_ptr<VariableWrapper> var, ...@@ -547,7 +616,7 @@ void VariableWrapperAdd(std::shared_ptr<VariableWrapper> var,
*dst = std::move(*(var->MutableVar())); *dst = std::move(*(var->MutableVar()));
} }
} else if (src.IsType<pten::SelectedRows>()) { } else if (src.IsType<pten::SelectedRows>()) {
auto temp = SelectedRowsMerge(src, *dst); auto temp = SelectedRowsMerge<VariableWrapper>(src, *dst);
*dst = std::move(*(temp->MutableVar())); *dst = std::move(*(temp->MutableVar()));
} else { } else {
PADDLE_THROW(platform::errors::InvalidArgument( PADDLE_THROW(platform::errors::InvalidArgument(
...@@ -603,7 +672,7 @@ void GradientAccumulator::AccumulateGrad() { ...@@ -603,7 +672,7 @@ void GradientAccumulator::AccumulateGrad() {
SelectedRowsAddToTensor(*dst, src); SelectedRowsAddToTensor(*dst, src);
*dst = std::move(*src); *dst = std::move(*src);
} else if (src->IsType<pten::SelectedRows>()) { } else if (src->IsType<pten::SelectedRows>()) {
auto temp = SelectedRowsMerge(*src, *dst); auto temp = SelectedRowsMerge<VariableWrapper>(*src, *dst);
*dst = std::move(*(temp->MutableVar())); *dst = std::move(*(temp->MutableVar()));
} }
} else { } else {
......
...@@ -17,10 +17,10 @@ ...@@ -17,10 +17,10 @@
#include <memory> #include <memory>
#include <utility> #include <utility>
#include <vector> #include <vector>
#include "paddle/fluid/eager/eager_tensor.h" #include "paddle/fluid/eager/eager_tensor.h"
#include "paddle/fluid/imperative/hooks.h" #include "paddle/fluid/imperative/hooks.h"
#include "paddle/fluid/imperative/layer.h" #include "paddle/fluid/imperative/layer.h"
#include "paddle/pten/api/include/tensor.h"
namespace paddle { namespace paddle {
namespace imperative { namespace imperative {
...@@ -164,6 +164,10 @@ class SortedGradientAccumulator : public GradientAccumulator { ...@@ -164,6 +164,10 @@ class SortedGradientAccumulator : public GradientAccumulator {
std::vector<SavedVarInfo> tmp_grad_vars_; std::vector<SavedVarInfo> tmp_grad_vars_;
}; };
template <typename ReturnVarType, typename VarType>
std::shared_ptr<ReturnVarType> SelectedRowsMerge(const VarType& src1,
const VarType& src2);
template <typename VarType> template <typename VarType>
void SelectedRowsAddToTensor(const VarType& src, VarType* dst); void SelectedRowsAddToTensor(const VarType& src, VarType* dst);
......
...@@ -177,9 +177,9 @@ std::string LayerDebugString(const std::string& op_type, ...@@ -177,9 +177,9 @@ std::string LayerDebugString(const std::string& op_type,
} }
std::string LayerDebugString(const std::string& op_type, std::string LayerDebugString(const std::string& op_type,
const NameVarMap<egr::EagerTensor>& ins, const NameVarMap<egr::EagerVariable>& ins,
const NameVarMap<egr::EagerTensor>& outs) { const NameVarMap<egr::EagerVariable>& outs) {
return LayerDebugStringImpl<egr::EagerTensor>(op_type, ins, outs); return LayerDebugStringImpl<egr::EagerVariable>(op_type, ins, outs);
} }
template <typename VarType> template <typename VarType>
...@@ -194,11 +194,16 @@ static void SetForwardDataTypeOfGradVars(const NameVarMap<VarType>& outs) { ...@@ -194,11 +194,16 @@ static void SetForwardDataTypeOfGradVars(const NameVarMap<VarType>& outs) {
} }
} }
template <> template <>
void SetForwardDataTypeOfGradVars<egr::EagerTensor>( void SetForwardDataTypeOfGradVars<egr::EagerVariable>(
const NameVarMap<egr::EagerTensor>& outs) { const NameVarMap<egr::EagerVariable>& outs) {
// In eager mode we don't need this. // In eager mode we don't need this.
} }
void TestSetForwardDataTypeOfGradVarsEager(
const NameVarMap<egr::EagerVariable>& outs) {
SetForwardDataTypeOfGradVars<egr::EagerVariable>(outs);
}
VarBase::VarBase(const std::shared_ptr<VariableWrapper>& var) VarBase::VarBase(const std::shared_ptr<VariableWrapper>& var)
: var_(var), grad_node_(var->GetGradNode()) { : var_(var), grad_node_(var->GetGradNode()) {
if (auto grad_var = var_->GetGradVar()) { if (auto grad_var = var_->GetGradVar()) {
...@@ -528,12 +533,12 @@ void OpBase::Run(const framework::OperatorBase& op, ...@@ -528,12 +533,12 @@ void OpBase::Run(const framework::OperatorBase& op,
} }
void OpBase::Run(const framework::OperatorBase& op, void OpBase::Run(const framework::OperatorBase& op,
const NameVarMap<egr::EagerTensor>& ins, const NameVarMap<egr::EagerVariable>& ins,
const NameVarMap<egr::EagerTensor>& outs, const NameVarMap<egr::EagerVariable>& outs,
const framework::AttributeMap& attrs, const framework::AttributeMap& attrs,
const framework::AttributeMap& default_attrs, const framework::AttributeMap& default_attrs,
const platform::Place& place) { const platform::Place& place) {
OpBaseRunImpl<egr::EagerTensor>(op, ins, outs, attrs, default_attrs, place); OpBaseRunImpl<egr::EagerVariable>(op, ins, outs, attrs, default_attrs, place);
} }
void ClearNoNeedBufferInputs(OpBase* op) { void ClearNoNeedBufferInputs(OpBase* op) {
......
...@@ -185,8 +185,8 @@ class OpBase { ...@@ -185,8 +185,8 @@ class OpBase {
const framework::AttributeMap& default_attrs, const framework::AttributeMap& default_attrs,
const platform::Place& place); const platform::Place& place);
static void Run(const framework::OperatorBase& op, static void Run(const framework::OperatorBase& op,
const NameVarMap<egr::EagerTensor>& ins, const NameVarMap<egr::EagerVariable>& ins,
const NameVarMap<egr::EagerTensor>& outs, const NameVarMap<egr::EagerVariable>& outs,
const framework::AttributeMap& attrs, const framework::AttributeMap& attrs,
const framework::AttributeMap& default_attrs, const framework::AttributeMap& default_attrs,
const platform::Place& place); const platform::Place& place);
......
...@@ -89,11 +89,16 @@ void HandleComplexGradToRealGrad(const NameVarMap<VarType>& outs) { ...@@ -89,11 +89,16 @@ void HandleComplexGradToRealGrad(const NameVarMap<VarType>& outs) {
} }
template <> template <>
void HandleComplexGradToRealGrad<egr::EagerTensor>( void HandleComplexGradToRealGrad<egr::EagerVariable>(
const NameVarMap<egr::EagerTensor>& outs) { const NameVarMap<egr::EagerVariable>& outs) {
// TODO(jiabin): Support Complex here. // TODO(jiabin): Support Complex here.
} }
void TestHandleComplexGradToRealGradEager(
const NameVarMap<egr::EagerVariable>& outs) {
HandleComplexGradToRealGrad<egr::EagerVariable>(outs);
}
PreparedOp::PreparedOp(const framework::OperatorBase& op, PreparedOp::PreparedOp(const framework::OperatorBase& op,
const framework::RuntimeContext& ctx, const framework::RuntimeContext& ctx,
const framework::OpKernelType& kernel_type, const framework::OpKernelType& kernel_type,
...@@ -278,6 +283,16 @@ PreparedOp PrepareImpl(const NameVarMap<VarType>& ins, ...@@ -278,6 +283,16 @@ PreparedOp PrepareImpl(const NameVarMap<VarType>& ins,
expected_kernel_key.place_ = platform::CPUPlace(); expected_kernel_key.place_ = platform::CPUPlace();
kernel_iter = kernels.find(expected_kernel_key); kernel_iter = kernels.find(expected_kernel_key);
} }
#endif
#ifdef PADDLE_WITH_CUSTOM_DEVICE
if (kernel_iter == kernels.end() &&
paddle::platform::is_custom_place(expected_kernel_key.place_)) {
VLOG(3) << "missing " << place.GetDeviceType() << " kernel: " << op.Type()
<< ", expected_kernel_key:" << expected_kernel_key
<< ", fallbacking to CPU one!";
expected_kernel_key.place_ = platform::CPUPlace();
kernel_iter = kernels.find(expected_kernel_key);
}
#endif #endif
// TODO(jiabin): Add operator.cc's line 1000 part back when we need that // TODO(jiabin): Add operator.cc's line 1000 part back when we need that
// case // case
...@@ -312,14 +327,14 @@ PreparedOp PreparedOp::Prepare(const NameVarMap<VariableWrapper>& ins, ...@@ -312,14 +327,14 @@ PreparedOp PreparedOp::Prepare(const NameVarMap<VariableWrapper>& ins,
default_attrs); default_attrs);
} }
PreparedOp PreparedOp::Prepare(const NameVarMap<egr::EagerTensor>& ins, PreparedOp PreparedOp::Prepare(const NameVarMap<egr::EagerVariable>& ins,
const NameVarMap<egr::EagerTensor>& outs, const NameVarMap<egr::EagerVariable>& outs,
const framework::OperatorWithKernel& op, const framework::OperatorWithKernel& op,
const platform::Place& place, const platform::Place& place,
const framework::AttributeMap& attrs, const framework::AttributeMap& attrs,
const framework::AttributeMap& default_attrs) { const framework::AttributeMap& default_attrs) {
return PrepareImpl<egr::EagerTensor>(ins, outs, op, place, attrs, return PrepareImpl<egr::EagerVariable>(ins, outs, op, place, attrs,
default_attrs); default_attrs);
} }
template <typename VarType> template <typename VarType>
static void PreparedOpRunImpl( static void PreparedOpRunImpl(
...@@ -451,18 +466,18 @@ void PreparedOp::Run(const NameVarMap<VariableWrapper>& ins, ...@@ -451,18 +466,18 @@ void PreparedOp::Run(const NameVarMap<VariableWrapper>& ins,
} }
} }
void PreparedOp::Run(const NameVarMap<egr::EagerTensor>& ins, void PreparedOp::Run(const NameVarMap<egr::EagerVariable>& ins,
const NameVarMap<egr::EagerTensor>& outs, const NameVarMap<egr::EagerVariable>& outs,
const framework::AttributeMap& attrs, const framework::AttributeMap& attrs,
const framework::AttributeMap& default_attrs) { const framework::AttributeMap& default_attrs) {
if (run_pten_kernel_) { if (run_pten_kernel_) {
PreparedOpRunPtImpl<egr::EagerTensor>( PreparedOpRunPtImpl<egr::EagerVariable>(
op_, kernel_type_, pt_kernel_signature_, pt_kernel_, dev_ctx_, ins, op_, kernel_type_, pt_kernel_signature_, pt_kernel_, dev_ctx_, ins,
outs, attrs, default_attrs); outs, attrs, default_attrs);
} else { } else {
PreparedOpRunImpl<egr::EagerTensor>(op_, ctx_, kernel_type_, func_, PreparedOpRunImpl<egr::EagerVariable>(op_, ctx_, kernel_type_, func_,
dev_ctx_, ins, outs, attrs, dev_ctx_, ins, outs, attrs,
default_attrs); default_attrs);
} }
} }
......
...@@ -63,8 +63,8 @@ void SetForwardDataTypeOfGradVar<VarBase>(const std::shared_ptr<VarBase>& var) { ...@@ -63,8 +63,8 @@ void SetForwardDataTypeOfGradVar<VarBase>(const std::shared_ptr<VarBase>& var) {
} }
template <> template <>
void SetForwardDataTypeOfGradVar<egr::EagerTensor>( void SetForwardDataTypeOfGradVar<egr::EagerVariable>(
const std::shared_ptr<egr::EagerTensor>& var) { const std::shared_ptr<egr::EagerVariable>& var) {
VLOG(10) << "Var in Eager dose not support SetForwardDataTypeOfGradVar: " VLOG(10) << "Var in Eager dose not support SetForwardDataTypeOfGradVar: "
<< var->name(); << var->name();
// TODO(jiabin): SetForwardDataType of Grad var is not supported yet in // TODO(jiabin): SetForwardDataType of Grad var is not supported yet in
...@@ -171,8 +171,8 @@ class PreparedOp { ...@@ -171,8 +171,8 @@ class PreparedOp {
const framework::AttributeMap& attrs, const framework::AttributeMap& attrs,
const framework::AttributeMap& default_attrs); const framework::AttributeMap& default_attrs);
static PreparedOp Prepare(const NameVarMap<egr::EagerTensor>& ins, static PreparedOp Prepare(const NameVarMap<egr::EagerVariable>& ins,
const NameVarMap<egr::EagerTensor>& outs, const NameVarMap<egr::EagerVariable>& outs,
const framework::OperatorWithKernel& op, const framework::OperatorWithKernel& op,
const platform::Place& place, const platform::Place& place,
const framework::AttributeMap& attrs, const framework::AttributeMap& attrs,
...@@ -187,8 +187,8 @@ class PreparedOp { ...@@ -187,8 +187,8 @@ class PreparedOp {
const framework::AttributeMap& attrs, const framework::AttributeMap& attrs,
const framework::AttributeMap& default_attrs); const framework::AttributeMap& default_attrs);
void Run(const NameVarMap<egr::EagerTensor>& ins, void Run(const NameVarMap<egr::EagerVariable>& ins,
const NameVarMap<egr::EagerTensor>& outs, const NameVarMap<egr::EagerVariable>& outs,
const framework::AttributeMap& attrs, const framework::AttributeMap& attrs,
const framework::AttributeMap& default_attrs); const framework::AttributeMap& default_attrs);
...@@ -270,26 +270,26 @@ void BuildDygraphPtenKernelContext( ...@@ -270,26 +270,26 @@ void BuildDygraphPtenKernelContext(
kernel_ctx->EmplaceBackInputWithoutSetRange(nullptr); kernel_ctx->EmplaceBackInputWithoutSetRange(nullptr);
auto end_idx = start_idx + 1; auto end_idx = start_idx + 1;
kernel_ctx->AssignInputRange(std::make_pair(start_idx, end_idx), i); kernel_ctx->AssignInputRange(std::make_pair(start_idx, end_idx), i);
} else { continue;
auto ins_vector = it->second; }
size_t end_idx = start_idx + ins_vector.size(); auto ins_vector = it->second;
size_t end_idx = start_idx + ins_vector.size();
for (size_t offset = 0; offset < ins_vector.size(); ++offset) {
const pten::TensorBase* tensor_in = nullptr; for (size_t offset = 0; offset < ins_vector.size(); ++offset) {
auto& var = ins_vector[offset]->Var(); const pten::TensorBase* tensor_in = nullptr;
if (var.template IsType<pten::DenseTensor>()) { auto& var = ins_vector[offset]->Var();
tensor_in = &(var.template Get<pten::DenseTensor>()); if (var.template IsType<pten::DenseTensor>()) {
} else if (var.template IsType<pten::SelectedRows>()) { tensor_in = &(var.template Get<pten::DenseTensor>());
tensor_in = &(var.template Get<pten::SelectedRows>()); } else if (var.template IsType<pten::SelectedRows>()) {
} else { tensor_in = &(var.template Get<pten::SelectedRows>());
PADDLE_THROW(platform::errors::Unimplemented( } else {
"Unsupported input `%s` type when call pt kernel.", PADDLE_THROW(platform::errors::Unimplemented(
framework::ToTypeName(var.Type()))); "Unsupported input `%s` type when call pt kernel.",
} framework::ToTypeName(var.Type())));
kernel_ctx->EmplaceBackInputWithoutSetRange(tensor_in);
} }
kernel_ctx->AssignInputRange(std::make_pair(start_idx, end_idx), i); kernel_ctx->EmplaceBackInputWithoutSetRange(tensor_in);
} }
kernel_ctx->AssignInputRange(std::make_pair(start_idx, end_idx), i);
} }
for (size_t i = 0; i < output_names.size(); ++i) { for (size_t i = 0; i < output_names.size(); ++i) {
...@@ -421,6 +421,8 @@ void BuildDygraphPtenKernelContext( ...@@ -421,6 +421,8 @@ void BuildDygraphPtenKernelContext(
kernel_ctx->EmplaceBackAttr(BOOST_GET_CONST(float, attr)); kernel_ctx->EmplaceBackAttr(BOOST_GET_CONST(float, attr));
} else if (attr_defs[i].type_index == std::type_index(typeid(bool))) { } else if (attr_defs[i].type_index == std::type_index(typeid(bool))) {
kernel_ctx->EmplaceBackAttr(BOOST_GET_CONST(bool, attr)); kernel_ctx->EmplaceBackAttr(BOOST_GET_CONST(bool, attr));
} else if (attr_defs[i].type_index == std::type_index(typeid(int64_t))) {
kernel_ctx->EmplaceBackAttr(BOOST_GET_CONST(int64_t, attr));
} else if (attr_defs[i].type_index == } else if (attr_defs[i].type_index ==
std::type_index(typeid(std::string))) { std::type_index(typeid(std::string))) {
kernel_ctx->EmplaceBackAttr(BOOST_GET_CONST(std::string, attr)); kernel_ctx->EmplaceBackAttr(BOOST_GET_CONST(std::string, attr));
...@@ -466,8 +468,7 @@ void PreparePtenData(const pten::Kernel& pt_kernel, ...@@ -466,8 +468,7 @@ void PreparePtenData(const pten::Kernel& pt_kernel,
for (size_t i = 0; i < input_names.size(); ++i) { for (size_t i = 0; i < input_names.size(); ++i) {
auto& in_def = input_defs.at(i); auto& in_def = input_defs.at(i);
auto it = ins.find(input_names[i]); if (ins.find(input_names[i]) == ins.end()) {
if (it == ins.end()) {
continue; continue;
} }
auto& ins_vector = ins.at(input_names[i]); auto& ins_vector = ins.at(input_names[i]);
......
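The prepared-operator hunk above extends the attribute dispatch in `BuildDygraphPtenKernelContext` with an `int64_t` case, keyed on the kernel's expected `std::type_index`. The following self-contained sketch illustrates that dispatch style; the container and helper names are illustrative, not the real kernel context API:

```cpp
// Sketch of std::type_index-based attribute dispatch (int64_t newly handled).
#include <cstdint>
#include <stdexcept>
#include <string>
#include <typeindex>
#include <typeinfo>
#include <variant>
#include <vector>

using AttrSketch = std::variant<float, bool, std::int64_t, std::string>;

void EmplaceAttrSketch(std::vector<AttrSketch>* ctx, std::type_index expected,
                       const AttrSketch& attr) {
  if (expected == std::type_index(typeid(float))) {
    ctx->push_back(std::get<float>(attr));
  } else if (expected == std::type_index(typeid(bool))) {
    ctx->push_back(std::get<bool>(attr));
  } else if (expected == std::type_index(typeid(std::int64_t))) {
    ctx->push_back(std::get<std::int64_t>(attr));  // newly supported case
  } else if (expected == std::type_index(typeid(std::string))) {
    ctx->push_back(std::get<std::string>(attr));
  } else {
    throw std::runtime_error("unsupported attribute type for this kernel");
  }
}
```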
...@@ -12,7 +12,7 @@ else() ...@@ -12,7 +12,7 @@ else()
endif(WIN32) endif(WIN32)
cc_test(test_gradient_accmulator SRCS test_gradient_accmulator.cc DEPS memcpy selected_rows_utils selected_rows_functor gradient_accumulator math_function) cc_test(test_gradient_accmulator SRCS test_gradient_accmulator.cc DEPS memcpy selected_rows_utils selected_rows_functor gradient_accumulator math_function pten_tensor pten_api pten_api_utils)
cc_test(test_layer SRCS test_layer.cc DEPS layer proto_desc operator op_registry variable_helper mul_op memcpy) cc_test(test_layer SRCS test_layer.cc DEPS layer proto_desc operator op_registry variable_helper mul_op memcpy)
cc_test(test_prepare_op SRCS test_prepare_op.cc DEPS prepared_operator op_info split_op layer concat_and_split activation_op place) cc_test(test_prepare_op SRCS test_prepare_op.cc DEPS prepared_operator op_info split_op layer concat_and_split activation_op place)
cc_test(test_tracer SRCS test_tracer.cc DEPS tracer layer proto_desc operator op_registry variable_helper mul_op reduce_sum_op elementwise_add_op memcpy) cc_test(test_tracer SRCS test_tracer.cc DEPS tracer layer proto_desc operator op_registry variable_helper mul_op reduce_sum_op elementwise_add_op memcpy)
......
...@@ -31,8 +31,8 @@ ...@@ -31,8 +31,8 @@
namespace paddle { namespace paddle {
namespace imperative { namespace imperative {
extern std::string LayerDebugString(const std::string& op_type, extern std::string LayerDebugString(const std::string& op_type,
const NameVarMap<egr::EagerTensor>& ins, const NameVarMap<egr::EagerVariable>& ins,
const NameVarMap<egr::EagerTensor>& outs); const NameVarMap<egr::EagerVariable>& outs);
extern std::shared_ptr<GradOpNode> CreateGradOpNode( extern std::shared_ptr<GradOpNode> CreateGradOpNode(
const framework::OperatorBase& op, const NameTensorMap& ins, const framework::OperatorBase& op, const NameTensorMap& ins,
...@@ -41,20 +41,21 @@ extern std::shared_ptr<GradOpNode> CreateGradOpNode( ...@@ -41,20 +41,21 @@ extern std::shared_ptr<GradOpNode> CreateGradOpNode(
const std::map<std::string, std::string>& inplace_map); const std::map<std::string, std::string>& inplace_map);
TEST(test_eager, eager_debug) { TEST(test_eager, eager_debug) {
std::shared_ptr<egr::EagerTensor> x_in(new egr::EagerTensor("x_in")); std::shared_ptr<egr::EagerVariable> x_in(new egr::EagerVariable("x_in"));
std::shared_ptr<egr::EagerTensor> y_in(new egr::EagerTensor("y_in")); std::shared_ptr<egr::EagerVariable> y_in(new egr::EagerVariable("y_in"));
std::shared_ptr<egr::EagerTensor> vout(new egr::EagerTensor("vout")); std::shared_ptr<egr::EagerVariable> vout(new egr::EagerVariable("vout"));
imperative::NameVarMap<egr::EagerTensor> ins = {{"X", {x_in}}, {"Y", {y_in}}}; imperative::NameVarMap<egr::EagerVariable> ins = {{"X", {x_in}},
imperative::NameVarMap<egr::EagerTensor> outs = {{"Out", {vout}}}; {"Y", {y_in}}};
imperative::NameVarMap<egr::EagerVariable> outs = {{"Out", {vout}}};
LayerDebugString("mul", ins, outs); LayerDebugString("mul", ins, outs);
} }
TEST(test_create_node, eager_node) { TEST(test_create_node, eager_node) {
auto op = framework::OpRegistry::CreateOp("mul", {}, {}, {}, false); auto op = framework::OpRegistry::CreateOp("mul", {}, {}, {}, false);
framework::Scope scope; framework::Scope scope;
auto ctx = framework::RuntimeContext({}, {}); auto ctx = framework::RuntimeContext({}, {});
imperative::NameVarMap<egr::EagerTensor> ins = {{"X", {nullptr}}, imperative::NameVarMap<egr::EagerVariable> ins = {{"X", {nullptr}},
{"Y", {nullptr}}}; {"Y", {nullptr}}};
imperative::NameVarMap<egr::EagerTensor> outs = {{"Out", {nullptr}}}; imperative::NameVarMap<egr::EagerVariable> outs = {{"Out", {nullptr}}};
CreateGradOpNode((*op.get()), ins, outs, framework::AttributeMap{}, CreateGradOpNode((*op.get()), ins, outs, framework::AttributeMap{},
framework::AttributeMap{}, platform::CPUPlace(), {}); framework::AttributeMap{}, platform::CPUPlace(), {});
} }
...@@ -72,26 +73,26 @@ TEST(test_var_helper, eager_var_helper) { ...@@ -72,26 +73,26 @@ TEST(test_var_helper, eager_var_helper) {
ASSERT_ANY_THROW( ASSERT_ANY_THROW(
InitializeVariable(&var8, paddle::framework::proto::VarType::FP64)); InitializeVariable(&var8, paddle::framework::proto::VarType::FP64));
auto egr_tensor = std::make_shared<egr::EagerTensor>(); auto egr_tensor = std::make_shared<egr::EagerVariable>();
auto egr_tensor2 = std::make_shared<egr::EagerTensor>(); auto egr_tensor2 = std::make_shared<egr::EagerVariable>();
egr_tensor->MutableVar() egr_tensor->MutableVar()
->GetMutable<pten::SelectedRows>() ->GetMutable<pten::SelectedRows>()
->mutable_value() ->mutable_value()
->mutable_data<float>(platform::CPUPlace()); ->mutable_data<float>(platform::CPUPlace());
egr_tensor2->MutableVar()->GetMutable<framework::LoDRankTable>(); egr_tensor2->MutableVar()->GetMutable<framework::LoDRankTable>();
VLOG(6) << "egr_tensor create with "; VLOG(6) << "egr_tensor create with ";
ASSERT_TRUE(platform::is_cpu_place(GetPlace<egr::EagerTensor>(egr_tensor))); ASSERT_TRUE(platform::is_cpu_place(GetPlace<egr::EagerVariable>(egr_tensor)));
ASSERT_TRUE(GetDataType<egr::EagerTensor>(egr_tensor) == ASSERT_TRUE(GetDataType<egr::EagerVariable>(egr_tensor) ==
framework::proto::VarType::FP32); framework::proto::VarType::FP32);
GetCachedValue<egr::EagerTensor>( GetCachedValue<egr::EagerVariable>(
egr_tensor, framework::OpKernelType(framework::proto::VarType::FP32, egr_tensor, framework::OpKernelType(framework::proto::VarType::FP32,
platform::CPUPlace())); platform::CPUPlace()));
SetCachedValue<egr::EagerTensor>( SetCachedValue<egr::EagerVariable>(
egr_tensor, framework::OpKernelType(framework::proto::VarType::FP32, egr_tensor, framework::OpKernelType(framework::proto::VarType::FP32,
platform::CPUPlace()), platform::CPUPlace()),
egr_tensor2); egr_tensor2);
ASSERT_ANY_THROW(GetPlace<egr::EagerTensor>(egr_tensor2)); ASSERT_ANY_THROW(GetPlace<egr::EagerVariable>(egr_tensor2));
ASSERT_ANY_THROW(SetType<egr::EagerTensor>( ASSERT_ANY_THROW(SetType<egr::EagerVariable>(
egr_tensor, paddle::framework::proto::VarType::LOD_TENSOR_ARRAY)); egr_tensor, paddle::framework::proto::VarType::LOD_TENSOR_ARRAY));
} }
} // namespace imperative } // namespace imperative
......
...@@ -29,6 +29,57 @@ namespace framework = paddle::framework; ...@@ -29,6 +29,57 @@ namespace framework = paddle::framework;
namespace paddle { namespace paddle {
namespace imperative { namespace imperative {
TEST(Test__SelectedRowsMerge_Test, SelectedRowsMerge) {
pten::CPUPlace cpu;
std::vector<int64_t> rows{0, 1, 2, 3, 4, 5, 6, 7, 8, 9};
int64_t table_size = 10;
int64_t embedding_width = 10;
auto sr1 = std::make_shared<pten::SelectedRows>(rows, table_size);
auto sr2 = std::make_shared<pten::SelectedRows>(rows, table_size);
// initialize a sparse table 1
sr1->mutable_value()->Resize(
pten::framework::make_ddim({table_size, embedding_width}));
auto* data_sr1 = sr1->mutable_value()->mutable_data<float>(cpu);
for (int64_t i = 0; i < table_size; ++i) {
for (int64_t j = 0; j < embedding_width; ++j) {
data_sr1[i * embedding_width + j] = static_cast<float>(i);
}
}
// initialize a sparse table 2
sr2->mutable_value()->Resize(
pten::framework::make_ddim({table_size, embedding_width}));
auto* data_sr2 = sr2->mutable_value()->mutable_data<float>(cpu);
for (int64_t i = 0; i < table_size; ++i) {
for (int64_t j = 0; j < embedding_width; ++j) {
data_sr2[i * embedding_width + j] = static_cast<float>(i);
}
}
// new 2 pten::Tensor
paddle::experimental::Tensor t1(sr1);
paddle::experimental::Tensor t2(sr2);
// call SelectedRowsMerge
auto new_buffer =
paddle::imperative::SelectedRowsMerge<paddle::experimental::Tensor>(t1,
t2);
auto* new_buffer_tensor =
static_cast<pten::SelectedRows*>(new_buffer->impl().get());
auto* new_buffer_data_sr1 =
new_buffer_tensor->mutable_value()->mutable_data<float>(cpu);
// verify the MergeAdd result
for (int64_t i = 0; i < table_size; ++i) {
for (int64_t j = 0; j < embedding_width; ++j) {
EXPECT_EQ(new_buffer_data_sr1[i * embedding_width + j],
(static_cast<float>(i) + static_cast<float>(i)));
}
}
}
template <typename Place1, typename Place2, typename T> template <typename Place1, typename Place2, typename T>
int TensorddTest(Place1 place1, Place2 place2, T t1, T t2) { int TensorddTest(Place1 place1, Place2 place2, T t1, T t2) {
framework::Variable var1; framework::Variable var1;
......
...@@ -265,5 +265,5 @@ TEST(TestHooks, TestGradVarLeafBackwardHookWithSortedGradAccmulated) { ...@@ -265,5 +265,5 @@ TEST(TestHooks, TestGradVarLeafBackwardHookWithSortedGradAccmulated) {
USE_OP(mul); USE_OP(mul);
USE_OP(mul_grad); USE_OP(mul_grad);
USE_OP(elementwise_add); USE_OP_ITSELF(elementwise_add);
USE_OP(elementwise_add_grad); USE_OP_ITSELF(elementwise_add_grad);
...@@ -39,6 +39,8 @@ using vb_vector = std::vector<std::shared_ptr<imperative::VarBase>>; ...@@ -39,6 +39,8 @@ using vb_vector = std::vector<std::shared_ptr<imperative::VarBase>>;
using var_pair = std::pair<std::string, vb_vector>; using var_pair = std::pair<std::string, vb_vector>;
extern void TestSetForwardDataTypeOfGradVarsEager(
const NameVarMap<egr::EagerVariable>& outs);
template <typename VarType> template <typename VarType>
class TestRuntimeInferVarTypeContext class TestRuntimeInferVarTypeContext
: public RuntimeInferVarTypeContext<VarType> { : public RuntimeInferVarTypeContext<VarType> {
...@@ -406,6 +408,11 @@ TEST(test_layer, test_inner_op_not_inited) { ...@@ -406,6 +408,11 @@ TEST(test_layer, test_inner_op_not_inited) {
ASSERT_THROW(op.CheckAttrs(), platform::EnforceNotMet); ASSERT_THROW(op.CheckAttrs(), platform::EnforceNotMet);
} }
TEST(test_layer, test_eager) {
imperative::NameTensorMap ins = {};
TestSetForwardDataTypeOfGradVarsEager(ins);
}
} // namespace imperative } // namespace imperative
} // namespace paddle } // namespace paddle
......
...@@ -32,6 +32,9 @@ namespace framework = paddle::framework; ...@@ -32,6 +32,9 @@ namespace framework = paddle::framework;
namespace paddle { namespace paddle {
namespace imperative { namespace imperative {
extern void TestHandleComplexGradToRealGradEager(
const NameVarMap<egr::EagerVariable>& outs);
static framework::VariableNameMap CreateVarNameMap( static framework::VariableNameMap CreateVarNameMap(
const framework::OpInfo& op_info, const std::string& op_type, const framework::OpInfo& op_info, const std::string& op_type,
const NameVarBaseMap& varbase_map, bool is_input) { const NameVarBaseMap& varbase_map, bool is_input) {
...@@ -209,6 +212,11 @@ TEST(test_prepare_op, test_prepare_data_same_place) { ...@@ -209,6 +212,11 @@ TEST(test_prepare_op, test_prepare_data_same_place) {
TestPrepareDataSamePlace({}); TestPrepareDataSamePlace({});
} }
TEST(test_prepare_op, test_complex_eager) {
NameVarMap<egr::EagerVariable> outs = {};
TestHandleComplexGradToRealGradEager(outs);
}
#ifdef PADDLE_WITH_MKLDNN #ifdef PADDLE_WITH_MKLDNN
TEST(test_prepare_op, test_prepare_data_cpu_mkldnn) { TEST(test_prepare_op, test_prepare_data_cpu_mkldnn) {
TestPrepareDataSamePlace({{"use_mkldnn", true}}); TestPrepareDataSamePlace({{"use_mkldnn", true}});
......
...@@ -37,9 +37,10 @@ namespace paddle { ...@@ -37,9 +37,10 @@ namespace paddle {
namespace imperative { namespace imperative {
using vb_vector = std::vector<std::shared_ptr<imperative::VarBase>>; using vb_vector = std::vector<std::shared_ptr<imperative::VarBase>>;
using var_pair = std::pair<std::string, vb_vector>; using var_pair = std::pair<std::string, vb_vector>;
using ev_vector = std::vector<std::shared_ptr<egr::EagerVariable>>;
using ev_pair = std::pair<std::string, ev_vector>;
TEST(test_tracer, test_trace_op) { TEST(test_tracer, test_trace_op) {
// Doing an mul // Doing an mul
imperative::Tracer tracer; imperative::Tracer tracer;
...@@ -546,6 +547,44 @@ TEST(test_tracer, test_execution_context) { ...@@ -546,6 +547,44 @@ TEST(test_tracer, test_execution_context) {
ASSERT_EQ(dy_ctx.OutputName("Out"), framework::kEmptyVarName); ASSERT_EQ(dy_ctx.OutputName("Out"), framework::kEmptyVarName);
} }
TEST(test_tracer, eager_tracer) {
// Doing an mul
imperative::Tracer tracer;
std::shared_ptr<egr::EagerVariable> x_in(new egr::EagerVariable("x_in"));
std::shared_ptr<egr::EagerVariable> y_in(new egr::EagerVariable("y_in"));
std::shared_ptr<egr::EagerVariable> vout(new egr::EagerVariable("vout"));
platform::CPUPlace place;
std::vector<float> src_data(10, 2.0);
std::vector<int64_t> dims1 = {2, 5};
std::vector<int64_t> dims2 = {5, 2};
auto* x_in_tensor = x_in->MutableVar()->GetMutable<framework::LoDTensor>();
auto* y_in_tensor = y_in->MutableVar()->GetMutable<framework::LoDTensor>();
x_in_tensor->Resize(framework::make_ddim(dims1));
auto* mutable_x = x_in_tensor->mutable_data<float>(place);
paddle::memory::Copy(place, mutable_x, place, src_data.data(),
sizeof(float) * src_data.size());
y_in_tensor->Resize(framework::make_ddim(dims2));
auto* mutable_y = y_in_tensor->mutable_data<float>(place);
paddle::memory::Copy(place, mutable_y, place, src_data.data(),
sizeof(float) * src_data.size());
ev_pair x_pair = ev_pair("X", ev_vector(1, x_in));
ev_pair y_pair = ev_pair("Y", ev_vector(1, y_in));
ev_pair out_pair = ev_pair("Out", ev_vector(1, vout));
imperative::NameTensorMap ins = {x_pair, y_pair};
imperative::NameTensorMap outs = {out_pair};
framework::AttributeMap mul_attr_map;
mul_attr_map["use_mkldnn"] = false;
tracer.TraceOp<egr::EagerVariable>("mul", ins, outs, mul_attr_map, place,
true);
const auto& out_tensor = vout->Var().Get<framework::LoDTensor>();
for (int i = 0; i < vout->Var().Get<framework::LoDTensor>().numel(); i++) {
ASSERT_EQ(out_tensor.data<float>()[i], 20.0);
}
}
} // namespace imperative } // namespace imperative
} // namespace paddle } // namespace paddle
...@@ -553,4 +592,4 @@ USE_OP(mul); ...@@ -553,4 +592,4 @@ USE_OP(mul);
USE_OP(mul_grad); USE_OP(mul_grad);
USE_OP(reduce_sum); USE_OP(reduce_sum);
USE_OP(reduce_sum_grad); USE_OP(reduce_sum_grad);
USE_OP(elementwise_add); USE_OP_ITSELF(elementwise_add);
...@@ -20,6 +20,7 @@ ...@@ -20,6 +20,7 @@
#include "paddle/fluid/imperative/amp_auto_cast.h" #include "paddle/fluid/imperative/amp_auto_cast.h"
#include "paddle/fluid/imperative/op_base.h" #include "paddle/fluid/imperative/op_base.h"
#include "paddle/fluid/platform/denormal.h" #include "paddle/fluid/platform/denormal.h"
#include "paddle/fluid/platform/device/device_wrapper.h"
#include "paddle/fluid/platform/profiler.h" #include "paddle/fluid/platform/profiler.h"
#include "paddle/fluid/string/string_helper.h" #include "paddle/fluid/string/string_helper.h"
...@@ -138,6 +139,17 @@ paddle::framework::GarbageCollector* Tracer::MutableGarbageCollectorIfNotExists( ...@@ -138,6 +139,17 @@ paddle::framework::GarbageCollector* Tracer::MutableGarbageCollectorIfNotExists(
PADDLE_THROW(platform::errors::PermissionDenied( PADDLE_THROW(platform::errors::PermissionDenied(
"Paddle can't use MLU device since it's not compiled with MLU," "Paddle can't use MLU device since it's not compiled with MLU,"
"Please recompile or reinstall Paddle with MLU support.")); "Please recompile or reinstall Paddle with MLU support."));
#endif
} else if (platform::is_custom_place(place)) {
#if defined(PADDLE_WITH_CUSTOM_DEVICE)
gc.reset(new framework::CustomDefaultStreamGarbageCollector(place, 0));
VLOG(10) << "Created GarbageCollector at " << place;
#else
PADDLE_THROW(platform::errors::PermissionDenied(
"Paddle can't use CustomDevice since it's not compiled with "
"CustomDevice,"
"Please recompile or reinstall Paddle with CustomDevice "
"support."));
#endif #endif
} else { } else {
PADDLE_THROW(platform::errors::PreconditionNotMet( PADDLE_THROW(platform::errors::PreconditionNotMet(
...@@ -156,7 +168,7 @@ void Tracer::TraceOp(const std::string& type, const NameVarMap<VarType>& ins, ...@@ -156,7 +168,7 @@ void Tracer::TraceOp(const std::string& type, const NameVarMap<VarType>& ins,
const platform::Place& place, bool trace_backward, const platform::Place& place, bool trace_backward,
const std::map<std::string, std::string>& inplace_map, const std::map<std::string, std::string>& inplace_map,
paddle::framework::AttributeMap* passed_default_attrs_, paddle::framework::AttributeMap* passed_default_attrs_,
bool override_default_attr_map) { bool use_default_attr_map) {
platform::RecordEvent op_type_record_event(type); platform::RecordEvent op_type_record_event(type);
platform::ScopedFlushDenormal flush; platform::ScopedFlushDenormal flush;
VLOG(1) << "Trace Op: " << type; VLOG(1) << "Trace Op: " << type;
...@@ -222,9 +234,17 @@ void Tracer::TraceOp(const std::string& type, const NameVarMap<VarType>& ins, ...@@ -222,9 +234,17 @@ void Tracer::TraceOp(const std::string& type, const NameVarMap<VarType>& ins,
#else #else
PADDLE_THROW(platform::errors::PreconditionNotMet( PADDLE_THROW(platform::errors::PreconditionNotMet(
"PaddlePaddle should compile with MLU if use MLUPlace.")); "PaddlePaddle should compile with MLU if use MLUPlace."));
#endif
} else if (platform::is_custom_place(place)) {
#ifdef PADDLE_WITH_CUSTOM_DEVICE
platform::DeviceManager::SetDevice(place);
#else
PADDLE_THROW(platform::errors::PreconditionNotMet(
"PaddlePaddle should compile with CustomDevice if use "
"CustomPlace."));
#endif #endif
} }
if (!override_default_attr_map) { if (!use_default_attr_map) {
PADDLE_ENFORCE_NOT_NULL(passed_default_attrs_, PADDLE_ENFORCE_NOT_NULL(passed_default_attrs_,
paddle::platform::errors::PermissionDenied( paddle::platform::errors::PermissionDenied(
"Detected default_attrs = nullptr.")); "Detected default_attrs = nullptr."));
...@@ -260,16 +280,14 @@ void Tracer::TraceOp(const std::string& type, const NameVarMap<VarType>& ins, ...@@ -260,16 +280,14 @@ void Tracer::TraceOp(const std::string& type, const NameVarMap<VarType>& ins,
} }
if (ComputeRequiredGrad(new_ins, outs, trace_backward)) { if (ComputeRequiredGrad(new_ins, outs, trace_backward)) {
if (!override_default_attr_map) { PADDLE_ENFORCE_EQ(
PADDLE_ENFORCE_NOT_NULL(passed_default_attrs_, passed_default_attrs_, nullptr,
paddle::platform::errors::PermissionDenied( paddle::platform::errors::PermissionDenied(
"Detected default_attrs = nullptr.")); "We expect passed_default_attrs_ is nullptr while "
CreateGradOpNode(*op, new_ins, outs, attrs, *passed_default_attrs_, place, "use_default_attr_map is true, however we got not null "
inplace_map); "passed_default_attrs_. Please check your usage of trace_op. "));
} else { CreateGradOpNode(*op, new_ins, outs, attrs, default_attrs, place,
CreateGradOpNode(*op, new_ins, outs, attrs, default_attrs, place, inplace_map);
inplace_map);
}
} else { } else {
VLOG(3) << "No Grad to track for Op: " << type; VLOG(3) << "No Grad to track for Op: " << type;
} }
...@@ -281,16 +299,14 @@ template void Tracer::TraceOp<VarBase>( ...@@ -281,16 +299,14 @@ template void Tracer::TraceOp<VarBase>(
const NameVarMap<VarBase>& outs, framework::AttributeMap attrs, const NameVarMap<VarBase>& outs, framework::AttributeMap attrs,
const platform::Place& place, bool trace_backward, const platform::Place& place, bool trace_backward,
const std::map<std::string, std::string>& inplace_map, const std::map<std::string, std::string>& inplace_map,
paddle::framework::AttributeMap* default_attrs, paddle::framework::AttributeMap* default_attrs, bool use_default_attr_map);
bool override_default_attr_map);
template void Tracer::TraceOp<egr::EagerTensor>( template void Tracer::TraceOp<egr::EagerVariable>(
const std::string& type, const NameVarMap<egr::EagerTensor>& ins, const std::string& type, const NameVarMap<egr::EagerVariable>& ins,
const NameVarMap<egr::EagerTensor>& outs, framework::AttributeMap attrs, const NameVarMap<egr::EagerVariable>& outs, framework::AttributeMap attrs,
const platform::Place& place, bool trace_backward, const platform::Place& place, bool trace_backward,
const std::map<std::string, std::string>& inplace_map_, const std::map<std::string, std::string>& inplace_map_,
paddle::framework::AttributeMap* default_attrs, paddle::framework::AttributeMap* default_attrs, bool use_default_attr_map);
bool override_default_attr_map);
void Tracer::TraceOp(const std::string& type, const NameVarBaseMap& ins, void Tracer::TraceOp(const std::string& type, const NameVarBaseMap& ins,
const NameVarBaseMap& outs, framework::AttributeMap attrs, const NameVarBaseMap& outs, framework::AttributeMap attrs,
...@@ -304,13 +320,12 @@ void Tracer::TraceOp(const std::string& type, const NameTensorMap& ins, ...@@ -304,13 +320,12 @@ void Tracer::TraceOp(const std::string& type, const NameTensorMap& ins,
paddle::framework::AttributeMap attrs, paddle::framework::AttributeMap attrs,
const paddle::platform::Place& place, const paddle::platform::Place& place,
paddle::framework::AttributeMap* default_attrs, paddle::framework::AttributeMap* default_attrs,
bool override_default_attr_map, bool use_default_attr_map,
const std::map<std::string, std::string>& inplace_map) { const std::map<std::string, std::string>& inplace_map) {
VLOG(6) << "Running On Eager TraceOp with override_default_attr_map: " VLOG(6) << "Running On Eager TraceOp with use_default_attr_map: "
<< override_default_attr_map; << use_default_attr_map;
TraceOp<egr::EagerTensor>(type, ins, outs, std::move(attrs), place, false, TraceOp<egr::EagerVariable>(type, ins, outs, std::move(attrs), place, false,
inplace_map, default_attrs, inplace_map, default_attrs, use_default_attr_map);
override_default_attr_map);
} }
void Tracer::TraceOp(const std::string& type, const NameTensorMap& ins, void Tracer::TraceOp(const std::string& type, const NameTensorMap& ins,
...@@ -318,8 +333,9 @@ void Tracer::TraceOp(const std::string& type, const NameTensorMap& ins, ...@@ -318,8 +333,9 @@ void Tracer::TraceOp(const std::string& type, const NameTensorMap& ins,
paddle::framework::AttributeMap attrs, paddle::framework::AttributeMap attrs,
const std::map<std::string, std::string>& inplace_map) { const std::map<std::string, std::string>& inplace_map) {
VLOG(6) << "Running On Eager TraceOp(less): "; VLOG(6) << "Running On Eager TraceOp(less): ";
TraceOp<egr::EagerTensor>(type, ins, outs, std::move(attrs), expected_place_, TraceOp<egr::EagerVariable>(type, ins, outs, std::move(attrs),
false, inplace_map, nullptr, true); expected_place_, false, inplace_map, nullptr,
true);
} }
void Tracer::SetExpectedPlace(platform::Place place) { void Tracer::SetExpectedPlace(platform::Place place) {
......
...@@ -69,7 +69,7 @@ class Tracer { ...@@ -69,7 +69,7 @@ class Tracer {
const platform::Place& place, bool trace_backward, const platform::Place& place, bool trace_backward,
const std::map<std::string, std::string>& inplace_map = {}, const std::map<std::string, std::string>& inplace_map = {},
paddle::framework::AttributeMap* passed_default_attrs_ = nullptr, paddle::framework::AttributeMap* passed_default_attrs_ = nullptr,
bool override_default_attr_map = true); bool use_default_attr_map = true);
void TraceOp(const std::string& type, const NameVarBaseMap& ins, void TraceOp(const std::string& type, const NameVarBaseMap& ins,
const NameVarBaseMap& outs, framework::AttributeMap attrs, const NameVarBaseMap& outs, framework::AttributeMap attrs,
...@@ -83,7 +83,7 @@ class Tracer { ...@@ -83,7 +83,7 @@ class Tracer {
const NameTensorMap& outs, paddle::framework::AttributeMap attrs, const NameTensorMap& outs, paddle::framework::AttributeMap attrs,
const paddle::platform::Place& place, const paddle::platform::Place& place,
paddle::framework::AttributeMap* default_attrs, paddle::framework::AttributeMap* default_attrs,
bool override_default_attr_map, bool use_default_attr_map,
const std::map<std::string, std::string>& inplace_map = {}); const std::map<std::string, std::string>& inplace_map = {});
bool ComputeRequiredGrad(const NameVarBaseMap& ins, bool ComputeRequiredGrad(const NameVarBaseMap& ins,
......
...@@ -95,8 +95,8 @@ template const paddle::platform::Place &GetPlace<VarBase>( ...@@ -95,8 +95,8 @@ template const paddle::platform::Place &GetPlace<VarBase>(
const std::shared_ptr<VarBase> &var); const std::shared_ptr<VarBase> &var);
template const paddle::platform::Place &GetPlace<VariableWrapper>( template const paddle::platform::Place &GetPlace<VariableWrapper>(
const std::shared_ptr<VariableWrapper> &var); const std::shared_ptr<VariableWrapper> &var);
template const paddle::platform::Place &GetPlace<egr::EagerTensor>( template const paddle::platform::Place &GetPlace<egr::EagerVariable>(
const std::shared_ptr<egr::EagerTensor> &var); const std::shared_ptr<egr::EagerVariable> &var);
/* GetNameFromVar */ /* GetNameFromVar */
template <typename VarType> template <typename VarType>
...@@ -104,8 +104,8 @@ const std::string &GetNameFromVar(std::shared_ptr<VarType> var) { ...@@ -104,8 +104,8 @@ const std::string &GetNameFromVar(std::shared_ptr<VarType> var) {
return var->Name(); return var->Name();
} }
template <> template <>
const std::string &GetNameFromVar<egr::EagerTensor>( const std::string &GetNameFromVar<egr::EagerVariable>(
std::shared_ptr<egr::EagerTensor> tensor) { std::shared_ptr<egr::EagerVariable> tensor) {
return tensor->name(); return tensor->name();
} }
template const std::string &GetNameFromVar<VariableWrapper>( template const std::string &GetNameFromVar<VariableWrapper>(
...@@ -120,8 +120,8 @@ void SetType(std::shared_ptr<VarType> var, ...@@ -120,8 +120,8 @@ void SetType(std::shared_ptr<VarType> var,
var->SetType(type); var->SetType(type);
} }
template <> template <>
void SetType<egr::EagerTensor>(std::shared_ptr<egr::EagerTensor> var, void SetType<egr::EagerVariable>(std::shared_ptr<egr::EagerVariable> var,
framework::proto::VarType::Type type) { framework::proto::VarType::Type type) {
switch (type) { switch (type) {
case paddle::framework::proto::VarType::LOD_TENSOR: { case paddle::framework::proto::VarType::LOD_TENSOR: {
var->MutableVar()->GetMutable<paddle::framework::LoDTensor>(); var->MutableVar()->GetMutable<paddle::framework::LoDTensor>();
...@@ -149,8 +149,8 @@ framework::proto::VarType::Type GetType(std::shared_ptr<VarType> var) { ...@@ -149,8 +149,8 @@ framework::proto::VarType::Type GetType(std::shared_ptr<VarType> var) {
return var->Type(); return var->Type();
} }
template <> template <>
framework::proto::VarType::Type GetType<egr::EagerTensor>( framework::proto::VarType::Type GetType<egr::EagerVariable>(
std::shared_ptr<egr::EagerTensor> var) { std::shared_ptr<egr::EagerVariable> var) {
if (var->Var().IsInitialized()) { if (var->Var().IsInitialized()) {
return paddle::framework::ToVarType(var->Var().Type()); return paddle::framework::ToVarType(var->Var().Type());
} else { } else {
...@@ -168,8 +168,8 @@ framework::proto::VarType::Type GetDataType(std::shared_ptr<VarType> var) { ...@@ -168,8 +168,8 @@ framework::proto::VarType::Type GetDataType(std::shared_ptr<VarType> var) {
return var->DataType(); return var->DataType();
} }
template <> template <>
framework::proto::VarType::Type GetDataType<egr::EagerTensor>( framework::proto::VarType::Type GetDataType<egr::EagerVariable>(
std::shared_ptr<egr::EagerTensor> var) { std::shared_ptr<egr::EagerVariable> var) {
if (var->Var().IsType<pten::SelectedRows>()) { if (var->Var().IsType<pten::SelectedRows>()) {
return framework::TransToProtoVarType( return framework::TransToProtoVarType(
var->Var().Get<pten::SelectedRows>().value().type()); var->Var().Get<pten::SelectedRows>().value().type());
...@@ -197,8 +197,8 @@ bool CheckCachedKey(std::shared_ptr<VarType> var, ...@@ -197,8 +197,8 @@ bool CheckCachedKey(std::shared_ptr<VarType> var,
return GetVariableWrapper(var)->hasCacheKey(key); return GetVariableWrapper(var)->hasCacheKey(key);
} }
template <> template <>
bool CheckCachedKey<egr::EagerTensor>( bool CheckCachedKey<egr::EagerVariable>(
std::shared_ptr<egr::EagerTensor> tensor, std::shared_ptr<egr::EagerVariable> tensor,
const paddle::framework::OpKernelType &key) { const paddle::framework::OpKernelType &key) {
// TODO(jiabin): Support this later // TODO(jiabin): Support this later
// VLOG(10) << "CheckCachedKey with tensor: " << tensor->name() << "and key is // VLOG(10) << "CheckCachedKey with tensor: " << tensor->name() << "and key is
...@@ -219,7 +219,7 @@ std::shared_ptr<VariableWrapper> GetCachedValue( ...@@ -219,7 +219,7 @@ std::shared_ptr<VariableWrapper> GetCachedValue(
} }
template <> template <>
std::shared_ptr<VariableWrapper> GetCachedValue( std::shared_ptr<VariableWrapper> GetCachedValue(
std::shared_ptr<egr::EagerTensor> var, std::shared_ptr<egr::EagerVariable> var,
const paddle::framework::OpKernelType &key) { const paddle::framework::OpKernelType &key) {
// TODO(jiabin): Support this later // TODO(jiabin): Support this later
// PADDLE_THROW(platform::errors::Fatal("In eager mode program should not // PADDLE_THROW(platform::errors::Fatal("In eager mode program should not
...@@ -243,10 +243,10 @@ void SetCachedValue(std::shared_ptr<VarType> var, ...@@ -243,10 +243,10 @@ void SetCachedValue(std::shared_ptr<VarType> var,
GetVariableWrapper(var)->setCacheValue(key, GetVariableWrapper(res)); GetVariableWrapper(var)->setCacheValue(key, GetVariableWrapper(res));
} }
template <> template <>
void SetCachedValue<egr::EagerTensor>( void SetCachedValue<egr::EagerVariable>(
std::shared_ptr<egr::EagerTensor> tensor, std::shared_ptr<egr::EagerVariable> tensor,
const paddle::framework::OpKernelType &key, const paddle::framework::OpKernelType &key,
std::shared_ptr<egr::EagerTensor> res) { std::shared_ptr<egr::EagerVariable> res) {
// PADDLE_THROW(platform::errors::Fatal("In eager mode program should not // PADDLE_THROW(platform::errors::Fatal("In eager mode program should not
// reach this, support cache and remove this error check later, or this // reach this, support cache and remove this error check later, or this
// should not be supported.")); // should not be supported."));
......
...@@ -18,7 +18,7 @@ ...@@ -18,7 +18,7 @@
#include "paddle/fluid/framework/variable.h" #include "paddle/fluid/framework/variable.h"
namespace egr { namespace egr {
class EagerTensor; class EagerVariable;
} // namespace egr } // namespace egr
namespace pten { namespace pten {
class DenseTensor; class DenseTensor;
......
...@@ -379,8 +379,10 @@ void TensorRtSubgraphPass::CreateTensorRTOp( ...@@ -379,8 +379,10 @@ void TensorRtSubgraphPass::CreateTensorRTOp(
trt_engine->SetUseInspector(Get<bool>("use_inspector")); trt_engine->SetUseInspector(Get<bool>("use_inspector"));
trt_engine->SetWithErnie( trt_engine->SetWithErnie(
graph->Has(framework::ir::kEmbEltwiseLayernormPass) && (graph->Has(framework::ir::kEmbEltwiseLayernormPass) &&
graph->Has(framework::ir::kMultiheadMatmulPass)); graph->Has(framework::ir::kMultiheadMatmulPass)) ||
(graph->Has(framework::ir::kPrelnEmbEltwiseLayernormPass) &&
graph->Has(framework::ir::kMultiheadMatmulPass)));
if (use_static_engine) { if (use_static_engine) {
trt_engine_serialized_data = GetTrtEngineSerializedData( trt_engine_serialized_data = GetTrtEngineSerializedData(
......
...@@ -1470,6 +1470,8 @@ USE_TRT_CONVERTER(conv3d_transpose); ...@@ -1470,6 +1470,8 @@ USE_TRT_CONVERTER(conv3d_transpose);
USE_TRT_CONVERTER(mish); USE_TRT_CONVERTER(mish);
USE_TRT_CONVERTER(deformable_conv); USE_TRT_CONVERTER(deformable_conv);
USE_TRT_CONVERTER(pool3d) USE_TRT_CONVERTER(pool3d)
USE_TRT_CONVERTER(fused_preln_embedding_eltwise_layernorm)
USE_TRT_CONVERTER(preln_skip_layernorm)
#endif #endif
namespace paddle_infer { namespace paddle_infer {
......
...@@ -82,22 +82,24 @@ const std::vector<std::string> kTRTSubgraphPasses({ ...@@ -82,22 +82,24 @@ const std::vector<std::string> kTRTSubgraphPasses({
"quant_conv2d_dequant_fuse_pass", // "quant_conv2d_dequant_fuse_pass", //
"delete_quant_dequant_op_pass", // "delete_quant_dequant_op_pass", //
"delete_quant_dequant_filter_op_pass", // "delete_quant_dequant_filter_op_pass", //
// "fc_fuse_pass", // // "fc_fuse_pass", //
"simplify_with_basic_ops_pass", // "simplify_with_basic_ops_pass", //
"embedding_eltwise_layernorm_fuse_pass", // "embedding_eltwise_layernorm_fuse_pass", //
"multihead_matmul_fuse_pass_v2", // "preln_embedding_eltwise_layernorm_fuse_pass", //
"multihead_matmul_fuse_pass_v3", // "multihead_matmul_fuse_pass_v2", //
"skip_layernorm_fuse_pass", // "multihead_matmul_fuse_pass_v3", //
"conv_bn_fuse_pass", // "skip_layernorm_fuse_pass", //
"unsqueeze2_eltwise_fuse_pass", // "preln_skip_layernorm_fuse_pass", //
"trt_squeeze2_matmul_fuse_pass", // "conv_bn_fuse_pass", //
"trt_reshape2_matmul_fuse_pass", // "unsqueeze2_eltwise_fuse_pass", //
"trt_flatten2_matmul_fuse_pass", // "trt_squeeze2_matmul_fuse_pass", //
"trt_map_matmul_v2_to_mul_pass", // "trt_reshape2_matmul_fuse_pass", //
"trt_map_matmul_v2_to_matmul_pass", // "trt_flatten2_matmul_fuse_pass", //
"trt_map_matmul_to_mul_pass", // "trt_map_matmul_v2_to_mul_pass", //
"fc_fuse_pass", // "trt_map_matmul_v2_to_matmul_pass", //
"conv_elementwise_add_fuse_pass", // "trt_map_matmul_to_mul_pass", //
"fc_fuse_pass", //
"conv_elementwise_add_fuse_pass", //
"add_support_int8_pass", "add_support_int8_pass",
"tensorrt_subgraph_pass", // "tensorrt_subgraph_pass", //
"conv_bn_fuse_pass", // "conv_bn_fuse_pass", //
......
...@@ -21,6 +21,8 @@ nv_library(tensorrt_converter ...@@ -21,6 +21,8 @@ nv_library(tensorrt_converter
nearest_interp_v2_op.cc nearest_interp_v2_op.cc
pool3d_op.cc pool3d_op.cc
deformable_conv_op.cc deformable_conv_op.cc
preln_emb_eltwise_layernorm.cc
preln_skip_layernorm.cc
DEPS tensorrt_engine tensorrt_plugin operator scope framework_proto op_registry) DEPS tensorrt_engine tensorrt_plugin operator scope framework_proto op_registry)
nv_test(test_op_converter SRCS test_op_converter.cc DEPS nv_test(test_op_converter SRCS test_op_converter.cc DEPS
......
...@@ -43,30 +43,161 @@ class GeluOpConverter : public OpConverter { ...@@ -43,30 +43,161 @@ class GeluOpConverter : public OpConverter {
void operator()(const framework::proto::OpDesc& op, void operator()(const framework::proto::OpDesc& op,
const framework::Scope& scope, bool test_mode) override { const framework::Scope& scope, bool test_mode) override {
VLOG(4) << "convert fluid gelu op to tensorrt gelu layer"; VLOG(4) << "convert fluid gelu op to tensorrt gelu layer";
framework::OpDesc op_desc(op, nullptr); framework::OpDesc op_desc(op, nullptr);
// Declare inputs // Declare inputs
int input_num = op_desc.Input("X").size();
auto* input = engine_->GetITensor(op_desc.Input("X")[0]); auto* input = engine_->GetITensor(op_desc.Input("X")[0]);
nvinfer1::ILayer* layer = nullptr; nvinfer1::ILayer* layer = nullptr;
if (engine_->with_dynamic_shape()) { if (op_desc.HasAttr("approximate") &&
#if IS_TRT_VERSION_GE(6000) BOOST_GET_CONST(bool, op_desc.GetAttr("approximate"))) {
bool with_fp16 = #if IS_TRT_VERSION_GE(7000)
engine_->WithFp16() && !engine_->disable_trt_plugin_fp16(); nvinfer1::Dims input_shape;
plugin::GeluPluginDynamic* plugin = input_shape.nbDims = input->getDimensions().nbDims;
new plugin::GeluPluginDynamic(with_fp16); for (int i = 0; i < input_shape.nbDims; ++i) {
layer = engine_->AddDynamicPlugin(&input, input_num, plugin); input_shape.d[i] = 1;
}
std::string out_name = op_desc.Output("Out").front();
auto create_weights = [&](float data, std::string type) -> float* {
std::unique_ptr<framework::Tensor> tmp_tensor(new framework::Tensor());
tmp_tensor->Resize({1});
auto* tmp_data = tmp_tensor->mutable_data<float>(platform::CPUPlace());
tmp_data[0] = data;
engine_->SetWeights(out_name + "_gelu_op_" + type,
std::move(tmp_tensor));
return tmp_data;
};
float* constant_pow = create_weights(3.0f, "constant_pow");
float* constant_multiply = create_weights(0.044715f, "constant_multiply");
float* constant_sqrt =
create_weights(0.79788456080286535587989211986876f, "constant_sqrt");
float* constant_one = create_weights(1.0f, "constant_one");
float* constant_half = create_weights(0.5f, "constant_half");
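// The chain of elementwise layers below builds the tanh approximation of
// GELU: GELU(x) ~= 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3))),
// where 0.79788456... is sqrt(2/pi). For example, GELU(1.0) ~= 0.841.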
auto constant_layer_pow = TRT_ENGINE_ADD_LAYER(
engine_, Constant, input_shape,
nvinfer1::Weights{nvinfer1::DataType::kFLOAT,
static_cast<void*>(constant_pow), 1});
auto constant_layer_multiply = TRT_ENGINE_ADD_LAYER(
engine_, Constant, input_shape,
nvinfer1::Weights{nvinfer1::DataType::kFLOAT,
static_cast<void*>(constant_multiply), 1});
auto constant_layer_sqrt = TRT_ENGINE_ADD_LAYER(
engine_, Constant, input_shape,
nvinfer1::Weights{nvinfer1::DataType::kFLOAT,
static_cast<void*>(constant_sqrt), 1});
auto constant_layer_one = TRT_ENGINE_ADD_LAYER(
engine_, Constant, input_shape,
nvinfer1::Weights{nvinfer1::DataType::kFLOAT,
static_cast<void*>(constant_one), 1});
auto constant_layer_half = TRT_ENGINE_ADD_LAYER(
engine_, Constant, input_shape,
nvinfer1::Weights{nvinfer1::DataType::kFLOAT,
static_cast<void*>(constant_half), 1});
auto layer_pow = TRT_ENGINE_ADD_LAYER(
engine_, ElementWise, *input, *constant_layer_pow->getOutput(0),
nvinfer1::ElementWiseOperation::kPOW);
auto layer_mul =
TRT_ENGINE_ADD_LAYER(engine_, ElementWise, *layer_pow->getOutput(0),
*constant_layer_multiply->getOutput(0),
nvinfer1::ElementWiseOperation::kPROD);
auto layer_add =
TRT_ENGINE_ADD_LAYER(engine_, ElementWise, *layer_mul->getOutput(0),
*input, nvinfer1::ElementWiseOperation::kSUM);
auto layer_sqrt =
TRT_ENGINE_ADD_LAYER(engine_, ElementWise, *layer_add->getOutput(0),
*constant_layer_sqrt->getOutput(0),
nvinfer1::ElementWiseOperation::kPROD);
auto layer_tanh =
TRT_ENGINE_ADD_LAYER(engine_, Activation, *layer_sqrt->getOutput(0),
nvinfer1::ActivationType::kTANH);
auto layer_one =
TRT_ENGINE_ADD_LAYER(engine_, ElementWise, *layer_tanh->getOutput(0),
*constant_layer_one->getOutput(0),
nvinfer1::ElementWiseOperation::kSUM);
auto layer_CDF =
TRT_ENGINE_ADD_LAYER(engine_, ElementWise, *layer_one->getOutput(0),
*constant_layer_half->getOutput(0),
nvinfer1::ElementWiseOperation::kPROD);
auto y =
TRT_ENGINE_ADD_LAYER(engine_, ElementWise, *layer_CDF->getOutput(0),
*input, nvinfer1::ElementWiseOperation::kPROD);
layer = y;
#else #else
PADDLE_THROW(platform::errors::Fatal( PADDLE_THROW(platform::errors::Fatal(
"You are running the TRT Dynamic Shape mode, need to confirm that " "You are running GeLU Op with approximate True, need to confirm that "
"your TRT version is no less than 6.0")); "your TRT version is no less than 7.0"));
#endif #endif
} else { } else {
bool with_fp16 = #if IS_TRT_VERSION_GE(7000)
engine_->WithFp16() && !engine_->disable_trt_plugin_fp16(); nvinfer1::Dims input_shape;
plugin::GeluPlugin* plugin = new plugin::GeluPlugin(with_fp16); input_shape.nbDims = input->getDimensions().nbDims;
layer = engine_->AddPlugin(&input, input_num, plugin); for (int i = 0; i < input_shape.nbDims; ++i) {
input_shape.d[i] = 1;
}
std::string out_name = op_desc.Output("Out").front();
auto create_weights = [&](float data, std::string type) -> float* {
std::unique_ptr<framework::Tensor> tmp_tensor(new framework::Tensor());
tmp_tensor->Resize({1});
auto* tmp_data = tmp_tensor->mutable_data<float>(platform::CPUPlace());
tmp_data[0] = data;
engine_->SetWeights(out_name + "_gelu_op_" + type,
std::move(tmp_tensor));
return tmp_data;
};
float* constant_one = create_weights(1.0f, "constant_one");
float* constant_half = create_weights(0.5f, "constant_half");
float* constant_rsqrt2 =
create_weights(0.70710678118f, "constant_rsqrt2");
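// The layers below implement exact GELU via the error function:
// GELU(x) = 0.5 * x * (1 + erf(x / sqrt(2))), where 0.70710678... is
// 1/sqrt(2); e.g. GELU(1.0) ~= 0.8413.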
auto constant_layer_one = TRT_ENGINE_ADD_LAYER(
engine_, Constant, input_shape,
nvinfer1::Weights{nvinfer1::DataType::kFLOAT,
static_cast<void*>(constant_one), 1});
auto constant_layer_half = TRT_ENGINE_ADD_LAYER(
engine_, Constant, input_shape,
nvinfer1::Weights{nvinfer1::DataType::kFLOAT,
static_cast<void*>(constant_half), 1});
auto constant_layer_rsqrt2 = TRT_ENGINE_ADD_LAYER(
engine_, Constant, input_shape,
nvinfer1::Weights{nvinfer1::DataType::kFLOAT,
static_cast<void*>(constant_rsqrt2), 1});
auto layer_mul = TRT_ENGINE_ADD_LAYER(
engine_, ElementWise, *input, *constant_layer_rsqrt2->getOutput(0),
nvinfer1::ElementWiseOperation::kPROD);
auto layer_erf =
TRT_ENGINE_ADD_LAYER(engine_, Unary, *layer_mul->getOutput(0),
nvinfer1::UnaryOperation::kERF);
auto layer_add =
TRT_ENGINE_ADD_LAYER(engine_, ElementWise, *layer_erf->getOutput(0),
*constant_layer_one->getOutput(0),
nvinfer1::ElementWiseOperation::kSUM);
auto layer_CDF =
TRT_ENGINE_ADD_LAYER(engine_, ElementWise, *layer_add->getOutput(0),
*constant_layer_half->getOutput(0),
nvinfer1::ElementWiseOperation::kPROD);
auto y =
TRT_ENGINE_ADD_LAYER(engine_, ElementWise, *layer_CDF->getOutput(0),
*input, nvinfer1::ElementWiseOperation::kPROD);
layer = y;
#else // if IS_TRT_VERSION_GE(7000)
int input_num = op_desc.Input("X").size();
if (engine_->with_dynamic_shape()) {
#if IS_TRT_VERSION_GE(6000)
bool with_fp16 =
engine_->WithFp16() && !engine_->disable_trt_plugin_fp16();
plugin::GeluPluginDynamic* plugin =
new plugin::GeluPluginDynamic(with_fp16);
layer = engine_->AddDynamicPlugin(&input, input_num, plugin);
#else
PADDLE_THROW(platform::errors::Fatal(
"You are running the TRT Dynamic Shape mode, need to confirm that "
"your TRT version is no less than 6.0"));
#endif
} else {
bool with_fp16 =
engine_->WithFp16() && !engine_->disable_trt_plugin_fp16();
plugin::GeluPlugin* plugin = new plugin::GeluPlugin(with_fp16);
layer = engine_->AddPlugin(&input, input_num, plugin);
}
#endif // if IS_TRT_VERSION_GE(7000)
} }
auto output_name = op_desc.Output("Out")[0]; auto output_name = op_desc.Output("Out")[0];
RreplenishLayerAndOutput(layer, "gelu", {output_name}, test_mode); RreplenishLayerAndOutput(layer, "gelu", {output_name}, test_mode);
......
...@@ -106,6 +106,9 @@ class Pool2dOpConverter : public OpConverter { ...@@ -106,6 +106,9 @@ class Pool2dOpConverter : public OpConverter {
reduce_operation = nvinfer1::ReduceOperation::kAVG; reduce_operation = nvinfer1::ReduceOperation::kAVG;
plugin_pool_type = plugin::PoolPlugin::PoolType::avg; plugin_pool_type = plugin::PoolPlugin::PoolType::avg;
} }
if (global_pooling || adaptive) {
std::fill(paddings.begin(), paddings.end(), 0);
}
if (padding_algorithm == "VALID") { if (padding_algorithm == "VALID") {
std::fill(paddings.begin(), paddings.end(), 0); std::fill(paddings.begin(), paddings.end(), 0);
...@@ -136,6 +139,46 @@ class Pool2dOpConverter : public OpConverter { ...@@ -136,6 +139,46 @@ class Pool2dOpConverter : public OpConverter {
#endif #endif
} }
std::vector<int> real_paddings = paddings;
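// Expand the 2-element paddings [pad_h, pad_w] into the 4-element form
// [pad_h, pad_h, pad_w, pad_w] (top, bottom, left, right) kept in
// real_paddings and later handed to the pooling plugin.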
for (int i = 0; i < 2; ++i) {
int copy_pad = *(paddings.begin() + i);
real_paddings.insert(real_paddings.begin() + 2 * i + 1, copy_pad);
}
// SAME
if (padding_algorithm == "SAME") {
// expand
for (int i = 0; i < 2; ++i) {
int copy_pad = *(paddings.begin() + 2 * i);
paddings.insert(paddings.begin() + 2 * i + 1, copy_pad);
}
// compute
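// TF-style SAME padding: out = ceil(in / stride); the total padding needed
// is max((out - 1) * stride + ksize - in, 0), split so that any odd pixel
// goes to the bottom/right side. E.g. in = 7, stride = 2, ksize = 3 gives
// out = 4, pad_sum = 2, pad_0 = pad_1 = 1.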
for (int i = 0; i < 2; ++i) {
int out_size = (input_shape.d[2 + i] + strides[i] - 1) / strides[i];
int pad_sum = std::max(
(out_size - 1) * strides[i] + ksize[i] - input_shape.d[2 + i], 0);
int pad_0 = pad_sum / 2;
int pad_1 = pad_sum - pad_0;
paddings[i * 2] = pad_0;
paddings[i * 2 + 1] = pad_1;
}
real_paddings = paddings;
// slice
for (int i = 0; i < 2; ++i) {
paddings.erase(paddings.begin() + i + 1);
}
}
// VALID
if (padding_algorithm == "VALID") {
std::fill(real_paddings.begin(), real_paddings.end(), 0);
}
if (global_pooling == true && !engine_->with_dynamic_shape()) {
nv_ksize.d[0] = input_shape.d[input_dims - 2];
nv_ksize.d[1] = input_shape.d[input_dims - 1];
ksize[0] = input_shape.d[input_dims - 2];
ksize[1] = input_shape.d[input_dims - 1];
}
if (engine_->with_dynamic_shape()) { if (engine_->with_dynamic_shape()) {
if (!adaptive && !global_pooling && !ceil_mode) { if (!adaptive && !global_pooling && !ceil_mode) {
// input_shape.d < 0 means we can't get shape info here. // input_shape.d < 0 means we can't get shape info here.
...@@ -173,15 +216,15 @@ class Pool2dOpConverter : public OpConverter { ...@@ -173,15 +216,15 @@ class Pool2dOpConverter : public OpConverter {
pool_layer->setPaddingMode(nvinfer1::PaddingMode::kEXPLICIT_ROUND_UP); pool_layer->setPaddingMode(nvinfer1::PaddingMode::kEXPLICIT_ROUND_UP);
} }
layer = pool_layer; layer = pool_layer;
} else if (global_pooling) { } else if (global_pooling && !adaptive) {
auto *reduce_layer = TRT_ENGINE_ADD_LAYER(engine_, Reduce, *input1, auto *reduce_layer = TRT_ENGINE_ADD_LAYER(engine_, Reduce, *input1,
reduce_operation, 12, true); reduce_operation, 12, true);
layer = reduce_layer; layer = reduce_layer;
} else { } else {
#if IS_TRT_VERSION_GE(6000) #if IS_TRT_VERSION_GE(6000)
plugin::PoolPluginDynamic *plugin = plugin::PoolPluginDynamic *plugin = new plugin::PoolPluginDynamic(
new plugin::PoolPluginDynamic(ceil_mode, pool_type, adaptive, ksize, ceil_mode, pool_type, adaptive, exclusive, ksize, strides, paddings,
strides, paddings, global_pooling); global_pooling);
layer = engine_->AddDynamicPlugin(&input1, 1, plugin); layer = engine_->AddDynamicPlugin(&input1, 1, plugin);
#endif #endif
} }
...@@ -195,21 +238,13 @@ class Pool2dOpConverter : public OpConverter { ...@@ -195,21 +238,13 @@ class Pool2dOpConverter : public OpConverter {
return; return;
} }
if (global_pooling == true) { if (global_pooling == true && adaptive == false) {
nv_ksize.d[0] = input_shape.d[input_dims - 2];
nv_ksize.d[1] = input_shape.d[input_dims - 1];
auto *pool_layer = TRT_ENGINE_ADD_LAYER(engine_, Pooling, *input1, auto *pool_layer = TRT_ENGINE_ADD_LAYER(engine_, Pooling, *input1,
nv_pool_type, nv_ksize); nv_pool_type, nv_ksize);
PADDLE_ENFORCE_NOT_NULL( PADDLE_ENFORCE_NOT_NULL(
pool_layer, platform::errors::Fatal( pool_layer, platform::errors::Fatal(
"trt pool layer in converter could not be created.")); "trt pool layer in converter could not be created."));
auto output_name = op_desc.Output("Out")[0]; auto output_name = op_desc.Output("Out")[0];
pool_layer->setStride(nv_strides);
pool_layer->setPadding(nv_paddings);
if (padding_algorithm == "SAME") {
pool_layer->setPaddingMode(nvinfer1::PaddingMode::kSAME_UPPER);
}
pool_layer->setAverageCountExcludesPadding(exclusive);
pool_layer->setName(("pool2d (Output: " + output_name + ")").c_str()); pool_layer->setName(("pool2d (Output: " + output_name + ")").c_str());
pool_layer->getOutput(0)->setName(output_name.c_str()); pool_layer->getOutput(0)->setName(output_name.c_str());
engine_->SetITensor(output_name, pool_layer->getOutput(0)); engine_->SetITensor(output_name, pool_layer->getOutput(0));
...@@ -222,58 +257,61 @@ class Pool2dOpConverter : public OpConverter { ...@@ -222,58 +257,61 @@ class Pool2dOpConverter : public OpConverter {
if (!adaptive) { if (!adaptive) {
if (ceil_mode) { if (ceil_mode) {
nvinfer1::DimsHW pre_pad(0, 0); std::vector<int> input_shape_v;
nvinfer1::DimsHW post_pad(0, 0); for (int i = 0; i < input_dims; i++) {
// If ceil mode is true, we will pad the appropriate size to the input. input_shape_v.push_back(input_shape.d[i]);
DealCeilMode(input_shape, ksize, strides, paddings, &pre_pad, &post_pad, }
input_dims); plugin::PoolPlugin *plugin = new plugin::PoolPlugin(
auto *pad_layer = ceil_mode, plugin_pool_type, adaptive, exclusive, ksize, strides,
TRT_ENGINE_ADD_LAYER(engine_, Padding, *input1, pre_pad, post_pad); paddings, input_shape_v, real_paddings);
auto *pool_layer = engine_->AddPlugin(&input1, 1, plugin);
PADDLE_ENFORCE_NOT_NULL( PADDLE_ENFORCE_NOT_NULL(
pad_layer, platform::errors::Fatal( pool_layer,
"Pad layer in poolOp converter could not be " platform::errors::Fatal(
"created. The pointer to pad layer is `NULL`.")); "trt pool plugin layer in converter could not be created."));
input1 = pad_layer->getOutput(0); layer = pool_layer;
} } else {
#if IS_TRT_VERSION_GE(8000) #if IS_TRT_VERSION_GE(8000)
// Exclude padding pixels from the average mean is not supported well by // Exclude padding pixels from the average mean is not supported well by
// TRT // TRT
// so enable padding for trt8.0 above. // so enable padding for trt8.0 above.
if ((g_post_pad.w() > 0 || g_post_pad.h() > 0) && if ((g_post_pad.w() > 0 || g_post_pad.h() > 0) &&
(padding_algorithm != "SAME") && !ceil_mode) { (padding_algorithm != "SAME") && !ceil_mode) {
auto *pad_layer = TRT_ENGINE_ADD_LAYER(engine_, Padding, *input1, auto *pad_layer = TRT_ENGINE_ADD_LAYER(engine_, Padding, *input1,
g_pre_pad, g_post_pad); g_pre_pad, g_post_pad);
PADDLE_ENFORCE_NOT_NULL( PADDLE_ENFORCE_NOT_NULL(
pad_layer, platform::errors::Fatal( pad_layer, platform::errors::Fatal(
"Pad layer in poolOp converter could not be " "Pad layer in poolOp converter could not be "
"created. The pointer to pad layer is `NULL`.")); "created. The pointer to pad layer is `NULL`."));
input1 = pad_layer->getOutput(0); input1 = pad_layer->getOutput(0);
} }
#endif #endif
auto *pool_layer = TRT_ENGINE_ADD_LAYER(engine_, Pooling, *input1, auto *pool_layer = TRT_ENGINE_ADD_LAYER(engine_, Pooling, *input1,
nv_pool_type, nv_ksize); nv_pool_type, nv_ksize);
PADDLE_ENFORCE_NOT_NULL( PADDLE_ENFORCE_NOT_NULL(
pool_layer, platform::errors::Fatal( pool_layer,
"trt pool layer in converter could not be created.")); platform::errors::Fatal(
pool_layer->setStride(nv_strides); "trt pool layer in converter could not be created."));
pool_layer->setPadding(nv_paddings); pool_layer->setStride(nv_strides);
if (padding_algorithm == "SAME") { pool_layer->setPadding(nv_paddings);
pool_layer->setPaddingMode(nvinfer1::PaddingMode::kSAME_UPPER); if (padding_algorithm == "SAME") {
pool_layer->setPaddingMode(nvinfer1::PaddingMode::kSAME_UPPER);
}
pool_layer->setAverageCountExcludesPadding(exclusive);
layer = pool_layer;
} }
pool_layer->setAverageCountExcludesPadding(exclusive);
layer = pool_layer;
} else { } else {
// Average pooling needs to exclude the padding pixels from the average // Average pooling needs to exclude the padding pixels from the average
// mean. // mean.
// It is not supported well by TRT, we use a plugin here. // It is not supported well by TRT, we use a plugin here
std::vector<int> input_shape_v; std::vector<int> input_shape_v;
for (int i = 0; i < input_dims; i++) { for (int i = 0; i < input_dims; i++) {
input_shape_v.push_back(input_shape.d[i]); input_shape_v.push_back(input_shape.d[i]);
} }
plugin::PoolPlugin *plugin = plugin::PoolPlugin *plugin = new plugin::PoolPlugin(
new plugin::PoolPlugin(ceil_mode, plugin_pool_type, adaptive, ksize, ceil_mode, plugin_pool_type, adaptive, exclusive, ksize, strides,
strides, paddings, input_shape_v); paddings, input_shape_v, real_paddings);
auto *pool_layer = engine_->AddPlugin(&input1, 1, plugin); auto *pool_layer = engine_->AddPlugin(&input1, 1, plugin);
PADDLE_ENFORCE_NOT_NULL( PADDLE_ENFORCE_NOT_NULL(
pool_layer, pool_layer,
......
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
#include "paddle/fluid/inference/tensorrt/helper.h"
namespace paddle {
namespace framework {
class Scope;
namespace proto {
class OpDesc;
} // namespace proto
} // namespace framework
} // namespace paddle
namespace paddle {
namespace inference {
namespace tensorrt {
class PrelnEmbEltwiseLayerNormOpConverter : public OpConverter {
public:
void operator()(const framework::proto::OpDesc& op,
const framework::Scope& scope, bool test_mode) override {
#if IS_TRT_VERSION_GE(7000)
VLOG(4) << "convert fluid PrelnEmbEltwiseLayerNorm op to tensorrt layer";
if (!(engine_->use_oss() && engine_->with_interleaved())) {
PADDLE_THROW(platform::errors::Fatal(
"PrelnErnie: If you want to use oss, must be with interleaved"));
}
framework::OpDesc op_desc(op, nullptr);
bool enable_int8 = op_desc.HasAttr("enable_int8");
if (!enable_int8) {
PADDLE_THROW(
platform::errors::Fatal("use with_interleaved must be int8."));
}
auto word_id_name = op_desc.Input("WordId").front();
auto pos_id_name = op_desc.Input("PosId").front();
engine_->Set("ernie_pos_name", new std::string(pos_id_name));
auto sent_id_name = op_desc.Input("SentId").front();
auto word_emb_name = op_desc.Input("WordEmbedding").front();
auto pos_emb_name = op_desc.Input("PosEmbedding").front();
auto sent_emb_name = op_desc.Input("SentEmbedding").front();
std::vector<std::string> id_names;
std::vector<std::string> emb_names;
id_names =
std::vector<std::string>{word_id_name, pos_id_name, sent_id_name};
emb_names =
std::vector<std::string>{word_emb_name, pos_emb_name, sent_emb_name};
int input_num = id_names.size();
// Declare inputs
std::vector<nvinfer1::ITensor*> input_ids;
for (int i = 0; i < input_num; i++) {
input_ids.push_back(engine_->GetITensor(id_names[i]));
}
// input_embs[0]: word_embedding
// input_embs[1]: pos_embedding
// input_embs[2]: sent_embedding
std::vector<float*> input_embs;
std::vector<int> emb_sizes;
// get the persistable var's data
auto get_persistable_data = [&](const std::string& var_name,
framework::DDim* dims) -> float* {
auto* temp_var = scope.FindVar(var_name);
auto* temp_tensor = temp_var->GetMutable<framework::LoDTensor>();
(*dims) = temp_tensor->dims();
auto* temp_data = engine_->GetWeightCPUData(var_name, temp_tensor, false);
return temp_data;
};
for (int i = 0; i < input_num; i++) {
framework::DDim emb_dims;
float* emb_data = get_persistable_data(emb_names[i], &emb_dims);
int64_t emb_size = framework::product(emb_dims);
input_embs.push_back(emb_data);
emb_sizes.push_back(emb_size);
PADDLE_ENFORCE_EQ(
emb_dims.size(), 2,
platform::errors::InvalidArgument(
"The fused PrelnEmbEltwiseLayerNorm's emb should be 2 dims."));
}
framework::DDim bias_dims, scale_dims;
auto* bias =
get_persistable_data(op_desc.Input("Bias").front(), &bias_dims);
auto* scale =
get_persistable_data(op_desc.Input("Scale").front(), &scale_dims);
int64_t bias_size = framework::product(bias_dims);
int64_t scale_size = framework::product(scale_dims);
int output_int8 = 1;
PADDLE_ENFORCE_EQ(
input_num, 3,
platform::errors::InvalidArgument(
"When using oss and var-len, embedding_eltwise_layernorm op"
"should have 3 inputs only, but got %d.",
input_num));
const std::vector<nvinfer1::PluginField> fields{
{"bert_embeddings_layernorm_beta", bias,
nvinfer1::PluginFieldType::kFLOAT32, static_cast<int32_t>(bias_size)},
{"bert_embeddings_layernorm_gamma", scale,
nvinfer1::PluginFieldType::kFLOAT32, static_cast<int32_t>(scale_size)},
{"bert_embeddings_word_embeddings", input_embs[0],
nvinfer1::PluginFieldType::kFLOAT32,
static_cast<int32_t>(emb_sizes[0])},
{"bert_embeddings_token_type_embeddings", input_embs[2],
nvinfer1::PluginFieldType::kFLOAT32,
static_cast<int32_t>(emb_sizes[2])},
{"bert_embeddings_position_embeddings", input_embs[1],
nvinfer1::PluginFieldType::kFLOAT32,
static_cast<int32_t>(emb_sizes[1])},
{"output_int8", &output_int8, nvinfer1::PluginFieldType::kINT32, 1},
};
nvinfer1::PluginFieldCollection* plugin_ptr =
static_cast<nvinfer1::PluginFieldCollection*>(
malloc(sizeof(*plugin_ptr) +
fields.size() * sizeof(nvinfer1::PluginField)));
plugin_ptr->nbFields = static_cast<int>(fields.size());
plugin_ptr->fields = fields.data();
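// plugin_ptr only has to stay alive until createPlugin() below consumes it;
// the field array itself still lives in the local `fields` vector, and
// plugin_ptr is freed right after the plugin layer is created.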
std::vector<nvinfer1::ITensor*> plugin_inputs;
plugin_inputs.emplace_back(
engine_->GetITensor(word_id_name)); // word_embedding,
// eval_placeholder_0
plugin_inputs.emplace_back(
engine_->GetITensor(sent_id_name)); // sent_embedding,
// eval_placeholder_1
plugin_inputs.emplace_back(
engine_->GetITensor(pos_id_name)); // cu_seqlens,
// eval_placeholder_2
auto max_seqlen_tensor =
engine_->GetITensor(engine_->network()->getInput(3)->getName());
auto* shuffle_layer =
TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *max_seqlen_tensor);
nvinfer1::Dims shape_dim;
shape_dim.nbDims = 1;
shape_dim.d[0] = -1;
shuffle_layer->setReshapeDimensions(shape_dim);
shuffle_layer->setName(
("PrelnEmbeltwise_Shuffle_reshape (Output: max_seqlen " +
op_desc.Output("Out")[0] + ")")
.c_str());
engine_->SetTensorDynamicRange(shuffle_layer->getOutput(0), 1.0f);
plugin_inputs.emplace_back(
shuffle_layer->getOutput(0)); // max_seqlen, eval_placeholder_3
auto creator = GetPluginRegistry()->getPluginCreator(
"CustomEmbLayerNormPluginDynamic", "3");
auto plugin_obj =
creator->createPlugin("CustomEmbLayerNormPluginDynamic", plugin_ptr);
auto plugin_layer = engine_->network()->addPluginV2(
plugin_inputs.data(), plugin_inputs.size(), *plugin_obj);
plugin_layer->setName(("CustomPrelnEmbLayerNormPluginDynamic_V3(Output: " +
op_desc.Output("Out")[0] + ")")
.c_str());
free(plugin_ptr);
float out_0_scale =
BOOST_GET_CONST(float, op_desc.GetAttr("out_0_threshold"));
float out_1_scale =
BOOST_GET_CONST(float, op_desc.GetAttr("out_1_threshold"));
engine_->SetTensorDynamicRange(plugin_layer->getOutput(0), out_0_scale);
engine_->SetTensorDynamicRange(plugin_layer->getOutput(1), out_1_scale);
auto* shuffler_embed_out0 =
TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *(plugin_layer->getOutput(0)));
nvinfer1::Permutation transpose_0{2, 1, 0, 3};
shuffler_embed_out0->setSecondTranspose(transpose_0);
shuffler_embed_out0->getOutput(0)->setName(
op_desc.Output("Out_0")[0].c_str());
engine_->SetITensor(op_desc.Output("Out_0")[0],
shuffler_embed_out0->getOutput(0));
shuffler_embed_out0->setName(
("shuffler_after_CustomPrelnEmbLayerNormPluginDynamic_V3(Output_0: " +
op_desc.Output("Out_0")[0] + ")")
.c_str());
auto* shuffler_embed_out1 =
TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *(plugin_layer->getOutput(1)));
nvinfer1::Permutation transpose_1{2, 1, 0, 3};
shuffler_embed_out1->setSecondTranspose(transpose_1);
shuffler_embed_out1->getOutput(0)->setName(
op_desc.Output("Out_1")[0].c_str());
engine_->SetITensor(op_desc.Output("Out_1")[0],
shuffler_embed_out1->getOutput(0));
shuffler_embed_out1->setName(
("shuffler_after_CustomPrelnEmbLayerNormPluginDynamic_V3(Output_1: " +
op_desc.Output("Out_1")[0] + ")")
.c_str());
#else
PADDLE_THROW(platform::errors::Fatal(
"PreInErnie want to use oss, must be with interleaved, "
"your TRT version is no less than 7.0"));
#endif
}
};
} // namespace tensorrt
} // namespace inference
} // namespace paddle
REGISTER_TRT_OP_CONVERTER(fused_preln_embedding_eltwise_layernorm,
PrelnEmbEltwiseLayerNormOpConverter);
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
namespace paddle {
namespace inference {
namespace tensorrt {
class PrelnSkipLayerNormOpConverter : public OpConverter {
public:
void operator()(const framework::proto::OpDesc& op,
const framework::Scope& scope, bool test_mode) override {
#if IS_TRT_VERSION_GE(7000)
VLOG(4) << "convert fused preln_skip_layernorm op to tensorrt layer";
if (!(engine_->use_oss() && engine_->with_interleaved())) {
PADDLE_THROW(platform::errors::Fatal(
"PrelnErnie: If you want to use oss, must be with interleaved"));
}
framework::OpDesc op_desc(op, nullptr);
bool enable_int8 = op_desc.HasAttr("enable_int8");
if (!enable_int8) {
PADDLE_THROW(
platform::errors::Fatal("use with_interleaved must be int8."));
}
// Declare inputs
auto* input1 = engine_->GetITensor(op_desc.Input("X")[0]);
auto* input2 = engine_->GetITensor(op_desc.Input("Y")[0]);
std::vector<nvinfer1::ITensor*> inputs;
inputs.push_back(input1);
inputs.push_back(input2);
auto get_persistable_data = [&](const std::string& arg_name,
framework::DDim* dims) -> float* {
std::string var_name = op_desc.Input(arg_name).front();
auto* temp_var = scope.FindVar(var_name);
auto* temp_tensor = temp_var->GetMutable<framework::LoDTensor>();
(*dims) = temp_tensor->dims();
auto* temp_data = engine_->GetWeightCPUData(var_name, temp_tensor, false);
return temp_data;
};
framework::DDim bias_dims, scale_dims;
auto* bias = get_persistable_data("Bias", &bias_dims);
auto* scale = get_persistable_data("Scale", &scale_dims);
int bias_size = framework::product(bias_dims);
int scale_size = framework::product(scale_dims);
nvinfer1::ILayer* layer = nullptr;
VLOG(4) << "fused preln_skip_layernorm op: use_oss and with_interleaved";
auto creator = GetPluginRegistry()->getPluginCreator(
"CustomSkipLayerNormPluginDynamic", "4");
PADDLE_ENFORCE_NE(
creator, nullptr,
platform::errors::InvalidArgument(
"fail to get creator of CustomPrelnSkipLayerNormPluginDynamic"));
const std::vector<nvinfer1::PluginField> fields{
{"beta", bias, nvinfer1::PluginFieldType::kFLOAT32, bias_size},
{ "gamma",
scale,
nvinfer1::PluginFieldType::kFLOAT32,
scale_size }};
nvinfer1::PluginFieldCollection* pluginPtr =
static_cast<nvinfer1::PluginFieldCollection*>(
malloc(sizeof(*pluginPtr) +
fields.size() * sizeof(nvinfer1::PluginField)));
pluginPtr->nbFields = static_cast<int>(fields.size());
pluginPtr->fields = fields.data();
auto pluginObj =
creator->createPlugin("CustomSkipLayerNormPluginDynamic", pluginPtr);
auto plugin_layer = engine_->network()->addPluginV2(
inputs.data(), inputs.size(), *pluginObj);
PADDLE_ENFORCE_NE(
plugin_layer, nullptr,
platform::errors::InvalidArgument(
"fail to add CustomPrelnSkipLayerNormPluginDynamic layer"));
layer = plugin_layer;
auto output_name = op_desc.Output("Out")[0];
RreplenishLayerAndOutput(layer, "preln_skip_layernorm", {output_name},
test_mode);
#else
PADDLE_THROW(platform::errors::Fatal(
"PreInErnie want to use oss, must be with interleaved, "
"your TRT version is no less than 7.0"));
#endif
}
};
} // namespace tensorrt
} // namespace inference
} // namespace paddle
REGISTER_TRT_OP_CONVERTER(preln_skip_layernorm, PrelnSkipLayerNormOpConverter);
...@@ -103,5 +103,5 @@ TEST(elementwise_op, plugin) { ...@@ -103,5 +103,5 @@ TEST(elementwise_op, plugin) {
} // namespace inference } // namespace inference
} // namespace paddle } // namespace paddle
USE_OP(elementwise_add); USE_OP_ITSELF(elementwise_add);
USE_OP(elementwise_mul); USE_OP(elementwise_mul);
...@@ -30,24 +30,6 @@ namespace tensorrt { ...@@ -30,24 +30,6 @@ namespace tensorrt {
// Just tell by the op_types. // Just tell by the op_types.
struct SimpleOpTypeSetTeller : public Teller { struct SimpleOpTypeSetTeller : public Teller {
SimpleOpTypeSetTeller() { SimpleOpTypeSetTeller() {
#if IS_TRT_VERSION_GE(5130)
teller_set.insert("relu6");
teller_set.insert("hard_sigmoid");
teller_set.insert("clip");
int8_teller_set.insert("relu6");
int8_teller_set.insert("hard_sigmoid");
int8_teller_set.insert("clip");
#endif
#if IS_TRT_VERSION_GE(6000)
teller_set.insert("fused_embedding_eltwise_layernorm");
teller_set.insert("multihead_matmul");
teller_set.insert("skip_layernorm");
teller_set.insert("slice");
int8_teller_set.insert("fused_embedding_eltwise_layernorm");
int8_teller_set.insert("multihead_matmul");
int8_teller_set.insert("skip_layernorm");
int8_teller_set.insert("slice");
#endif
// TODO(baoachun) The group_norm trt plugin will check input's dim // TODO(baoachun) The group_norm trt plugin will check input's dim
// not -1 failed when dynamic shape mode. // not -1 failed when dynamic shape mode.
// #if IS_TRT_VERSION_GE(7130) // #if IS_TRT_VERSION_GE(7130)
...@@ -76,104 +58,124 @@ struct SimpleOpTypeSetTeller : public Teller { ...@@ -76,104 +58,124 @@ struct SimpleOpTypeSetTeller : public Teller {
private: private:
// use this set for no calib int8. // use this set for no calib int8.
std::unordered_set<std::string> int8_teller_set{"mul", std::unordered_set<std::string> int8_teller_set{
"matmul", "mul",
"conv2d", "matmul",
"conv2d_fusion", "conv2d",
"pool2d", "conv2d_fusion",
"relu", "pool2d",
"softmax", "relu",
"sigmoid", "softmax",
"hard_swish", "sigmoid",
"depthwise_conv2d", "hard_swish",
"batch_norm", "depthwise_conv2d",
"concat", "batch_norm",
"tanh", "concat",
"pad", "tanh",
"elementwise_add", "pad",
"elementwise_mul", "elementwise_add",
"dropout", "elementwise_mul",
"prelu", "dropout",
"conv2d_transpose", "prelu",
"depthwise_conv2d_transpose", "conv2d_transpose",
"leaky_relu", "depthwise_conv2d_transpose",
"fc", "leaky_relu",
"shuffle_channel", "fc",
"swish", "shuffle_channel",
"split", "swish",
"instance_norm", "split",
"gelu", "instance_norm",
"layer_norm", "gelu",
"scale", "layer_norm",
"stack", "scale",
"transpose2", "stack",
"transpose", "transpose2",
"flatten2", "transpose",
"flatten", "flatten2",
"gather", "flatten",
"gather_nd", "gather",
"yolo_box", "gather_nd",
"roi_align", "yolo_box",
"affine_channel", "roi_align",
"nearest_interp", "affine_channel",
"anchor_generator", "nearest_interp",
"reduce_sum", "anchor_generator",
"reduce_mean", "reduce_sum",
"conv3d", "reduce_mean",
"conv3d_transpose", "conv3d",
"mish", "conv3d_transpose",
"nearest_interp_v2", "mish",
"pool3d", "nearest_interp_v2",
"deformable_conv"}; "pool3d",
std::unordered_set<std::string> teller_set{"mul", "deformable_conv",
"matmul", "relu6",
"conv2d", "hard_sigmoid",
"conv2d_fusion", "clip",
"pool2d", "fused_embedding_eltwise_layernorm",
"relu", "multihead_matmul",
"softmax", "skip_layernorm",
"sigmoid", "slice",
"hard_swish", "fused_preln_embedding_eltwise_layernorm",
"depthwise_conv2d", "preln_skip_layernorm"};
"batch_norm", std::unordered_set<std::string> teller_set{
"concat", "mul",
"tanh", "matmul",
"pad", "conv2d",
"elementwise_add", "conv2d_fusion",
"elementwise_mul", "pool2d",
"dropout", "relu",
"prelu", "softmax",
"conv2d_transpose", "sigmoid",
"depthwise_conv2d_transpose", "hard_swish",
"leaky_relu", "depthwise_conv2d",
"fc", "batch_norm",
"shuffle_channel", "concat",
"swish", "tanh",
"split", "pad",
"instance_norm", "elementwise_add",
"gelu", "elementwise_mul",
"layer_norm", "dropout",
"scale", "prelu",
"stack", "conv2d_transpose",
"transpose2", "depthwise_conv2d_transpose",
"transpose", "leaky_relu",
"flatten2", "fc",
"flatten", "shuffle_channel",
"gather", "swish",
"gather_nd", "split",
"yolo_box", "instance_norm",
"roi_align", "gelu",
"affine_channel", "layer_norm",
"nearest_interp", "scale",
"anchor_generator", "stack",
"reduce_sum", "transpose2",
"reduce_mean", "transpose",
"conv3d", "flatten2",
"conv3d_transpose", "flatten",
"mish", "gather",
"nearest_interp_v2", "gather_nd",
"pool3d", "yolo_box",
"deformable_conv"}; "roi_align",
"affine_channel",
"nearest_interp",
"anchor_generator",
"reduce_sum",
"reduce_mean",
"conv3d",
"conv3d_transpose",
"mish",
"nearest_interp_v2",
"pool3d",
"deformable_conv",
"relu6",
"hard_sigmoid",
"clip",
"fused_embedding_eltwise_layernorm",
"multihead_matmul",
"skip_layernorm",
"slice",
"fused_preln_embedding_eltwise_layernorm",
"preln_skip_layernorm"};
}; };
bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8, bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8,
...@@ -1007,6 +1009,24 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8, ...@@ -1007,6 +1009,24 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8,
} }
} }
if (op_type == "fused_preln_embedding_eltwise_layernorm") {
if (!with_dynamic_shape) {
VLOG(3)
<< "fused_preln_embedding_eltwise_layernorm should run on dynamic "
"shape mode.";
return false;
}
if (desc.Input("Ids").size() != desc.Input("Embs").size()) {
VLOG(3) << "The id and emb size of fused PrelnEmbEltwiseLayerNormOp "
"should be same ";
return false;
}
if (!desc.HasAttr("enable_int8")) {
VLOG(3) << "PrelnEmbEltwiseLayerNormOp must use int8 mode.";
return false;
}
}
if (op_type == "gelu") { if (op_type == "gelu") {
if (desc.Input("X").size() != 1) { if (desc.Input("X").size() != 1) {
VLOG(3) << "gelu op has only 1 input, but got " VLOG(3) << "gelu op has only 1 input, but got "
...@@ -1019,9 +1039,12 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8, ...@@ -1019,9 +1039,12 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8,
return false; return false;
} }
#if IS_TRT_VERSION_LT(7000)
if (desc.HasAttr("approximate")) { if (desc.HasAttr("approximate")) {
VLOG(3) << "approximate gelu op needs TensorRT 7.0 and after";
if (BOOST_GET_CONST(bool, desc.GetAttr("approximate"))) return false; if (BOOST_GET_CONST(bool, desc.GetAttr("approximate"))) return false;
} }
#endif
auto* block = desc.Block(); auto* block = desc.Block();
if (block == nullptr) { if (block == nullptr) {
...@@ -1030,6 +1053,7 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8, ...@@ -1030,6 +1053,7 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8,
"the pass."; "the pass.";
return false; return false;
} }
auto x_var_name = desc.Input("X")[0]; auto x_var_name = desc.Input("X")[0];
auto* x_var_desc = block->FindVar(x_var_name); auto* x_var_desc = block->FindVar(x_var_name);
const auto x_shape = x_var_desc->GetShape(); const auto x_shape = x_var_desc->GetShape();
...@@ -1312,6 +1336,17 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8, ...@@ -1312,6 +1336,17 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8,
} }
} }
if (op_type == "preln_skip_layernorm") {
if (!with_dynamic_shape) {
VLOG(3) << "the preln_skip_layernorm does not support static shape yet";
return false;
}
if (!desc.HasAttr("enable_int8")) {
VLOG(3) << "PrelnEmbEltwiseLayerNormOp must use int8 mode.";
return false;
}
}
if (op_type == "multihead_matmul") { if (op_type == "multihead_matmul") {
if (!with_dynamic_shape) { if (!with_dynamic_shape) {
VLOG(3) << "the multihead_matmul does not support static shape yet"; VLOG(3) << "the multihead_matmul does not support static shape yet";
......
...@@ -35,6 +35,36 @@ nvinfer1::Dims PoolPlugin::getOutputDimensions(int index, ...@@ -35,6 +35,36 @@ nvinfer1::Dims PoolPlugin::getOutputDimensions(int index,
return output_dims; return output_dims;
} }
size_t PoolPlugin::getSerializationSize() const TRT_NOEXCEPT {
return getBaseSerializationSize() + SerializedSize(ceil_mode_) +
SerializedSize(pool_type_) + SerializedSize(adaptive_) +
SerializedSize(exclusive_) + SerializedSize(ksize_) +
SerializedSize(strides_) + SerializedSize(paddings_) +
SerializedSize(real_paddings_) + SerializedSize(input_shape_) +
SerializedSize(output_shape_);
}
// TRT will call this func to serialize the plugin's configuration into the
// serialized engine.
void PoolPlugin::serialize(void *buffer) const TRT_NOEXCEPT {
serializeBase(buffer);
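// The write order below must match the read order used when the plugin is
// deserialized, otherwise the rebuilt plugin would read misaligned fields.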
SerializeValue(&buffer, ceil_mode_);
SerializeValue(&buffer, pool_type_);
SerializeValue(&buffer, adaptive_);
SerializeValue(&buffer, exclusive_);
SerializeValue(&buffer, ksize_);
SerializeValue(&buffer, strides_);
SerializeValue(&buffer, paddings_);
SerializeValue(&buffer, real_paddings_);
SerializeValue(&buffer, input_shape_);
SerializeValue(&buffer, output_shape_);
}
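// clone() rebuilds the plugin from its full configuration so the copies
// TensorRT makes for each execution context behave like the original.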
PoolPlugin *PoolPlugin::clone() const TRT_NOEXCEPT {
return new PoolPlugin(ceil_mode_, pool_type_, adaptive_, exclusive_, ksize_,
strides_, paddings_, input_shape_, real_paddings_);
}
int PoolPlugin::enqueue(int batchSize, const void *const *inputs, int PoolPlugin::enqueue(int batchSize, const void *const *inputs,
#if IS_TRT_VERSION_LT(8000) #if IS_TRT_VERSION_LT(8000)
void **outputs, void *workspace, void **outputs, void *workspace,
...@@ -59,14 +89,15 @@ int PoolPlugin::enqueue(int batchSize, const void *const *inputs, ...@@ -59,14 +89,15 @@ int PoolPlugin::enqueue(int batchSize, const void *const *inputs,
paddle::operators::math::MaxPool<float>, float> paddle::operators::math::MaxPool<float>, float>
pool2d_forward; pool2d_forward;
pool2d_forward(idata, input_shape, output_shape, ksize_, strides_, pool2d_forward(idata, input_shape, output_shape, ksize_, strides_,
paddings_, true, adaptive_, odatas[0], stream, pool_process); paddings_, true, false, odatas[0], stream, pool_process);
} else if (pool_type_ == PoolType::avg) { } else if (pool_type_ == PoolType::avg) {
paddle::operators::math::AvgPool<float> pool_process; paddle::operators::math::AvgPool<float> pool_process;
paddle::operators::math::Pool2dDirectCUDAFunctor< paddle::operators::math::Pool2dDirectCUDAFunctor<
paddle::operators::math::AvgPool<float>, float> paddle::operators::math::AvgPool<float>, float>
pool2d_forward; pool2d_forward;
pool2d_forward(idata, input_shape, output_shape, ksize_, strides_, pool2d_forward(idata, input_shape, output_shape, ksize_, strides_,
paddings_, true, adaptive_, odatas[0], stream, pool_process); paddings_, exclusive_, adaptive_, odatas[0], stream,
pool_process);
} }
return cudaGetLastError() != cudaSuccess; return cudaGetLastError() != cudaSuccess;
...@@ -82,6 +113,7 @@ PoolPluginDynamic::PoolPluginDynamic(void const *serialData,
DeserializeValue(&serialData, &serialLength, &pool_type);
pool_type_ = std::string(pool_type);
DeserializeValue(&serialData, &serialLength, &adaptive_);
DeserializeValue(&serialData, &serialLength, &exclusive_);
DeserializeValue(&serialData, &serialLength, &ksize_);
DeserializeValue(&serialData, &serialLength, &strides_);
DeserializeValue(&serialData, &serialLength, &paddings_);
...@@ -90,21 +122,27 @@ PoolPluginDynamic::PoolPluginDynamic(void const *serialData,
size_t PoolPluginDynamic::getSerializationSize() const TRT_NOEXCEPT {
return SerializedSize(ceil_mode_) + SerializedSize(pool_type_.c_str()) +
- SerializedSize(adaptive_) + SerializedSize(ksize_) +
- SerializedSize(strides_) + SerializedSize(paddings_) +
- SerializedSize(is_global_);
+ SerializedSize(adaptive_) + SerializedSize(exclusive_) +
+ SerializedSize(ksize_) + SerializedSize(strides_) +
+ SerializedSize(paddings_) + SerializedSize(is_global_);
}
void PoolPluginDynamic::serialize(void *buffer) const TRT_NOEXCEPT {
SerializeValue(&buffer, ceil_mode_);
SerializeValue(&buffer, pool_type_.c_str());
SerializeValue(&buffer, adaptive_);
SerializeValue(&buffer, exclusive_);
SerializeValue(&buffer, ksize_);
SerializeValue(&buffer, strides_);
SerializeValue(&buffer, paddings_);
SerializeValue(&buffer, is_global_);
}
nvinfer1::IPluginV2DynamicExt *PoolPluginDynamic::clone() const TRT_NOEXCEPT {
return new PoolPluginDynamic(ceil_mode_, pool_type_, adaptive_, exclusive_,
ksize_, strides_, paddings_, is_global_);
}
nvinfer1::DimsExprs PoolPluginDynamic::getOutputDimensions(
int output_index, const nvinfer1::DimsExprs *inputs, int nb_inputs,
nvinfer1::IExprBuilder &expr_builder) TRT_NOEXCEPT {
...@@ -117,11 +155,14 @@ nvinfer1::DimsExprs PoolPluginDynamic::getOutputDimensions(
platform::errors::InvalidArgument("The channel dimension should be "
"static, but we found it's dynamic."));
nvinfer1::DimsExprs output(inputs[0]);
- if (is_global_) {
+ if (is_global_ && !adaptive_) {
output.d[2] = expr_builder.constant(1);
output.d[3] = expr_builder.constant(1);
return output;
}
if (is_global_ && adaptive_) {
return inputs[0];
}
if (adaptive_) {
output.d[2] = expr_builder.constant(ksize_[0]);
output.d[3] = expr_builder.constant(ksize_[1]);
...@@ -245,6 +286,10 @@ int PoolPluginDynamic::enqueue(const nvinfer1::PluginTensorDesc *input_desc,
output_shape[2] = data_dim[0];
output_shape[3] = data_dim[1];
}
if (adaptive_) {
output_shape[2] = h;
output_shape[3] = w;
}
if (pool_type_ == "max") { if (pool_type_ == "max") {
paddle::operators::math::MaxPool<float> pool_process; paddle::operators::math::MaxPool<float> pool_process;
...@@ -252,14 +297,14 @@ int PoolPluginDynamic::enqueue(const nvinfer1::PluginTensorDesc *input_desc, ...@@ -252,14 +297,14 @@ int PoolPluginDynamic::enqueue(const nvinfer1::PluginTensorDesc *input_desc,
paddle::operators::math::MaxPool<float>, float> paddle::operators::math::MaxPool<float>, float>
pool2d_forward; pool2d_forward;
pool2d_forward(input, input_shape, output_shape, ksize, strides_, paddings, pool2d_forward(input, input_shape, output_shape, ksize, strides_, paddings,
true, adaptive_, output, stream, pool_process); true, false, output, stream, pool_process);
} else if (pool_type_ == "avg") { } else if (pool_type_ == "avg") {
paddle::operators::math::AvgPool<float> pool_process; paddle::operators::math::AvgPool<float> pool_process;
paddle::operators::math::Pool2dDirectCUDAFunctor< paddle::operators::math::Pool2dDirectCUDAFunctor<
paddle::operators::math::AvgPool<float>, float> paddle::operators::math::AvgPool<float>, float>
pool2d_forward; pool2d_forward;
pool2d_forward(input, input_shape, output_shape, ksize, strides_, paddings, pool2d_forward(input, input_shape, output_shape, ksize, strides_, paddings,
true, adaptive_, output, stream, pool_process); exclusive_, adaptive_, output, stream, pool_process);
} }
return cudaGetLastError() != cudaSuccess; return cudaGetLastError() != cudaSuccess;
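The avg branch now forwards the real exclusive_ and adaptive_ flags to the functor instead of hard-coded values. Below is a rough, self-contained CPU sketch of what the exclusive flag means for average pooling; it is an illustration of the general semantics (padded positions excluded from the divisor when exclusive is true), not Paddle's actual Pool2dDirectCUDAFunctor.

#include <algorithm>
#include <cstdio>
#include <vector>

// Minimal 1-D average pooling: with exclusive=true the divisor counts only
// positions inside the input; with exclusive=false it always uses ksize,
// so zero padding drags the average down.
std::vector<float> AvgPool1D(const std::vector<float>& in, int ksize,
                             int stride, int pad, bool exclusive) {
  const int n = static_cast<int>(in.size());
  const int out_len = (n + 2 * pad - ksize) / stride + 1;
  std::vector<float> out(out_len, 0.f);
  for (int o = 0; o < out_len; ++o) {
    const int start = o * stride - pad;
    const int end = start + ksize;
    float sum = 0.f;
    int valid = 0;
    for (int i = std::max(start, 0); i < std::min(end, n); ++i) {
      sum += in[i];
      ++valid;
    }
    const int divisor = exclusive ? std::max(valid, 1) : ksize;
    out[o] = sum / divisor;
  }
  return out;
}

int main() {
  std::vector<float> x = {1.f, 2.f, 3.f, 4.f};
  auto excl = AvgPool1D(x, /*ksize=*/3, /*stride=*/1, /*pad=*/1, true);
  auto incl = AvgPool1D(x, 3, 1, 1, false);
  // First window covers {pad, 1, 2}: exclusive -> (1+2)/2 = 1.5,
  // inclusive -> (1+2)/3 = 1.0.
  std::printf("%.3f vs %.3f\n", excl[0], incl[0]);
  return 0;
}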
......
...@@ -58,6 +58,11 @@ else ()
set(AllocatorFacadeDeps)
endif()
if (WITH_CUSTOM_DEVICE)
cc_library(custom_allocator SRCS custom_allocator.cc DEPS allocator device_manager)
set(AllocatorFacadeDeps ${AllocatorFacadeDeps} custom_allocator)
endif()
if (WITH_GPU)
nv_test(best_fit_allocator_test
SRCS best_fit_allocator_test.cc
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/memory/allocation/custom_allocator.h"
#include "paddle/fluid/platform/device/device_wrapper.h"
#include "paddle/fluid/platform/enforce.h"
namespace paddle {
namespace memory {
namespace allocation {
bool CustomAllocator::IsAllocThreadSafe() const { return true; }
void CustomAllocator::FreeImpl(pten::Allocation* allocation) {
PADDLE_ENFORCE_EQ(
allocation->place(), place_,
platform::errors::PermissionDenied("CustomDevice memory is "
"freed in incorrect device. "
"This may be a bug"));
delete allocation;
}
pten::Allocation* CustomAllocator::AllocateImpl(size_t size) {
std::call_once(once_flag_,
[this] { platform::DeviceManager::SetDevice(place_); });
void* ptr =
platform::DeviceManager::GetDeviceWithPlace(place_)->MemoryAllocate(size);
if (LIKELY(ptr)) {
return new Allocation(ptr, size, place_);
}
size_t avail, total;
platform::DeviceManager::MemoryStats(place_, &total, &avail);
auto dev_type = platform::PlaceHelper::GetDeviceType(place_);
auto dev_id = platform::PlaceHelper::GetDeviceId(place_);
PADDLE_THROW_BAD_ALLOC(platform::errors::ResourceExhausted(
"\n\nOut of memory error on %s:%d. "
"Cannot allocate %s memory on %s:%d, "
"available memory is only %s.\n\n"
"Please check whether there is any other process using %s:%d.\n"
"1. If yes, please stop them, or start PaddlePaddle on another %s.\n"
"2. If no, please decrease the batch size of your model.\n\n",
dev_type, dev_id, string::HumanReadableSize(size), dev_type, dev_id,
string::HumanReadableSize(avail), dev_type, dev_id, dev_type));
}
} // namespace allocation
} // namespace memory
} // namespace paddle
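AllocateImpl above follows a common allocator pattern: try the device allocation first, and only on the failure path query memory statistics to build a descriptive out-of-memory error. A standalone sketch of that pattern follows; malloc stands in for MemoryAllocate(), a zeroed stats pair stands in for DeviceManager::MemoryStats, and a runtime_error stands in for PADDLE_THROW_BAD_ALLOC(ResourceExhausted(...)), so none of the Paddle APIs are modeled here.

#include <cstdio>
#include <cstdlib>
#include <stdexcept>
#include <string>

// Attempt an allocation; on failure, gather stats and raise a detailed error.
// The stats query is only paid for on the failure path.
void* AllocateOrThrow(size_t size, const std::string& dev_type, int dev_id) {
  void* ptr = std::malloc(size);  // stand-in for the device MemoryAllocate()
  if (ptr != nullptr) {
    return ptr;
  }
  size_t avail = 0, total = 0;    // would come from the device's memory stats
  char msg[256];
  std::snprintf(msg, sizeof(msg),
                "Out of memory error on %s:%d. Cannot allocate %zu bytes, "
                "available memory is only %zu of %zu bytes.",
                dev_type.c_str(), dev_id, size, avail, total);
  throw std::runtime_error(msg);
}

int main() {
  void* p = AllocateOrThrow(1024, "FakeDevice", 0);
  std::free(p);
  return 0;
}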
(The remaining file diffs in this commit are collapsed and not shown.)