diff --git a/paddle/fluid/lite/core/mir/runtime_context_assign_pass.cc b/paddle/fluid/lite/core/mir/runtime_context_assign_pass.cc new file mode 100644 index 0000000000000000000000000000000000000000..46275ed4d99a0d55fd363061ab082e7d60e449d4 --- /dev/null +++ b/paddle/fluid/lite/core/mir/runtime_context_assign_pass.cc @@ -0,0 +1,80 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/lite/core/mir/pass.h" +#include "paddle/fluid/lite/core/mir/pass_registry.h" + +namespace paddle { +namespace lite { +namespace mir { + +class RuntimeContextAssignPass : public InstructionPass { + public: + RuntimeContextAssignPass() { +#ifdef LITE_WITH_CUDA + InitCudaBlas(); +#endif + } + + void Apply(std::unique_ptr<SSAGraph>& graph) override { + for (auto& node : graph->mutable_nodes()) { + if (!node.IsInstruct()) continue; + + auto& inst = node.AsInstruct(); + + switch (inst.picked_kernel().target()) { + case TARGET(kHost): + case TARGET(kX86): + inst.picked_kernel().SetContext(NewHostContext()); + break; + case TARGET(kCUDA): + inst.picked_kernel().SetContext(NewCudaContext()); + break; + default: + LOG(FATAL) << "unsupported target " + << TargetToStr(inst.picked_kernel().target()); + } + } + } + + std::unique_ptr<KernelContext> NewHostContext() { + std::unique_ptr<KernelContext> ctx(new KernelContext); + ctx->AsX86Context(); + // Some initialization here. 
+ return ctx; + } + + std::unique_ptr<KernelContext> NewCudaContext() { + std::unique_ptr<KernelContext> ctx(new KernelContext); + auto& cuda = ctx->AsCudaContext(); + // Some initialization here. + CHECK(cublas_fp32_) << "cublas_fp32 should be set first"; + cuda.blas_fp32 = cublas_fp32_; + return ctx; + } + + void InitCudaBlas() { + cublas_fp32_ = std::make_shared<lite::cuda::Blas<float>>(); + } + + private: + std::shared_ptr<lite::cuda::Blas<float>> cublas_fp32_; +}; + +} // namespace mir +} // namespace lite +} // namespace paddle + +REGISTER_MIR_PASS(runtime_context_assign_pass, + paddle::lite::mir::RuntimeContextAssignPass); diff --git a/paddle/fluid/lite/core/types_test.cc b/paddle/fluid/lite/core/types_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..6cb4ef77a5dd1c89d9460d81a79941581eb7b863 --- /dev/null +++ b/paddle/fluid/lite/core/types_test.cc @@ -0,0 +1,43 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "paddle/fluid/lite/core/types.h" +#include <gtest/gtest.h> + +namespace paddle { +namespace lite { +namespace core { + +TEST(KernelPickFactor, Default) { + KernelPickFactor factor; + ASSERT_FALSE(factor.IsTargetConsidered()); + ASSERT_FALSE(factor.IsPrecisionConsidered()); + ASSERT_FALSE(factor.IsDataLayoutConsidered()); +} + +TEST(KernelPickFactor, Set) { + KernelPickFactor factor; + factor.ConsiderTarget(); + ASSERT_TRUE(factor.IsTargetConsidered()); + factor.ConsiderPrecision(); + ASSERT_TRUE(factor.IsPrecisionConsidered()); + factor.ConsiderDataLayout(); + ASSERT_TRUE(factor.IsDataLayoutConsidered()); + + LOG(INFO) << "factor " << factor; +} + +} // namespace core +} // namespace lite +} // namespace paddle