Unverified commit 7aed7eb5, authored by QI JUN, committed by GitHub

cache memory in local scope (#7058)

* add KernelTypeToString interface

* cache memory in local scope

* fix typo

* refine trans logic
Parent: b775b6cb
paddle/framework/data_transform.h:

@@ -27,7 +27,7 @@ limitations under the License. */
 namespace paddle {
 namespace framework {
 
-using DataTransformFN =
+using DataTransformFn =
     std::function<void(const std::vector<platform::DeviceContext*> ctx,
                        const Variable& in, Variable* out)>;
 using KernelTypePair = std::pair<OpKernelType, OpKernelType>;
@@ -47,7 +47,7 @@ struct KernelTypePairHash {
 };
 
 using DataTransformMap =
-    std::unordered_map<KernelTypePair, DataTransformFN, KernelTypePairHash>;
+    std::unordered_map<KernelTypePair, DataTransformFn, KernelTypePairHash>;
 
 class DataTransformFnMap {
  public:
@@ -58,25 +58,25 @@ class DataTransformFnMap {
   }
 
   void Insert(const OpKernelType& left, const OpKernelType& right,
-              const DataTransformFN& data_tranform_fn) {
+              const DataTransformFn& data_tranform_fn) {
     Insert(std::make_pair(left, right), data_tranform_fn);
   }
 
   void Insert(const KernelTypePair& kernel_type_pair,
-              const DataTransformFN& data_tranform_fn) {
+              const DataTransformFn& data_tranform_fn) {
     PADDLE_ENFORCE(!Has(kernel_type_pair),
                    "KernelTypePair %s has been registered", "");
     map_.insert({kernel_type_pair, data_tranform_fn});
   }
 
-  const DataTransformFN& Get(const KernelTypePair& key_pair) const {
+  const DataTransformFn& Get(const KernelTypePair& key_pair) const {
     auto data_transformer = GetNullable(key_pair);
     PADDLE_ENFORCE_NOT_NULL(data_transformer,
-                            "DataTransformFN should not be NULL");
+                            "DataTransformFn should not be NULL");
     return *data_transformer;
   }
 
-  const DataTransformFN* GetNullable(const KernelTypePair& key_pair) const {
+  const DataTransformFn* GetNullable(const KernelTypePair& key_pair) const {
     auto it = map_.find(key_pair);
     if (it == map_.end()) {
       return nullptr;
...
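For orientation, here is a hypothetical sketch of how a transform could be registered and retrieved through this map. It is not part of the commit: the kernel-type arguments, the place/enum spellings (abbreviated as in the unit test further down), and the no-op lambda are all invented for illustration.

// Hypothetical usage sketch -- not in this commit.
OpKernelType plain_cpu(DataType::FP32, CPUPlace(), DataLayout::kNCHW,
                       LibraryType::kPlain);
OpKernelType cudnn_gpu(DataType::FP32, CUDAPlace(0), DataLayout::kNCHW,
                       LibraryType::kCUDNN);

// A DataTransformFn receives the device contexts plus the input and
// output Variables; the body (copy / cast / layout conversion) is elided.
DataTransformFn dummy_fn = [](const std::vector<platform::DeviceContext*> ctx,
                              const Variable& in, Variable* out) {
  // ... move or convert `in` into `out` here ...
};

// Register under the (actual, expected) kernel-type pair ...
DataTransformFnMap::Instance().Insert(plain_cpu, cudnn_gpu, dummy_fn);

// ... and look it up the way operator.cc does below; nullptr means no
// transform is registered for this pair, so no caching work is needed.
const DataTransformFn* fn =
    DataTransformFnMap::Instance().GetNullable({plain_cpu, cudnn_gpu});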
paddle/framework/op_kernel_type.h:

@@ -68,6 +68,8 @@ struct OpKernelType {
            data_type_ == o.data_type_ && data_layout_ == o.data_layout_ &&
            library_type_ == o.library_type_;
   }
+
+  bool operator!=(const OpKernelType& o) const { return !(*this == o); }
 };
 
 inline std::ostream& operator<<(std::ostream& os,
@@ -78,5 +80,11 @@ inline std::ostream& operator<<(std::ostream& os,
   return os;
 }
 
+inline std::string KernelTypeToString(const OpKernelType& kernel_key) {
+  std::ostringstream stream;
+  stream << kernel_key;
+  return stream.str();
+}
+
 }  // namespace framework
 }  // namespace paddle
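These two additions feed the caching logic in operator.cc below: KernelTypeToString yields a stable string suffix for naming a cached, transformed variable, and operator!= makes the "do we need a transform?" check direct. A minimal sketch, where var_name and the two kernel keys are placeholders:

// Sketch only: composing a cache key from a variable name and the
// stringified target kernel type, guarded by the new operator!=.
if (actual_kernel_key != expected_kernel_key) {
  std::string cached_name =
      var_name + paddle::framework::KernelTypeToString(expected_kernel_key);
  // look up `cached_name` in the scope before transforming again ...
}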
paddle/framework/op_kernel_type_test.cc:

@@ -26,10 +26,8 @@ TEST(OpKernelType, ToString) {
   OpKernelType op_kernel_type(DataType::FP32, CPUPlace(), DataLayout::kNCHW,
                               LibraryType::kCUDNN);
 
-  std::ostringstream stream;
-  stream << op_kernel_type;
-
   ASSERT_EQ(
-      stream.str(),
+      paddle::framework::KernelTypeToString(op_kernel_type),
       "data_type[5]:data_layout[NCHW]:place[CPUPlace]:library_type[CUDNN]");
 }
...
paddle/framework/operator.cc:

@@ -413,37 +413,51 @@ void OperatorWithKernel::Run(const Scope& scope,
   }
 
   if (actual_kernel_key == expected_kernel_key) {
-    kernel_iter->second->Compute(ctx);
+    PADDLE_ENFORCE_EQ(actual_kernel_key.place_, expected_kernel_key.place_,
+                      "Currently, model parallelism is only supported between "
+                      "CPU and other devices. For example, multi-GPU model "
+                      "parallelism will fail.");
   } else {
-    Scope& op_scope = scope.NewScope();
-    auto input_vars = this->InputVars();
-    for (auto var_name : input_vars) {
-      op_scope.Var(var_name);
-    }
+    const DataTransformFn* trans_fun =
+        DataTransformFnMap::Instance().GetNullable(
+            std::make_pair(actual_kernel_key, expected_kernel_key));
+    if (trans_fun) {
+      auto input_vars = this->InputVars();
+      // TODO(qijun) filter the input vars that do not need to be transformed
 
-    // TODO(qijun) get appropriate DeviceContext from DeviceContext pool
-    platform::DeviceContext* trans_dev_ctx = nullptr;
-    std::vector<platform::DeviceContext*> trans_dev_ctx_vec{trans_dev_ctx};
+      // filter vars that have been transformed
+      std::vector<std::string> need_trans;
+      for (auto var_name : input_vars) {
+        auto var_name_trans =
+            var_name + framework::KernelTypeToString(expected_kernel_key);
+        if (!scope.FindVar(var_name_trans)) {
+          const_cast<Scope&>(scope).Var(var_name_trans);
+          need_trans.push_back(var_name);
+        }
+      }
 
-    // TODO(qijun) get appropriate DataTransformFN from global map
-    framework::DataTransformFN trans_fun = nullptr;
+      if (!need_trans.empty()) {
+        // TODO(qijun) get appropriate DeviceContext from DeviceContext pool
+        platform::DeviceContext* trans_dev_ctx = nullptr;
+        std::vector<platform::DeviceContext*> trans_dev_ctx_vec{trans_dev_ctx};
 
-    // Wait for transform starting
-    dev_ctx->Wait();
+        // Wait for transform starting
+        dev_ctx->Wait();
 
-    for (auto var_name : input_vars) {
-      trans_fun(trans_dev_ctx_vec, *(scope.FindVar(var_name)),
-                op_scope.FindVar(var_name));
-    }
-    // Wait for data transform finishing
-    for (auto ctx : trans_dev_ctx_vec) {
-      ctx->Wait();
+        for (auto var_name : need_trans) {
+          (*trans_fun)(trans_dev_ctx_vec, *(scope.FindVar(var_name)),
+                       scope.FindVar(var_name + framework::KernelTypeToString(
+                                                    expected_kernel_key)));
+        }
+        // Wait for data transform finishing
+        for (auto ctx : trans_dev_ctx_vec) {
+          ctx->Wait();
+        }
+      }
     }
-
-    // Create a new ExecutionContext
-    ExecutionContext op_ctx(*this, op_scope, *dev_ctx);
-    kernel_iter->second->Compute(op_ctx);
   }
+
+  kernel_iter->second->Compute(ctx);
 }
 
 OpKernelType OperatorWithKernel::GetActualKernelType(
...
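This hunk is the heart of the commit title: the transformed copy of input var_name is stored in the enclosing scope under var_name + KernelTypeToString(expected_kernel_key), so a later Run that expects the same kernel type finds the variable already present and skips the transform. A stripped-down sketch of that memoization pattern, with the Scope replaced by a plain map and all names hypothetical:

#include <string>
#include <unordered_map>

// Stand-in for the Scope: maps variable names to (dummy) tensor values.
std::unordered_map<std::string, float> scope_vars;

float GetTransformedInput(const std::string& var_name,
                          const std::string& kernel_type_str) {
  const std::string cached_name = var_name + kernel_type_str;  // cache key
  auto it = scope_vars.find(cached_name);
  if (it == scope_vars.end()) {
    // First request for this (var, kernel type): run the expensive
    // transform once and cache the result under the derived name.
    float transformed = scope_vars.at(var_name) * 2.0f;  // placeholder transform
    it = scope_vars.emplace(cached_name, transformed).first;
  }
  // Subsequent calls return the cached value without transforming again.
  return it->second;
}

This also explains why the per-run scope.NewScope() disappears: a fresh local scope meant every input was re-transformed on every call, whereas keying the cached copy by kernel type in the enclosing scope lets later runs detect and reuse it.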