提交 aa52f08f 编写于 作者: 叶剑武

Support fallback from OpenCL to CPU in ReshapeOp

上级 6a231fdb
...@@ -18,8 +18,8 @@ ...@@ -18,8 +18,8 @@
#include "mace/utils/math.h" #include "mace/utils/math.h"
#ifdef MACE_ENABLE_OPENCL #ifdef MACE_ENABLE_OPENCL
#include "mace/ops/opencl/image/reshape.h"
#include "mace/ops/opencl/buffer/reshape.h" #include "mace/ops/opencl/buffer/reshape.h"
#include "mace/ops/opencl/image/reshape.h"
#endif #endif
namespace mace { namespace mace {
...@@ -46,8 +46,7 @@ MaceStatus GetOutputShape(const Tensor *input, ...@@ -46,8 +46,7 @@ MaceStatus GetOutputShape(const Tensor *input,
MACE_CHECK(shape_data[i] >= 0, "Shape must be non-negative: ", MACE_CHECK(shape_data[i] >= 0, "Shape must be non-negative: ",
shape_data[i]); shape_data[i]);
if (shape_data[i] == 0) { if (shape_data[i] == 0) {
MACE_CHECK(i < input->dim_size(), MACE_CHECK(i < input->dim_size(), "dims:0 out of input dims' range.");
"dims:0 out of input dims' range.");
n = input->dim(i); n = input->dim(i);
} else { } else {
n = shape_data[i]; n = shape_data[i];
...@@ -59,10 +58,10 @@ MaceStatus GetOutputShape(const Tensor *input, ...@@ -59,10 +58,10 @@ MaceStatus GetOutputShape(const Tensor *input,
if (unknown_idx != -1) { if (unknown_idx != -1) {
MACE_CHECK(product != 0) MACE_CHECK(product != 0)
<< "Cannot infer shape if there is zero shape size."; << "Cannot infer shape if there is zero shape size.";
const index_t missing = input->size() / product; const index_t missing = input->size() / product;
MACE_CHECK(missing * product == input->size()) MACE_CHECK(missing * product == input->size())
<< "Input size not match reshaped tensor size"; << "Input size not match reshaped tensor size";
(*out_shape)[unknown_idx] = missing; (*out_shape)[unknown_idx] = missing;
} }
...@@ -71,7 +70,7 @@ MaceStatus GetOutputShape(const Tensor *input, ...@@ -71,7 +70,7 @@ MaceStatus GetOutputShape(const Tensor *input,
} // namespace } // namespace
template<DeviceType D, class T> template <DeviceType D, class T>
class ReshapeOp : public Operation { class ReshapeOp : public Operation {
public: public:
explicit ReshapeOp(OpConstructContext *context) explicit ReshapeOp(OpConstructContext *context)
...@@ -90,11 +89,11 @@ class ReshapeOp : public Operation { ...@@ -90,11 +89,11 @@ class ReshapeOp : public Operation {
GetOutputShape(input, shape_data, num_dims, &out_shape)); GetOutputShape(input, shape_data, num_dims, &out_shape));
// NHWC -> NCHW // NHWC -> NCHW
if (has_df_ && D == DeviceType::CPU if (has_df_ && D == DeviceType::CPU && out_shape.size() == 4 &&
&& out_shape.size() == 4 && shape->is_weight()) { shape->is_weight()) {
std::vector<int> dst_dims = {0, 3, 1, 2}; std::vector<int> dst_dims = {0, 3, 1, 2};
std::vector<index_t> trans_shape = TransposeShape<index_t, index_t>( std::vector<index_t> trans_shape =
out_shape, dst_dims); TransposeShape<index_t, index_t>(out_shape, dst_dims);
out_shape = trans_shape; out_shape = trans_shape;
} }
...@@ -114,12 +113,11 @@ class ReshapeOp : public Operation { ...@@ -114,12 +113,11 @@ class ReshapeOp : public Operation {
}; };
#ifdef MACE_ENABLE_OPENCL #ifdef MACE_ENABLE_OPENCL
template<> template <>
class ReshapeOp<GPU, float> : public Operation { class ReshapeOp<GPU, float> : public Operation {
public: public:
explicit ReshapeOp(OpConstructContext *context) explicit ReshapeOp(OpConstructContext *context)
: Operation(context), : Operation(context), dim_(Operation::GetRepeatedArgs<int>("dim")) {
dim_(Operation::GetRepeatedArgs<int>("dim")) {
if (context->GetOpMemoryType() == MemoryType::GPU_IMAGE) { if (context->GetOpMemoryType() == MemoryType::GPU_IMAGE) {
kernel_ = make_unique<opencl::image::ReshapeKernel>(context); kernel_ = make_unique<opencl::image::ReshapeKernel>(context);
} else { } else {
...@@ -148,11 +146,25 @@ class ReshapeOp<GPU, float> : public Operation { ...@@ -148,11 +146,25 @@ class ReshapeOp<GPU, float> : public Operation {
#endif #endif
// Registers the Reshape op for CPU (float and int32_t) and, when OpenCL is
// enabled, for GPU. A device-placement condition restricts GPU execution to
// the cases the OpenCL kernels support, falling back to CPU otherwise
// (this commit's purpose: fallback from OpenCL to CPU in ReshapeOp).
void RegisterReshape(OpRegistryBase *op_registry) {
  MACE_REGISTER_OP(op_registry, "Reshape", ReshapeOp, DeviceType::CPU, float);
  MACE_REGISTER_OP(op_registry, "Reshape", ReshapeOp, DeviceType::CPU, int32_t);

  MACE_REGISTER_GPU_OP(op_registry, "Reshape", ReshapeOp);

  MACE_REGISTER_OP_CONDITION(
      op_registry,
      OpConditionBuilder("Reshape").SetDevicePlacerFunc(
          [](OpConditionContext *context) -> std::set<DeviceType> {
            auto op = context->operator_def();
            // If not every output carries a recorded shape we cannot prove
            // the GPU path is unsupported, so allow both devices.
            if (op->output_shape_size() != op->output_size()) {
              return {DeviceType::CPU, DeviceType::GPU};
            }
            // Only a 4-D input reshaped to a 4-D output may stay on GPU;
            // every other rank combination falls back to CPU.
            auto tensor_shape_info = context->tensor_shape_info();
            const std::string &input_0 = op->input(0);
            // NOTE(review): at(input_0) throws std::out_of_range if the
            // input's shape was never recorded — presumably the converter
            // always populates tensor_shape_info; confirm.
            if (4 == op->output_shape(0).dims_size() &&
                4 == tensor_shape_info->at(input_0).size()) {
              return {DeviceType::CPU, DeviceType::GPU};
            }
            return {DeviceType::CPU};
          }));
}
} // namespace ops } // namespace ops
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册