diff --git a/paddle/fluid/framework/operator.cc b/paddle/fluid/framework/operator.cc index 9a0348871b050278da2ad07ac6992188a702da42..385921f704cf48c6c6a463c6800b4ec992f73084 100644 --- a/paddle/fluid/framework/operator.cc +++ b/paddle/fluid/framework/operator.cc @@ -921,7 +921,7 @@ void OperatorWithKernel::RunImpl(const Scope& scope, OpKernelMap& kernels = kernels_iter->second; auto expected_kernel_key = this->GetExpectedKernelType( - ExecutionContext(*this, scope, *dev_ctx, ctx)); + ExecutionContext(*this, scope, *dev_ctx, ctx, nullptr)); VLOG(3) << "expected_kernel_key:" << expected_kernel_key; auto kernel_iter = kernels.find(expected_kernel_key); @@ -940,6 +940,12 @@ void OperatorWithKernel::RunImpl(const Scope& scope, KernelTypeToString(expected_kernel_key)); } + auto config_iter = kernel_configs_map_.find(expected_kernel_key); + std::vector* kernel_configs = nullptr; + if (config_iter != kernel_configs_map_.end()) { + kernel_configs = &(config_iter->second); + } + // do data transformScope &transfer_scope; std::vector transfered_inplace_vars; auto* transfer_scope = @@ -957,7 +963,8 @@ void OperatorWithKernel::RunImpl(const Scope& scope, this->InferShape(&infer_shape_ctx); // TODO(panyx0718): ExecutionContext should only depend on RuntimeContext // not Scope. Imperative mode only pass inputs and get outputs. - kernel_iter->second(ExecutionContext(*this, exec_scope, *dev_ctx, ctx)); + kernel_iter->second( + ExecutionContext(*this, exec_scope, *dev_ctx, ctx, kernel_configs)); if (!transfered_inplace_vars.empty()) { // there is inplace variable has been transfered. diff --git a/paddle/fluid/framework/operator.h b/paddle/fluid/framework/operator.h index e33214b44bb5d8ea5eb32d442d597a369c198bdd..b8d2c1eaf2ca633af7b819772a832213b11c7b54 100644 --- a/paddle/fluid/framework/operator.h +++ b/paddle/fluid/framework/operator.h @@ -184,12 +184,125 @@ class OperatorBase { const platform::Place& place) const = 0; }; +template +class AlgorithmsCache { + public: + AlgorithmsCache() : search_times_(0) { hash_.clear(); } + // Caches the best algorithm for a given + // combination of tensor dimensions & compute data type. + TAlgorithm GetAlgorithm( + const std::vector& dims1, const std::vector& dims2, + const std::vector& strides, const std::vector& paddings, + const std::vector& dilations, + int algorithmFlags, // can set for different data type + std::function gen_func); + + TAlgorithm GetAlgorithm(int64_t area, int search_times, int algorithmFlags, + std::function gen_func); + + private: + std::unordered_map hash_; + std::mutex mutex_; + + int search_times_; +}; + +template +TAlgorithm framework::AlgorithmsCache::GetAlgorithm( + const std::vector& dims1, const std::vector& dims2, + const std::vector& strides, const std::vector& paddings, + const std::vector& dilations, int algorithmFlags, + std::function gen_func) { + std::lock_guard lock(mutex_); + int64_t seed = 0; + // Hash all of the inputs, use to try and look up a previously + // discovered algorithm, or fall back to generating a new one. + std::hash hashFn; + // do hash like boost + // https://stackoverflow.com/questions/2590677/how-do-i-combine-hash-values-in-c0x + for (const auto num : dims1) { + seed ^= hashFn(num) + 0x9e3779b9 + (seed << 6) + (seed >> 2); + } + + for (const auto num : dims2) { + seed ^= hashFn(num) + 0x9e3779b9 + (seed << 6) + (seed >> 2) + 1; + } + + for (const auto num : strides) { + seed ^= hashFn(static_cast(num)) + 0x9e3779b9 + (seed << 6) + + (seed >> 2) + 2; + } + + for (const auto num : paddings) { + seed ^= hashFn(static_cast(num)) + 0x9e3779b9 + (seed << 6) + + (seed >> 2) + 3; + } + + for (const auto num : dilations) { + seed ^= hashFn(static_cast(num)) + 0x9e3779b9 + (seed << 6) + + (seed >> 2) + 4; + } + + seed ^= hashFn(static_cast(algorithmFlags)) + 0x9e3779b9 + + (seed << 6) + (seed >> 2) + 5; + + if (seed == 0) return gen_func(); + + if (hash_.find(seed) == hash_.end()) { + TAlgorithm value = gen_func(); + hash_[seed] = value; + } + return hash_[seed]; +} + +template +TAlgorithm AlgorithmsCache::GetAlgorithm( + int64_t area, int search_times, int algorithmFlags, + std::function gen_func) { + if (hash_.find(area) != hash_.end()) { + return hash_[area]; + } + if (search_times_ < search_times) { + auto algo = gen_func(); + hash_[area] = algo; + ++search_times_; + return algo; + } + TAlgorithm algo; + int64_t min = static_cast(INT_MAX); + for (const auto& m : hash_) { + if (m.first < min) { + min = m.first; + algo = m.second; + } + } + return algo; +} + +#ifdef PADDLE_WITH_CUDA +using KernelConfig = boost::variant< + std::shared_ptr>, + std::shared_ptr>, + std::shared_ptr>>; +#else +using KernelConfig = boost::variant; +#endif + +using OpKernelConfigsMap = + std::unordered_map, + OpKernelType::Hash>; + class ExecutionContext { public: ExecutionContext(const OperatorBase& op, const Scope& scope, const platform::DeviceContext& device_context, - const RuntimeContext& ctx) - : op_(op), scope_(scope), device_context_(device_context), ctx_(ctx) {} + const RuntimeContext& ctx, + std::vector* configs) + : op_(op), + scope_(scope), + device_context_(device_context), + ctx_(ctx), + kernel_configs_(configs) {} const OperatorBase& op() const { return op_; } @@ -398,11 +511,20 @@ class ExecutionContext { return temp_tensor; } + template + T& GetKernelConfig(int idx) const { + PADDLE_ENFORCE(kernel_configs_ && kernel_configs_->size() > idx, + "%s selected kernel doesn't have kernel config %lu <= %d", + op_.Type().c_str(), kernel_configs_->size(), idx); + return *boost::get>(kernel_configs_->at(idx)); + } + private: const OperatorBase& op_; const Scope& scope_; const platform::DeviceContext& device_context_; const RuntimeContext& ctx_; + mutable std::vector* kernel_configs_; }; template <> @@ -508,6 +630,9 @@ class OperatorWithKernel : public OperatorBase { void TransferInplaceVarsBack(const Scope& scope, const std::vector& inplace_vars, const Scope& exec_scope) const; + + protected: + mutable OpKernelConfigsMap kernel_configs_map_; }; extern bool OpSupportGPU(const std::string& op_type); diff --git a/paddle/fluid/framework/var_type_traits.h b/paddle/fluid/framework/var_type_traits.h index 733542e4972b16a71f9e76c3076b424b7a901066..fa77b96a7bdfa28ed982db022e8e5ecaef0b443c 100644 --- a/paddle/fluid/framework/var_type_traits.h +++ b/paddle/fluid/framework/var_type_traits.h @@ -50,8 +50,6 @@ class Scope; } // namespace framework namespace operators { -template -class AlgorithmsCache; class CudnnRNNCache; @@ -144,9 +142,6 @@ using VarTypeRegistry = detail::VarTypeRegistryImpl< #ifndef _WIN32 ncclUniqueId, platform::Communicator, #endif - operators::AlgorithmsCache, - operators::AlgorithmsCache, - operators::AlgorithmsCache, operators::CudnnRNNCache, #endif int, float>; diff --git a/paddle/fluid/imperative/layer.cc b/paddle/fluid/imperative/layer.cc index 8f20f0c06e043ddc629e47c6e49280c5467b0e20..aff5cf24be7c41cf58929069768d4fdb34386ae6 100644 --- a/paddle/fluid/imperative/layer.cc +++ b/paddle/fluid/imperative/layer.cc @@ -249,7 +249,8 @@ std::map> OpBase::ApplyGrad() { framework::Scope scope; PreparedOp p = PreparedOp::Prepare(ctx, *op_kernel, place_); p.op.RuntimeInferShape(scope, place_, ctx); - p.func(framework::ExecutionContext(p.op, scope, *p.dev_ctx, p.ctx)); + p.func( + framework::ExecutionContext(p.op, scope, *p.dev_ctx, p.ctx, nullptr)); } } diff --git a/paddle/fluid/imperative/layer.h b/paddle/fluid/imperative/layer.h index 78205486c5534ac0c61cc6d545bdafa4dfc95695..2dbc1b0f9690587868d0a0e8602a0d6332e2806b 100644 --- a/paddle/fluid/imperative/layer.h +++ b/paddle/fluid/imperative/layer.h @@ -64,8 +64,9 @@ class PreparedOp { framework::OperatorWithKernel::OpKernelMap& kernels = kernels_iter->second; - auto expected_kernel_key = op.GetExpectedKernelType( - framework::ExecutionContext(op, framework::Scope(), *dev_ctx, ctx)); + auto expected_kernel_key = + op.GetExpectedKernelType(framework::ExecutionContext( + op, framework::Scope(), *dev_ctx, ctx, nullptr)); VLOG(3) << "expected_kernel_key:" << expected_kernel_key; auto kernel_iter = kernels.find(expected_kernel_key); diff --git a/paddle/fluid/imperative/tracer.cc b/paddle/fluid/imperative/tracer.cc index bc39d11ba00a6a7c386162a1f9201c6f992c8692..1982fdb1c79b1eb1547835d1cfaac64c2f7fb5ac 100644 --- a/paddle/fluid/imperative/tracer.cc +++ b/paddle/fluid/imperative/tracer.cc @@ -139,7 +139,7 @@ void Tracer::Trace(OpBase* op, const VarBasePtrMap& inputs, PreparedOp prepared_op = PreparedOp::Prepare(ctx, *op_kernel, op->place_); prepared_op.op.RuntimeInferShape(scope, op->place_, ctx); prepared_op.func(framework::ExecutionContext( - prepared_op.op, scope, *prepared_op.dev_ctx, prepared_op.ctx)); + prepared_op.op, scope, *prepared_op.dev_ctx, prepared_op.ctx, nullptr)); if (!stop_gradient) { std::unique_ptr> grad_to_var( diff --git a/paddle/fluid/operators/beam_search_decode_op.cc b/paddle/fluid/operators/beam_search_decode_op.cc index 7f2bde55c98277b9fd4b3374657001c42d673d43..cf78c83297a87beb08a8b8e6e4b182f03f1909d3 100644 --- a/paddle/fluid/operators/beam_search_decode_op.cc +++ b/paddle/fluid/operators/beam_search_decode_op.cc @@ -123,7 +123,7 @@ class BeamSearchDecodeOp : public framework::OperatorBase { auto& dev_ctx = *pool.Get(dev_place); framework::RuntimeContext run_ctx(Inputs(), Outputs(), scope); - framework::ExecutionContext ctx(*this, scope, dev_ctx, run_ctx); + framework::ExecutionContext ctx(*this, scope, dev_ctx, run_ctx, nullptr); const LoDTensorArray* ids = ctx.Input("Ids"); const LoDTensorArray* scores = ctx.Input("Scores"); diff --git a/paddle/fluid/operators/conv_cudnn_op.cu.cc b/paddle/fluid/operators/conv_cudnn_op.cu.cc index f5208e7a601f4dd33b486e5840178022f66431e5..9e5ccd928e9d6012c1da3baa17521dcac0c8ff2f 100644 --- a/paddle/fluid/operators/conv_cudnn_op.cu.cc +++ b/paddle/fluid/operators/conv_cudnn_op.cu.cc @@ -42,6 +42,7 @@ using ScopedConvolutionDescriptor = platform::ScopedConvolutionDescriptor; using DataLayout = platform::DataLayout; template using ScalingParamType = typename platform::CudnnDataType::ScalingParamType; +using framework::AlgorithmsCache; template class CUDNNConvOpKernel : public framework::OpKernel { @@ -169,18 +170,8 @@ class CUDNNConvOpKernel : public framework::OpKernel { workspace_size_limit, &algo)); VLOG(3) << "cuDNN forward algo " << algo; } else if (exhaustive_search && (!half_float)) { - AlgorithmsCache* algo_cache = nullptr; - if (ctx.scope().FindVar(kCUDNNFwdAlgoCache)) { - algo_cache = - ctx.scope() - .FindVar(kCUDNNFwdAlgoCache) - ->GetMutable>(); - } else { - algo_cache = - const_cast(ctx.scope()) - .Var(kCUDNNFwdAlgoCache) - ->GetMutable>(); - } + AlgorithmsCache& algo_cache = + ctx.GetKernelConfig>(0); cudnn_workspace = ctx.AllocateTmpTensor( framework::make_ddim( @@ -188,7 +179,7 @@ class CUDNNConvOpKernel : public framework::OpKernel { dev_ctx); cudnn_workspace_ptr = static_cast(cudnn_workspace.data()); - algo = algo_cache->GetAlgorithm( + algo = algo_cache.GetAlgorithm( x_dims, f_dims, strides, paddings, dilations, 0, [&]() { int returned_algo_count; std::array @@ -382,22 +373,11 @@ class CUDNNConvGradOpKernel : public framework::OpKernel { if (input_grad) { T* input_grad_data = input_grad->mutable_data(ctx.GetPlace()); if (exhaustive_search) { - AlgorithmsCache* data_algo_cache; - if (ctx.scope().FindVar(kCUDNNBwdDataAlgoCache)) { - data_algo_cache = - ctx.scope() - .FindVar(kCUDNNBwdDataAlgoCache) - ->GetMutable< - AlgorithmsCache>(); - } else { - data_algo_cache = - const_cast(ctx.scope()) - .Var(kCUDNNBwdDataAlgoCache) - ->GetMutable< - AlgorithmsCache>(); - } - - data_algo = data_algo_cache->GetAlgorithm( + AlgorithmsCache& data_algo_cache = + ctx.GetKernelConfig>( + 0); + + data_algo = data_algo_cache.GetAlgorithm( x_dims, f_dims, strides, paddings, dilations, 0, [&]() { int returned_algo_count; std::array { if (filter_grad) { T* filter_grad_data = filter_grad->mutable_data(ctx.GetPlace()); if (exhaustive_search) { - AlgorithmsCache* f_algo_cache; - if (ctx.scope().FindVar(kCUDNNBwdFilterAlgoCache)) { - f_algo_cache = - ctx.scope() - .FindVar(kCUDNNBwdFilterAlgoCache) - ->GetMutable< - AlgorithmsCache>(); - } else { - f_algo_cache = - const_cast(ctx.scope()) - .Var(kCUDNNBwdFilterAlgoCache) - ->GetMutable< - AlgorithmsCache>(); - } - - filter_algo = f_algo_cache->GetAlgorithm( + AlgorithmsCache& f_algo_cache = + ctx.GetKernelConfig< + AlgorithmsCache>(1); + + filter_algo = f_algo_cache.GetAlgorithm( x_dims, f_dims, strides, paddings, dilations, 0, [&]() { int returned_algo_count; std::array #include #include +#include "paddle/fluid/framework/operator.h" #include "paddle/fluid/platform/cudnn_helper.h" DECLARE_uint64(conv_workspace_size_limit); @@ -46,100 +47,5 @@ static constexpr size_t kNUM_CUDNN_BWD_FILTER_ALGS = 4; static constexpr size_t kNUM_CUDNN_BWD_DATA_ALGS = 5; #endif -template -class AlgorithmsCache { - public: - AlgorithmsCache() : search_times_(0) { hash_.clear(); } - // Caches the best algorithm for a given - // combination of tensor dimensions & compute data type. - TAlgorithm GetAlgorithm( - const std::vector& dims1, const std::vector& dims2, - const std::vector& strides, const std::vector& paddings, - const std::vector& dilations, - int algorithmFlags, // can set for different data type - std::function gen_func); - - TAlgorithm GetAlgorithm(int64_t area, int search_times, int algorithmFlags, - std::function gen_func); - - private: - std::unordered_map hash_; - std::mutex mutex_; - - int search_times_; -}; - -template -TAlgorithm AlgorithmsCache::GetAlgorithm( - const std::vector& dims1, const std::vector& dims2, - const std::vector& strides, const std::vector& paddings, - const std::vector& dilations, int algorithmFlags, - std::function gen_func) { - std::lock_guard lock(mutex_); - int64_t seed = 0; - // Hash all of the inputs, use to try and look up a previously - // discovered algorithm, or fall back to generating a new one. - std::hash hashFn; - // do hash like boost - // https://stackoverflow.com/questions/2590677/how-do-i-combine-hash-values-in-c0x - for (const auto num : dims1) { - seed ^= hashFn(num) + 0x9e3779b9 + (seed << 6) + (seed >> 2); - } - - for (const auto num : dims2) { - seed ^= hashFn(num) + 0x9e3779b9 + (seed << 6) + (seed >> 2) + 1; - } - - for (const auto num : strides) { - seed ^= hashFn(static_cast(num)) + 0x9e3779b9 + (seed << 6) + - (seed >> 2) + 2; - } - - for (const auto num : paddings) { - seed ^= hashFn(static_cast(num)) + 0x9e3779b9 + (seed << 6) + - (seed >> 2) + 3; - } - - for (const auto num : dilations) { - seed ^= hashFn(static_cast(num)) + 0x9e3779b9 + (seed << 6) + - (seed >> 2) + 4; - } - - seed ^= hashFn(static_cast(algorithmFlags)) + 0x9e3779b9 + - (seed << 6) + (seed >> 2) + 5; - - if (seed == 0) return gen_func(); - - if (hash_.find(seed) == hash_.end()) { - TAlgorithm value = gen_func(); - hash_[seed] = value; - } - return hash_[seed]; -} - -template -TAlgorithm AlgorithmsCache::GetAlgorithm( - int64_t area, int search_times, int algorithmFlags, - std::function gen_func) { - if (hash_.find(area) != hash_.end()) { - return hash_[area]; - } - if (search_times_ < search_times) { - auto algo = gen_func(); - hash_[area] = algo; - ++search_times_; - return algo; - } - TAlgorithm algo; - int64_t min = static_cast(INT_MAX); - for (const auto& m : hash_) { - if (m.first < min) { - min = m.first; - algo = m.second; - } - } - return algo; -} - } // namespace operators } // namespace paddle diff --git a/paddle/fluid/operators/conv_fusion_op.cu.cc b/paddle/fluid/operators/conv_fusion_op.cu.cc index d8b997cca613f660046106512fc03bf55f9b992d..705ce41a3ff869d1ac1bfe89790d55e964940db2 100644 --- a/paddle/fluid/operators/conv_fusion_op.cu.cc +++ b/paddle/fluid/operators/conv_fusion_op.cu.cc @@ -30,6 +30,8 @@ using ScopedFilterDescriptor = platform::ScopedFilterDescriptor; using ScopedConvolutionDescriptor = platform::ScopedConvolutionDescriptor; using ScopedActivationDescriptor = platform::ScopedActivationDescriptor; using DataLayout = platform::DataLayout; +using framework::AlgorithmsCache; + template using ScalingParamType = typename platform::CudnnDataType::ScalingParamType; @@ -139,38 +141,23 @@ class CUDNNConvFusionOpKernel : public framework::OpKernel { } return fwd_perf_stat[0].algo; }; - AlgorithmsCache* algo_cache = nullptr; + AlgorithmsCache& algo_cache = + ctx.GetKernelConfig>(0); int search_times = ctx.Attr("search_times"); search_times = std::max( static_cast(FLAGS_cudnn_exhaustive_search_times), search_times); + // TODO(dangqingqing): Unify this if-else. if (search_times > 0) { // The searched algo will be cached by `search_times` times for // different input dimension. For other dimensions, select the algo // of closest area. - auto var_name = ctx.Inputs("AlgoCache")[0]; - algo_cache = - ctx.scope() - .FindVar(var_name) - ->GetMutable>(); - algo = algo_cache->GetAlgorithm(x_dims[2] * x_dims[3], search_times, 0, - search_func); + algo = algo_cache.GetAlgorithm(x_dims[2] * x_dims[3], search_times, 0, + search_func); } else { // Cache searched algo in Var(kCUDNNFwdAlgoCache). // all conv ops use the same kCUDNNFwdAlgoCache variable. - if (ctx.scope().FindVar(kCUDNNFwdAlgoCache)) { - algo_cache = - ctx.scope() - .FindVar(kCUDNNFwdAlgoCache) - ->GetMutable>(); - } else { - // TODO(qingqing) remove const_cast - algo_cache = - const_cast(ctx.scope().parent()) - ->Var(kCUDNNFwdAlgoCache) - ->GetMutable>(); - } - algo = algo_cache->GetAlgorithm(x_dims, f_dims, strides, paddings, - dilations, 0, search_func); + algo = algo_cache.GetAlgorithm(x_dims, f_dims, strides, paddings, + dilations, 0, search_func); } VLOG(3) << "choose algo " << algo; } diff --git a/paddle/fluid/operators/conv_op.cc b/paddle/fluid/operators/conv_op.cc index fd9f156d070bdb1990a2fc9c63305933050e5524..a37c8d3ccd9c3bb8fae8a5f198bc4db714301b68 100644 --- a/paddle/fluid/operators/conv_op.cc +++ b/paddle/fluid/operators/conv_op.cc @@ -18,6 +18,7 @@ limitations under the License. */ #include #ifdef PADDLE_WITH_CUDA +#include "paddle/fluid/operators/conv_cudnn_op_cache.h" #include "paddle/fluid/platform/cudnn_helper.h" #endif #ifdef PADDLE_WITH_MKLDNN @@ -109,8 +110,20 @@ framework::OpKernelType ConvOp::GetExpectedKernelType( "float16 can only be used when CUDNN is used"); } - return framework::OpKernelType(input_data_type, ctx.GetPlace(), layout, - library, customized_type_value); + auto type = framework::OpKernelType(input_data_type, ctx.GetPlace(), layout, + library, customized_type_value); +#ifdef PADDLE_WITH_CUDA + std::vector& configs = kernel_configs_map_[type]; + // TODO(dangqingqing): Currently conv_fusion_op use cudnn but sets use_cudnn + // to false. It should be fixed and then here should only create if library + // is kCUDNN. + if (configs.empty()) { + std::shared_ptr> p( + new framework::AlgorithmsCache()); + configs.push_back(p); + } +#endif + return type; } void Conv2DOpMaker::Make() { @@ -410,9 +423,25 @@ framework::OpKernelType ConvOpGrad::GetExpectedKernelType( } #endif - return framework::OpKernelType(ctx.Input("Input")->type(), - ctx.GetPlace(), layout_, library_, - customized_type_value); + auto type = framework::OpKernelType(ctx.Input("Input")->type(), + ctx.GetPlace(), layout_, library_, + customized_type_value); +#ifdef PADDLE_WITH_CUDA + if (library_ == framework::LibraryType::kCUDNN) { + std::vector& configs = kernel_configs_map_[type]; + if (configs.empty()) { + std::shared_ptr> + p(new framework::AlgorithmsCache()); + configs.push_back(p); + + std::shared_ptr< + framework::AlgorithmsCache> + p2(new framework::AlgorithmsCache()); + configs.push_back(p2); + } + } +#endif + return type; } class Conv2dGradMaker : public framework::SingleGradOpDescMaker { diff --git a/paddle/fluid/platform/temporary_allocator_test.cc b/paddle/fluid/platform/temporary_allocator_test.cc index 3879cd540017ea22b0cf4eee794a172e56716b74..6dae84f016e5db8007b4a4b4df2b5ed7f5cb4f19 100644 --- a/paddle/fluid/platform/temporary_allocator_test.cc +++ b/paddle/fluid/platform/temporary_allocator_test.cc @@ -141,7 +141,7 @@ TEST(temporary_allocator, create_tensor_with_allocationptr) { platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); auto* dev_ctx = static_cast(pool.Get(cpu_place)); - framework::ExecutionContext ctx(op, scope, *dev_ctx, run_ctx); + framework::ExecutionContext ctx(op, scope, *dev_ctx, run_ctx, nullptr); int numel = memory_size / sizeof(float); framework::Tensor tensor = @@ -156,7 +156,7 @@ TEST(temporary_allocator, create_tensor_with_allocationptr) { platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); auto* dev_ctx = static_cast(pool.Get(gpu_place)); - framework::ExecutionContext ctx(op, scope, *dev_ctx, run_ctx); + framework::ExecutionContext ctx(op, scope, *dev_ctx, run_ctx, nullptr); int numel = memory_size / sizeof(float); framework::Tensor tensor = ctx.AllocateTmpTensor( @@ -179,7 +179,7 @@ TEST(temporary_allocator, create_tensor_with_allocationptr2) { platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); auto* dev_ctx = static_cast(pool.Get(cpu_place)); - framework::ExecutionContext ctx(op, scope, *dev_ctx, run_ctx); + framework::ExecutionContext ctx(op, scope, *dev_ctx, run_ctx, nullptr); int numel = memory_size / sizeof(float); framework::Tensor out_side_tensor; @@ -200,7 +200,7 @@ TEST(temporary_allocator, create_tensor_with_allocationptr2) { platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); auto* dev_ctx = static_cast(pool.Get(gpu_place)); - framework::ExecutionContext ctx(op, scope, *dev_ctx, run_ctx); + framework::ExecutionContext ctx(op, scope, *dev_ctx, run_ctx, nullptr); size_t memory_size = 500; int numel = memory_size / sizeof(float); diff --git a/python/paddle/fluid/framework.py b/python/paddle/fluid/framework.py index 15367c724e5304fed78ef58f8a27932e1d6de318..dd0deb02340a1c11bd8bbf3cf09224956270188d 100644 --- a/python/paddle/fluid/framework.py +++ b/python/paddle/fluid/framework.py @@ -723,7 +723,6 @@ class Operator(object): self._update_desc_attr(attr_name, attr_val) self.desc.check_attrs() - if self._has_kernel(type): self.desc.infer_var_type(self.block.desc) self.desc.infer_shape(self.block.desc)