diff --git a/mace/core/workspace.cc b/mace/core/workspace.cc
index 46966284274e50f79694b4baf074e95ef2166061..c7401fcdd5bb898d21979baad70b89e7beae4e81 100644
--- a/mace/core/workspace.cc
+++ b/mace/core/workspace.cc
@@ -124,6 +124,7 @@ MaceStatus Workspace::LoadModelTensor(const NetDef &net_def,
       tensor_map_[const_tensor.name()] = std::move(tensor);
     }
+    fused_buffer_ = false;
   } else {
 #else
   {
@@ -165,6 +166,7 @@ MaceStatus Workspace::LoadModelTensor(const NetDef &net_def,
       tensor->SetZeroPoint(const_tensor.zero_point());
       tensor_map_[const_tensor.name()] = std::move(tensor);
     }
+    fused_buffer_ = true;
   }
 }
@@ -327,7 +329,34 @@ void Workspace::RemoveUnusedBuffer() {
       tensor_map_.erase(old_iter);
     }
   }
+  tensor_buffer_.reset(nullptr);
+}
+
+void Workspace::RemoveAndReloadBuffer(const NetDef &net_def,
+                                      const unsigned char *model_data) {
+  for (auto &const_tensor : net_def.tensors()) {
+    auto iter = tensor_map_.find(const_tensor.name());
+    if (iter->second->unused()) {
+      tensor_map_.erase(iter);
+    } else if (fused_buffer_) {
+      tensor_map_.erase(iter);
+      std::vector<index_t> dims;
+      for (const index_t d : const_tensor.dims()) {
+        dims.push_back(d);
+      }
+      std::unique_ptr<Tensor> tensor(
+          new Tensor(GetDeviceAllocator(DeviceType::GPU),
+                     const_tensor.data_type()));
+      tensor->Resize(dims);
+      MACE_CHECK(tensor->size() == const_tensor.data_size(),
+                 "Tensor's data_size not equal with the shape");
+      tensor->CopyBytes(model_data + const_tensor.offset(),
+                        const_tensor.data_size() *
+                            GetEnumTypeSize(const_tensor.data_type()));
+      tensor_map_[const_tensor.name()] = std::move(tensor);
+    }
+  }
   tensor_buffer_.reset(nullptr);
 }
diff --git a/mace/core/workspace.h b/mace/core/workspace.h
index ec636128e2883e942bef7ab68a9995efbf8e6be5..20f214b0018a93b59b84d8bf4cae7004e4e6ba0d 100644
--- a/mace/core/workspace.h
+++ b/mace/core/workspace.h
@@ -55,6 +55,9 @@ class Workspace {
 
   void RemoveUnusedBuffer();
 
+  void RemoveAndReloadBuffer(const NetDef &net_def,
+                             const unsigned char *model_data);
+
  private:
   MaceStatus CreateOutputTensorBuffer(const NetDef &net_def,
                                       DeviceType device_type);
@@ -66,6 +69,7 @@ class Workspace {
   PreallocatedPooledAllocator preallocated_allocator_;
 
   std::unique_ptr<ScratchBuffer> host_scratch_buffer_;
+  bool fused_buffer_;
 
   MACE_DISABLE_COPY_AND_ASSIGN(Workspace);
 };
diff --git a/mace/kernels/deconv_2d.h b/mace/kernels/deconv_2d.h
index 3369f4f4619b7541c4f1203cb2144f01e3642985..ad527a84410fe8e077914df9515c38fcaff03d22 100644
--- a/mace/kernels/deconv_2d.h
+++ b/mace/kernels/deconv_2d.h
@@ -174,15 +174,15 @@ struct Deconv2dFunctorBase {
     switch (padding) {
       case VALID:
         expected_input_height =
-            (out_height - filter_h) / strides[0] + 1;
+            (out_height - filter_h + strides[0]) / strides[0];
         expected_input_width =
-            (out_width - filter_w) / strides[1] + 1;
+            (out_width - filter_w + strides[1]) / strides[1];
         break;
       case SAME:
         expected_input_height =
-            (out_height - 1) / strides[0] + 1;
+            (out_height + strides[0] - 1) / strides[0];
         expected_input_width =
-            (out_width - 1) / strides[1] + 1;
+            (out_width + strides[1] - 1) / strides[1];
         break;
       default:
         MACE_CHECK(false, "Unsupported padding type: ", padding);
diff --git a/mace/kernels/eltwise.h b/mace/kernels/eltwise.h
index a246846b6d7283d3e8de01a452d7ad000da00c99..42d220fa2dd5a7c6bbd39052dd6a99960d24cda5 100644
--- a/mace/kernels/eltwise.h
+++ b/mace/kernels/eltwise.h
@@ -805,13 +805,19 @@ inline void TensorEltwisePerChannel(const EltwiseType type,
 struct EltwiseFunctorBase {
   EltwiseFunctorBase(const EltwiseType type,
                      const std::vector<float> &coeff,
-                     const float value,
+                     const float scalar_input,
+                     const int32_t scalar_input_index,
                      const DataFormat data_format)
-      : type_(type), coeff_(coeff), value_(value), data_format_(data_format) {}
+      : type_(type),
+        coeff_(coeff),
+        scalar_input_(scalar_input),
+        scalar_input_index_(scalar_input_index),
+        data_format_(data_format) {}
 
   EltwiseType type_;
   std::vector<float> coeff_;
-  float value_;
+  float scalar_input_;
+  int32_t scalar_input_index_;
   DataFormat data_format_;
 };
 
@@ -819,9 +825,14 @@ template <DeviceType D, typename T>
 struct EltwiseFunctor : EltwiseFunctorBase {
   EltwiseFunctor(const EltwiseType type,
                  const std::vector<float> &coeff,
-                 const float value,  // keep it float as it comes from arg
+                 const float scalar_input,  // float as it comes from arg
+                 const int32_t scalar_input_index,
                  const DataFormat data_format)
-      : EltwiseFunctorBase(type, coeff, value, data_format) {}
+      : EltwiseFunctorBase(type,
+                           coeff,
+                           scalar_input,
+                           scalar_input_index,
+                           data_format) {}
 
   template <typename DstType>
   MaceStatus DoEltwise(const Tensor *input0,
@@ -832,6 +843,9 @@ struct EltwiseFunctor : EltwiseFunctorBase {
       std::swap(input0, input1);
       swapped = true;
     }
+    if (scalar_input_index_ == 0) {
+      swapped = !swapped;
+    }
 
     // check if we can broadcast tensor
     uint32_t rank_diff =
@@ -924,7 +938,7 @@ struct EltwiseFunctor : EltwiseFunctorBase {
       scalar_tensor_.Resize({});
       Tensor::MappingGuard guard(&scalar_tensor_);
       auto scalar_data = scalar_tensor_.mutable_data<T>();
-      scalar_data[0] = static_cast<T>(value_);
+      scalar_data[0] = static_cast<T>(scalar_input_);
       input1 = &scalar_tensor_;
     }
 
@@ -944,9 +958,14 @@ template <typename T>
 struct EltwiseFunctor<DeviceType::GPU, T> : EltwiseFunctorBase {
   EltwiseFunctor(const EltwiseType type,
                  const std::vector<float> &coeff,
-                 const float value,
+                 const float scalar_input,
+                 const int32_t scalar_input_index,
                  const DataFormat data_format)
-      : EltwiseFunctorBase(type, coeff, value, data_format) {}
+      : EltwiseFunctorBase(type,
+                           coeff,
+                           scalar_input,
+                           scalar_input_index,
+                           data_format) {}
 
   MaceStatus operator()(const Tensor *input0,
                         const Tensor *input1,
diff --git a/mace/kernels/opencl/deconv_2d.cc b/mace/kernels/opencl/deconv_2d.cc
index 80e6370d751af8cd6f5e26fa6b5b8f0046b54592..770d64efc06d1cf34b3bb178ac3d0e4cdce87709 100644
--- a/mace/kernels/opencl/deconv_2d.cc
+++ b/mace/kernels/opencl/deconv_2d.cc
@@ -152,7 +152,6 @@ MaceStatus Deconv2dFunctor<DeviceType::GPU, T>::operator()(
   MACE_CHECK_NOTNULL(input);
   MACE_CHECK_NOTNULL(filter);
   MACE_CHECK_NOTNULL(output);
-
   if (!from_caffe_) {
     if (output_shape_.size() != 4) {
       MACE_CHECK_NOTNULL(output_shape_tensor);
@@ -174,7 +173,6 @@ MaceStatus Deconv2dFunctor<DeviceType::GPU, T>::operator()(
     CalcDeconvOutputSize(input->shape().data(), filter->shape().data(),
                          strides_, output_shape_.data(), paddings_.data());
   }
-
   std::vector<index_t> output_image_shape;
   CalImage2DShape(output_shape_, BufferType::IN_OUT_CHANNEL,
                   &output_image_shape);
diff --git a/mace/kernels/opencl/eltwise.cc b/mace/kernels/opencl/eltwise.cc
index 1f9eebe35702c9cd713b82e50d7a8abcbc43f830..9eedf011009ee8acbb97a00a84da5e140f11fa8c 100644
--- a/mace/kernels/opencl/eltwise.cc
+++ b/mace/kernels/opencl/eltwise.cc
@@ -48,6 +48,10 @@ MaceStatus EltwiseFunctor<DeviceType::GPU, T>::operator()(const Tensor *input0,
     }
   }
 
+  if (scalar_input_index_ == 0) {
+    swapped = !swapped;
+  }
+
   std::vector<index_t> output_shape(4);
   output_shape[0] = input0->dim(0);
   output_shape[1] = input0->dim(1);
@@ -104,7 +108,7 @@ MaceStatus EltwiseFunctor<DeviceType::GPU, T>::operator()(const Tensor *input0,
     SET_3D_GWS_ARGS(kernel_);
     kernel_.setArg(idx++, *(input0->opencl_image()));
     if (input1 == nullptr) {
-      kernel_.setArg(idx++, value_);
+      kernel_.setArg(idx++, scalar_input_);
     } else {
       kernel_.setArg(idx++, *(input1->opencl_image()));
     }
diff --git a/mace/kernels/scalar_math.h b/mace/kernels/scalar_math.h
new file mode 100644
index 0000000000000000000000000000000000000000..604302074cf30e6b03c1b0c2ded96b2596696b62
--- /dev/null
+++ b/mace/kernels/scalar_math.h
@@ -0,0 +1,158 @@
+// Copyright 2018 Xiaomi, Inc. All rights reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef MACE_KERNELS_SCALAR_MATH_H_
+#define MACE_KERNELS_SCALAR_MATH_H_
+
+#include <algorithm>
+#include <vector>
+
+#include "mace/core/future.h"
+#include "mace/core/tensor.h"
+#include "mace/public/mace.h"
+#include "mace/kernels/eltwise.h"
+
+namespace mace {
+namespace kernels {
+
+template <typename T, typename DstType>
+void ScalarEltwise(const T* in0,
+                   const T* in1,
+                   const EltwiseType type,
+                   const std::vector<float> &coeff,
+                   const bool swapped,
+                   DstType* out) {
+  switch (type) {
+    case SUM:
+      if (coeff.empty()) {
+        out[0] = in0[0] + in1[0];
+      } else {
+        MACE_CHECK(coeff.size() == 2,
+                   "sum's coeff params' size should be 2.");
+        if (swapped)
+          out[0] = in0[0] * coeff[1] + in1[0] * coeff[0];
+        else
+          out[0] = in0[0] * coeff[0] + in1[0] * coeff[1];
+      }
+      break;
+    case SUB:
+      if (swapped)
+        out[0] = in1[0] - in0[0];
+      else
+        out[0] = in0[0] - in1[0];
+      break;
+    case PROD:
+      out[0] = in0[0] * in1[0];
+      break;
+    case DIV:
+      if (swapped)
+        out[0] = in1[0] / in0[0];
+      else
+        out[0] = in0[0] / in1[0];
+      break;
+    case MIN:
+      out[0] = std::min(in1[0], in0[0]);
+      break;
+    case MAX:
+      out[0] = std::max(in1[0], in0[0]);
+      break;
+    case SQR_DIFF:
+      out[0] = std::pow(in1[0] - in0[0], 2.f);
+      break;
+    case POW:
+      out[0] = std::pow(in0[0], in1[0]);
+      break;
+    case EQUAL:
+      out[0] = in1[0] == in0[0];
+      break;
+    case NEG:
+      out[0] = -in0[0];
+      break;
+    case ABS:
+      out[0] = in0[0] > 0 ? in0[0] : -in0[0];
+      break;
+    default:
+      LOG(FATAL) << "Eltwise op not support type " << type;
+  }
+}
+
+
+template <DeviceType D, typename T>
+struct ScalarMathFunctor {
+  explicit ScalarMathFunctor(const EltwiseType type,
+                             const std::vector<float> &coeff,
+                             const float scalar_input,
+                             const int32_t scalar_input_index)
+      : type_(type),
+        coeff_(coeff),
+        scalar_input_(scalar_input),
+        scalar_input_index_(scalar_input_index) {}
+
+  MaceStatus operator()(const std::vector<const Tensor *> &inputs,
+                        Tensor *output,
+                        StatsFuture *future) {
+    const Tensor* input0 = inputs[0];
+    const Tensor* input1 = (inputs.size() >= 2) ?
+        inputs[1] : nullptr;
+    MACE_CHECK(input0->dim_size() <= 1 && input0->size() == 1,
+               "not support input dim size") << input0->dim_size();
+    Tensor::MappingGuard in0_guard(input0);
+    const T* in0 = input0->data<T>();
+    auto v = static_cast<T>(scalar_input_);
+    const T* in1 = &v;
+    Tensor::MappingGuard in1_guard(input1);
+    if (input1) {
+      MACE_CHECK(input1->dim_size() == 0);
+      in1 = input1->data<T>();
+    }
+    if (input0->dim_size() > 0) {
+      MACE_RETURN_IF_ERROR(output->Resize(input0->shape()));
+    } else {
+      output->Resize({});
+    }
+
+    Tensor::MappingGuard output_guard(output);
+    bool swapped = scalar_input_index_ == 0;
+
+    if (IsLogicalType(type_)) {
+      int32_t* out = output->mutable_data<int32_t>();
+      ScalarEltwise(in0,
+                    in1,
+                    type_,
+                    coeff_,
+                    swapped,
+                    out);
+    } else {
+      T* out = output->mutable_data<T>();
+      ScalarEltwise(in0,
+                    in1,
+                    type_,
+                    coeff_,
+                    swapped,
+                    out);
+    }
+
+    SetFutureDefaultWaitFn(future);
+    return MACE_SUCCESS;
+  }
+
+  EltwiseType type_;
+  std::vector<float> coeff_;
+  float scalar_input_;
+  int32_t scalar_input_index_;
+};
+
+}  // namespace kernels
+}  // namespace mace
+
+#endif  // MACE_KERNELS_SCALAR_MATH_H_
diff --git a/mace/kernels/stack.h b/mace/kernels/stack.h
index 3a630d8f28caa18e7950a22f8835a78305ffc79e..9a84bed0a4d5fc41670aa4d7c5cdae4aafb9544b 100644
--- a/mace/kernels/stack.h
+++ b/mace/kernels/stack.h
@@ -46,7 +46,13 @@ struct StackFunctor {
     output_shape.insert(output_shape.begin() + axis_, inputs.size());
     MACE_RETURN_IF_ERROR(output->Resize(output_shape));
 
-    // On host, no need to map data
+    // Some inputs may be in gpu memory, so add mapping here.
+    std::vector<Tensor::MappingGuard> mappers;
+    for (size_t i = 0; i < inputs.size(); ++i) {
+      mappers.emplace_back(Tensor::MappingGuard(inputs[i]));
+    }
+
+    // Output is on host, no need to map data
     T *output_data = output->mutable_data<T>();
     std::vector<const T *> input_data(inputs.size());
    for (size_t i = 0; i < inputs.size(); ++i) {
diff --git a/mace/kernels/strided_slice.h b/mace/kernels/strided_slice.h
index 20c508aa73109fe5124d18804938a29d578720a0..a6afb46c56cd2e500899197836b5803583dc6c06 100644
--- a/mace/kernels/strided_slice.h
+++ b/mace/kernels/strided_slice.h
@@ -51,7 +51,6 @@ struct StridedSliceFunctor {
                         StatsFuture *future) {
     MACE_CHECK(ellipsis_mask_ == 0 && new_axis_mask_ == 0,
                "ellipsis_mask and new_axis_mask are not supported yet.");
-
     if (strides == nullptr) {
       tmp_strides_tensor_.Resize({begin_indices->size()});
       Tensor::MappingGuard strides_guard(&tmp_strides_tensor_);
@@ -68,7 +67,6 @@ struct StridedSliceFunctor {
     const int32_t *begin_indices_data = begin_indices->data<int32_t>();
     const int32_t *end_indices_data = end_indices->data<int32_t>();
     const int32_t *strides_data = strides->data<int32_t>();
-
     std::vector<index_t> pad_begin_indices(input->dim_size(), 0);
     std::vector<index_t> pad_end_indices(input->dim_size(), 0);
     std::vector<index_t> pad_strides_indices(input->dim_size(), 1);
diff --git a/mace/libmace/mace.cc b/mace/libmace/mace.cc
index 470d082f7ba4b8c82d56f232d6433789e60d4e8c..ae7edef3d101b1376d8a338492bd70a2e71f3868 100644
--- a/mace/libmace/mace.cc
+++ b/mace/libmace/mace.cc
@@ -267,7 +267,7 @@ MaceStatus MaceEngine::Impl::Init(
   }
 #endif
   if (device_type_ == DeviceType::GPU) {
-    ws_->RemoveUnusedBuffer();
+    ws_->RemoveAndReloadBuffer(*net_def, model_data);
   }
   return MaceStatus::MACE_SUCCESS;
 }
diff --git a/mace/ops/eltwise.h b/mace/ops/eltwise.h
index 8c88e9a2d9c968ceaefa9baa3c5e179deb3fce7f..161d0e4fd9b5dba81d5d9d504ecd4a608edbedd4 100644
--- a/mace/ops/eltwise.h
+++ b/mace/ops/eltwise.h
@@ -30,7 +30,8 @@ class EltwiseOp : public Operator<D, T> {
         static_cast<kernels::EltwiseType>(OperatorBase::GetOptionalArg<int>(
             "type",
             static_cast<int>(kernels::EltwiseType::NONE))),
         OperatorBase::GetRepeatedArgs<float>("coeff"),
-        OperatorBase::GetOptionalArg<float>("value", 1.0),
+        OperatorBase::GetOptionalArg<float>("scalar_input", 1.0),
+        OperatorBase::GetOptionalArg<int32_t>("scalar_input_index", 1),
         static_cast<DataFormat>(OperatorBase::GetOptionalArg<int>(
             "data_format", 0))) {}
diff --git a/mace/ops/eltwise_test.cc b/mace/ops/eltwise_test.cc
index ddf113d8f9dd3ef6db847af24332f1d5eb35b918..55a0ce977563c16e8e3914a5c435fed7cedbbabc 100644
--- a/mace/ops/eltwise_test.cc
+++ b/mace/ops/eltwise_test.cc
@@ -39,7 +39,7 @@ void SimpleScalarScalar(const kernels::EltwiseType type,
       .Input("Input")
       .AddIntArg("T", DataTypeToEnum<T>::v())
       .AddIntArg("type", static_cast<int>(type))
-      .AddFloatArg("value", x)
+      .AddFloatArg("scalar_input", x)
       .OutputType({kernels::IsLogicalType(type) ? DT_INT32 : DT_FLOAT})
       .Output("Output")
       .Finalize(net.NewOperatorDef());
@@ -72,7 +72,7 @@ void SimpleTensorScalar(const kernels::EltwiseType type,
       .Input("TInput")
       .AddIntArg("T", DataTypeToEnum<T>::v())
       .AddIntArg("type", static_cast<int>(type))
-      .AddFloatArg("value", x)
+      .AddFloatArg("scalar_input", x)
       .AddIntArg("data_format", DataFormat::NCHW)
       .OutputType({kernels::IsLogicalType(type) ? DT_INT32 : DT_FLOAT})
       .Output("TOutput")
@@ -86,7 +86,7 @@ void SimpleTensorScalar(const kernels::EltwiseType type,
     OpDefBuilder("Eltwise", "EltwiseTest")
         .Input("InputImg")
         .AddIntArg("type", static_cast<int>(type))
-        .AddFloatArg("value", x)
+        .AddFloatArg("scalar_input", x)
         .Output("OutputImg")
         .Finalize(net.NewOperatorDef());
 
@@ -468,7 +468,7 @@ void RandomTensorScalar(const kernels::EltwiseType type,
   OpDefBuilder("Eltwise", "EltwiseTest")
       .Input("TInput")
      .AddIntArg("type", static_cast<int>(type))
-      .AddFloatArg("value", 0.1)
+      .AddFloatArg("scalar_input", 0.1)
       .AddIntArg("data_format", DataFormat::NCHW)
       .Output("TOutput")
       .Finalize(net.NewOperatorDef());
@@ -484,7 +484,7 @@ void RandomTensorScalar(const kernels::EltwiseType type,
   OpDefBuilder("Eltwise", "EltwiseTest")
       .Input("InputImg")
       .AddIntArg("type", static_cast<int>(type))
-      .AddFloatArg("value", 0.1)
+      .AddFloatArg("scalar_input", 0.1)
       .Output("OutputImg")
       .AddIntArg("T", static_cast<int>(DataTypeToEnum<T>::value))
       .Finalize(net.NewOperatorDef());
diff --git a/mace/ops/ops_register.cc b/mace/ops/ops_register.cc
index 3afe66c9c1c408993a18dfde19d8f1e63ba920fa..c318eb4417165ecfe6a2aa49339af3fa98093964 100644
--- a/mace/ops/ops_register.cc
+++ b/mace/ops/ops_register.cc
@@ -48,6 +48,7 @@ extern void Register_Quantize(OperatorRegistryBase *op_registry);
 extern void Register_ReduceMean(OperatorRegistryBase *op_registry);
 extern void Register_Reshape(OperatorRegistryBase *op_registry);
 extern void Register_ResizeBilinear(OperatorRegistryBase *op_registry);
+extern void Register_ScalarMath(OperatorRegistryBase *op_registry);
 extern void Register_Shape(OperatorRegistryBase *op_registry);
 extern void Register_Split(OperatorRegistryBase *op_registry);
 extern void Register_Softmax(OperatorRegistryBase *op_registry);
@@ -99,6 +100,7 @@ OperatorRegistry::OperatorRegistry() : OperatorRegistryBase() {
   ops::Register_ReduceMean(this);
   ops::Register_Reshape(this);
   ops::Register_ResizeBilinear(this);
+  ops::Register_ScalarMath(this);
   ops::Register_Shape(this);
   ops::Register_Split(this);
   ops::Register_Softmax(this);
diff --git a/mace/ops/scalar_math.cc b/mace/ops/scalar_math.cc
new file mode 100644
index 0000000000000000000000000000000000000000..82ef3eb3b3205adb45fb89f5d30c77af47355921
--- /dev/null
+++ b/mace/ops/scalar_math.cc
@@ -0,0 +1,44 @@
+// Copyright 2018 Xiaomi, Inc. All rights reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "mace/ops/scalar_math.h"
+
+namespace mace {
+namespace ops {
+
+void Register_ScalarMath(OperatorRegistryBase *op_registry) {
+  MACE_REGISTER_OPERATOR(op_registry, OpKeyBuilder("ScalarMath")
+                                          .Device(DeviceType::CPU)
+                                          .TypeConstraint<float>("T")
+                                          .Build(),
+                         ScalarMathOp<DeviceType::CPU, float>);
+  MACE_REGISTER_OPERATOR(op_registry, OpKeyBuilder("ScalarMath")
+                                          .Device(DeviceType::CPU)
+                                          .TypeConstraint<int32_t>("T")
+                                          .Build(),
+                         ScalarMathOp<DeviceType::CPU, int32_t>);
+  MACE_REGISTER_OPERATOR(op_registry, OpKeyBuilder("ScalarMath")
+                                          .Device(DeviceType::GPU)
+                                          .TypeConstraint<float>("T")
+                                          .Build(),
+                         ScalarMathOp<DeviceType::GPU, float>);
+  MACE_REGISTER_OPERATOR(op_registry, OpKeyBuilder("ScalarMath")
+                                          .Device(DeviceType::GPU)
+                                          .TypeConstraint<int32_t>("T")
+                                          .Build(),
+                         ScalarMathOp<DeviceType::GPU, int32_t>);
+}
+
+}  // namespace ops
+}  // namespace mace
diff --git a/mace/ops/scalar_math.h b/mace/ops/scalar_math.h
new file mode 100644
index 0000000000000000000000000000000000000000..29cb478c718f0d7eef1a8c1e18c61550ca9f2cee
--- /dev/null
+++ b/mace/ops/scalar_math.h
@@ -0,0 +1,52 @@
+// Copyright 2018 Xiaomi, Inc. All rights reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef MACE_OPS_SCALAR_MATH_H_
+#define MACE_OPS_SCALAR_MATH_H_
+
+#include <vector>
+
+#include "mace/core/operator.h"
+#include "mace/kernels/scalar_math.h"
+
+namespace mace {
+namespace ops {
+
+template <DeviceType D, typename T>
+class ScalarMathOp : public Operator<D, T> {
+ public:
+  ScalarMathOp(const OperatorDef &op_def, Workspace *ws)
+      : Operator<D, T>(op_def, ws),
+        functor_(static_cast<kernels::EltwiseType>(
+                     OperatorBase::GetOptionalArg<int>(
+                         "type", static_cast<int>(kernels::EltwiseType::NONE))),
+                 OperatorBase::GetRepeatedArgs<float>("coeff"),
+                 OperatorBase::GetOptionalArg<float>("scalar_input", 1.0),
+                 OperatorBase::GetOptionalArg<int32_t>(
+                     "scalar_input_index", 1)) {}
+
+  MaceStatus Run(StatsFuture *future) override {
+    const std::vector<const Tensor *> input_list = this->Inputs();
+    Tensor *output = this->Output(0);
+    return functor_(input_list, output, future);
+  }
+
+ private:
+  kernels::ScalarMathFunctor<D, T> functor_;
+};
+
+}  // namespace ops
+}  // namespace mace
+
+#endif  // MACE_OPS_SCALAR_MATH_H_
diff --git a/mace/ops/scalar_math_test.cc b/mace/ops/scalar_math_test.cc
new file mode 100644
index 0000000000000000000000000000000000000000..32b9db0001f4c9edb5639e90683bb5ac49a3449d
--- /dev/null
+++ b/mace/ops/scalar_math_test.cc
@@ -0,0 +1,109 @@
+// Copyright 2018 Xiaomi, Inc. All rights reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "mace/core/operator.h"
+#include "mace/ops/ops_test_util.h"
+#include "mace/kernels/eltwise.h"
+
+namespace mace {
+namespace ops {
+namespace test {
+
+class ScalarMathOpTest : public OpsTestBase {};
+
+namespace {
+template <DeviceType D, typename T, typename DstType>
+void ScalarMathTest(const kernels::EltwiseType type,
+                    const T input0,
+                    const T input1,
+                    const float x,
+                    const DstType output) {
+  // Construct graph
+  OpsTestNet net;
+
+  // Add input data
+  net.AddInputFromArray<D, T>("Input0", {}, {input0});
+  net.AddInputFromArray<D, T>("Input1", {}, {input1});
+
+  OpDefBuilder("ScalarMath", "ScalarMathTest")
+      .Input("Input0")
+      .Input("Input1")
+      .AddIntArg("T", DataTypeToEnum<T>::v())
+      .AddIntArg("type", static_cast<int>(type))
+      .AddFloatArg("scalar_input", x)
+      .OutputType({kernels::IsLogicalType(type) ? DT_INT32 : DT_FLOAT})
+      .Output("Output")
+      .Finalize(net.NewOperatorDef());
+  // Run
+  net.RunOp(D);
+
+
+  auto expected = CreateTensor<DstType>({}, {output});
+
+  ExpectTensorNear<DstType>(*expected, *net.GetOutput("Output"), 1e-5);
+}
+}  // namespace
+
+TEST_F(ScalarMathOpTest, SimpleCPU) {
+ScalarMathTest<DeviceType::CPU, float, float>(
+    kernels::EltwiseType::SUM, 1, 2, 3, 3);
+ScalarMathTest<DeviceType::CPU, float, float>(
+    kernels::EltwiseType::SUB, 1, 2, 3, -1);
+ScalarMathTest<DeviceType::CPU, float, float>(
+    kernels::EltwiseType::PROD, 3, -2, 3, -6);
+ScalarMathTest<DeviceType::CPU, float, float>(
+    kernels::EltwiseType::DIV, 3, -2, 1, -1.5);
+ScalarMathTest<DeviceType::CPU, float, float>(
+    kernels::EltwiseType::MIN, 3, -2, 1, -2);
+ScalarMathTest<DeviceType::CPU, float, float>(
+    kernels::EltwiseType::MAX, 3, -2, 1, 3);
+ScalarMathTest<DeviceType::CPU, float, float>(
+    kernels::EltwiseType::NEG, 3, -2, 1, -3);
+ScalarMathTest<DeviceType::CPU, float, float>(
+    kernels::EltwiseType::ABS, 3, -2, 1, 3);
+ScalarMathTest<DeviceType::CPU, float, float>(
+    kernels::EltwiseType::SQR_DIFF, 3, -2, 1, 25);
+ScalarMathTest<DeviceType::CPU, float, float>(
+    kernels::EltwiseType::POW, 3, 1, 1, 3);
+ScalarMathTest<DeviceType::CPU, float, int32_t>(
+    kernels::EltwiseType::EQUAL, 3, 3, 1, 1);
+}
+
+TEST_F(ScalarMathOpTest, SimpleGPU) {
+ScalarMathTest<DeviceType::GPU, float, float>(
+    kernels::EltwiseType::SUM, 1, 2, 1, 3);
+ScalarMathTest<DeviceType::GPU, float, float>(
+    kernels::EltwiseType::SUB, 1, 2, 1, -1);
+ScalarMathTest<DeviceType::GPU, float, float>(
+    kernels::EltwiseType::PROD, 3, -2, 1, -6);
+ScalarMathTest<DeviceType::GPU, float, float>(
+    kernels::EltwiseType::DIV, 3, -2, 1, -1.5);
+ScalarMathTest<DeviceType::GPU, float, float>(
+    kernels::EltwiseType::MIN, 3, -2, 1, -2);
+ScalarMathTest<DeviceType::GPU, float, float>(
+    kernels::EltwiseType::MAX, 3, -2, 1, 3);
+ScalarMathTest<DeviceType::GPU, float, float>(
+    kernels::EltwiseType::NEG, 3, -2, 1, -3);
+ScalarMathTest<DeviceType::GPU, float, float>(
+    kernels::EltwiseType::ABS, 3, -2, 1, 3);
+ScalarMathTest<DeviceType::GPU, float, float>(
+    kernels::EltwiseType::SQR_DIFF, 3, -2, 1, 25);
+ScalarMathTest<DeviceType::GPU, float, float>(
+    kernels::EltwiseType::POW, 3, 1, 1, 3);
+ScalarMathTest<DeviceType::GPU, float, int32_t>(
+    kernels::EltwiseType::EQUAL, 3, 3, 1, 1);
+}
+}  // namespace test
+}  // namespace ops
+}  // namespace mace
diff --git a/mace/ops/stack.cc b/mace/ops/stack.cc
index 968f859d5945d8d353d56cf631f433e409c22f54..7aa7c07eb407e35b36170c0b7784f001297415f1 100644
--- a/mace/ops/stack.cc
+++ b/mace/ops/stack.cc
@@ -28,6 +28,16 @@ void Register_Stack(OperatorRegistryBase *op_registry) {
                                           .TypeConstraint<int32_t>("T")
                                           .Build(),
                          StackOp<DeviceType::CPU, int32_t>);
+  MACE_REGISTER_OPERATOR(op_registry, OpKeyBuilder("Stack")
+                                          .Device(DeviceType::GPU)
+                                          .TypeConstraint<float>("T")
+                                          .Build(),
+                         StackOp<DeviceType::GPU, float>);
+  MACE_REGISTER_OPERATOR(op_registry, OpKeyBuilder("Stack")
+                                          .Device(DeviceType::GPU)
+                                          .TypeConstraint<int32_t>("T")
+                                          .Build(),
+                         StackOp<DeviceType::GPU, int32_t>);
 }
 
 }  // namespace ops
diff --git a/mace/ops/strided_slice.cc b/mace/ops/strided_slice.cc
index b449be038f33e34d03b3af9360634513f852f544..0f608b1722fc60603f4e6f0e1d95d9f6e57e1e69 100644
--- a/mace/ops/strided_slice.cc
+++ b/mace/ops/strided_slice.cc
@@ -28,6 +28,16 @@ void Register_StridedSlice(OperatorRegistryBase *op_registry) {
                                           .TypeConstraint<int32_t>("T")
                                           .Build(),
                          StridedSliceOp<DeviceType::CPU, int32_t>);
+  MACE_REGISTER_OPERATOR(op_registry, OpKeyBuilder("StridedSlice")
+                                          .Device(DeviceType::GPU)
+                                          .TypeConstraint<float>("T")
+                                          .Build(),
+                         StridedSliceOp<DeviceType::GPU, float>);
+  MACE_REGISTER_OPERATOR(op_registry, OpKeyBuilder("StridedSlice")
+                                          .Device(DeviceType::GPU)
+                                          .TypeConstraint<int32_t>("T")
+                                          .Build(),
+                         StridedSliceOp<DeviceType::GPU, int32_t>);
 }
 
 }  // namespace ops
diff --git a/mace/python/tools/converter_tool/base_converter.py b/mace/python/tools/converter_tool/base_converter.py
index cfd8409abfaad1de1b066583665afd9f47d4af41..e31ed5b833a1b2db0457dbf812f4e636a9c37445 100644
--- a/mace/python/tools/converter_tool/base_converter.py
+++ b/mace/python/tools/converter_tool/base_converter.py
@@ -101,6 +101,7 @@ MaceSupportedOps = [
     'ReduceMean',
     'Reshape',
     'ResizeBilinear',
+    'ScalarMath',
     'Slice',
     'Split',
     'Shape',
@@ -153,7 +154,7 @@ class MaceKeyword(object):
     mace_shape_str = 'shape'
     mace_winograd_filter_transformed = 'is_filter_transformed'
     mace_device = 'device'
-    mace_value_str = 'value'
+    mace_scalar_input_str = 'scalar_input'
     mace_wino_block_size = 'wino_block_size'
     mace_output_shape_str = 'output_shape'
     mace_begin_mask_str = 'begin_mask'
@@ -167,6 +168,8 @@ class MaceKeyword(object):
     mace_offset_str = 'offset'
     mace_from_caffe_str = 'from_caffe'
     mace_opencl_max_image_size = "opencl_max_image_size"
+    mace_seperate_buffer_str = 'seperate_buffer'
+    mace_scalar_input_index_str = 'scalar_input_index'
 
 
 class TransformerRule(Enum):
diff --git a/mace/python/tools/converter_tool/tensorflow_converter.py b/mace/python/tools/converter_tool/tensorflow_converter.py
index 9583d0e163be75b3a0e92afdfcb5a994bc59c5a6..da9384fa93affdca98ad05983055e35a6771b0e3 100644
--- a/mace/python/tools/converter_tool/tensorflow_converter.py
+++ b/mace/python/tools/converter_tool/tensorflow_converter.py
@@ -401,13 +401,24 @@ class TensorflowConverter(base_converter.ConverterInterface):
         type_arg.name = MaceKeyword.mace_element_type_str
         type_arg.i = self.eltwise_type[tf_op.type].value
 
+        def check_is_scalar(tf_op):
+            if len(tf_op.inputs) == 1:
+                return len(tf_op.inputs[0].shape) == 0
+            elif len(tf_op.inputs) == 2:
+                return len(tf_op.inputs[0].shape) == 0 and\
+                    len(tf_op.inputs[1].shape) == 0
+
+        if check_is_scalar(tf_op):
+            op.type = MaceOp.ScalarMath.name
+        else:
+            op.type = MaceOp.Eltwise.name
         if tf_op.type == TFOpType.Square:
             value_arg = op.arg.add()
-            value_arg.name = MaceKeyword.mace_value_str
+            value_arg.name = MaceKeyword.mace_scalar_input_str
             value_arg.f = 2.0
         elif tf_op.type == TFOpType.Rsqrt:
             value_arg = op.arg.add()
-            value_arg.name = MaceKeyword.mace_value_str
+            value_arg.name = MaceKeyword.mace_scalar_input_str
             value_arg.f = -0.5
 
         if type_arg.i != EltwiseType.NEG.value \
@@ -418,19 +429,31 @@ class TensorflowConverter(base_converter.ConverterInterface):
                     EltwiseType.SUM, EltwiseType.PROD,
                     EltwiseType.MAX, EltwiseType.MIN]
 
-            if len(tf_op.inputs) > 1 and len(tf_op.inputs[1].shape) == 0:
+            if len(tf_op.inputs) > 1 and\
+                    len(tf_op.inputs[1].shape) == 0 and\
+                    tf_op.inputs[1].op.type == TFOpType.Const.name:
                 scalar = tf_op.inputs[1].eval().astype(np.float32)
                 value_arg = op.arg.add()
-                value_arg.name = MaceKeyword.mace_value_str
+                value_arg.name = MaceKeyword.mace_scalar_input_str
                 value_arg.f = scalar
                 self._skip_tensor.add(tf_op.inputs[1].name)
+                value_index_arg = op.arg.add()
+                value_index_arg.name =\
+                    MaceKeyword.mace_scalar_input_index_str
+                value_index_arg.i = 1
+                self._skip_tensor.add(tf_op.inputs[1].name)
                 del op.input[1]
-            elif len(tf_op.inputs[0].shape) == 0 and \
+            elif len(tf_op.inputs[0].shape) == 0 and\
+                    tf_op.inputs[0].op.type == TFOpType.Const.name and\
                     is_commutative(type_arg.i):
                 scalar = tf_op.inputs[0].eval().astype(np.float32)
                 value_arg = op.arg.add()
-                value_arg.name = MaceKeyword.mace_value_str
+                value_arg.name = MaceKeyword.mace_scalar_input_str
                 value_arg.f = scalar
+                value_index_arg = op.arg.add()
+                value_index_arg.name =\
+                    MaceKeyword.mace_scalar_input_index_str
+                value_index_arg.i = 0
                 self._skip_tensor.add(tf_op.inputs[0].name)
                 del op.input[0]
         except tf.errors.InvalidArgumentError:
@@ -771,7 +794,6 @@ class TensorflowConverter(base_converter.ConverterInterface):
     def convert_split(self, tf_op):
         axis = tf_op.inputs[0].eval().astype(np.int32)
         axis = len(op.output_shape[0].dims) + axis if axis < 0 else axis
-        input_shape = self.infer_tensor_shape(tf_op.inputs[1])
         op = self.convert_general_op(tf_op)
         op.type = MaceOp.Split.name
         del op.input[0]
diff --git a/mace/python/tools/converter_tool/transformer.py b/mace/python/tools/converter_tool/transformer.py
index 5f179d653f25e5773587d178b6668d2fd0641356..640992e20f32a366ed84dc5a8989e59de932afa2 100644
--- a/mace/python/tools/converter_tool/transformer.py
+++ b/mace/python/tools/converter_tool/transformer.py
@@ -117,7 +117,6 @@ class Transformer(base_converter.ConverterInterface):
                 changed = transformer()
                 if not changed:
                     break
-
         return self._model
 
     def filter_format(self):
diff --git a/mace/python/tools/memory_optimizer.py b/mace/python/tools/memory_optimizer.py
index 36ee96074121ec009b3bd9032f6d15fffd5c5a5d..c0f1ddd022671e8b4db291301abc1b08aa4fe255 100644
--- a/mace/python/tools/memory_optimizer.py
+++ b/mace/python/tools/memory_optimizer.py
@@ -228,16 +228,24 @@ class GPUMemoryOptimizer(MemoryOptimizer):
                 mace_pb2.GPU_IMAGE,
                 calculate_image_shape(OpenCLBufferType.IN_OUT_HEIGHT,
                                       buffer_shape))
-        elif op_type == 'Shape':
-            mem_block = MemoryBlock(mace_pb2.CPU_BUFFER,
-                                    [output_shape[0], 1])
+        elif op_type in ['Shape', 'StridedSlice', 'Stack', 'ScalarMath']:
+            if len(output_shape) == 1:
+                mem_block = MemoryBlock(mace_pb2.CPU_BUFFER,
+                                        [output_shape[0], 1])
+            elif len(output_shape) == 0:
+                mem_block = MemoryBlock(mace_pb2.CPU_BUFFER,
+                                        [1, 1])
+            else:
+                raise Exception('%s output shape dim size is not 0 or 1.' %
+                                op_type)
         else:
            if len(output_shape) == 2:  # only support fc/softmax
                buffer_shape = [output_shape[0], 1, 1, output_shape[1]]
            elif len(output_shape) == 4:
                buffer_shape = output_shape
            else:
-                raise Exception('output shape dim size is not 2 or 4.')
+                raise Exception('%s output shape dim size is not 2 or 4.' %
+                                op_type)
            mem_block = MemoryBlock(
                mace_pb2.GPU_IMAGE,
                calculate_image_shape(OpenCLBufferType.IN_OUT_CHANNEL,
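
Note: the following standalone sketch (not part of the patch) illustrates the operand-order rule introduced above by the new scalar_input / scalar_input_index arguments: when the folded scalar constant was originally the first input of the TensorFlow op, the converter records index 0 and the kernels flip their "swapped" flag, so non-commutative ops such as SUB and DIV keep the original operand order. The enum, function and file names below are hypothetical and assume nothing beyond the C++ standard library.

    // scalar_order_sketch.cc -- illustrative only.
    #include <cassert>

    enum class BinType { SUB, DIV };

    // Mirrors the swapped handling in ScalarEltwise/EltwiseFunctor: the kernel
    // always receives the surviving tensor operand as in0 and the folded scalar
    // as in1; "swapped" records that the original expression had them the other
    // way around (scalar_input_index == 0).
    float EvalScalarBinary(BinType type, float in0, float scalar_input,
                           int scalar_input_index) {
      bool swapped = (scalar_input_index == 0);
      float in1 = scalar_input;
      switch (type) {
        case BinType::SUB:
          return swapped ? in1 - in0 : in0 - in1;
        case BinType::DIV:
          return swapped ? in1 / in0 : in0 / in1;
      }
      return 0.f;
    }

    int main() {
      // x - 2: the scalar is input 1, no swap is needed.
      assert(EvalScalarBinary(BinType::SUB, 5.f, 2.f, 1) == 3.f);
      // 2 - x: the scalar was input 0, so the operands are swapped back.
      assert(EvalScalarBinary(BinType::SUB, 5.f, 2.f, 0) == -3.f);
      return 0;
    }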