提交 6751c7be 编写于 作者: 吴承辉

Merge branch 'mace-status' into 'master'

Add mace status to op run

See merge request !529
......@@ -155,18 +155,13 @@ MaceStatus MaceEngine::Impl::Init(
}
} else {
#endif
MaceStatus status =
ws_->LoadModelTensor(*net_def, device_type_, model_data);
if (status != MaceStatus::MACE_SUCCESS) {
return status;
}
MACE_FAILURE_RETURN(ws_->LoadModelTensor(
*net_def, device_type_, model_data));
// Init model
auto net = CreateNet(op_registry_, *net_def, ws_.get(), device_type_,
NetMode::INIT);
if (!net->Run()) {
LOG(FATAL) << "Net init run failed";
}
MACE_FAILURE_RETURN(net->Run());
net_ = CreateNet(op_registry_, *net_def, ws_.get(), device_type_);
#ifdef MACE_ENABLE_HEXAGON
}
......@@ -226,9 +221,7 @@ MaceStatus MaceEngine::Impl::Run(
hexagon_controller_->ExecuteGraph(*input_tensors[0], output_tensors[0]);
} else {
#endif
if (!net_->Run(run_metadata)) {
LOG(FATAL) << "Net run failed";
}
MACE_FAILURE_RETURN(net_->Run(run_metadata));
#ifdef MACE_ENABLE_HEXAGON
}
#endif
......
......@@ -57,7 +57,7 @@ SerialNet::SerialNet(const std::shared_ptr<const OperatorRegistry> op_registry,
}
}
bool SerialNet::Run(RunMetadata *run_metadata) {
MaceStatus SerialNet::Run(RunMetadata *run_metadata) {
MACE_MEMORY_LOGGING_GUARD();
MACE_LATENCY_LOGGER(1, "Running net");
for (auto iter = operators_.begin(); iter != operators_.end(); ++iter) {
......@@ -68,11 +68,10 @@ bool SerialNet::Run(RunMetadata *run_metadata) {
(run_metadata != nullptr ||
std::distance(iter, operators_.end()) == 1));
bool ret;
CallStats call_stats;
if (future_wait) {
StatsFuture future;
ret = op->Run(&future);
MACE_FAILURE_RETURN(op->Run(&future));
if (run_metadata != nullptr) {
future.wait_fn(&call_stats);
} else {
......@@ -80,10 +79,10 @@ bool SerialNet::Run(RunMetadata *run_metadata) {
}
} else if (run_metadata != nullptr) {
call_stats.start_micros = NowMicros();
ret = op->Run(nullptr);
MACE_FAILURE_RETURN(op->Run(nullptr));
call_stats.end_micros = NowMicros();
} else {
ret = op->Run(nullptr);
MACE_FAILURE_RETURN(op->Run(nullptr));
}
if (run_metadata != nullptr) {
......@@ -117,16 +116,11 @@ bool SerialNet::Run(RunMetadata *run_metadata) {
run_metadata->op_stats.emplace_back(op_stats);
}
if (!ret) {
LOG(ERROR) << "Operator failed: " << op->debug_def().name();
return false;
}
VLOG(3) << "Operator " << op->debug_def().name()
<< " has shape: " << MakeString(op->Output(0)->shape());
}
return true;
return MACE_SUCCESS;
}
std::unique_ptr<NetBase> CreateNet(
......
......@@ -36,7 +36,7 @@ class NetBase {
DeviceType type);
virtual ~NetBase() noexcept {}
virtual bool Run(RunMetadata *run_metadata = nullptr) = 0;
virtual MaceStatus Run(RunMetadata *run_metadata = nullptr) = 0;
const std::string &Name() const { return name_; }
......@@ -55,7 +55,7 @@ class SerialNet : public NetBase {
DeviceType type,
const NetMode mode = NetMode::NORMAL);
bool Run(RunMetadata *run_metadata = nullptr) override;
MaceStatus Run(RunMetadata *run_metadata = nullptr) override;
protected:
std::vector<std::unique_ptr<OperatorBase> > operators_;
......
......@@ -73,7 +73,7 @@ class OperatorBase {
inline const std::vector<Tensor *> &Outputs() { return outputs_; }
// Run Op asynchronously (depends on device), return a future if not nullptr.
virtual bool Run(StatsFuture *future) = 0;
virtual MaceStatus Run(StatsFuture *future) = 0;
inline const OperatorDef &debug_def() const {
MACE_CHECK(has_debug_def(), "operator_def was null!");
......@@ -130,7 +130,7 @@ class Operator : public OperatorBase {
}
}
}
bool Run(StatsFuture *future) override = 0;
MaceStatus Run(StatsFuture *future) override = 0;
~Operator() noexcept override {}
};
......
......@@ -132,7 +132,7 @@ class ActivationFunctor<DeviceType::CPU, float> {
ActivationFunctor(ActivationType type, float relux_max_limit)
: activation_(type), relux_max_limit_(relux_max_limit) {}
void operator()(const Tensor *input,
MaceStatus operator()(const Tensor *input,
const Tensor *alpha,
Tensor *output,
StatsFuture *future) {
......@@ -144,16 +144,13 @@ class ActivationFunctor<DeviceType::CPU, float> {
const float *alpha_ptr = alpha->data<float>();
const index_t outer_size = output->dim(0);
const index_t inner_size = output->dim(2) * output->dim(3);
PReLUActivation(input_ptr,
outer_size,
input->dim(1),
inner_size,
alpha_ptr,
output_ptr);
PReLUActivation(input_ptr, outer_size, input->dim(1), inner_size,
alpha_ptr, output_ptr);
} else {
DoActivation(input_ptr, output_ptr, output->size(), activation_,
relux_max_limit_);
}
return MACE_SUCCESS;
}
private:
......@@ -168,7 +165,7 @@ class ActivationFunctor<DeviceType::GPU, T> {
ActivationFunctor(ActivationType type, T relux_max_limit)
: activation_(type), relux_max_limit_(static_cast<T>(relux_max_limit)) {}
void operator()(const Tensor *input,
MaceStatus operator()(const Tensor *input,
const Tensor *alpha,
Tensor *output,
StatsFuture *future);
......
......@@ -36,11 +36,11 @@ constexpr int kCostPerGroup = 1024;
template <DeviceType D, typename T>
struct AddNFunctor {
void operator()(const std::vector<const Tensor *> &input_tensors,
MaceStatus operator()(const std::vector<const Tensor *> &input_tensors,
Tensor *output_tensor,
StatsFuture *future) {
MACE_UNUSED(future);
output_tensor->ResizeLike(input_tensors[0]);
MACE_FAILURE_RETURN(output_tensor->ResizeLike(input_tensors[0]));
index_t size = output_tensor->size();
Tensor::MappingGuard output_map(output_tensor);
float *output_data = output_tensor->mutable_data<float>();
......@@ -89,13 +89,14 @@ struct AddNFunctor {
}
}
}
return MACE_SUCCESS;
}
};
#ifdef MACE_ENABLE_OPENCL
template <typename T>
struct AddNFunctor<DeviceType::GPU, T> {
void operator()(const std::vector<const Tensor *> &input_tensors,
MaceStatus operator()(const std::vector<const Tensor *> &input_tensors,
Tensor *output_tensor,
StatsFuture *future);
......
......@@ -20,7 +20,7 @@
namespace mace {
namespace kernels {
extern void Conv2dNeonK1x1S1(const float *input,
void Conv2dNeonK1x1S1(const float *input,
const float *filter,
const index_t batch,
const index_t height,
......@@ -29,61 +29,61 @@ extern void Conv2dNeonK1x1S1(const float *input,
const index_t out_channels,
float *output);
extern void Conv2dNeonK3x3S1(const float *input,
void Conv2dNeonK3x3S1(const float *input,
const float *filter,
const index_t *in_shape,
const index_t *out_shape,
float *output);
extern void Conv2dNeonK3x3S2(const float *input,
void Conv2dNeonK3x3S2(const float *input,
const float *filter,
const index_t *in_shape,
const index_t *out_shape,
float *output);
extern void Conv2dNeonK5x5S1(const float *input,
void Conv2dNeonK5x5S1(const float *input,
const float *filter,
const index_t *in_shape,
const index_t *out_shape,
float *output);
extern void Conv2dNeonK1x7S1(const float *input,
void Conv2dNeonK1x7S1(const float *input,
const float *filter,
const index_t *in_shape,
const index_t *out_shape,
float *output);
extern void Conv2dNeonK7x1S1(const float *input,
void Conv2dNeonK7x1S1(const float *input,
const float *filter,
const index_t *in_shape,
const index_t *out_shape,
float *output);
extern void Conv2dNeonK7x7S1(const float *input,
void Conv2dNeonK7x7S1(const float *input,
const float *filter,
const index_t *in_shape,
const index_t *out_shape,
float *output);
extern void Conv2dNeonK7x7S2(const float *input,
void Conv2dNeonK7x7S2(const float *input,
const float *filter,
const index_t *in_shape,
const index_t *out_shape,
float *output);
extern void Conv2dNeonK7x7S3(const float *input,
void Conv2dNeonK7x7S3(const float *input,
const float *filter,
const index_t *in_shape,
const index_t *out_shape,
float *output);
extern void Conv2dNeonK1x15S1(const float *input,
void Conv2dNeonK1x15S1(const float *input,
const float *filter,
const index_t *in_shape,
const index_t *out_shape,
float *output);
extern void Conv2dNeonK15x1S1(const float *input,
void Conv2dNeonK15x1S1(const float *input,
const float *filter,
const index_t *in_shape,
const index_t *out_shape,
......
......@@ -56,7 +56,7 @@ struct BatchNormFunctor<DeviceType::CPU, float> : BatchNormFunctorBase {
const float relux_max_limit)
: BatchNormFunctorBase(folded_constant, activation, relux_max_limit) {}
void operator()(const Tensor *input,
MaceStatus operator()(const Tensor *input,
const Tensor *scale,
const Tensor *offset,
const Tensor *mean,
......@@ -124,6 +124,8 @@ struct BatchNormFunctor<DeviceType::CPU, float> : BatchNormFunctorBase {
}
DoActivation(output_ptr, output_ptr, output->size(), activation_,
relux_max_limit_);
return MACE_SUCCESS;
}
};
......@@ -134,7 +136,7 @@ struct BatchNormFunctor<DeviceType::GPU, T> : BatchNormFunctorBase {
const ActivationType activation,
const float relux_max_limit)
: BatchNormFunctorBase(folded_constant, activation, relux_max_limit) {}
void operator()(const Tensor *input,
MaceStatus operator()(const Tensor *input,
const Tensor *scale,
const Tensor *offset,
const Tensor *mean,
......
......@@ -34,7 +34,7 @@ struct BiasAddFunctor;
template<>
struct BiasAddFunctor<DeviceType::CPU, float> {
void operator()(const Tensor *input,
MaceStatus operator()(const Tensor *input,
const Tensor *bias,
Tensor *output,
StatsFuture *future) {
......@@ -61,13 +61,15 @@ struct BiasAddFunctor<DeviceType::CPU, float> {
}
}
}
return MACE_SUCCESS;
}
};
#ifdef MACE_ENABLE_OPENCL
template<typename T>
struct BiasAddFunctor<DeviceType::GPU, T> {
void operator()(const Tensor *input,
MaceStatus operator()(const Tensor *input,
const Tensor *bias,
Tensor *output,
StatsFuture *future);
......
......@@ -33,7 +33,7 @@ struct BufferToImageFunctorBase {
template <DeviceType D, typename T>
struct BufferToImageFunctor : BufferToImageFunctorBase {
BufferToImageFunctor() {}
void operator()(const Tensor *input,
MaceStatus operator()(const Tensor *input,
const BufferType type,
Tensor *output,
StatsFuture *future) {
......@@ -42,13 +42,14 @@ struct BufferToImageFunctor : BufferToImageFunctorBase {
MACE_UNUSED(output);
MACE_UNUSED(future);
MACE_NOT_IMPLEMENTED;
return MACE_SUCCESS;
}
};
template <typename T>
struct BufferToImageFunctor<DeviceType::GPU, T> : BufferToImageFunctorBase {
BufferToImageFunctor() {}
void operator()(const Tensor *input,
MaceStatus operator()(const Tensor *input,
const BufferType type,
Tensor *output,
StatsFuture *future);
......
......@@ -28,11 +28,11 @@ template<DeviceType D, typename T>
struct ChannelShuffleFunctor {
explicit ChannelShuffleFunctor(const int groups) : groups_(groups) {}
void operator()(const Tensor *input,
MaceStatus operator()(const Tensor *input,
Tensor *output,
StatsFuture *future) {
MACE_UNUSED(future);
output->ResizeLike(input);
MACE_FAILURE_RETURN(output->ResizeLike(input));
Tensor::MappingGuard logits_guard(input);
Tensor::MappingGuard output_guard(output);
......@@ -61,6 +61,8 @@ struct ChannelShuffleFunctor {
}
}
}
return MACE_SUCCESS;
}
const int groups_;
......@@ -71,7 +73,9 @@ template<typename T>
struct ChannelShuffleFunctor<DeviceType::GPU, T> {
explicit ChannelShuffleFunctor(const int groups) : groups_(groups) {}
void operator()(const Tensor *input, Tensor *output, StatsFuture *future);
MaceStatus operator()(const Tensor *input,
Tensor *output,
StatsFuture *future);
cl::Kernel kernel_;
uint32_t kwg_size_;
......
......@@ -40,7 +40,7 @@ template <DeviceType D, typename T>
struct ConcatFunctor : ConcatFunctorBase {
explicit ConcatFunctor(const int32_t axis) : ConcatFunctorBase(axis) {}
void operator()(const std::vector<const Tensor *> &input_list,
MaceStatus operator()(const std::vector<const Tensor *> &input_list,
Tensor *output,
StatsFuture *future) {
MACE_UNUSED(future);
......@@ -68,7 +68,7 @@ struct ConcatFunctor : ConcatFunctorBase {
outer_sizes[i] = input->size() / inner_size;
output_shape[axis_] += input->dim(axis_);
}
output->Resize(output_shape);
MACE_FAILURE_RETURN(output->Resize(output_shape));
T *output_ptr = output->mutable_data<T>();
......@@ -89,6 +89,8 @@ struct ConcatFunctor : ConcatFunctorBase {
}
}
}
return MACE_SUCCESS;
}
};
......@@ -97,7 +99,7 @@ template <typename T>
struct ConcatFunctor<DeviceType::GPU, T> : ConcatFunctorBase {
explicit ConcatFunctor(const int32_t axis) : ConcatFunctorBase(axis) {}
void operator()(const std::vector<const Tensor *> &input_list,
MaceStatus operator()(const std::vector<const Tensor *> &input_list,
Tensor *output,
StatsFuture *future);
cl::Kernel kernel_;
......
......@@ -256,7 +256,7 @@ struct Conv2dFunctor<DeviceType::CPU, float> : Conv2dFunctorBase {
} // b
}
void operator()(const Tensor *input,
MaceStatus operator()(const Tensor *input,
const Tensor *filter,
const Tensor *bias,
Tensor *output,
......@@ -296,7 +296,7 @@ struct Conv2dFunctor<DeviceType::CPU, float> : Conv2dFunctorBase {
RoundType::FLOOR,
output_shape.data());
}
output->Resize(output_shape);
MACE_FAILURE_RETURN(output->Resize(output_shape));
index_t batch = output->dim(0);
index_t channels = output->dim(1);
......@@ -497,7 +497,8 @@ struct Conv2dFunctor<DeviceType::CPU, float> : Conv2dFunctorBase {
if (is_filter_transformed_) {
transformed_filter_ptr = filter_data;
} else {
transformed_filter_.Resize(transformed_filter_shape);
MACE_FAILURE_RETURN(transformed_filter_.Resize(
transformed_filter_shape));
switch (winograd_out_tile_size) {
case 2:
TransformFilter4x4(filter_data,
......@@ -643,12 +644,12 @@ struct Conv2dFunctor<DeviceType::CPU, float> : Conv2dFunctorBase {
const Tensor *pad_input_ptr = input;
if (extra_input_height != input_height
|| extra_input_width != input_width) {
ConstructNCHWInputWithSpecificPadding(input,
MACE_FAILURE_RETURN(ConstructNCHWInputWithSpecificPadding(input,
pad_top,
pad_bottom,
pad_left,
pad_right,
&padded_input);
&padded_input));
pad_input_ptr = &padded_input;
}
......@@ -701,6 +702,8 @@ struct Conv2dFunctor<DeviceType::CPU, float> : Conv2dFunctorBase {
DoActivation(output_data, output_data, output->size(), activation_,
relux_max_limit_);
return MACE_SUCCESS;
}
Tensor transformed_filter_;
......@@ -729,7 +732,7 @@ struct Conv2dFunctor<DeviceType::GPU, T> : Conv2dFunctorBase {
MACE_UNUSED(scratch);
}
void operator()(const Tensor *input,
MaceStatus operator()(const Tensor *input,
const Tensor *filter,
const Tensor *bias,
Tensor *output,
......
......@@ -286,7 +286,7 @@ void CalPaddingSize(const index_t *input_shape, // NCHW
}
void ConstructNCHWInputWithPadding(const Tensor *input_tensor,
MaceStatus ConstructNCHWInputWithPadding(const Tensor *input_tensor,
const int *paddings,
Tensor *output_tensor,
bool padding_same_value) {
......@@ -306,7 +306,7 @@ void ConstructNCHWInputWithPadding(const Tensor *input_tensor,
const int padded_top = paddings[0] / 2;
const int padded_left = paddings[1] / 2;
output_tensor->Resize(output_shape);
MACE_FAILURE_RETURN(output_tensor->Resize(output_shape));
Tensor::MappingGuard padded_output_mapper(output_tensor);
float *output_data = output_tensor->mutable_data<float>();
......@@ -356,9 +356,11 @@ void ConstructNCHWInputWithPadding(const Tensor *input_tensor,
}
}
}
return MACE_SUCCESS;
}
void ConstructNCHWInputWithSpecificPadding(const Tensor *input_tensor,
MaceStatus ConstructNCHWInputWithSpecificPadding(const Tensor *input_tensor,
const int pad_top,
const int pad_bottom,
const int pad_left,
......@@ -376,7 +378,7 @@ void ConstructNCHWInputWithSpecificPadding(const Tensor *input_tensor,
const int pad_width = pad_left + pad_right;
std::vector<index_t> output_shape(
{batch, channels, height + pad_height, width + pad_width});
output_tensor->Resize(output_shape);
MACE_FAILURE_RETURN(output_tensor->Resize(output_shape));
output_tensor->Clear();
Tensor::MappingGuard padded_output_mapper(output_tensor);
float *output_data = output_tensor->mutable_data<float>();
......@@ -400,10 +402,12 @@ void ConstructNCHWInputWithSpecificPadding(const Tensor *input_tensor,
// Skip the padded bottom in this channel and top in the next channel
}
}
return MACE_SUCCESS;
}
void ConstructNHWCInputWithPadding(const Tensor *input_tensor,
MaceStatus ConstructNHWCInputWithPadding(const Tensor *input_tensor,
const int *paddings,
Tensor *output_tensor,
bool padding_same_value) {
......@@ -424,7 +428,7 @@ void ConstructNHWCInputWithPadding(const Tensor *input_tensor,
const int padded_top = paddings[0] / 2;
const int padded_left = paddings[1] / 2;
output_tensor->Resize(output_shape);
MACE_FAILURE_RETURN(output_tensor->Resize(output_shape));
Tensor::MappingGuard padded_output_mapper(output_tensor);
float *output_data = output_tensor->mutable_data<float>();
......@@ -450,6 +454,8 @@ void ConstructNHWCInputWithPadding(const Tensor *input_tensor,
}
}
}
return MACE_SUCCESS;
}
} // namespace kernels
......
......@@ -71,17 +71,17 @@ void CalPaddingSize(const index_t *input_shape, // NCHW
Padding padding,
int *padding_size);
void ConstructNCHWInputWithSpecificPadding(const Tensor *input,
MaceStatus ConstructNCHWInputWithSpecificPadding(const Tensor *input,
const int pad_top, const int pad_bottom,
const int pad_left, const int pad_right,
Tensor *output_tensor);
void ConstructNCHWInputWithPadding(const Tensor *input,
MaceStatus ConstructNCHWInputWithPadding(const Tensor *input,
const int *paddings,
Tensor *output_tensor,
bool padding_same_value = false);
void ConstructNHWCInputWithPadding(const Tensor *input,
MaceStatus ConstructNHWCInputWithPadding(const Tensor *input,
const int *paddings,
Tensor *output_tensor,
bool padding_same_value = false);
......
......@@ -226,7 +226,7 @@ struct Deconv2dFunctor : Deconv2dFunctorBase {
activation,
relux_max_limit) {}
void operator()(const Tensor *input, // NCHW
MaceStatus operator()(const Tensor *input, // NCHW
const Tensor *filter, // OIHW
const Tensor *bias,
Tensor *output,
......@@ -250,7 +250,7 @@ struct Deconv2dFunctor : Deconv2dFunctorBase {
strides_, padding_type_,
output_shape.data(),
paddings_.data(), true);
output->Resize(output_shape);
MACE_FAILURE_RETURN(output->Resize(output_shape));
} else {
output_shape_.clear();
output_shape_ = std::vector<index_t>(4, 0);
......@@ -259,7 +259,7 @@ struct Deconv2dFunctor : Deconv2dFunctorBase {
strides_,
output_shape_.data(),
paddings_.data(), true);
output->Resize(output_shape_);
MACE_FAILURE_RETURN(output->Resize(output_shape_));
}
index_t kernel_h = filter->dim(2);
index_t kernel_w = filter->dim(3);
......@@ -298,6 +298,8 @@ struct Deconv2dFunctor : Deconv2dFunctorBase {
output->size(),
activation_,
relux_max_limit_);
return MACE_SUCCESS;
}
};
......@@ -317,7 +319,7 @@ struct Deconv2dFunctor<DeviceType::GPU, T> : Deconv2dFunctorBase {
activation,
relux_max_limit) {}
void operator()(const Tensor *input,
MaceStatus operator()(const Tensor *input,
const Tensor *filter,
const Tensor *bias,
Tensor *output,
......
......@@ -32,7 +32,9 @@ template<DeviceType D, typename T>
struct DepthToSpaceOpFunctor {
explicit DepthToSpaceOpFunctor(const int block_size, bool d2s)
: block_size_(block_size), d2s_(d2s) {}
void operator()(const Tensor *input, Tensor *output, StatsFuture *future) {
MaceStatus operator()(const Tensor *input,
Tensor *output,
StatsFuture *future) {
MACE_UNUSED(future);
const index_t batch_size = input->dim(0);
const index_t input_depth = input->dim(1);
......@@ -53,7 +55,7 @@ struct DepthToSpaceOpFunctor {
std::vector<index_t> output_shape = {batch_size, output_depth,
output_height, output_width};
output->Resize(output_shape);
MACE_FAILURE_RETURN(output->Resize(output_shape));
Tensor::MappingGuard logits_guard(input);
Tensor::MappingGuard output_guard(output);
......@@ -75,7 +77,8 @@ struct DepthToSpaceOpFunctor {
const index_t in_d = d + offset_d;
const index_t o_index =
((b * output_depth + d) * output_height + h) * output_width + w;
((b * output_depth + d) * output_height + h) * output_width
+ w;
const index_t i_index =
((b * input_depth + in_d) * input_height + in_h) * input_width
+ in_w;
......@@ -110,6 +113,8 @@ struct DepthToSpaceOpFunctor {
}
}
}
return MACE_SUCCESS;
}
const int block_size_;
......@@ -121,7 +126,9 @@ template<typename T>
struct DepthToSpaceOpFunctor<DeviceType::GPU, T> {
DepthToSpaceOpFunctor(const int block_size, bool d2s)
: block_size_(block_size), d2s_(d2s) {}
void operator()(const Tensor *input, Tensor *output, StatsFuture *future);
MaceStatus operator()(const Tensor *input,
Tensor *output,
StatsFuture *future);
const int block_size_;
bool d2s_;
......
......@@ -127,7 +127,7 @@ struct DepthwiseConv2dFunctor<DeviceType::CPU, float>
}
}
void operator()(const Tensor *input,
MaceStatus operator()(const Tensor *input,
const Tensor *filter,
const Tensor *bias,
Tensor *output,
......@@ -161,7 +161,7 @@ struct DepthwiseConv2dFunctor<DeviceType::CPU, float>
RoundType::FLOOR,
output_shape.data());
}
output->Resize(output_shape);
MACE_FAILURE_RETURN(output->Resize(output_shape));
output->Clear();
index_t batch = output->dim(0);
......@@ -275,6 +275,8 @@ struct DepthwiseConv2dFunctor<DeviceType::CPU, float>
DoActivation(output_data, output_data, output->size(), activation_,
relux_max_limit_);
return MACE_SUCCESS;
}
};
......@@ -295,7 +297,7 @@ struct DepthwiseConv2dFunctor<DeviceType::GPU, T>
activation,
relux_max_limit) {}
void operator()(const Tensor *input,
MaceStatus operator()(const Tensor *input,
const Tensor *filter,
const Tensor *bias,
Tensor *output,
......
......@@ -466,7 +466,7 @@ struct EltwiseFunctor<DeviceType::CPU, float>: EltwiseFunctorBase {
const float value)
: EltwiseFunctorBase(type, coeff, value) {}
void operator()(const Tensor *input0,
MaceStatus operator()(const Tensor *input0,
const Tensor *input1,
Tensor *output,
StatsFuture *future) {
......@@ -494,7 +494,7 @@ struct EltwiseFunctor<DeviceType::CPU, float>: EltwiseFunctorBase {
}
}
}
output->ResizeLike(input0);
MACE_FAILURE_RETURN(output->ResizeLike(input0));
Tensor::MappingGuard input0_guard(input0);
Tensor::MappingGuard output_guard(output);
......@@ -530,6 +530,8 @@ struct EltwiseFunctor<DeviceType::CPU, float>: EltwiseFunctorBase {
}
}
}
return MACE_SUCCESS;
}
};
......@@ -541,7 +543,7 @@ struct EltwiseFunctor<DeviceType::GPU, T> : EltwiseFunctorBase {
const float value)
: EltwiseFunctorBase(type, coeff, value) {}
void operator()(const Tensor *input0,
MaceStatus operator()(const Tensor *input0,
const Tensor *input1,
Tensor *output,
StatsFuture *future);
......
......@@ -50,14 +50,14 @@ struct FullyConnectedFunctor<DeviceType::CPU, float>: FullyConnectedBase {
const float relux_max_limit)
: FullyConnectedBase(activation, relux_max_limit) {}
void operator()(const Tensor *input,
MaceStatus operator()(const Tensor *input,
const Tensor *weight,
const Tensor *bias,
Tensor *output,
StatsFuture *future) {
MACE_UNUSED(future);
std::vector<index_t> output_shape = {input->dim(0), weight->dim(0), 1, 1};
output->Resize(output_shape);
MACE_FAILURE_RETURN(output->Resize(output_shape));
const index_t N = output->dim(0);
const index_t input_size = weight->dim(1) * weight->dim(2) * weight->dim(3);
const index_t output_size = weight->dim(0);
......@@ -80,6 +80,8 @@ struct FullyConnectedFunctor<DeviceType::CPU, float>: FullyConnectedBase {
DoActivation(output_ptr, output_ptr, output->size(), activation_,
relux_max_limit_);
return MACE_SUCCESS;
}
};
......@@ -90,7 +92,7 @@ struct FullyConnectedFunctor<DeviceType::GPU, T> : FullyConnectedBase {
const float relux_max_limit)
: FullyConnectedBase(activation, relux_max_limit) {}
void operator()(const Tensor *input,
MaceStatus operator()(const Tensor *input,
const Tensor *weight,
const Tensor *bias,
Tensor *output,
......
......@@ -33,7 +33,7 @@ struct ImageToBufferFunctorBase {
template <DeviceType D, typename T>
struct ImageToBufferFunctor : ImageToBufferFunctorBase {
ImageToBufferFunctor() {}
void operator()(const Tensor *input,
MaceStatus operator()(const Tensor *input,
const BufferType type,
Tensor *output,
StatsFuture *future) {
......@@ -42,13 +42,14 @@ struct ImageToBufferFunctor : ImageToBufferFunctorBase {
MACE_UNUSED(output);
MACE_UNUSED(future);
MACE_NOT_IMPLEMENTED;
return MACE_SUCCESS;
}
};
template <typename T>
struct ImageToBufferFunctor<DeviceType::GPU, T> : ImageToBufferFunctorBase {
ImageToBufferFunctor() {}
void operator()(const Tensor *input,
MaceStatus operator()(const Tensor *input,
const BufferType type,
Tensor *output,
StatsFuture *future);
......
......@@ -35,7 +35,7 @@ struct LocalResponseNormFunctor;
template<>
struct LocalResponseNormFunctor<DeviceType::CPU, float> {
void operator()(const Tensor *input,
MaceStatus operator()(const Tensor *input,
int depth_radius,
float bias,
float alpha,
......@@ -74,6 +74,8 @@ struct LocalResponseNormFunctor<DeviceType::CPU, float> {
}
}
}
return MACE_SUCCESS;
}
};
......
......@@ -38,13 +38,13 @@ namespace kernels {
template<DeviceType D, typename T>
struct MatMulFunctor {
void operator()(const Tensor *A,
MaceStatus operator()(const Tensor *A,
const Tensor *B,
Tensor *C,
StatsFuture *future) {
MACE_UNUSED(future);
std::vector<index_t> c_shape = {A->dim(0), A->dim(1), B->dim(2), 1};
C->Resize(c_shape);
MACE_FAILURE_RETURN(C->Resize(c_shape));
Tensor::MappingGuard guarda(A);
Tensor::MappingGuard guardb(B);
......@@ -63,13 +63,15 @@ struct MatMulFunctor {
memset(c_ptr_base, 0, batch * height * width * sizeof(T));
Gemm(a_ptr_base, b_ptr_base, batch, height, K, width, c_ptr_base);
return MACE_SUCCESS;
}
};
#ifdef MACE_ENABLE_OPENCL
template<typename T>
struct MatMulFunctor<DeviceType::GPU, T> {
void operator()(const Tensor *A,
MaceStatus operator()(const Tensor *A,
const Tensor *B,
Tensor *C,
StatsFuture *future);
......
......@@ -21,8 +21,9 @@
namespace mace {
namespace kernels {
template <typename T>
void ActivationFunctor<DeviceType::GPU, T>::operator()(const Tensor *input,
template<typename T>
MaceStatus ActivationFunctor<DeviceType::GPU,
T>::operator()(const Tensor *input,
const Tensor *alpha,
Tensor *output,
StatsFuture *future) {
......@@ -55,28 +56,22 @@ void ActivationFunctor<DeviceType::GPU, T>::operator()(const Tensor *input,
built_options.emplace("-DNON_UNIFORM_WORK_GROUP");
}
switch (activation_) {
case RELU:
tuning_key_prefix_ = "relu_opencl_kernel";
case RELU:tuning_key_prefix_ = "relu_opencl_kernel";
built_options.emplace("-DUSE_RELU");
break;
case RELUX:
tuning_key_prefix_ = "relux_opencl_kernel";
case RELUX:tuning_key_prefix_ = "relux_opencl_kernel";
built_options.emplace("-DUSE_RELUX");
break;
case PRELU:
tuning_key_prefix_ = "prelu_opencl_kernel";
case PRELU:tuning_key_prefix_ = "prelu_opencl_kernel";
built_options.emplace("-DUSE_PRELU");
break;
case TANH:
tuning_key_prefix_ = "tanh_opencl_kernel";
case TANH:tuning_key_prefix_ = "tanh_opencl_kernel";
built_options.emplace("-DUSE_TANH");
break;
case SIGMOID:
tuning_key_prefix_ = "sigmoid_opencl_kernel";
case SIGMOID:tuning_key_prefix_ = "sigmoid_opencl_kernel";
built_options.emplace("-DUSE_SIGMOID");
break;
default:
LOG(FATAL) << "Unknown activation type: " << activation_;
default:LOG(FATAL) << "Unknown activation type: " << activation_;
}
kernel_ = runtime->BuildKernel("activation", kernel_name, built_options);
......@@ -122,9 +117,13 @@ void ActivationFunctor<DeviceType::GPU, T>::operator()(const Tensor *input,
MACE_CHECK(*kerror_code == 0) << "Kernel error code: " << *kerror_code;
kernel_error_->UnMap();
}
return MACE_SUCCESS;
}
template struct ActivationFunctor<DeviceType::GPU, float>;
template struct ActivationFunctor<DeviceType::GPU, half>;
template
struct ActivationFunctor<DeviceType::GPU, float>;
template
struct ActivationFunctor<DeviceType::GPU, half>;
} // namespace kernels
} // namespace mace
......@@ -22,7 +22,7 @@ namespace mace {
namespace kernels {
template <typename T>
void AddNFunctor<DeviceType::GPU, T>::operator()(
MaceStatus AddNFunctor<DeviceType::GPU, T>::operator()(
const std::vector<const Tensor *> &input_tensors,
Tensor *output_tensor,
StatsFuture *future) {
......@@ -87,7 +87,8 @@ void AddNFunctor<DeviceType::GPU, T>::operator()(
std::vector<size_t> output_image_shape;
CalImage2DShape(output_shape, BufferType::IN_OUT_CHANNEL,
&output_image_shape);
output_tensor->ResizeImage(output_shape, output_image_shape);
MACE_FAILURE_RETURN(output_tensor->ResizeImage(output_shape,
output_image_shape));
uint32_t idx = 0;
if (runtime->IsOutOfRangeCheckEnabled()) {
......@@ -118,6 +119,8 @@ void AddNFunctor<DeviceType::GPU, T>::operator()(
MACE_CHECK(*kerror_code == 0) << "Kernel error code: " << *kerror_code;
kernel_error_->UnMap();
}
return MACE_SUCCESS;
}
template struct AddNFunctor<DeviceType::GPU, float>;
......
......@@ -23,7 +23,7 @@ namespace mace {
namespace kernels {
template <typename T>
void BatchNormFunctor<DeviceType::GPU, T>::operator()(const Tensor *input,
MaceStatus BatchNormFunctor<DeviceType::GPU, T>::operator()(const Tensor *input,
const Tensor *scale,
const Tensor *offset,
const Tensor *mean,
......@@ -132,6 +132,8 @@ void BatchNormFunctor<DeviceType::GPU, T>::operator()(const Tensor *input,
MACE_CHECK(*kerror_code == 0) << "Kernel error code: " << *kerror_code;
kernel_error_->UnMap();
}
return MACE_SUCCESS;
}
template struct BatchNormFunctor<DeviceType::GPU, float>;
......
......@@ -22,7 +22,7 @@ namespace mace {
namespace kernels {
template <typename T>
void BiasAddFunctor<DeviceType::GPU, T>::operator()(const Tensor *input,
MaceStatus BiasAddFunctor<DeviceType::GPU, T>::operator()(const Tensor *input,
const Tensor *bias,
Tensor *output,
StatsFuture *future) {
......@@ -115,6 +115,8 @@ void BiasAddFunctor<DeviceType::GPU, T>::operator()(const Tensor *input,
}
};
}
return MACE_SUCCESS;
}
template struct BiasAddFunctor<DeviceType::GPU, float>;
......
......@@ -20,7 +20,7 @@ namespace mace {
namespace kernels {
template <typename T>
void BufferToImageFunctor<DeviceType::GPU, T>::operator()(
MaceStatus BufferToImageFunctor<DeviceType::GPU, T>::operator()(
const Tensor *buffer,
const BufferType type,
Tensor *image,
......@@ -30,9 +30,9 @@ void BufferToImageFunctor<DeviceType::GPU, T>::operator()(
CalImage2DShape(buffer->shape(), type, &image_shape);
if (type == WINOGRAD_FILTER) {
std::vector<index_t> new_shape = CalWinogradShape(buffer->shape(), type);
image->ResizeImage(new_shape, image_shape);
MACE_FAILURE_RETURN(image->ResizeImage(new_shape, image_shape));
} else {
image->ResizeImage(buffer->shape(), image_shape);
MACE_FAILURE_RETURN(image->ResizeImage(buffer->shape(), image_shape));
}
uint32_t gws[2] = {static_cast<uint32_t>(image_shape[0]),
......@@ -175,6 +175,8 @@ void BufferToImageFunctor<DeviceType::GPU, T>::operator()(
}
};
}
return MACE_SUCCESS;
}
template struct BufferToImageFunctor<DeviceType::GPU, float>;
......
......@@ -23,11 +23,11 @@ namespace mace {
namespace kernels {
template <typename T>
void ChannelShuffleFunctor<DeviceType::GPU, T>::operator()(
MaceStatus ChannelShuffleFunctor<DeviceType::GPU, T>::operator()(
const Tensor *input,
Tensor *output,
StatsFuture *future) {
output->ResizeLike(input);
MACE_FAILURE_RETURN(output->ResizeLike(input));
const index_t batch = input->dim(0);
const index_t height = input->dim(1);
......@@ -103,6 +103,8 @@ void ChannelShuffleFunctor<DeviceType::GPU, T>::operator()(
MACE_CHECK(*kerror_code == 0) << "Kernel error code: " << *kerror_code;
kernel_error_->UnMap();
}
return MACE_SUCCESS;
}
template
......
......@@ -235,7 +235,7 @@ static void ConcatN(cl::Kernel *kernel,
}
template <typename T>
void ConcatFunctor<DeviceType::GPU, T>::operator()(
MaceStatus ConcatFunctor<DeviceType::GPU, T>::operator()(
const std::vector<const Tensor *> &input_list,
Tensor *output,
StatsFuture *future) {
......@@ -266,7 +266,7 @@ void ConcatFunctor<DeviceType::GPU, T>::operator()(
"Dimensions of inputs should be divisible by 4 when inputs_count > 2.");
std::vector<size_t> image_shape;
CalImage2DShape(output_shape, BufferType::IN_OUT_CHANNEL, &image_shape);
output->ResizeImage(output_shape, image_shape);
MACE_FAILURE_RETURN(output->ResizeImage(output_shape, image_shape));
switch (inputs_count) {
case 2:
......@@ -281,6 +281,8 @@ void ConcatFunctor<DeviceType::GPU, T>::operator()(
MACE_NOT_IMPLEMENTED;
}
}
return MACE_SUCCESS;
}
template struct ConcatFunctor<DeviceType::GPU, float>;
......
......@@ -67,7 +67,7 @@ extern void Conv2dOpencl(cl::Kernel *kernel,
std::unique_ptr<BufferBase> *kernel_error);
template <typename T>
void Conv2dFunctor<DeviceType::GPU, T>::operator()(const Tensor *input,
MaceStatus Conv2dFunctor<DeviceType::GPU, T>::operator()(const Tensor *input,
const Tensor *filter,
const Tensor *bias,
Tensor *output,
......@@ -111,7 +111,7 @@ void Conv2dFunctor<DeviceType::GPU, T>::operator()(const Tensor *input,
std::vector<size_t> output_image_shape;
CalImage2DShape(output_shape, BufferType::IN_OUT_CHANNEL,
&output_image_shape);
output->ResizeImage(output_shape, output_image_shape);
MACE_FAILURE_RETURN(output->ResizeImage(output_shape, output_image_shape));
if (kernel_h == kernel_w && kernel_h <= 5 &&
selector[kernel_h - 1] != nullptr) {
......@@ -126,6 +126,8 @@ void Conv2dFunctor<DeviceType::GPU, T>::operator()(const Tensor *input,
DataTypeToEnum<T>::value, &input_shape_, output, future,
&kwg_size_, &kernel_error_);
}
return MACE_SUCCESS;
}
template struct Conv2dFunctor<DeviceType::GPU, float>;
......
......@@ -154,7 +154,7 @@ void Deconv2dOpencl(cl::Kernel *kernel,
} // namespace
template <typename T>
void Deconv2dFunctor<DeviceType::GPU, T>::operator()(const Tensor *input,
MaceStatus Deconv2dFunctor<DeviceType::GPU, T>::operator()(const Tensor *input,
const Tensor *filter,
const Tensor *bias,
Tensor *output,
......@@ -185,13 +185,15 @@ void Deconv2dFunctor<DeviceType::GPU, T>::operator()(const Tensor *input,
std::vector<size_t> output_image_shape;
CalImage2DShape(output_shape_, BufferType::IN_OUT_CHANNEL,
&output_image_shape);
output->ResizeImage(output_shape_, output_image_shape);
MACE_FAILURE_RETURN(output->ResizeImage(output_shape_, output_image_shape));
Deconv2dOpencl(&kernel_, input, filter, bias,
strides_[0], paddings_.data(),
activation_, relux_max_limit_,
DataTypeToEnum<T>::value, &input_shape_,
output, future, &kwg_size_, &kernel_error_);
return MACE_SUCCESS;
}
template struct Deconv2dFunctor<DeviceType::GPU, float>;
......
......@@ -23,7 +23,7 @@ namespace mace {
namespace kernels {
template <typename T>
void DepthToSpaceOpFunctor<DeviceType::GPU, T>::operator()(
MaceStatus DepthToSpaceOpFunctor<DeviceType::GPU, T>::operator()(
const Tensor *input, Tensor *output, StatsFuture *future) {
const index_t batch = input->dim(0);
const index_t input_height = input->dim(1);
......@@ -70,7 +70,7 @@ void DepthToSpaceOpFunctor<DeviceType::GPU, T>::operator()(
std::vector<size_t> image_shape;
CalImage2DShape(output_shape, BufferType::IN_OUT_CHANNEL, &image_shape);
output->ResizeImage(output_shape, image_shape);
MACE_FAILURE_RETURN(output->ResizeImage(output_shape, image_shape));
auto runtime = OpenCLRuntime::Global();
......@@ -144,6 +144,8 @@ void DepthToSpaceOpFunctor<DeviceType::GPU, T>::operator()(
MACE_CHECK(*kerror_code == 0) << "Kernel error code: " << *kerror_code;
kernel_error_->UnMap();
}
return MACE_SUCCESS;
}
template struct DepthToSpaceOpFunctor<DeviceType::GPU, float>;
......
......@@ -194,7 +194,7 @@ static void DepthwiseConv2d(cl::Kernel *kernel,
}
template <typename T>
void DepthwiseConv2dFunctor<DeviceType::GPU, T>::operator()(
MaceStatus DepthwiseConv2dFunctor<DeviceType::GPU, T>::operator()(
const Tensor *input,
const Tensor *filter, /* MIHW */
const Tensor *bias,
......@@ -209,10 +209,9 @@ void DepthwiseConv2dFunctor<DeviceType::GPU, T>::operator()(
<< " stride " << strides_[0] << "x" << strides_[1]
<< " is not implemented yet, using slow version";
// TODO(heliangliang) The CPU/NEON kernel should map the buffer
DepthwiseConv2dFunctor<DeviceType::CPU, float>(
return DepthwiseConv2dFunctor<DeviceType::CPU, float>(
strides_, padding_type_, paddings_, dilations_, activation_,
relux_max_limit_)(input, filter, bias, output, future);
return;
}
// Create a fake conv_2d filter to calculate the paddings and output size
......@@ -238,12 +237,14 @@ void DepthwiseConv2dFunctor<DeviceType::GPU, T>::operator()(
std::vector<size_t> output_image_shape;
CalImage2DShape(output_shape, BufferType::IN_OUT_CHANNEL,
&output_image_shape);
output->ResizeImage(output_shape, output_image_shape);
MACE_FAILURE_RETURN(output->ResizeImage(output_shape, output_image_shape));
DepthwiseConv2d(&kernel_, input, filter, bias, strides_[0], paddings.data(),
dilations_, activation_, relux_max_limit_,
DataTypeToEnum<T>::value, &input_shape_, output, future,
&kwg_size_, &kernel_error_);
return MACE_SUCCESS;
}
template struct DepthwiseConv2dFunctor<DeviceType::GPU, float>;
......
......@@ -21,7 +21,7 @@ namespace mace {
namespace kernels {
template <typename T>
void EltwiseFunctor<DeviceType::GPU, T>::operator()(const Tensor *input0,
MaceStatus EltwiseFunctor<DeviceType::GPU, T>::operator()(const Tensor *input0,
const Tensor *input1,
Tensor *output,
StatsFuture *future) {
......@@ -60,7 +60,7 @@ void EltwiseFunctor<DeviceType::GPU, T>::operator()(const Tensor *input0,
CalImage2DShape(output_shape,
BufferType::IN_OUT_CHANNEL,
&output_image_shape);
output->ResizeImage(output_shape, output_image_shape);
MACE_FAILURE_RETURN(output->ResizeImage(output_shape, output_image_shape));
const index_t batch = output->dim(0);
const index_t height = output->dim(1);
......@@ -151,6 +151,8 @@ void EltwiseFunctor<DeviceType::GPU, T>::operator()(const Tensor *input0,
MACE_CHECK(*kerror_code == 0) << "Kernel error code: " << *kerror_code;
kernel_error_->UnMap();
}
return MACE_SUCCESS;
}
template struct EltwiseFunctor<DeviceType::GPU, float>;
......
......@@ -282,7 +282,7 @@ void FCWTXKernel(cl::Kernel *kernel,
} // namespace
template <typename T>
void FullyConnectedFunctor<DeviceType::GPU, T>::operator()(
MaceStatus FullyConnectedFunctor<DeviceType::GPU, T>::operator()(
const Tensor *input,
const Tensor *weight,
const Tensor *bias,
......@@ -292,11 +292,13 @@ void FullyConnectedFunctor<DeviceType::GPU, T>::operator()(
std::vector<size_t> output_image_shape;
CalImage2DShape(output_shape, BufferType::IN_OUT_CHANNEL,
&output_image_shape);
output->ResizeImage(output_shape, output_image_shape);
MACE_FAILURE_RETURN(output->ResizeImage(output_shape, output_image_shape));
FCWXKernel<T>(&kernel_, input, weight, bias, &input_shape_, output,
activation_, &gws_, &lws_, relux_max_limit_, future,
&kernel_error_);
return MACE_SUCCESS;
}
template struct FullyConnectedFunctor<DeviceType::GPU, float>;
......
......@@ -20,7 +20,7 @@ namespace mace {
namespace kernels {
template <typename T>
void ImageToBufferFunctor<DeviceType::GPU, T>::operator()(
MaceStatus ImageToBufferFunctor<DeviceType::GPU, T>::operator()(
const Tensor *image,
const BufferType type,
Tensor *buffer,
......@@ -28,7 +28,7 @@ void ImageToBufferFunctor<DeviceType::GPU, T>::operator()(
std::vector<size_t> image_shape;
CalImage2DShape(image->shape(), type, &image_shape);
buffer->Resize(image->shape());
MACE_FAILURE_RETURN(buffer->Resize(image->shape()));
uint32_t gws[2] = {static_cast<uint32_t>(image_shape[0]),
static_cast<uint32_t>(image_shape[1])};
......@@ -163,6 +163,8 @@ void ImageToBufferFunctor<DeviceType::GPU, T>::operator()(
}
};
}
return MACE_SUCCESS;
}
template struct ImageToBufferFunctor<DeviceType::GPU, float>;
......
......@@ -21,7 +21,7 @@ namespace mace {
namespace kernels {
template <typename T>
void MatMulFunctor<DeviceType::GPU, T>::operator()(const Tensor *A,
MaceStatus MatMulFunctor<DeviceType::GPU, T>::operator()(const Tensor *A,
const Tensor *B,
Tensor *C,
StatsFuture *future) {
......@@ -29,7 +29,7 @@ void MatMulFunctor<DeviceType::GPU, T>::operator()(const Tensor *A,
std::vector<index_t> c_shape = {A->dim(0), A->dim(1), B->dim(2), 1};
std::vector<size_t> c_image_shape;
CalImage2DShape(c_shape, BufferType::IN_OUT_HEIGHT, &c_image_shape);
C->ResizeImage(c_shape, c_image_shape);
MACE_FAILURE_RETURN(C->ResizeImage(c_shape, c_image_shape));
const index_t batch = C->dim(0);
const index_t height = C->dim(1);
......@@ -98,6 +98,8 @@ void MatMulFunctor<DeviceType::GPU, T>::operator()(const Tensor *A,
MACE_CHECK(*kerror_code == 0) << "Kernel error code: " << *kerror_code;
kernel_error_->UnMap();
}
return MACE_SUCCESS;
}
template struct MatMulFunctor<DeviceType::GPU, float>;
......
......@@ -21,7 +21,7 @@ namespace mace {
namespace kernels {
template<typename T>
void PadFunctor<DeviceType::GPU, T>::operator()(
MaceStatus PadFunctor<DeviceType::GPU, T>::operator()(
const Tensor *input,
Tensor *output,
StatsFuture *future) {
......@@ -39,7 +39,7 @@ void PadFunctor<DeviceType::GPU, T>::operator()(
std::vector<size_t> image_shape;
CalImage2DShape(output_shape, BufferType::IN_OUT_CHANNEL, &image_shape);
output->ResizeImage(output_shape, image_shape);
MACE_FAILURE_RETURN(output->ResizeImage(output_shape, image_shape));
const index_t batch = output->dim(0);
const index_t height = output->dim(1);
......@@ -114,6 +114,8 @@ void PadFunctor<DeviceType::GPU, T>::operator()(
MACE_CHECK(*kerror_code == 0) << "Kernel error code: " << *kerror_code;
kernel_error_->UnMap();
}
return MACE_SUCCESS;
}
template
......
......@@ -44,7 +44,7 @@ std::vector<uint32_t> LocalWS(const uint32_t *gws,
} // namespace
template <typename T>
void PoolingFunctor<DeviceType::GPU, T>::operator()(const Tensor *input,
MaceStatus PoolingFunctor<DeviceType::GPU, T>::operator()(const Tensor *input,
Tensor *output,
StatsFuture *future) {
MACE_CHECK(dilations_[0] == 1 && dilations_[1] == 1)
......@@ -108,7 +108,7 @@ void PoolingFunctor<DeviceType::GPU, T>::operator()(const Tensor *input,
std::vector<size_t> output_image_shape;
CalImage2DShape(output_shape, BufferType::IN_OUT_CHANNEL,
&output_image_shape);
output->ResizeImage(output_shape, output_image_shape);
MACE_FAILURE_RETURN(output->ResizeImage(output_shape, output_image_shape));
index_t batch = output->dim(0);
index_t out_height = output->dim(1);
......@@ -169,6 +169,8 @@ void PoolingFunctor<DeviceType::GPU, T>::operator()(const Tensor *input,
MACE_CHECK(*kerror_code == 0) << "Kernel error code: " << *kerror_code;
kernel_error_->UnMap();
}
return MACE_SUCCESS;
}
template struct PoolingFunctor<DeviceType::GPU, float>;
......
......@@ -51,7 +51,7 @@ std::vector<uint32_t> LocalWS(const uint32_t *gws,
} // namespace
template <typename T>
void ResizeBilinearFunctor<DeviceType::GPU, T>::operator()(
MaceStatus ResizeBilinearFunctor<DeviceType::GPU, T>::operator()(
const Tensor *input, Tensor *output, StatsFuture *future) {
const index_t batch = input->dim(0);
const index_t in_height = input->dim(1);
......@@ -100,7 +100,7 @@ void ResizeBilinearFunctor<DeviceType::GPU, T>::operator()(
std::vector<size_t> output_image_shape;
CalImage2DShape(output_shape, BufferType::IN_OUT_CHANNEL,
&output_image_shape);
output->ResizeImage(output_shape, output_image_shape);
MACE_FAILURE_RETURN(output->ResizeImage(output_shape, output_image_shape));
float height_scale =
CalculateResizeScale(in_height, out_height, align_corners_);
......@@ -140,6 +140,8 @@ void ResizeBilinearFunctor<DeviceType::GPU, T>::operator()(
MACE_CHECK(*kerror_code == 0) << "Kernel error code: " << *kerror_code;
kernel_error_->UnMap();
}
return MACE_SUCCESS;
}
template struct ResizeBilinearFunctor<DeviceType::GPU, float>;
......
......@@ -21,7 +21,7 @@ namespace mace {
namespace kernels {
template<typename T>
void SliceFunctor<DeviceType::GPU, T>::operator()(
MaceStatus SliceFunctor<DeviceType::GPU, T>::operator()(
const Tensor *input,
const std::vector<Tensor *> &output_list,
StatsFuture *future) {
......@@ -36,7 +36,7 @@ void SliceFunctor<DeviceType::GPU, T>::operator()(
std::vector<size_t> image_shape;
CalImage2DShape(output_shape, BufferType::IN_OUT_CHANNEL, &image_shape);
for (size_t i= 0; i < outputs_count; ++i) {
output_list[i]->ResizeImage(output_shape, image_shape);
MACE_FAILURE_RETURN(output_list[i]->ResizeImage(output_shape, image_shape));
}
auto runtime = OpenCLRuntime::Global();
......@@ -131,6 +131,8 @@ void SliceFunctor<DeviceType::GPU, T>::operator()(
}
};
}
return MACE_SUCCESS;
}
template
......
......@@ -44,7 +44,7 @@ std::vector<uint32_t> LocalWS(const uint32_t *gws,
} // namespace
template <typename T>
void SoftmaxFunctor<DeviceType::GPU, T>::operator()(const Tensor *logits,
MaceStatus SoftmaxFunctor<DeviceType::GPU, T>::operator()(const Tensor *logits,
Tensor *output,
StatsFuture *future) {
const index_t batch = logits->dim(0);
......@@ -115,6 +115,8 @@ void SoftmaxFunctor<DeviceType::GPU, T>::operator()(const Tensor *logits,
MACE_CHECK(*kerror_code == 0) << "Kernel error code: " << *kerror_code;
kernel_error_->UnMap();
}
return MACE_SUCCESS;
}
template struct SoftmaxFunctor<DeviceType::GPU, float>;
......
......@@ -25,7 +25,7 @@ namespace mace {
namespace kernels {
template <typename T>
void SpaceToBatchFunctor<DeviceType::GPU, T>::operator()(
MaceStatus SpaceToBatchFunctor<DeviceType::GPU, T>::operator()(
Tensor *space_tensor,
Tensor *batch_tensor,
StatsFuture *future) {
......@@ -45,10 +45,12 @@ void SpaceToBatchFunctor<DeviceType::GPU, T>::operator()(
CalImage2DShape(output_shape, BufferType::IN_OUT_CHANNEL,
&output_image_shape);
if (b2s_) {
space_tensor->ResizeImage(output_shape, output_image_shape);
MACE_FAILURE_RETURN(space_tensor->ResizeImage(output_shape,
output_image_shape));
kernel_name = "batch_to_space";
} else {
batch_tensor->ResizeImage(output_shape, output_image_shape);
MACE_FAILURE_RETURN(batch_tensor->ResizeImage(output_shape,
output_image_shape));
kernel_name = "space_to_batch";
}
const uint32_t chan_blk = RoundUpDiv4<uint32_t>(batch_tensor->dim(3));
......@@ -129,6 +131,8 @@ void SpaceToBatchFunctor<DeviceType::GPU, T>::operator()(
MACE_CHECK(*kerror_code == 0) << "Kernel error code: " << *kerror_code;
kernel_error_->UnMap();
}
return MACE_SUCCESS;
}
template struct SpaceToBatchFunctor<DeviceType::GPU, float>;
......
......@@ -22,7 +22,7 @@ namespace mace {
namespace kernels {
template <typename T>
void WinogradTransformFunctor<DeviceType::GPU, T>::operator()(
MaceStatus WinogradTransformFunctor<DeviceType::GPU, T>::operator()(
const Tensor *input_tensor, Tensor *output_tensor, StatsFuture *future) {
auto runtime = OpenCLRuntime::Global();
......@@ -78,7 +78,7 @@ void WinogradTransformFunctor<DeviceType::GPU, T>::operator()(
output_shape = {16, input_tensor->dim(3), out_width, 1};
std::vector<size_t> image_shape;
CalImage2DShape(output_shape, BufferType::IN_OUT_HEIGHT, &image_shape);
output_tensor->ResizeImage(output_shape, image_shape);
MACE_FAILURE_RETURN(output_tensor->ResizeImage(output_shape, image_shape));
uint32_t idx = 0;
if (runtime->IsOutOfRangeCheckEnabled()) {
......@@ -115,10 +115,12 @@ void WinogradTransformFunctor<DeviceType::GPU, T>::operator()(
MACE_CHECK(*kerror_code == 0) << "Kernel error code: " << *kerror_code;
kernel_error_->UnMap();
}
return MACE_SUCCESS;
}
template <typename T>
void WinogradInverseTransformFunctor<DeviceType::GPU, T>::operator()(
MaceStatus WinogradInverseTransformFunctor<DeviceType::GPU, T>::operator()(
const Tensor *input_tensor,
const Tensor *bias,
Tensor *output_tensor,
......@@ -186,7 +188,7 @@ void WinogradInverseTransformFunctor<DeviceType::GPU, T>::operator()(
input_tensor->dim(1)};
std::vector<size_t> image_shape;
CalImage2DShape(output_shape, BufferType::IN_OUT_CHANNEL, &image_shape);
output_tensor->ResizeImage(output_shape, image_shape);
MACE_FAILURE_RETURN(output_tensor->ResizeImage(output_shape, image_shape));
const uint32_t round_h = (height_ + 1) / 2;
const uint32_t round_w = (width_ + 1) / 2;
......@@ -230,6 +232,8 @@ void WinogradInverseTransformFunctor<DeviceType::GPU, T>::operator()(
MACE_CHECK(*kerror_code == 0) << "Kernel error code: " << *kerror_code;
kernel_error_->UnMap();
}
return MACE_SUCCESS;
}
template struct WinogradTransformFunctor<DeviceType::GPU, float>;
......
......@@ -38,23 +38,27 @@ struct PadFunctorBase {
float constant_value_;
};
template <DeviceType D, typename T>
template<DeviceType D, typename T>
struct PadFunctor : public PadFunctorBase {
PadFunctor(const std::vector<int> &paddings,
const float constant_value)
: PadFunctorBase(paddings, constant_value) {}
void operator()(const Tensor *input,
MaceStatus operator()(const Tensor *input,
Tensor *output,
StatsFuture *future) {
MACE_UNUSED(future);
MACE_CHECK(
this->paddings_.size() == static_cast<size_t>(input->dim_size()) * 2);
auto input_shape = input->shape();
output->Resize({input_shape[0] + this->paddings_[0] + this->paddings_[1],
input_shape[1] + this->paddings_[2] + this->paddings_[3],
input_shape[2] + this->paddings_[4] + this->paddings_[5],
input_shape[3] + this->paddings_[6] + this->paddings_[7]});
MACE_FAILURE_RETURN(output->Resize({input_shape[0] + this->paddings_[0]
+ this->paddings_[1],
input_shape[1] + this->paddings_[2]
+ this->paddings_[3],
input_shape[2] + this->paddings_[4]
+ this->paddings_[5],
input_shape[3] + this->paddings_[6]
+ this->paddings_[7]}));
Tensor::MappingGuard input_guard(input);
Tensor::MappingGuard output_guard(output);
......@@ -81,6 +85,8 @@ struct PadFunctor : public PadFunctorBase {
}
}
}
return MACE_SUCCESS;
}
};
......@@ -91,7 +97,7 @@ struct PadFunctor<DeviceType::GPU, T> : PadFunctorBase {
const float constant_value)
: PadFunctorBase(paddings, constant_value) {}
void operator()(const Tensor *input,
MaceStatus operator()(const Tensor *input,
Tensor *output,
StatsFuture *future);
......
......@@ -167,7 +167,7 @@ struct PoolingFunctor<DeviceType::CPU, float>: PoolingFunctorBase {
}
}
void operator()(const Tensor *input_tensor,
MaceStatus operator()(const Tensor *input_tensor,
Tensor *output_tensor,
StatsFuture *future) {
MACE_UNUSED(future);
......@@ -190,7 +190,7 @@ struct PoolingFunctor<DeviceType::CPU, float>: PoolingFunctorBase {
RoundType::CEIL,
output_shape.data());
}
output_tensor->Resize(output_shape);
MACE_FAILURE_RETURN(output_tensor->Resize(output_shape));
Tensor::MappingGuard input_guard(input_tensor);
Tensor::MappingGuard output_guard(output_tensor);
......@@ -220,6 +220,8 @@ struct PoolingFunctor<DeviceType::CPU, float>: PoolingFunctorBase {
} else {
MACE_NOT_IMPLEMENTED;
}
return MACE_SUCCESS;
}
};
......@@ -235,7 +237,7 @@ struct PoolingFunctor<DeviceType::GPU, T> : PoolingFunctorBase {
: PoolingFunctorBase(
pooling_type, kernels, strides, padding_type, paddings, dilations) {
}
void operator()(const Tensor *input_tensor,
MaceStatus operator()(const Tensor *input_tensor,
Tensor *output_tensor,
StatsFuture *future);
......
......@@ -136,7 +136,7 @@ struct ProposalFunctor {
feat_stride_(feat_stride),
anchors_(GenerateAnchors(scales, ratios, base_size)) {}
void operator()(const Tensor *rpn_cls_prob,
MaceStatus operator()(const Tensor *rpn_cls_prob,
const Tensor *rpn_bbox_pred,
const Tensor *img_info_tensor,
Tensor *output,
......@@ -180,7 +180,7 @@ struct ProposalFunctor {
for (int h_idx = 0; h_idx < feat_height; ++h_idx) {
for (int w_idx = 0; w_idx < feat_width; ++w_idx) {
for (int a_idx = 0; a_idx < anchors_size; ++a_idx) {
const int sanc_idx = (h_idx * feat_width + w_idx) * anchors_size
const index_t sanc_idx = (h_idx * feat_width + w_idx) * anchors_size
+ a_idx;
const float width = proposals[sanc_idx][2] -
proposals[sanc_idx][0] + 1;
......@@ -216,7 +216,7 @@ struct ProposalFunctor {
for (int h_idx = 0; h_idx < feat_height; ++h_idx) {
for (int w_idx = 0; w_idx < feat_width; ++w_idx) {
for (int a_idx = 0; a_idx < anchors_size; ++a_idx) {
const int sanc_idx = (h_idx * feat_width + w_idx) * anchors_size
const index_t sanc_idx = (h_idx * feat_width + w_idx) * anchors_size
+ a_idx;
const float width = proposals[sanc_idx][2]
- proposals[sanc_idx][0] + 1;
......@@ -267,7 +267,7 @@ struct ProposalFunctor {
// Our RPN implementation only supports a single input image, so all
// batch inds are 0
size = static_cast<int>(nms_result.size());
output->Resize({size, 1, 1, 5});
MACE_FAILURE_RETURN(output->Resize({size, 1, 1, 5}));
auto output_ptr = output->mutable_data<float>();
#pragma omp parallel for
for (int i = 0; i < size; ++i) {
......@@ -279,6 +279,8 @@ struct ProposalFunctor {
output_ptr[out_idx + 3] = nms_proposals[nms_idx + 2];
output_ptr[out_idx + 4] = nms_proposals[nms_idx + 3];
}
return MACE_SUCCESS;
}
const int min_size_;
......
......@@ -34,7 +34,7 @@ struct PSROIAlignFunctor {
output_dim_(output_dim),
group_size_(group_size) {}
void operator()(const Tensor *input,
MaceStatus operator()(const Tensor *input,
const Tensor *rois,
Tensor *output,
StatsFuture *future) {
......@@ -47,10 +47,11 @@ struct PSROIAlignFunctor {
const T *input_ptr = input->data<T>();
const T *rois_ptr = rois->data<T>();
// Number of ROIs
const int num_rois = rois->dim(0);
const int batch_size = input->dim(0);
const index_t num_rois = rois->dim(0);
const index_t batch_size = input->dim(0);
output->Resize({num_rois, pooled_height, pooled_width, output_dim_});
MACE_FAILURE_RETURN(output->Resize({num_rois, pooled_height, pooled_width,
output_dim_}));
T *output_ptr = output->mutable_data<T>();
for (int n = 0; n < num_rois; ++n) {
......@@ -176,6 +177,8 @@ struct PSROIAlignFunctor {
rois_ptr += 5;
output_ptr += pooled_height * pooled_width * output_dim_;
}
return MACE_SUCCESS;
}
const T spatial_scale_;
......
......@@ -74,7 +74,7 @@ template<>
struct QuantizeFunctor<CPU, uint8_t> {
QuantizeFunctor() {}
void operator()(const Tensor *input,
MaceStatus operator()(const Tensor *input,
const Tensor *in_min,
const Tensor *in_max,
Tensor *output,
......@@ -95,6 +95,8 @@ struct QuantizeFunctor<CPU, uint8_t> {
output_data[i] = Saturate<uint8_t>(roundf(
(input_data[i] - in_min_data) * recip_stepsize));
}
return MACE_SUCCESS;
}
};
......@@ -105,7 +107,7 @@ template<>
struct DequantizeFunctor<CPU, uint8_t> {
DequantizeFunctor() {}
void operator()(const Tensor *input,
MaceStatus operator()(const Tensor *input,
const Tensor *in_min,
const Tensor *in_max,
Tensor *output,
......@@ -120,6 +122,8 @@ struct DequantizeFunctor<CPU, uint8_t> {
for (int i = 0; i < input->size(); ++i) {
output_data[i] = in_min_data + stepsize * input_data[i];
}
return MACE_SUCCESS;
}
};
......@@ -130,7 +134,7 @@ template<>
struct RequantizeFunctor<CPU, uint8_t> {
RequantizeFunctor() {}
void operator()(const Tensor *input,
MaceStatus operator()(const Tensor *input,
const Tensor *in_min,
const Tensor *in_max,
const Tensor *rerange_min,
......@@ -189,6 +193,8 @@ struct RequantizeFunctor<CPU, uint8_t> {
Saturate<uint8_t>(roundf(
quantized_out_zero + input_data[i] * step_ratio));
}
return MACE_SUCCESS;
}
};
......
......@@ -31,12 +31,14 @@ template <DeviceType D, typename T>
struct ReshapeFunctor {
ReshapeFunctor() {}
void operator()(const Tensor *input,
MaceStatus operator()(const Tensor *input,
const std::vector<index_t> &out_shape,
Tensor *output,
StatsFuture *future) {
MACE_UNUSED(future);
output->ResizeWithBuffer(out_shape, input->UnderlyingBuffer());
return MACE_SUCCESS;
}
};
......
......@@ -87,7 +87,8 @@ inline void ResizeImage(const float *images,
for (index_t b = 0; b < batch_size; ++b) {
for (index_t c = 0; c < channels; ++c) {
const float
*channel_input_ptr = images + (b * channels + c) * in_height * in_width;
*channel_input_ptr =
images + (b * channels + c) * in_height * in_width;
float *channel_output_ptr =
output + (b * channels + c) * out_height * out_width;
for (index_t y = 0; y < out_height; ++y) {
......@@ -136,7 +137,9 @@ struct ResizeBilinearFunctor<DeviceType::CPU, float>
ResizeBilinearFunctor(const std::vector<index_t> &size, bool align_corners)
: ResizeBilinearFunctorBase(size, align_corners) {}
void operator()(const Tensor *input, Tensor *output, StatsFuture *future) {
MaceStatus operator()(const Tensor *input,
Tensor *output,
StatsFuture *future) {
MACE_UNUSED(future);
const index_t batch = input->dim(0);
const index_t channels = input->dim(1);
......@@ -147,7 +150,7 @@ struct ResizeBilinearFunctor<DeviceType::CPU, float>
index_t out_width = out_width_;
MACE_CHECK(out_height > 0 && out_width > 0);
std::vector<index_t> out_shape{batch, channels, out_height, out_width};
output->Resize(out_shape);
MACE_FAILURE_RETURN(output->Resize(out_shape));
Tensor::MappingGuard input_mapper(input);
Tensor::MappingGuard output_mapper(output);
......@@ -158,7 +161,7 @@ struct ResizeBilinearFunctor<DeviceType::CPU, float>
std::copy(input_data,
input_data + batch * channels * in_height * in_width,
output_data);
return;
return MACE_SUCCESS;
}
float height_scale =
......@@ -175,6 +178,8 @@ struct ResizeBilinearFunctor<DeviceType::CPU, float>
ResizeImage(input_data, batch, in_height, in_width, out_height, out_width,
channels, xs, ys, output_data);
return MACE_SUCCESS;
}
};
......@@ -185,7 +190,9 @@ struct ResizeBilinearFunctor<DeviceType::GPU, T>
ResizeBilinearFunctor(const std::vector<index_t> &size, bool align_corners)
: ResizeBilinearFunctorBase(size, align_corners) {}
void operator()(const Tensor *input, Tensor *output, StatsFuture *future);
MaceStatus operator()(const Tensor *input,
Tensor *output,
StatsFuture *future);
cl::Kernel kernel_;
uint32_t kwg_size_;
......
......@@ -41,7 +41,7 @@ template<DeviceType D, typename T>
struct SliceFunctor : SliceFunctorBase {
explicit SliceFunctor(const int32_t axis) : SliceFunctorBase(axis) {}
void operator()(const Tensor *input,
MaceStatus operator()(const Tensor *input,
const std::vector<Tensor *> &output_list,
StatsFuture *future) {
MACE_UNUSED(future);
......@@ -61,7 +61,7 @@ struct SliceFunctor : SliceFunctorBase {
1,
std::multiplies<index_t>());
for (size_t i= 0; i < outputs_count; ++i) {
output_list[i]->Resize(output_shape);
MACE_FAILURE_RETURN(output_list[i]->Resize(output_shape));
output_ptrs[i] = output_list[i]->mutable_data<T>();
}
const T *input_ptr = input->data<T>();
......@@ -82,6 +82,8 @@ struct SliceFunctor : SliceFunctorBase {
input_idx += output_channels * inner_size;
}
}
return MACE_SUCCESS;
}
};
......@@ -90,7 +92,7 @@ template<typename T>
struct SliceFunctor<DeviceType::GPU, T> : SliceFunctorBase {
explicit SliceFunctor(const int32_t axis) : SliceFunctorBase(axis) {}
void operator()(const Tensor *input,
MaceStatus operator()(const Tensor *input,
const std::vector<Tensor *> &output_list,
StatsFuture *future);
cl::Kernel kernel_;
......
......@@ -38,7 +38,9 @@ struct SoftmaxFunctor;
template<>
struct SoftmaxFunctor<DeviceType::CPU, float> {
void operator()(const Tensor *input, Tensor *output, StatsFuture *future) {
MaceStatus operator()(const Tensor *input,
Tensor *output,
StatsFuture *future) {
MACE_UNUSED(future);
const index_t batch = input->dim(0);
const index_t class_count = input->dim(1);
......@@ -82,13 +84,17 @@ struct SoftmaxFunctor<DeviceType::CPU, float> {
}
} // k
} // b
return MACE_SUCCESS;
}
};
#ifdef MACE_ENABLE_OPENCL
template<typename T>
struct SoftmaxFunctor<DeviceType::GPU, T> {
void operator()(const Tensor *logits, Tensor *output, StatsFuture *future);
MaceStatus operator()(const Tensor *logits,
Tensor *output,
StatsFuture *future);
cl::Kernel kernel_;
uint32_t kwg_size_;
......
......@@ -140,7 +140,7 @@ struct SpaceToBatchFunctor<DeviceType::CPU, float> : SpaceToBatchFunctorBase {
bool b2s)
: SpaceToBatchFunctorBase(paddings, block_shape, b2s) {}
void operator()(Tensor *space_tensor,
MaceStatus operator()(Tensor *space_tensor,
Tensor *batch_tensor,
StatsFuture *future) {
MACE_UNUSED(future);
......@@ -150,12 +150,12 @@ struct SpaceToBatchFunctor<DeviceType::CPU, float> : SpaceToBatchFunctorBase {
CalculateBatchToSpaceOutputShape(batch_tensor,
DataFormat::NCHW,
output_shape.data());
space_tensor->Resize(output_shape);
MACE_FAILURE_RETURN(space_tensor->Resize(output_shape));
} else {
CalculateSpaceToBatchOutputShape(space_tensor,
DataFormat::NCHW,
output_shape.data());
batch_tensor->Resize(output_shape);
MACE_FAILURE_RETURN(batch_tensor->Resize(output_shape));
}
Tensor::MappingGuard input_guard(space_tensor);
......@@ -312,6 +312,7 @@ struct SpaceToBatchFunctor<DeviceType::CPU, float> : SpaceToBatchFunctorBase {
} // block_h
} // c
}
return MACE_SUCCESS;
}
};
......@@ -323,7 +324,7 @@ struct SpaceToBatchFunctor<DeviceType::GPU, T> : SpaceToBatchFunctorBase {
bool b2s)
: SpaceToBatchFunctorBase(paddings, block_shape, b2s) {}
void operator()(Tensor *space_tensor,
MaceStatus operator()(Tensor *space_tensor,
Tensor *batch_tensor,
StatsFuture *future);
......
......@@ -107,7 +107,9 @@ template<DeviceType D, typename T>
struct TransposeFunctor {
explicit TransposeFunctor(const std::vector<int> &dims) : dims_(dims) {}
void operator()(const Tensor *input, Tensor *output, StatsFuture *future) {
MaceStatus operator()(const Tensor *input,
Tensor *output,
StatsFuture *future) {
MACE_UNUSED(future);
Tensor::MappingGuard input_guard(input);
Tensor::MappingGuard output_guard(output);
......@@ -175,6 +177,8 @@ struct TransposeFunctor {
} else {
MACE_NOT_IMPLEMENTED;
}
return MACE_SUCCESS;
}
std::vector<int> dims_;
......
......@@ -44,29 +44,34 @@ struct WinogradTransformFunctorBase {
std::vector<int> paddings_;
};
template <DeviceType D, typename T>
template<DeviceType D, typename T>
struct WinogradTransformFunctor : WinogradTransformFunctorBase {
WinogradTransformFunctor(const Padding &padding_type,
const std::vector<int> &paddings)
: WinogradTransformFunctorBase(padding_type, paddings) {}
void operator()(const Tensor *input, Tensor *output, StatsFuture *future) {
MaceStatus operator()(const Tensor *input,
Tensor *output,
StatsFuture *future) {
MACE_UNUSED(input);
MACE_UNUSED(output);
MACE_UNUSED(future);
MACE_NOT_IMPLEMENTED;
return MACE_SUCCESS;
}
};
#ifdef MACE_ENABLE_OPENCL
template <typename T>
template<typename T>
struct WinogradTransformFunctor<DeviceType::GPU, T>
: WinogradTransformFunctorBase {
WinogradTransformFunctor(const Padding &padding_type,
const std::vector<int> &paddings)
: WinogradTransformFunctorBase(padding_type, paddings) {}
void operator()(const Tensor *input, Tensor *output, StatsFuture *future);
MaceStatus operator()(const Tensor *input,
Tensor *output,
StatsFuture *future);
cl::Kernel kernel_;
uint32_t kwg_size_;
......@@ -94,7 +99,7 @@ struct WinogradInverseTransformFunctorBase {
const float relux_max_limit_;
};
template <DeviceType D, typename T>
template<DeviceType D, typename T>
struct WinogradInverseTransformFunctor : WinogradInverseTransformFunctorBase {
WinogradInverseTransformFunctor(const int batch,
const int height,
......@@ -104,7 +109,7 @@ struct WinogradInverseTransformFunctor : WinogradInverseTransformFunctorBase {
: WinogradInverseTransformFunctorBase(
batch, height, width, activation, relux_max_limit) {}
void operator()(const Tensor *input,
MaceStatus operator()(const Tensor *input,
const Tensor *bias,
Tensor *output,
StatsFuture *future) {
......@@ -113,6 +118,7 @@ struct WinogradInverseTransformFunctor : WinogradInverseTransformFunctorBase {
MACE_UNUSED(output);
MACE_UNUSED(future);
MACE_NOT_IMPLEMENTED;
return MACE_SUCCESS;
}
};
......@@ -128,7 +134,7 @@ struct WinogradInverseTransformFunctor<DeviceType::GPU, T>
: WinogradInverseTransformFunctorBase(
batch, height, width, activation, relux_max_limit) {}
void operator()(const Tensor *input,
MaceStatus operator()(const Tensor *input,
const Tensor *bias,
Tensor *output,
StatsFuture *future);
......
......@@ -34,15 +34,14 @@ class ActivationOp : public Operator<D, T> {
static_cast<T>(OperatorBase::GetSingleArgument<float>(
"max_limit", 0.0f))) {}
bool Run(StatsFuture *future) override {
MaceStatus Run(StatsFuture *future) override {
const Tensor *input_tensor = this->Input(0);
const Tensor *alpha_tensor =
this->InputSize() >= 2 ? this->Input(1) : nullptr;
Tensor *output_tensor = this->Output(0);
output_tensor->ResizeLike(input_tensor);
MACE_FAILURE_RETURN(output_tensor->ResizeLike(input_tensor));
functor_(input_tensor, alpha_tensor, output_tensor, future);
return true;
return functor_(input_tensor, alpha_tensor, output_tensor, future);
}
private:
......
......@@ -24,8 +24,7 @@ namespace test {
namespace {
template <DeviceType D, typename T>
void ReluBenchmark(
int iters, int batch, int channels, int height, int width) {
void ReluBenchmark(int iters, int batch, int channels, int height, int width) {
mace::testing::StopTiming();
OpsTestNet net;
......@@ -94,8 +93,7 @@ BM_RELU(1, 64, 256, 256);
namespace {
template <DeviceType D, typename T>
void ReluxBenchmark(
int iters, int batch, int channels, int height, int width) {
void ReluxBenchmark(int iters, int batch, int channels, int height, int width) {
mace::testing::StopTiming();
OpsTestNet net;
......@@ -162,8 +160,7 @@ BM_RELUX(1, 64, 256, 256);
namespace {
template <DeviceType D, typename T>
void PreluBenchmark(
int iters, int batch, int channels, int height, int width) {
void PreluBenchmark(int iters, int batch, int channels, int height, int width) {
mace::testing::StopTiming();
OpsTestNet net;
......@@ -237,8 +234,7 @@ BM_PRELU(1, 64, 256, 256);
namespace {
template <DeviceType D, typename T>
void TanhBenchmark(
int iters, int batch, int channels, int height, int width) {
void TanhBenchmark(int iters, int batch, int channels, int height, int width) {
mace::testing::StopTiming();
OpsTestNet net;
......
......@@ -29,7 +29,7 @@ class AddNOp : public Operator<D, T> {
AddNOp(const OperatorDef &operator_def, Workspace *ws)
: Operator<D, T>(operator_def, ws) {}
bool Run(StatsFuture *future) override {
MaceStatus Run(StatsFuture *future) override {
Tensor *output_tensor = this->Output(0);
int n = this->inputs_.size();
std::vector<const Tensor *> inputs(n, nullptr);
......@@ -42,9 +42,7 @@ class AddNOp : public Operator<D, T> {
<< ", size: " << inputs[0]->size() << ". Input " << i << ": "
<< MakeString(inputs[i]->shape()) << ", size: " << inputs[i]->size();
}
functor_(inputs, output_tensor, future);
return true;
return functor_(inputs, output_tensor, future);
}
private:
......
......@@ -32,7 +32,7 @@ class BatchNormOp : public Operator<D, T> {
static_cast<float>(1e-4));
}
bool Run(StatsFuture *future) override {
MaceStatus Run(StatsFuture *future) override {
const Tensor *input = this->Input(INPUT);
const Tensor *scale = this->Input(SCALE);
const Tensor *offset = this->Input(OFFSET);
......@@ -51,10 +51,8 @@ class BatchNormOp : public Operator<D, T> {
var->dim_size());
Tensor *output = this->Output(OUTPUT);
output->ResizeLike(input);
functor_(input, scale, offset, mean, var, epsilon_, output, future);
return true;
MACE_FAILURE_RETURN(output->ResizeLike(input));
return functor_(input, scale, offset, mean, var, epsilon_, output, future);
}
private:
......
......@@ -33,12 +33,11 @@ class BatchToSpaceNDOp : public Operator<D, T> {
OperatorBase::GetRepeatedArgument<int>("block_shape", {1, 1}),
true) {}
bool Run(StatsFuture *future) override {
MaceStatus Run(StatsFuture *future) override {
const Tensor *batch_tensor = this->Input(INPUT);
Tensor *space_tensor = this->Output(OUTPUT);
functor_(space_tensor, const_cast<Tensor *>(batch_tensor),
return functor_(space_tensor, const_cast<Tensor *>(batch_tensor),
future);
return true;
}
private:
......
......@@ -27,7 +27,7 @@ class BiasAddOp : public Operator<D, T> {
BiasAddOp(const OperatorDef &operator_def, Workspace *ws)
: Operator<D, T>(operator_def, ws), functor_() {}
bool Run(StatsFuture *future) override {
MaceStatus Run(StatsFuture *future) override {
const Tensor *input = this->Input(INPUT);
const Tensor *bias = this->Input(BIAS);
......@@ -37,10 +37,9 @@ class BiasAddOp : public Operator<D, T> {
bias->dim_size());
Tensor *output = this->Output(OUTPUT);
output->ResizeLike(input);
MACE_FAILURE_RETURN(output->ResizeLike(input));
functor_(input, bias, output, future);
return true;
return functor_(input, bias, output, future);
}
private:
......
......@@ -27,7 +27,7 @@ class BufferToImageOp : public Operator<D, T> {
BufferToImageOp(const OperatorDef &op_def, Workspace *ws)
: Operator<D, T>(op_def, ws) {}
bool Run(StatsFuture *future) override {
MaceStatus Run(StatsFuture *future) override {
const Tensor *input_tensor = this->Input(INPUT);
kernels::BufferType type =
......@@ -35,8 +35,7 @@ class BufferToImageOp : public Operator<D, T> {
"buffer_type", static_cast<int>(kernels::CONV2D_FILTER)));
Tensor *output = this->Output(OUTPUT);
functor_(input_tensor, type, output, future);
return true;
return functor_(input_tensor, type, output, future);
}
private:
......
......@@ -31,7 +31,7 @@ class ChannelShuffleOp : public Operator<D, T> {
group_(OperatorBase::GetSingleArgument<int>("group", 1)),
functor_(this->group_) {}
bool Run(StatsFuture *future) override {
MaceStatus Run(StatsFuture *future) override {
const Tensor *input = this->Input(INPUT);
Tensor *output = this->Output(OUTPUT);
int channels;
......@@ -45,9 +45,7 @@ class ChannelShuffleOp : public Operator<D, T> {
MACE_CHECK(channels % group_ == 0,
"input channels must be an integral multiple of group. ",
input->dim(3));
functor_(input, output, future);
return true;
return functor_(input, output, future);
}
protected:
......
......@@ -30,7 +30,7 @@ class ConcatOp : public Operator<D, T> {
: Operator<D, T>(op_def, ws),
functor_(OperatorBase::GetSingleArgument<int>("axis", 3)) {}
bool Run(StatsFuture *future) override {
MaceStatus Run(StatsFuture *future) override {
MACE_CHECK(this->InputSize() >= 2)
<< "There must be at least two inputs to concat";
const std::vector<const Tensor *> input_list = this->Inputs();
......@@ -44,8 +44,7 @@ class ConcatOp : public Operator<D, T> {
Tensor *output = this->Output(OUTPUT);
functor_(input_list, output, future);
return true;
return functor_(input_list, output, future);
}
private:
......
......@@ -42,14 +42,12 @@ class Conv2dOp : public ConvPool2dOpBase<D, T> {
"is_filter_transformed", false)),
ws->GetScratchBuffer(D)) {}
bool Run(StatsFuture *future) override {
MaceStatus Run(StatsFuture *future) override {
const Tensor *input = this->Input(INPUT);
const Tensor *filter = this->Input(FILTER);
const Tensor *bias = this->InputSize() >= 3 ? this->Input(BIAS) : nullptr;
Tensor *output = this->Output(OUTPUT);
functor_(input, filter, bias, output, future);
return true;
return functor_(input, filter, bias, output, future);
}
private:
......
......@@ -36,15 +36,13 @@ class Deconv2dOp : public ConvPool2dOpBase<D, T> {
kernels::ActivationType::NOOP,
0.0f) {}
bool Run(StatsFuture *future) override {
MaceStatus Run(StatsFuture *future) override {
const Tensor *input = this->Input(INPUT);
const Tensor *filter = this->Input(FILTER);
const Tensor *bias = this->InputSize() >= 3 ? this->Input(BIAS) : nullptr;
Tensor *output = this->Output(OUTPUT);
functor_(input, filter, bias, output, future);
return true;
return functor_(input, filter, bias, output, future);
}
private:
......
......@@ -32,7 +32,7 @@ class DepthToSpaceOp : public Operator<D, T> {
block_size_(OperatorBase::GetSingleArgument<int>("block_size", 1)),
functor_(this->block_size_, true) {}
bool Run(StatsFuture *future) override {
MaceStatus Run(StatsFuture *future) override {
const Tensor *input = this->Input(INPUT);
Tensor *output = this->Output(OUTPUT);
MACE_CHECK(input->dim_size() == 4, "input dim should be 4");
......@@ -50,8 +50,7 @@ class DepthToSpaceOp : public Operator<D, T> {
input_depth);
MACE_CHECK((input_depth % 4) == 0,
"input channel should be dividable by 4");
functor_(input, output, future);
return true;
return functor_(input, output, future);
}
protected:
......
......@@ -40,7 +40,7 @@ class DepthwiseConv2dOp : public ConvPool2dOpBase<D, T> {
"NOOP")),
OperatorBase::GetSingleArgument<float>("max_limit", 0.0f)) {}
bool Run(StatsFuture *future) override {
MaceStatus Run(StatsFuture *future) override {
const Tensor *input = this->Input(INPUT);
const Tensor *filter = this->Input(FILTER);
const Tensor *bias = nullptr;
......@@ -48,8 +48,7 @@ class DepthwiseConv2dOp : public ConvPool2dOpBase<D, T> {
bias = this->Input(BIAS);
}
Tensor *output = this->Output(OUTPUT);
functor_(input, filter, bias, output, future);
return true;
return functor_(input, filter, bias, output, future);
}
private:
......
......@@ -32,12 +32,11 @@ class EltwiseOp : public Operator<D, T> {
OperatorBase::GetRepeatedArgument<float>("coeff"),
OperatorBase::GetSingleArgument<float>("x", 1.0)) {}
bool Run(StatsFuture *future) override {
MaceStatus Run(StatsFuture *future) override {
const Tensor* input0 = this->Input(0);
const Tensor* input1 = this->InputSize() == 2 ? this->Input(1) : nullptr;
Tensor *output = this->Output(OUTPUT);
functor_(input0, input1, output, future);
return true;
return functor_(input0, input1, output, future);
}
private:
......
......@@ -34,7 +34,7 @@ class FoldedBatchNormOp : public Operator<D, T> {
"NOOP")),
OperatorBase::GetSingleArgument<float>("max_limit", 0.0f)) {}
bool Run(StatsFuture *future) override {
MaceStatus Run(StatsFuture *future) override {
const Tensor *input = this->Input(INPUT);
const Tensor *scale = this->Input(SCALE);
const Tensor *offset = this->Input(OFFSET);
......@@ -47,10 +47,9 @@ class FoldedBatchNormOp : public Operator<D, T> {
offset->dim_size());
Tensor *output = this->Output(OUTPUT);
output->ResizeLike(input);
MACE_FAILURE_RETURN(output->ResizeLike(input));
functor_(input, scale, offset, nullptr, nullptr, 0, output, future);
return true;
return functor_(input, scale, offset, nullptr, nullptr, 0, output, future);
}
private:
......
......@@ -33,7 +33,7 @@ class FullyConnectedOp : public Operator<D, T> {
"NOOP")),
OperatorBase::GetSingleArgument<float>("max_limit", 0.0f)) {}
bool Run(StatsFuture *future) override {
MaceStatus Run(StatsFuture *future) override {
const Tensor *input = this->Input(INPUT);
const Tensor *weight = this->Input(WEIGHT); // OIHW
const Tensor *bias = this->InputSize() >= 3 ? this->Input(BIAS) : nullptr;
......@@ -65,8 +65,7 @@ class FullyConnectedOp : public Operator<D, T> {
" don't match.");
}
functor_(input, weight, bias, output, future);
return true;
return functor_(input, weight, bias, output, future);
}
private:
......
......@@ -27,15 +27,14 @@ class ImageToBufferOp : public Operator<D, T> {
ImageToBufferOp(const OperatorDef &op_def, Workspace *ws)
: Operator<D, T>(op_def, ws) {}
bool Run(StatsFuture *future) override {
MaceStatus Run(StatsFuture *future) override {
const Tensor *input = this->Input(INPUT);
Tensor *output = this->Output(OUTPUT);
kernels::BufferType type =
static_cast<kernels::BufferType>(OperatorBase::GetSingleArgument<int>(
"buffer_type", static_cast<int>(kernels::CONV2D_FILTER)));
functor_(input, type, output, future);
return true;
return functor_(input, type, output, future);
}
private:
......
......@@ -33,17 +33,16 @@ class LocalResponseNormOp : public Operator<D, T> {
beta_ = OperatorBase::GetSingleArgument<float>("beta", 0.5f);
}
bool Run(StatsFuture *future) override {
MaceStatus Run(StatsFuture *future) override {
const Tensor *input = this->Input(INPUT);
MACE_CHECK(input->dim_size() == 4, "input must be 4-dimensional. ",
input->dim_size());
Tensor *output = this->Output(OUTPUT);
output->ResizeLike(input);
MACE_FAILURE_RETURN(output->ResizeLike(input));
functor_(input, depth_radius_, bias_, alpha_, beta_, output, future);
return true;
return functor_(input, depth_radius_, bias_, alpha_, beta_, output, future);
}
private:
......
......@@ -27,7 +27,7 @@ class MatMulOp : public Operator<D, T> {
MatMulOp(const OperatorDef &operator_def, Workspace *ws)
: Operator<D, T>(operator_def, ws) {}
bool Run(StatsFuture *future) override {
MaceStatus Run(StatsFuture *future) override {
const Tensor *A = this->Input(0);
const Tensor *B = this->Input(1);
Tensor *C = this->Output(0);
......@@ -38,8 +38,7 @@ class MatMulOp : public Operator<D, T> {
<< "the number of A's column " << A->dim(2)
<< " must be equal to B's row " << B->dim(1);
functor_(A, B, C, future);
return true;
return functor_(A, B, C, future);
}
private:
......
......@@ -354,7 +354,7 @@ class OpsTestNet {
return net_ != nullptr;
}
bool Run() {
MaceStatus Run() {
MACE_CHECK_NOTNULL(net_);
return net_->Run();
}
......@@ -362,7 +362,7 @@ class OpsTestNet {
// DEPRECATED(liyin):
// Test and benchmark should setup model once and run multiple times.
// Setup time should not be counted during benchmark.
bool RunOp(DeviceType device) {
MaceStatus RunOp(DeviceType device) {
Setup(device);
return Run();
}
......@@ -370,16 +370,14 @@ class OpsTestNet {
// DEPRECATED(liyin):
// Test and benchmark should setup model once and run multiple times.
// Setup time should not be counted during benchmark.
bool RunOp() {
MaceStatus RunOp() {
return RunOp(DeviceType::CPU);
}
bool RunNet(const NetDef &net_def, const DeviceType device) {
MaceStatus RunNet(const NetDef &net_def, const DeviceType device) {
device_ = device;
net_ = CreateNet(op_registry_, net_def, &ws_, device, NetMode::INIT);
if (!net_->Run()) {
return false;
}
MACE_FAILURE_RETURN(net_->Run());
net_ = CreateNet(op_registry_, net_def, &ws_, device);
return net_->Run();
}
......
......@@ -32,11 +32,10 @@ class PadOp : public Operator<D, T> {
OperatorBase::GetSingleArgument<float>("constant_value", 0.0))
{}
bool Run(StatsFuture *future) override {
MaceStatus Run(StatsFuture *future) override {
const Tensor *input_tensor = this->Input(0);
Tensor *output_tensor = this->Output(0);
functor_(input_tensor, output_tensor, future);
return true;
return functor_(input_tensor, output_tensor, future);
}
private:
......
......@@ -40,12 +40,11 @@ class PoolingOp : public ConvPool2dOpBase<D, T> {
this->paddings_,
this->dilations_.data()) {}
bool Run(StatsFuture *future) override {
MaceStatus Run(StatsFuture *future) override {
const Tensor *input = this->Input(INPUT);
Tensor *output = this->Output(OUTPUT);
functor_(input, output, future);
return true;
return functor_(input, output, future);
};
protected:
......
......@@ -35,15 +35,14 @@ class ProposalOp : public Operator<D, T> {
OperatorBase::GetRepeatedArgument<int>("scales"),
OperatorBase::GetRepeatedArgument<float>("ratios")) {}
bool Run(StatsFuture *future) override {
MaceStatus Run(StatsFuture *future) override {
const Tensor *rpn_cls_prob = this->Input(RPN_CLS_PROB);
const Tensor *rpn_bbox_pred = this->Input(RPN_BBOX_PRED);
const Tensor *img_info = this->Input(IMG_INFO);
Tensor *output = this->Output(ROIS);
functor_(rpn_cls_prob, rpn_bbox_pred, img_info, output, future);
return true;
return functor_(rpn_cls_prob, rpn_bbox_pred, img_info, output, future);
}
private:
......
......@@ -30,14 +30,13 @@ class PSROIAlignOp : public Operator<D, T> {
OperatorBase::GetSingleArgument<int>("output_dim", 0),
OperatorBase::GetSingleArgument<int>("group_size", 0)) {}
bool Run(StatsFuture *future) override {
MaceStatus Run(StatsFuture *future) override {
const Tensor *input = this->Input(INPUT);
const Tensor *rois = this->Input(ROIS);
Tensor *output = this->Output(OUTPUT);
functor_(input, rois, output, future);
return true;
return functor_(input, rois, output, future);
}
private:
......
......@@ -28,7 +28,7 @@ class QuantizeOp : public Operator<D, T> {
: Operator<D, T>(operator_def, ws) {
}
bool Run(StatsFuture *future) override {
MaceStatus Run(StatsFuture *future) override {
const Tensor *input = this->Input(INPUT);
const Tensor *in_min = this->Input(IN_MIN);
const Tensor *in_max = this->Input(IN_MAX);
......@@ -39,12 +39,11 @@ class QuantizeOp : public Operator<D, T> {
Tensor *output = this->Output(OUTPUT);
Tensor *out_min = this->Output(OUT_MIN);
Tensor *out_max = this->Output(OUT_MAX);
output->ResizeLike(input);
out_min->ResizeLike(in_min);
out_max->ResizeLike(in_max);
MACE_FAILURE_RETURN(output->ResizeLike(input));
MACE_FAILURE_RETURN(out_min->ResizeLike(in_min));
MACE_FAILURE_RETURN(out_max->ResizeLike(in_max));
functor_(input, in_min, in_max, output, out_min, out_max, future);
return true;
return functor_(input, in_min, in_max, output, out_min, out_max, future);
}
private:
......@@ -62,7 +61,7 @@ class DequantizeOp : public Operator<D, T> {
: Operator<D, T>(operator_def, ws) {
}
bool Run(StatsFuture *future) override {
MaceStatus Run(StatsFuture *future) override {
const Tensor *input = this->Input(INPUT);
const Tensor *in_min = this->Input(IN_MIN);
const Tensor *in_max = this->Input(IN_MAX);
......@@ -71,10 +70,9 @@ class DequantizeOp : public Operator<D, T> {
MACE_CHECK(in_max->size() == 1, "max val tensor has more than 1 value");
Tensor *output = this->Output(OUTPUT);
output->ResizeLike(input);
MACE_FAILURE_RETURN(output->ResizeLike(input));
functor_(input, in_min, in_max, output, future);
return true;
return functor_(input, in_min, in_max, output, future);
}
private:
......@@ -92,7 +90,7 @@ class RequantizeOp : public Operator<D, T> {
: Operator<D, T>(operator_def, ws) {
}
bool Run(StatsFuture *future) override {
MaceStatus Run(StatsFuture *future) override {
const Tensor *input = this->Input(INPUT);
const Tensor *in_min = this->Input(IN_MIN);
const Tensor *in_max = this->Input(IN_MAX);
......@@ -114,11 +112,11 @@ class RequantizeOp : public Operator<D, T> {
Tensor *output = this->Output(OUTPUT);
Tensor *out_min = this->Output(OUT_MIN);
Tensor *out_max = this->Output(OUT_MAX);
output->ResizeLike(input);
out_min->ResizeLike(in_min);
out_max->ResizeLike(out_max);
MACE_FAILURE_RETURN(output->ResizeLike(input));
MACE_FAILURE_RETURN(out_min->ResizeLike(in_min));
MACE_FAILURE_RETURN(out_max->ResizeLike(out_max));
functor_(input,
return functor_(input,
in_min,
in_max,
rerange_min,
......@@ -127,7 +125,6 @@ class RequantizeOp : public Operator<D, T> {
out_min,
out_max,
future);
return true;
}
private:
......
......@@ -30,7 +30,7 @@ class ReshapeOp : public Operator<D, T> {
: Operator<D, T>(op_def, ws),
shape_(OperatorBase::GetRepeatedArgument<int64_t>("shape")) {}
bool Run(StatsFuture *future) override {
MaceStatus Run(StatsFuture *future) override {
const Tensor *input = this->Input(INPUT);
const index_t num_dims = shape_.size();
int unknown_idx = -1;
......@@ -61,8 +61,7 @@ class ReshapeOp : public Operator<D, T> {
Tensor *output = this->Output(OUTPUT);
functor_(input, out_shape, output, future);
return true;
return functor_(input, out_shape, output, future);
}
private:
......
......@@ -30,15 +30,14 @@ class ResizeBilinearOp : public Operator<D, T> {
OperatorBase::GetRepeatedArgument<index_t>("size", {-1, -1}),
OperatorBase::GetSingleArgument<bool>("align_corners", false)) {}
bool Run(StatsFuture *future) override {
MaceStatus Run(StatsFuture *future) override {
const Tensor *input = this->Input(0);
Tensor *output = this->Output(0);
MACE_CHECK(input->dim_size() == 4, "input must be 4-dimensional.",
input->dim_size());
functor_(input, output, future);
return true;
return functor_(input, output, future);
}
private:
......
......@@ -30,7 +30,7 @@ class SliceOp : public Operator<D, T> {
: Operator<D, T>(op_def, ws),
functor_(OperatorBase::GetSingleArgument<int>("axis", 3)) {}
bool Run(StatsFuture *future) override {
MaceStatus Run(StatsFuture *future) override {
MACE_CHECK(this->OutputSize() >= 2)
<< "There must be at least two outputs for slicing";
const Tensor *input = this->Input(INPUT);
......@@ -39,8 +39,7 @@ class SliceOp : public Operator<D, T> {
MACE_CHECK((input->dim(slice_axis) % this->OutputSize()) == 0)
<< "Outputs do not split input equally.";
functor_(input, output_list, future);
return true;
return functor_(input, output_list, future);
}
private:
......
......@@ -27,14 +27,13 @@ class SoftmaxOp : public Operator<D, T> {
SoftmaxOp(const OperatorDef &operator_def, Workspace *ws)
: Operator<D, T>(operator_def, ws) {}
bool Run(StatsFuture *future) override {
MaceStatus Run(StatsFuture *future) override {
const Tensor *logits = this->Input(LOGITS);
Tensor *output = this->Output(OUTPUT);
output->ResizeLike(logits);
functor_(logits, output, future);
return true;
return functor_(logits, output, future);
}
private:
......
......@@ -34,12 +34,11 @@ class SpaceToBatchNDOp : public Operator<D, T> {
OperatorBase::GetRepeatedArgument<int>("block_shape", {1, 1}),
false) {}
bool Run(StatsFuture *future) override {
MaceStatus Run(StatsFuture *future) override {
const Tensor *space_tensor = this->Input(INPUT);
Tensor *batch_tensor = this->Output(OUTPUT);
functor_(const_cast<Tensor *>(space_tensor), batch_tensor,
return functor_(const_cast<Tensor *>(space_tensor), batch_tensor,
future);
return true;
}
private:
......
......@@ -32,7 +32,7 @@ class SpaceToDepthOp : public Operator<D, T> {
functor_(OperatorBase::GetSingleArgument<int>("block_size", 1), false) {
}
bool Run(StatsFuture *future) override {
MaceStatus Run(StatsFuture *future) override {
const Tensor *input = this->Input(INPUT);
Tensor *output = this->Output(OUTPUT);
MACE_CHECK(input->dim_size() == 4, "input dim should be 4");
......@@ -58,8 +58,7 @@ class SpaceToDepthOp : public Operator<D, T> {
(input_width % block_size == 0) && (input_height % block_size == 0),
"input width and height should be dividable by block_size",
input->dim(3));
functor_(input, output, future);
return true;
return functor_(input, output, future);
}
protected:
......
......@@ -31,7 +31,7 @@ class TransposeOp : public Operator<D, T> {
dims_(OperatorBase::GetRepeatedArgument<int>("dims")),
functor_(dims_) {}
bool Run(StatsFuture *future) override {
MaceStatus Run(StatsFuture *future) override {
const Tensor *input = this->Input(INPUT);
Tensor *output = this->Output(OUTPUT);
const std::vector<index_t> &input_shape = input->shape();
......@@ -42,9 +42,8 @@ class TransposeOp : public Operator<D, T> {
for (size_t i = 0; i < dims_.size(); ++i) {
output_shape.push_back(input_shape[dims_[i]]);
}
output->Resize(output_shape);
functor_(input, output, future);
return true;
MACE_FAILURE_RETURN(output->Resize(output_shape));
return functor_(input, output, future);
}
protected:
......
......@@ -38,12 +38,11 @@ class WinogradInverseTransformOp : public Operator<D, T> {
"NOOP")),
OperatorBase::GetSingleArgument<float>("max_limit", 0.0f)) {}
bool Run(StatsFuture *future) override {
MaceStatus Run(StatsFuture *future) override {
const Tensor *input_tensor = this->Input(INPUT);
const Tensor *bias = this->InputSize() == 2 ? this->Input(BIAS) : nullptr;
Tensor *output_tensor = this->Output(OUTPUT);
functor_(input_tensor, bias, output_tensor, future);
return true;
return functor_(input_tensor, bias, output_tensor, future);
}
private:
......
......@@ -32,12 +32,11 @@ class WinogradTransformOp : public Operator<D, T> {
"padding", static_cast<int>(VALID))),
OperatorBase::GetRepeatedArgument<int>("padding_values")) {}
bool Run(StatsFuture *future) override {
MaceStatus Run(StatsFuture *future) override {
const Tensor *input_tensor = this->Input(INPUT);
Tensor *output_tensor = this->Output(OUTPUT);
functor_(input_tensor, output_tensor, future);
return true;
return functor_(input_tensor, output_tensor, future);
}
private:
......
......@@ -65,6 +65,15 @@ enum MaceStatus {
MACE_OUT_OF_RESOURCES = 2
};
#define MACE_FAILURE_RETURN(stmt) \
{ \
MaceStatus status = (stmt); \
if (status != MACE_SUCCESS) { \
VLOG(0) << "Mace runtime failure: " << __FILE__ << ":" << __LINE__; \
return status; \
} \
}
// MACE input/output tensor
class MaceTensor {
public:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册