Commit 66cf184f authored by 叶剑武

Merge branch 'scratch-image-bug' into 'master'

Bug: Replace OpenCLRuntime with GPURuntime in GPUDevice.

See merge request !904
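
In short, `Device` no longer hands out `OpenCLRuntime` directly: GPU code now goes through the new `GPURuntime` wrapper, which exposes the underlying `OpenCLRuntime` plus the scratch-image manager and the image/buffer memory-type decision that previously lived inside `OpenCLRuntime`. A minimal sketch of the resulting call pattern (illustrative only; the helper function and variable names below are not part of this change):

```cpp
#include "mace/core/device.h"
#include "mace/core/runtime/opencl/gpu_runtime.h"
#include "mace/core/runtime/opencl/opencl_runtime.h"

namespace mace {

#ifdef MACE_ENABLE_OPENCL
// Hypothetical helper showing how a GPU code path reaches OpenCL state
// after this commit; `device` is any Device* whose device type is GPU.
bool UsesImageMemory(Device *device) {
  GPURuntime *gpu = device->gpu_runtime();         // was: device->opencl_runtime()
  OpenCLRuntime *runtime = gpu->opencl_runtime();  // raw OpenCL runtime, as before
  (void)runtime;  // kernels still use this for programs and the command queue
  return gpu->UseImageMemory();                    // moved from OpenCLRuntime
}
#endif  // MACE_ENABLE_OPENCL

}  // namespace mace
```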
......@@ -33,8 +33,8 @@ CPURuntime *CPUDevice::cpu_runtime() {
}
#ifdef MACE_ENABLE_OPENCL
OpenCLRuntime *CPUDevice::opencl_runtime() {
LOG(FATAL) << "CPU device should not call OpenCL Runtime";
GPURuntime *CPUDevice::gpu_runtime() {
LOG(FATAL) << "CPU device should not call GPU Runtime";
return nullptr;
}
#endif
......
......@@ -21,7 +21,7 @@
#include "mace/core/allocator.h"
#ifdef MACE_ENABLE_OPENCL
#include "mace/core/runtime/opencl/opencl_runtime.h"
#include "mace/core/runtime/opencl/gpu_runtime.h"
#endif
namespace mace {
......@@ -33,7 +33,7 @@ class Device {
virtual ~Device() {}
#ifdef MACE_ENABLE_OPENCL
virtual OpenCLRuntime *opencl_runtime() = 0;
virtual GPURuntime *gpu_runtime() = 0;
#endif // MACE_ENABLE_OPENCL
virtual CPURuntime *cpu_runtime() = 0;
......@@ -50,7 +50,7 @@ class CPUDevice : public Device {
virtual ~CPUDevice();
#ifdef MACE_ENABLE_OPENCL
OpenCLRuntime *opencl_runtime() override;
GPURuntime *gpu_runtime() override;
#endif
CPURuntime *cpu_runtime() override;
......
......@@ -30,12 +30,13 @@ GPUDevice::GPUDevice(std::shared_ptr<Tuner<uint32_t>> tuner,
runtime_(new OpenCLRuntime(opencl_cache_storage, priority, perf,
opencl_binary_storage, tuner)),
allocator_(new OpenCLAllocator(runtime_.get())),
scratch_buffer_(new ScratchBuffer(allocator_.get())) {}
scratch_buffer_(new ScratchBuffer(allocator_.get())),
gpu_runtime_(new GPURuntime(runtime_.get())) {}
GPUDevice::~GPUDevice() = default;
OpenCLRuntime* GPUDevice::opencl_runtime() {
return runtime_.get();
GPURuntime* GPUDevice::gpu_runtime() {
return gpu_runtime_.get();
}
Allocator *GPUDevice::allocator() {
......
......@@ -19,6 +19,7 @@
#include "mace/core/device_context.h"
#include "mace/core/device.h"
#include "mace/core/runtime/opencl/gpu_runtime.h"
#include "mace/core/runtime/opencl/opencl_allocator.h"
namespace mace {
......@@ -34,7 +35,7 @@ class GPUDevice : public CPUDevice {
CPUAffinityPolicy cpu_affinity_policy = AFFINITY_NONE,
bool use_gemmlowp = false);
~GPUDevice();
OpenCLRuntime *opencl_runtime() override;
GPURuntime *gpu_runtime() override;
Allocator *allocator() override;
DeviceType device_type() const override;
ScratchBuffer *scratch_buffer() override;
......@@ -42,6 +43,7 @@ class GPUDevice : public CPUDevice {
std::unique_ptr<OpenCLRuntime> runtime_;
std::unique_ptr<OpenCLAllocator> allocator_;
std::unique_ptr<ScratchBuffer> scratch_buffer_;
std::unique_ptr<GPURuntime> gpu_runtime_;
};
} // namespace mace
......
// Copyright 2018 Xiaomi, Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mace/core/runtime/opencl/gpu_runtime.h"
#include "mace/core/runtime/opencl/scratch_image.h"
namespace mace {
GPURuntime::GPURuntime(mace::OpenCLRuntime *runtime)
: runtime_(runtime),
scratch_image_manager_(new ScratchImageManager),
mem_type_(MemoryType::GPU_IMAGE) {}
GPURuntime::~GPURuntime() = default;
OpenCLRuntime* GPURuntime::opencl_runtime() {
return runtime_;
}
ScratchImageManager* GPURuntime::scratch_image_manager() const {
return scratch_image_manager_.get();
}
bool GPURuntime::UseImageMemory() {
return this->mem_type_ == MemoryType::GPU_IMAGE;
}
void GPURuntime::set_mem_type(MemoryType type) {
this->mem_type_ = type;
}
} // namespace mace
// Copyright 2018 Xiaomi, Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef MACE_CORE_RUNTIME_OPENCL_GPU_RUNTIME_H_
#define MACE_CORE_RUNTIME_OPENCL_GPU_RUNTIME_H_
#include <memory>
#include "mace/proto/mace.pb.h"
namespace mace {
class OpenCLRuntime;
class ScratchImageManager;
class GPURuntime {
public:
explicit GPURuntime(OpenCLRuntime *runtime);
~GPURuntime();
OpenCLRuntime *opencl_runtime();
ScratchImageManager *scratch_image_manager() const;
// TODO(liuqi): remove this function in the future, make decision at runtime.
bool UseImageMemory();
void set_mem_type(MemoryType type);
private:
OpenCLRuntime *runtime_;
std::unique_ptr<ScratchImageManager> scratch_image_manager_;
MemoryType mem_type_;
};
} // namespace mace
#endif // MACE_CORE_RUNTIME_OPENCL_GPU_RUNTIME_H_
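
Taken together with gpu_runtime.cc above, the wrapper now owns the memory-type flag and the scratch-image manager. A minimal usage sketch (the function below is illustrative and not part of the commit; `MemoryType::GPU_BUFFER` comes from mace/proto/mace.pb.h):

```cpp
#include "mace/core/runtime/opencl/gpu_runtime.h"
#include "mace/core/runtime/opencl/scratch_image.h"

namespace mace {

// Hypothetical configuration helper exercising the members that moved
// from OpenCLRuntime into GPURuntime in this commit.
void ConfigureGpuMemory(GPURuntime *gpu_runtime) {
  // The ops tests below switch memory types per run via set_mem_type().
  gpu_runtime->set_mem_type(MemoryType::GPU_BUFFER);

  if (!gpu_runtime->UseImageMemory()) {
    // Scratch images are now managed by the wrapper, not by OpenCLRuntime.
    ScratchImageManager *manager = gpu_runtime->scratch_image_manager();
    (void)manager;
  }
}

}  // namespace mace
```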
......@@ -284,9 +284,7 @@ OpenCLRuntime::OpenCLRuntime(
is_opencl_avaliable_(false),
is_profiling_enabled_(false),
opencl_version_(CL_VER_UNKNOWN),
gpu_type_(UNKNOWN),
mem_type_(MemoryType::GPU_IMAGE),
scratch_image_manager_(new ScratchImageManager) {
gpu_type_(UNKNOWN) {
std::vector<cl::Platform> all_platforms;
cl::Platform::get(&all_platforms);
if (all_platforms.size() == 0) {
......@@ -471,14 +469,6 @@ uint32_t OpenCLRuntime::device_compute_units() const {
return device_compute_units_;
}
bool OpenCLRuntime::UseImageMemory() {
return this->mem_type_ == MemoryType::GPU_IMAGE;
}
void OpenCLRuntime::set_mem_type(MemoryType type) {
this->mem_type_ = type;
}
bool OpenCLRuntime::BuildProgramFromCache(
const std::string &built_program_key,
const std::string &build_options_str,
......@@ -792,8 +782,4 @@ bool OpenCLRuntime::is_profiling_enabled() const {
return is_profiling_enabled_;
}
ScratchImageManager* OpenCLRuntime::scratch_image_manager() const {
return scratch_image_manager_.get();
}
} // namespace mace
......@@ -83,11 +83,7 @@ class OpenCLRuntime {
uint64_t device_global_mem_cache_size() const;
uint32_t device_compute_units() const;
Tuner<uint32_t> *tuner();
ScratchImageManager *scratch_image_manager() const;
bool is_opencl_avaliable();
// TODO(liuqi): remove this function in the future, make decision at runtime.
bool UseImageMemory();
void set_mem_type(MemoryType type);
void GetCallStats(const cl::Event &event, CallStats *stats);
uint64_t GetDeviceMaxWorkGroupSize();
......@@ -135,8 +131,6 @@ class OpenCLRuntime {
bool is_profiling_enabled_;
OpenCLVersion opencl_version_;
GPUType gpu_type_;
MemoryType mem_type_;
std::unique_ptr<ScratchImageManager> scratch_image_manager_;
// All OpenCL object must be a pointer and manually deleted before unloading
// OpenCL library.
std::shared_ptr<cl::Context> context_;
......
......@@ -109,7 +109,7 @@ MaceStatus Workspace::LoadModelTensor(const NetDef &net_def,
(!is_quantize_model && HasQuantizedTensor(net_def))));
#ifdef MACE_ENABLE_OPENCL
diffused_buffer_ = diffused_buffer_ || (device_type == DeviceType::GPU &&
device->opencl_runtime()->GetDeviceMaxMemAllocSize() <=
device->gpu_runtime()->opencl_runtime()->GetDeviceMaxMemAllocSize() <=
static_cast<uint64_t>(model_data_size));
#endif
if (diffused_buffer_) {
......
......@@ -69,8 +69,8 @@ void UnloadModelData(const unsigned char *model_data,
#ifdef MACE_ENABLE_OPENCL
MaceStatus CheckGPUAvalibility(const NetDef *net_def, Device *device) {
// Check OpenCL avaliable
auto runtime = device->opencl_runtime();
if (!runtime->is_opencl_avaliable()) {
auto runtime = device->gpu_runtime();
if (!runtime->opencl_runtime()->is_opencl_avaliable()) {
LOG(WARNING) << "The device does not support OpenCL";
return MaceStatus::MACE_OUT_OF_RESOURCES;
}
......@@ -678,8 +678,8 @@ MaceStatus MaceEngine::Impl::Run(
#ifdef MACE_ENABLE_OPENCL
if (device_type_ == GPU) {
device_->opencl_runtime()->command_queue().finish();
device_->opencl_runtime()->SaveBuiltCLProgram();
device_->gpu_runtime()->opencl_runtime()->command_queue().finish();
device_->gpu_runtime()->opencl_runtime()->SaveBuiltCLProgram();
}
#endif
for (auto &output : *outputs) {
......
......@@ -81,7 +81,7 @@ class ActivationOp<DeviceType::GPU, T> : public Operation {
auto relux_max_limit = static_cast<T>(
Operation::GetOptionalArg<float>("max_limit", 0.0f));
MemoryType mem_type;
if (context->device()->opencl_runtime()->UseImageMemory()) {
if (context->device()->gpu_runtime()->UseImageMemory()) {
mem_type = MemoryType::GPU_IMAGE;
kernel_.reset(
new opencl::image::ActivationKernel<T>(type, relux_max_limit));
......
......@@ -106,7 +106,7 @@ class AddNOp<DeviceType::GPU, T> : public Operation {
public:
explicit AddNOp(OpConstructContext *context)
: Operation(context) {
if (context->device()->opencl_runtime()->UseImageMemory()) {
if (context->device()->gpu_runtime()->UseImageMemory()) {
kernel_.reset(new opencl::image::AddNKernel<T>);
} else {
MACE_NOT_IMPLEMENTED;
......
......@@ -149,7 +149,7 @@ class BatchNormOp<DeviceType::GPU, T> : public Operation {
Operation::GetOptionalArg<std::string>("activation", "NOOP"));
float relux_max_limit = Operation::GetOptionalArg<float>("max_limit", 0.0f);
MemoryType mem_type;
if (context->device()->opencl_runtime()->UseImageMemory()) {
if (context->device()->gpu_runtime()->UseImageMemory()) {
mem_type = MemoryType::GPU_IMAGE;
kernel_.reset(new opencl::image::BatchNormKernel<T>(
epsilon, activation, relux_max_limit));
......
......@@ -265,7 +265,7 @@ class BatchToSpaceNDOp<DeviceType::GPU, T> : public BatchToSpaceOpBase {
public:
explicit BatchToSpaceNDOp(OpConstructContext *context)
: BatchToSpaceOpBase(context) {
if (context->device()->opencl_runtime()->UseImageMemory()) {
if (context->device()->gpu_runtime()->UseImageMemory()) {
kernel_.reset(new opencl::image::BatchToSpaceKernel<T>);
} else {
MACE_NOT_IMPLEMENTED;
......
......@@ -101,7 +101,7 @@ class BiasAddOp<DeviceType::GPU, T> : public Operation {
data_format_(static_cast<DataFormat>(Operation::GetOptionalArg<int>(
"data_format", NHWC))) {
MemoryType mem_type;
if (context->device()->opencl_runtime()->UseImageMemory()) {
if (context->device()->gpu_runtime()->UseImageMemory()) {
mem_type = MemoryType::GPU_IMAGE;
kernel_.reset(new opencl::image::BiasAddKernel<T>);
} else {
......
......@@ -84,7 +84,7 @@ class ChannelShuffleOp<DeviceType::GPU, T> : public Operation {
explicit ChannelShuffleOp(OpConstructContext *context)
: Operation(context) {
const int groups = Operation::GetOptionalArg<int>("group", 1);
if (context->device()->opencl_runtime()->UseImageMemory()) {
if (context->device()->gpu_runtime()->UseImageMemory()) {
kernel_.reset(new opencl::image::ChannelShuffleKernel<T>(groups));
} else {
MACE_NOT_IMPLEMENTED;
......
......@@ -196,7 +196,7 @@ class ConcatOp<DeviceType::GPU, T> : public ConcatOpBase {
public:
explicit ConcatOp(OpConstructContext *context)
: ConcatOpBase(context) {
if (context->device()->opencl_runtime()->UseImageMemory()) {
if (context->device()->gpu_runtime()->UseImageMemory()) {
kernel_.reset(new opencl::image::ConcatKernel<T>(axis_));
} else {
MACE_NOT_IMPLEMENTED;
......
......@@ -963,7 +963,7 @@ class Conv2dOp<DeviceType::GPU, T> : public ConvPool2dOpBase {
relux_max_limit_(Operation::GetOptionalArg<float>("max_limit", 0.0f)),
wino_block_size_(Operation::GetOptionalArg<int>("wino_block_size", 0)) {
MemoryType mem_type;
if (context->device()->opencl_runtime()->UseImageMemory()) {
if (context->device()->gpu_runtime()->UseImageMemory()) {
mem_type = MemoryType::GPU_IMAGE;
kernel_.reset(new opencl::image::Conv2dKernel<T>);
} else {
......@@ -974,7 +974,7 @@ class Conv2dOp<DeviceType::GPU, T> : public ConvPool2dOpBase {
// Transform filter tensor to target format
if ((wino_block_size_ == 2 || wino_block_size_ == 4) &&
(kernel_->CheckUseWinograd(
context->device()->opencl_runtime(),
context->device()->gpu_runtime()->opencl_runtime(),
context->workspace()->GetTensor(
operator_def_->input(1))->shape(),
std::vector<index_t>(operator_def_->output_shape(0).dims().begin(),
......
......@@ -113,7 +113,7 @@ class CropOp<DeviceType::GPU, T> : public Operation {
explicit CropOp(OpConstructContext *context)
: Operation(context) {
const int axis = Operation::GetOptionalArg<int>("axis", 2);
if (context->device()->opencl_runtime()->UseImageMemory()) {
if (context->device()->gpu_runtime()->UseImageMemory()) {
kernel_.reset(new opencl::image::CropKernel<T>(
axis, Operation::GetRepeatedArgs<int>("offset")));
} else {
......
......@@ -360,7 +360,7 @@ class Deconv2dOp<DeviceType::GPU, T> : public Deconv2dOpBase {
explicit Deconv2dOp(OpConstructContext *context)
: Deconv2dOpBase(context) {
MemoryType mem_type = MemoryType::GPU_IMAGE;
if (context->device()->opencl_runtime()->UseImageMemory()) {
if (context->device()->gpu_runtime()->UseImageMemory()) {
kernel_.reset(new opencl::image::Deconv2dKernel<T>);
} else {
MACE_NOT_IMPLEMENTED;
......
......@@ -96,7 +96,7 @@ class DepthToSpaceOp<DeviceType::GPU, T> : public Operation {
explicit DepthToSpaceOp(OpConstructContext *context)
: Operation(context) {
int block_size = Operation::GetOptionalArg<int>("block_size", 1);
if (context->device()->opencl_runtime()->UseImageMemory()) {
if (context->device()->gpu_runtime()->UseImageMemory()) {
kernel_.reset(new opencl::image::DepthToSpaceKernel<T>(block_size));
} else {
MACE_NOT_IMPLEMENTED;
......
......@@ -492,7 +492,7 @@ class DepthwiseConv2dOp<DeviceType::GPU, T> : public DepthwiseConv2dOpBase {
explicit DepthwiseConv2dOp(OpConstructContext *context)
: DepthwiseConv2dOpBase(context) {
MemoryType mem_type;
if (context->device()->opencl_runtime()->UseImageMemory()) {
if (context->device()->gpu_runtime()->UseImageMemory()) {
mem_type = MemoryType::GPU_IMAGE;
kernel_.reset(new opencl::image::DepthwiseConv2dKernel<T>);
} else {
......
......@@ -410,7 +410,7 @@ class DepthwiseDeconv2dOp<DeviceType::GPU, T> : public Deconv2dOpBase {
explicit DepthwiseDeconv2dOp(OpConstructContext *context)
: Deconv2dOpBase(context) {
MemoryType mem_type = MemoryType::GPU_IMAGE;
if (context->device()->opencl_runtime()->UseImageMemory()) {
if (context->device()->gpu_runtime()->UseImageMemory()) {
kernel_.reset(new opencl::image::DepthwiseDeconv2dKernel<T>);
} else {
MACE_NOT_IMPLEMENTED;
......
......@@ -1088,7 +1088,7 @@ class EltwiseOp<DeviceType::GPU, T> : public Operation {
int32_t scalar_input_index = Operation::GetOptionalArg<int32_t>(
"scalar_input_index", 1);
MemoryType mem_type;
if (context->device()->opencl_runtime()->UseImageMemory()) {
if (context->device()->gpu_runtime()->UseImageMemory()) {
mem_type = MemoryType::GPU_IMAGE;
kernel_.reset(new opencl::image::EltwiseKernel<T>(
type, coeff, scalar_input, scalar_input_index));
......
......@@ -194,7 +194,7 @@ class FullyConnectedOp<DeviceType::GPU, T> : public FullyConnectedOpBase {
explicit FullyConnectedOp(OpConstructContext *context)
: FullyConnectedOpBase(context) {
MemoryType mem_type;
if (context->device()->opencl_runtime()->UseImageMemory()) {
if (context->device()->gpu_runtime()->UseImageMemory()) {
mem_type = MemoryType::GPU_IMAGE;
kernel_.reset(new opencl::image::FullyConnectedKernel<T>);
} else {
......
......@@ -34,7 +34,7 @@ class LSTMCellOp<DeviceType::GPU, T> : public Operation {
Operation::GetOptionalArg<float>("scalar_input",
0.0));
MemoryType mem_type = MemoryType::GPU_IMAGE;
if (context->device()->opencl_runtime()->UseImageMemory()) {
if (context->device()->gpu_runtime()->UseImageMemory()) {
kernel_.reset(new opencl::image::LSTMCellKernel<T>(forget_bias));
} else {
MACE_NOT_IMPLEMENTED;
......
......@@ -47,7 +47,7 @@ MaceStatus TransformConv2DFilter(
MACE_RETURN_IF_ERROR(output->Resize(transformed_shape));
output->Reshape(input->shape());
auto runtime = context->device()->opencl_runtime();
auto runtime = context->device()->gpu_runtime()->opencl_runtime();
MACE_OUT_OF_RANGE_DEFINITION
if (kernel->get() == nullptr) {
std::set<std::string> built_options;
......@@ -116,7 +116,7 @@ MaceStatus TransformDWConv2DFilter(
MACE_RETURN_IF_ERROR(output->Resize(transformed_shape));
output->Reshape(input->shape());
auto runtime = context->device()->opencl_runtime();
auto runtime = context->device()->gpu_runtime()->opencl_runtime();
MACE_OUT_OF_RANGE_DEFINITION
if (kernel->get() == nullptr) {
std::set<std::string> built_options;
......@@ -173,7 +173,7 @@ MaceStatus TransformArgument(
MACE_RETURN_IF_ERROR(output->Resize(transformed_shape));
output->Reshape(input->shape());
auto runtime = context->device()->opencl_runtime();
auto runtime = context->device()->gpu_runtime()->opencl_runtime();
MACE_OUT_OF_RANGE_DEFINITION
if (kernel->get() == nullptr) {
std::set<std::string> built_options;
......
......@@ -31,7 +31,7 @@ MaceStatus BufferTypeTransform(
Tensor *output) {
MACE_RETURN_IF_ERROR(output->ResizeLike(input));
auto runtime = context->device()->opencl_runtime();
auto runtime = context->device()->gpu_runtime()->opencl_runtime();
MACE_OUT_OF_RANGE_DEFINITION
const uint32_t gws =
......
......@@ -43,7 +43,7 @@ MaceStatus Conv2d1x1(OpContext *context,
const index_t in_height = padded_input->dim(1);
const index_t in_width = padded_input->dim(2);
auto runtime = context->device()->opencl_runtime();
auto runtime = context->device()->gpu_runtime()->opencl_runtime();
MACE_OUT_OF_RANGE_DEFINITION;
if (kernel->get() == nullptr) {
......
......@@ -48,7 +48,7 @@ MaceStatus Conv2dGeneral(OpContext *context,
const index_t filter_height = filter->dim(2);
const index_t filter_width = filter->dim(3);
auto runtime = context->device()->opencl_runtime();
auto runtime = context->device()->gpu_runtime()->opencl_runtime();
MACE_OUT_OF_RANGE_DEFINITION;
if (kernel->get() == nullptr) {
......
......@@ -48,7 +48,7 @@ MaceStatus DepthwiseConv2d(OpContext *context,
const index_t filter_height = filter->dim(2);
const index_t filter_width = filter->dim(3);
auto runtime = context->device()->opencl_runtime();
auto runtime = context->device()->gpu_runtime()->opencl_runtime();
MACE_OUT_OF_RANGE_DEFINITION
if (kernel->get() == nullptr) {
......
......@@ -92,7 +92,7 @@ MaceStatus PoolingKernel<T>::Compute(
bool input_changed = !IsVecEqual(input_shape_, input->shape());
input_shape_ = input->shape();
auto runtime = context->device()->opencl_runtime();
auto runtime = context->device()->gpu_runtime()->opencl_runtime();
// pad input
std::vector<index_t> padded_input_shape = input->shape();
......
......@@ -75,7 +75,7 @@ MaceStatus SoftmaxKernel<T>::Compute(
static_cast<uint32_t>(width),
static_cast<uint32_t>(height * batch)};
auto runtime = context->device()->opencl_runtime();
auto runtime = context->device()->gpu_runtime()->opencl_runtime();
MACE_OUT_OF_RANGE_DEFINITION
if (kernel_.get() == nullptr) {
......
......@@ -47,7 +47,7 @@ MaceStatus PadInput(OpContext *context,
static_cast<uint32_t>(padded_height * batch)
};
auto runtime = context->device()->opencl_runtime();
auto runtime = context->device()->gpu_runtime()->opencl_runtime();
MACE_OUT_OF_RANGE_DEFINITION;
if (kernel->get() == nullptr) {
......
......@@ -66,7 +66,7 @@ MaceStatus ActivationKernel<T>::Compute(
const index_t channel_blocks = RoundUpDiv4(channels);
auto runtime = context->device()->opencl_runtime();
auto runtime = context->device()->gpu_runtime()->opencl_runtime();
MACE_OUT_OF_RANGE_DEFINITION;
if (kernel_.get() == nullptr) {
......
......@@ -57,7 +57,7 @@ MaceStatus AddNKernel<T>::Compute(
const index_t width = input_tensors[0]->dim(2);
const index_t channels = input_tensors[0]->dim(3);
auto runtime = context->device()->opencl_runtime();
auto runtime = context->device()->gpu_runtime()->opencl_runtime();
MACE_OUT_OF_RANGE_DEFINITION;
for (size_t i = 1; i < size; ++i) {
......
......@@ -85,7 +85,7 @@ MaceStatus BatchNormKernel<T>::Compute(
static_cast<uint32_t>(width),
static_cast<uint32_t>(height * batch)};
auto runtime = context->device()->opencl_runtime();
auto runtime = context->device()->gpu_runtime()->opencl_runtime();
MACE_OUT_OF_RANGE_DEFINITION;
if (kernel_.get() == nullptr) {
......
......@@ -68,7 +68,7 @@ MaceStatus BatchToSpaceKernel<T>::Compute(
chan_blk, static_cast<uint32_t>(batch_tensor->dim(2)),
static_cast<uint32_t>(batch_tensor->dim(0) * batch_tensor->dim(1))};
auto runtime = context->device()->opencl_runtime();
auto runtime = context->device()->gpu_runtime()->opencl_runtime();
MACE_OUT_OF_RANGE_DEFINITION;
if (kernel_.get() == nullptr) {
......
......@@ -62,7 +62,7 @@ MaceStatus BiasAddKernel<T>::Compute(
static_cast<uint32_t>(width),
static_cast<uint32_t>(height * batch)};
auto runtime = context->device()->opencl_runtime();
auto runtime = context->device()->gpu_runtime()->opencl_runtime();
MACE_OUT_OF_RANGE_DEFINITION;
if (kernel_.get() == nullptr) {
......
......@@ -98,7 +98,7 @@ MaceStatus BufferToImage<T>::Compute(
}
}
auto runtime = context->device()->opencl_runtime();
auto runtime = context->device()->gpu_runtime()->opencl_runtime();
MACE_OUT_OF_RANGE_DEFINITION;
if (kernel_.get() == nullptr) {
......
......@@ -70,7 +70,7 @@ MaceStatus ChannelShuffleKernel<T>::Compute(
static_cast<uint32_t>(width),
static_cast<uint32_t>(height * batch)};
auto runtime = context->device()->opencl_runtime();
auto runtime = context->device()->gpu_runtime()->opencl_runtime();
MACE_OUT_OF_RANGE_DEFINITION;
if (kernel_.get() == nullptr) {
......
......@@ -65,7 +65,7 @@ MaceStatus Concat2(OpContext *context,
static_cast<uint32_t>(batch * height),
};
auto runtime = context->device()->opencl_runtime();
auto runtime = context->device()->gpu_runtime()->opencl_runtime();
MACE_OUT_OF_RANGE_DEFINITION;
if (kernel->get() == nullptr) {
......@@ -126,7 +126,7 @@ MaceStatus ConcatN(OpContext *context,
const index_t height = output->dim(1);
const index_t width = output->dim(2);
auto runtime = context->device()->opencl_runtime();
auto runtime = context->device()->gpu_runtime()->opencl_runtime();
MACE_OUT_OF_RANGE_DEFINITION;
if (kernel->get() == nullptr) {
......
......@@ -95,7 +95,7 @@ extern MaceStatus Conv2dK1x1(OpContext *context,
const index_t width_blocks = RoundUpDiv4(width);
const index_t input_channel_blocks = RoundUpDiv4(input_channels);
auto runtime = context->device()->opencl_runtime();
auto runtime = context->device()->gpu_runtime()->opencl_runtime();
MACE_OUT_OF_RANGE_DEFINITION;
if (kernel->get() == nullptr) {
......
......@@ -83,7 +83,7 @@ extern MaceStatus Conv2dK3x3(OpContext *context,
const index_t input_channel_blocks = RoundUpDiv4(input_channels);
const index_t width_blocks = RoundUpDiv<index_t, 5>(width);
auto runtime = context->device()->opencl_runtime();
auto runtime = context->device()->gpu_runtime()->opencl_runtime();
MACE_OUT_OF_RANGE_DEFINITION;
if (kernel->get() == nullptr) {
......
......@@ -91,7 +91,7 @@ extern MaceStatus Conv2d(OpContext *context,
const index_t input_channel_blocks = RoundUpDiv4(input_channels);
const index_t width_blocks = RoundUpDiv4(width);
auto runtime = context->device()->opencl_runtime();
auto runtime = context->device()->gpu_runtime()->opencl_runtime();
MACE_OUT_OF_RANGE_DEFINITION;
if (kernel->get() == nullptr) {
......
......@@ -141,7 +141,7 @@ MaceStatus CropKernel<T>::Compute(
static_cast<uint32_t>(output->dim(0) * output->dim(1))
};
auto runtime = context->device()->opencl_runtime();
auto runtime = context->device()->gpu_runtime()->opencl_runtime();
MACE_OUT_OF_RANGE_DEFINITION;
if (kernel_.get() == nullptr) {
......
......@@ -92,7 +92,7 @@ MaceStatus Deconv2dKernel<T>::Compute(
const int align_w = stride_w - 1 - padding_w;
const int kernel_size = filter->dim(2) * filter->dim(3);
auto runtime = context->device()->opencl_runtime();
auto runtime = context->device()->gpu_runtime()->opencl_runtime();
MACE_OUT_OF_RANGE_DEFINITION;
if (kernel_.get() == nullptr) {
......
......@@ -87,7 +87,7 @@ MaceStatus DepthToSpaceKernel<T>::Compute(
static_cast<uint32_t>(output_width),
static_cast<uint32_t>(output_height * batch)
};
auto runtime = context->device()->opencl_runtime();
auto runtime = context->device()->gpu_runtime()->opencl_runtime();
MACE_OUT_OF_RANGE_DEFINITION;
if (kernel_.get() == nullptr) {
......
......@@ -93,7 +93,7 @@ MaceStatus DepthwiseConv2d(OpContext *context,
static_cast<uint32_t>(width_blocks),
static_cast<uint32_t>(height * batch)};
auto runtime = context->device()->opencl_runtime();
auto runtime = context->device()->gpu_runtime()->opencl_runtime();
MACE_OUT_OF_RANGE_DEFINITION;
if (kernel->get() == nullptr) {
......
......@@ -98,7 +98,7 @@ MaceStatus DepthwiseDeconv2dKernel<T>::Compute(
const int align_w = stride_w - 1 - padding_w;
const int kernel_size = filter->dim(2) * filter->dim(3);
auto runtime = context->device()->opencl_runtime();
auto runtime = context->device()->gpu_runtime()->opencl_runtime();
MACE_OUT_OF_RANGE_DEFINITION;
if (kernel_.get() == nullptr) {
......
......@@ -117,7 +117,7 @@ MaceStatus EltwiseKernel<T>::Compute(
static_cast<uint32_t>(width),
static_cast<uint32_t>(batch_height_pixels)};
auto runtime = context->device()->opencl_runtime();
auto runtime = context->device()->gpu_runtime()->opencl_runtime();
MACE_OUT_OF_RANGE_DEFINITION;
if (kernel_.get() == nullptr) {
std::set<std::string> built_options;
......
......@@ -64,7 +64,7 @@ MaceStatus FullyConnectedKernel<T>::Compute(
&output_image_shape);
MACE_RETURN_IF_ERROR(output->ResizeImage(output_shape, output_image_shape));
auto runtime = context->device()->opencl_runtime();
auto runtime = context->device()->gpu_runtime()->opencl_runtime();
MACE_OUT_OF_RANGE_DEFINITION;
if (kernel_.get() == nullptr) {
......
......@@ -92,7 +92,7 @@ MaceStatus ImageToBuffer<T>::Compute(OpContext *context,
break;
}
auto runtime = context->device()->opencl_runtime();
auto runtime = context->device()->gpu_runtime()->opencl_runtime();
MACE_OUT_OF_RANGE_DEFINITION;
if (kernel_.get() == nullptr) {
......
......@@ -71,7 +71,7 @@ MaceStatus LSTMCellKernel<T>::Compute(
const index_t hidden_units = pre_output->dim(1);
const index_t w_blocks = hidden_units >> 2;
auto runtime = context->device()->opencl_runtime();
auto runtime = context->device()->gpu_runtime()->opencl_runtime();
MACE_OUT_OF_RANGE_DEFINITION;
if (kernel_.get() == nullptr) {
......
......@@ -82,7 +82,7 @@ MaceStatus MatMulKernel<T>::Compute(
static_cast<uint32_t>(height_blocks * batch),
};
auto runtime = context->device()->opencl_runtime();
auto runtime = context->device()->gpu_runtime()->opencl_runtime();
MACE_OUT_OF_RANGE_DEFINITION;
if (kernel_.get() == nullptr) {
......
......@@ -80,7 +80,7 @@ MaceStatus PadKernel<T>::Compute(
const index_t channel_blocks = RoundUpDiv4(channels);
auto runtime = context->device()->opencl_runtime();
auto runtime = context->device()->gpu_runtime()->opencl_runtime();
MACE_OUT_OF_RANGE_DEFINITION;
if (kernel_.get() == nullptr) {
......
......@@ -112,7 +112,7 @@ MaceStatus PoolingKernel<T>::Compute(
&output_image_shape);
MACE_RETURN_IF_ERROR(output->ResizeImage(output_shape, output_image_shape));
auto runtime = context->device()->opencl_runtime();
auto runtime = context->device()->gpu_runtime()->opencl_runtime();
MACE_OUT_OF_RANGE_DEFINITION;
if (kernel_.get() == nullptr) {
......
......@@ -76,7 +76,7 @@ MaceStatus ReduceMeanKernel<T>::Compute(
&output_image_shape);
MACE_RETURN_IF_ERROR(output->ResizeImage(output_shape, output_image_shape));
auto runtime = context->device()->opencl_runtime();
auto runtime = context->device()->gpu_runtime()->opencl_runtime();
MACE_OUT_OF_RANGE_DEFINITION;
if (kernel_.get() == nullptr) {
......
......@@ -102,7 +102,7 @@ MaceStatus ResizeBicubicKernel<T>::Compute(
static_cast<uint32_t>(out_width),
static_cast<uint32_t>(out_height * batch)};
auto runtime = context->device()->opencl_runtime();
auto runtime = context->device()->gpu_runtime()->opencl_runtime();
MACE_OUT_OF_RANGE_DEFINITION;
if (kernel_.get() == nullptr) {
......
......@@ -107,7 +107,7 @@ MaceStatus ResizeBilinearKernel<T>::Compute(
static_cast<uint32_t>(out_width),
static_cast<uint32_t>(out_height * batch)};
auto runtime = context->device()->opencl_runtime();
auto runtime = context->device()->gpu_runtime()->opencl_runtime();
MACE_OUT_OF_RANGE_DEFINITION;
if (kernel_.get() == nullptr) {
......
......@@ -102,7 +102,7 @@ MaceStatus SoftmaxKernel<T>::Compute(
static_cast<uint32_t>(width),
static_cast<uint32_t>(height * batch)};
auto runtime = context->device()->opencl_runtime();
auto runtime = context->device()->gpu_runtime()->opencl_runtime();
MACE_OUT_OF_RANGE_DEFINITION;
if (kernel_.get() == nullptr) {
......
......@@ -66,7 +66,7 @@ MaceStatus SpaceToBatchKernel<T>::Compute(
chan_blk, static_cast<uint32_t>(batch_tensor->dim(2)),
static_cast<uint32_t>(batch_tensor->dim(0) * batch_tensor->dim(1))};
auto runtime = context->device()->opencl_runtime();
auto runtime = context->device()->gpu_runtime()->opencl_runtime();
MACE_OUT_OF_RANGE_DEFINITION;
if (kernel_.get() == nullptr) {
......
......@@ -79,7 +79,7 @@ MaceStatus SpaceToDepthKernel<T>::Compute(
&image_shape);
MACE_RETURN_IF_ERROR(output->ResizeImage(output_shape, image_shape));
auto runtime = context->device()->opencl_runtime();
auto runtime = context->device()->gpu_runtime()->opencl_runtime();
MACE_OUT_OF_RANGE_DEFINITION;
if (kernel_.get() == nullptr) {
......
......@@ -70,7 +70,7 @@ MaceStatus SplitKernel<T>::Compute(
output_list[i]->ResizeImage(output_shape, image_shape));
}
auto runtime = context->device()->opencl_runtime();
auto runtime = context->device()->gpu_runtime()->opencl_runtime();
MACE_OUT_OF_RANGE_DEFINITION;
if (kernel_.get() == nullptr) {
......
......@@ -72,7 +72,7 @@ MaceStatus SqrDiffMeanKernel<T>::Compute(
&output_image_shape);
MACE_RETURN_IF_ERROR(output->ResizeImage(output_shape, output_image_shape));
auto runtime = context->device()->opencl_runtime();
auto runtime = context->device()->gpu_runtime()->opencl_runtime();
MACE_OUT_OF_RANGE_DEFINITION;
if (kernel_.get() == nullptr) {
......
......@@ -37,7 +37,7 @@ MaceStatus WinogradInputTransform(OpContext *context,
Tensor *output_tensor,
uint32_t *kwg_size,
StatsFuture *future) {
OpenCLRuntime *runtime = context->device()->opencl_runtime();
OpenCLRuntime *runtime = context->device()->gpu_runtime()->opencl_runtime();
const index_t out_width = output_tensor->dim(2);
MACE_OUT_OF_RANGE_DEFINITION;
......@@ -119,7 +119,7 @@ MaceStatus WinogradOutputTransform(OpContext *context,
Tensor *output_tensor,
uint32_t *kwg_size,
StatsFuture *future) {
OpenCLRuntime *runtime = context->device()->opencl_runtime();
OpenCLRuntime *runtime = context->device()->gpu_runtime()->opencl_runtime();
auto &output_shape = output_tensor->shape();
MACE_OUT_OF_RANGE_DEFINITION;
......@@ -227,8 +227,9 @@ extern MaceStatus WinogradConv2dK3x3S1(OpContext *context,
std::vector<index_t> *prev_input_shape,
Tensor *output,
uint32_t *kwg_size[3]) {
OpenCLRuntime *runtime = context->device()->opencl_runtime();
ScratchImageManager *scratch_manager = runtime->scratch_image_manager();
OpenCLRuntime *runtime = context->device()->gpu_runtime()->opencl_runtime();
ScratchImageManager *scratch_manager =
context->device()->gpu_runtime()->scratch_image_manager();
StatsFuture t_input_future, mm_future, t_output_future;
bool input_changed = !IsVecEqual(*prev_input_shape, input->shape());
*prev_input_shape = input->shape();
......
......@@ -35,7 +35,7 @@ MaceStatus BufferToImageOpImpl(OpContext *context,
uint32_t gws[2] = {static_cast<uint32_t>(image_shape[0]),
static_cast<uint32_t>(image_shape[1])};
auto runtime = context->device()->opencl_runtime();
auto runtime = context->device()->gpu_runtime()->opencl_runtime();
std::string kernel_name = "in_out_buffer_to_image";
std::string obfuscated_kernel_name = MACE_OBFUSCATE_SYMBOL(kernel_name);
......
......@@ -206,7 +206,7 @@ MaceStatus OpsTestNet::RunOp(mace::DeviceType device) {
auto opencl_mem_types = OpTestContext::Get()->opencl_mem_types();
for (auto type : opencl_mem_types) {
OpTestContext::Get()->GetDevice(device)
->opencl_runtime()->set_mem_type(type);
->gpu_runtime()->set_mem_type(type);
Setup(device);
MACE_RETURN_IF_ERROR(Run());
}
......@@ -242,8 +242,8 @@ MaceStatus OpsTestNet::RunNet(const mace::NetDef &net_def,
void OpsTestNet::Sync() {
#ifdef MACE_ENABLE_OPENCL
if (net_ && device_type_ == DeviceType::GPU) {
OpTestContext::Get()->GetDevice(DeviceType::GPU)->opencl_runtime()
->command_queue().finish();
OpTestContext::Get()->GetDevice(DeviceType::GPU)->gpu_runtime()
->opencl_runtime()->command_queue().finish();
}
#endif
}
......
......@@ -97,7 +97,7 @@ class PadOp<DeviceType::GPU, T> : public Operation {
std::vector<int> paddings = Operation::GetRepeatedArgs<int>("paddings");
float constant_value = Operation::GetOptionalArg<float>(
"constant_value", 0.0);
if (context->device()->opencl_runtime()->UseImageMemory()) {
if (context->device()->gpu_runtime()->UseImageMemory()) {
kernel_.reset(new opencl::image::PadKernel<T>(paddings, constant_value));
} else {
MACE_NOT_IMPLEMENTED;
......
......@@ -429,7 +429,7 @@ class PoolingOp<DeviceType::GPU, T> : public PoolingOpBase {
public:
explicit PoolingOp(OpConstructContext *context)
: PoolingOpBase(context) {
if (context->device()->opencl_runtime()->UseImageMemory()) {
if (context->device()->gpu_runtime()->UseImageMemory()) {
kernel_.reset(new opencl::image::PoolingKernel<T>);
} else {
context->set_output_mem_type(MemoryType::GPU_BUFFER);
......
......@@ -246,7 +246,7 @@ class ReduceMeanOp<DeviceType::GPU, T> : public ReduceMeanOpBase {
public:
explicit ReduceMeanOp(OpConstructContext *context)
: ReduceMeanOpBase(context) {
if (context->device()->opencl_runtime()->UseImageMemory()) {
if (context->device()->gpu_runtime()->UseImageMemory()) {
kernel_.reset(new opencl::image::ReduceMeanKernel<T>(axis_, keep_dims_));
} else {
MACE_NOT_IMPLEMENTED;
......
......@@ -195,7 +195,7 @@ class ResizeBicubicOp<DeviceType::GPU, T> : public Operation {
std::vector<index_t> size = Operation::GetRepeatedArgs<index_t>(
"size", {-1, -1});
MACE_CHECK(size.size() == 2);
if (context->device()->opencl_runtime()->UseImageMemory()) {
if (context->device()->gpu_runtime()->UseImageMemory()) {
kernel_.reset(new opencl::image::ResizeBicubicKernel<T>(align_corners,
size[0],
size[1]));
......
......@@ -331,7 +331,7 @@ class ResizeBilinearOp<DeviceType::GPU, T> : public Operation {
std::vector<index_t> size = Operation::GetRepeatedArgs<index_t>(
"size", {-1, -1});
MACE_CHECK(size.size() == 2);
if (context->device()->opencl_runtime()->UseImageMemory()) {
if (context->device()->gpu_runtime()->UseImageMemory()) {
kernel_.reset(new opencl::image::ResizeBilinearKernel<T>(align_corners,
size[0],
size[1]));
......
......@@ -364,7 +364,7 @@ class SoftmaxOp<DeviceType::GPU, T> : public Operation {
public:
explicit SoftmaxOp(OpConstructContext *context)
: Operation(context) {
if (context->device()->opencl_runtime()->UseImageMemory()) {
if (context->device()->gpu_runtime()->UseImageMemory()) {
kernel_.reset(new opencl::image::SoftmaxKernel<T>);
} else {
context->set_output_mem_type(MemoryType::GPU_BUFFER);
......
......@@ -308,7 +308,7 @@ class SpaceToBatchNDOp<DeviceType::GPU, T> : public SpaceToBatchOpBase {
public:
explicit SpaceToBatchNDOp(OpConstructContext *context)
: SpaceToBatchOpBase(context) {
if (context->device()->opencl_runtime()->UseImageMemory()) {
if (context->device()->gpu_runtime()->UseImageMemory()) {
kernel_.reset(new opencl::image::SpaceToBatchKernel<T>);
} else {
MACE_NOT_IMPLEMENTED;
......
......@@ -94,7 +94,7 @@ class SpaceToDepthOp<DeviceType::GPU, T> : public Operation {
explicit SpaceToDepthOp(OpConstructContext *context)
: Operation(context) {
int block_size = Operation::GetOptionalArg<int>("block_size", 1);
if (context->device()->opencl_runtime()->UseImageMemory()) {
if (context->device()->gpu_runtime()->UseImageMemory()) {
kernel_.reset(new opencl::image::SpaceToDepthKernel<T>(block_size));
} else {
MACE_NOT_IMPLEMENTED;
......
......@@ -105,7 +105,7 @@ class SplitOp<DeviceType::GPU, T> : public Operation {
explicit SplitOp(OpConstructContext *context)
: Operation(context) {
int32_t axis = Operation::GetOptionalArg<int>("axis", 3);
if (context->device()->opencl_runtime()->UseImageMemory()) {
if (context->device()->gpu_runtime()->UseImageMemory()) {
kernel_.reset(new opencl::image::SplitKernel<T>(axis));
} else {
MACE_NOT_IMPLEMENTED;
......
......@@ -82,7 +82,7 @@ class SqrDiffMeanOp<DeviceType::GPU, T> : public Operation {
public:
explicit SqrDiffMeanOp(OpConstructContext *context)
: Operation(context) {
if (context->device()->opencl_runtime()->UseImageMemory()) {
if (context->device()->gpu_runtime()->UseImageMemory()) {
kernel_.reset(new opencl::image::SqrDiffMeanKernel<T>());
} else {
MACE_NOT_IMPLEMENTED;
......