diff --git a/paddle/framework/CMakeLists.txt b/paddle/framework/CMakeLists.txt index 968fefcfafb53f697839b3c2f9b81b700ae1726c..7436e8c228db2caeb1421f8d78ddcf55f00deee4 100644 --- a/paddle/framework/CMakeLists.txt +++ b/paddle/framework/CMakeLists.txt @@ -21,6 +21,8 @@ cc_test(variable_test SRCS variable_test.cc) cc_library(scope SRCS scope.cc DEPS glog) cc_test(scope_test SRCS scope_test.cc DEPS scope) +cc_library(data_transform SRCS data_transform.cc DEPS tensor framework_proto) +cc_test(data_transform_test SRCS data_transform_test.cc DEPS data_transform device_context) cc_library(attribute SRCS attribute.cc DEPS framework_proto) cc_test(program_desc_test SRCS program_desc_test.cc DEPS proto_desc @@ -29,7 +31,8 @@ cc_library(op_proto_maker SRCS op_proto_maker.cc DEPS framework_proto attribute) cc_test(op_proto_maker_test SRCS op_proto_maker_test.cc DEPS op_proto_maker) cc_library(op_info SRCS op_info.cc DEPS attribute framework_proto) cc_library(shape_inference SRCS shape_inference.cc DEPS ddim attribute) -cc_library(operator SRCS operator.cc DEPS op_info device_context tensor scope glog shape_inference) +cc_library(operator SRCS operator.cc DEPS op_info device_context tensor scope glog + shape_inference data_transform) cc_test(operator_test SRCS operator_test.cc DEPS operator op_registry init) cc_library(proto_desc SRCS var_desc.cc op_desc.cc block_desc.cc program_desc.cc DEPS shape_inference op_info operator glog) @@ -65,6 +68,3 @@ cc_library(init SRCS init.cc DEPS gflags device_context place stringpiece) cc_test(init_test SRCS init_test.cc DEPS init) cc_test(op_kernel_type_test SRCS op_kernel_type_test.cc DEPS place device_context framework_proto) - -cc_library(data_transform SRCS data_transform.cc DEPS tensor framework_proto) -cc_test(data_transform_test SRCS data_transform_test.cc DEPS data_transform device_context) diff --git a/paddle/framework/data_transform.h b/paddle/framework/data_transform.h index 
c83c08ba5cee8d26bd9b951b53763801f99124cc..73f894a3e20ab779f8607e63a67139b0e8cce79a 100644 --- a/paddle/framework/data_transform.h +++ b/paddle/framework/data_transform.h @@ -32,17 +32,16 @@ using DataTransformFN = const Variable& in, Variable* out)>; using KernelTypePair = std::pair<OpKernelType, OpKernelType>; -static void hash_combine(std::size_t& seed, const OpKernelType& t) { - OpKernelType::Hash kernel_type_hasher; - seed ^= kernel_type_hasher(t) + 0x9e3779b9 + (seed << 6) + (seed >> 2); -} - struct KernelTypePairHash { + static void HashCombine(const OpKernelType& t, std::size_t* seed) { + OpKernelType::Hash kernel_type_hasher; + (*seed) ^= kernel_type_hasher(t) + 0x9e3779b9 + (*seed << 6) + (*seed >> 2); + } + size_t operator()(const KernelTypePair& kernel_pair) const { std::size_t seed = 0; - hash_combine(seed, kernel_pair.first); - hash_combine(seed, kernel_pair.second); - + HashCombine(kernel_pair.first, &seed); + HashCombine(kernel_pair.second, &seed); return seed; } }; diff --git a/paddle/framework/operator.cc b/paddle/framework/operator.cc index 66840a2e037e7ca0fd1eacc64421865b170b47f8..886f73e7b81c35cac573bd041e6462eb2111bf85 100644 --- a/paddle/framework/operator.cc +++ b/paddle/framework/operator.cc @@ -15,6 +15,7 @@ limitations under the License. 
*/ #include #include +#include "paddle/framework/data_transform.h" #include "paddle/framework/executor.h" #include "paddle/framework/lod_tensor_array.h" #include "paddle/framework/operator.h" @@ -411,7 +412,38 @@ void OperatorWithKernel::Run(const Scope& scope, expected_kernel_key); } - kernel_iter->second->Compute(ctx); + if (actual_kernel_key == expected_kernel_key) { + kernel_iter->second->Compute(ctx); + } else { + Scope& op_scope = scope.NewScope(); + auto input_vars = this->InputVars(); + for (auto var_name : input_vars) { + op_scope.Var(var_name); + } + + // TODO(qijun) get appropriate DeviceContext from DeviceContext pool + platform::DeviceContext* trans_dev_ctx = nullptr; + std::vector<platform::DeviceContext*> trans_dev_ctx_vec{trans_dev_ctx}; + + // TODO(qijun) get appropriate DataTransformFN from global map + framework::DataTransformFN trans_fun = nullptr; + + // Wait for transform starting + dev_ctx->Wait(); + + for (auto var_name : input_vars) { + trans_fun(trans_dev_ctx_vec, *(scope.FindVar(var_name)), + op_scope.FindVar(var_name)); + } + // Wait for data transform finishing + for (auto ctx : trans_dev_ctx_vec) { + ctx->Wait(); + } + + // Create a new ExecutionContext + ExecutionContext op_ctx(*this, op_scope, *dev_ctx); + kernel_iter->second->Compute(op_ctx); + } } OpKernelType OperatorWithKernel::GetActualKernelType(