提交 aa52f08f 编写于 作者: 叶剑武

Support fallback from OpenCL to CPU in ReshapeOp

上级 6a231fdb
......@@ -18,8 +18,8 @@
#include "mace/utils/math.h"
#ifdef MACE_ENABLE_OPENCL
#include "mace/ops/opencl/image/reshape.h"
#include "mace/ops/opencl/buffer/reshape.h"
#include "mace/ops/opencl/image/reshape.h"
#endif
namespace mace {
......@@ -46,8 +46,7 @@ MaceStatus GetOutputShape(const Tensor *input,
MACE_CHECK(shape_data[i] >= 0, "Shape must be non-negative: ",
shape_data[i]);
if (shape_data[i] == 0) {
MACE_CHECK(i < input->dim_size(),
"dims:0 out of input dims' range.");
MACE_CHECK(i < input->dim_size(), "dims:0 out of input dims' range.");
n = input->dim(i);
} else {
n = shape_data[i];
......@@ -59,10 +58,10 @@ MaceStatus GetOutputShape(const Tensor *input,
if (unknown_idx != -1) {
MACE_CHECK(product != 0)
<< "Cannot infer shape if there is zero shape size.";
<< "Cannot infer shape if there is zero shape size.";
const index_t missing = input->size() / product;
MACE_CHECK(missing * product == input->size())
<< "Input size not match reshaped tensor size";
<< "Input size not match reshaped tensor size";
(*out_shape)[unknown_idx] = missing;
}
......@@ -71,7 +70,7 @@ MaceStatus GetOutputShape(const Tensor *input,
} // namespace
template<DeviceType D, class T>
template <DeviceType D, class T>
class ReshapeOp : public Operation {
public:
explicit ReshapeOp(OpConstructContext *context)
......@@ -90,11 +89,11 @@ class ReshapeOp : public Operation {
GetOutputShape(input, shape_data, num_dims, &out_shape));
// NHWC -> NCHW
if (has_df_ && D == DeviceType::CPU
&& out_shape.size() == 4 && shape->is_weight()) {
if (has_df_ && D == DeviceType::CPU && out_shape.size() == 4 &&
shape->is_weight()) {
std::vector<int> dst_dims = {0, 3, 1, 2};
std::vector<index_t> trans_shape = TransposeShape<index_t, index_t>(
out_shape, dst_dims);
std::vector<index_t> trans_shape =
TransposeShape<index_t, index_t>(out_shape, dst_dims);
out_shape = trans_shape;
}
......@@ -114,12 +113,11 @@ class ReshapeOp : public Operation {
};
#ifdef MACE_ENABLE_OPENCL
template<>
template <>
class ReshapeOp<GPU, float> : public Operation {
public:
explicit ReshapeOp(OpConstructContext *context)
: Operation(context),
dim_(Operation::GetRepeatedArgs<int>("dim")) {
: Operation(context), dim_(Operation::GetRepeatedArgs<int>("dim")) {
if (context->GetOpMemoryType() == MemoryType::GPU_IMAGE) {
kernel_ = make_unique<opencl::image::ReshapeKernel>(context);
} else {
......@@ -148,11 +146,25 @@ class ReshapeOp<GPU, float> : public Operation {
#endif
// Registers the "Reshape" op with `op_registry`: CPU registrations for float
// and int32_t, a GPU registration via MACE_REGISTER_GPU_OP, and a
// device-placer condition that decides which devices the op may run on.
void RegisterReshape(OpRegistryBase *op_registry) {
// NOTE(review): the CPU float/int32_t registrations appear twice below, once
// wrapped over two lines and once on a single line. This looks like the
// before/after sides of a formatting-only change shown together; confirm
// that only one pair is meant to remain.
MACE_REGISTER_OP(op_registry, "Reshape", ReshapeOp,
DeviceType::CPU, float);
MACE_REGISTER_OP(op_registry, "Reshape", ReshapeOp,
DeviceType::CPU, int32_t);
MACE_REGISTER_OP(op_registry, "Reshape", ReshapeOp, DeviceType::CPU, float);
MACE_REGISTER_OP(op_registry, "Reshape", ReshapeOp, DeviceType::CPU, int32_t);
MACE_REGISTER_GPU_OP(op_registry, "Reshape", ReshapeOp);
// Device placement rule for "Reshape": the lambda returns the set of device
// types this particular op instance is allowed to be placed on.
MACE_REGISTER_OP_CONDITION(
op_registry, OpConditionBuilder("Reshape").SetDevicePlacerFunc(
[](OpConditionContext *context) -> std::set<DeviceType> {
auto op = context->operator_def();
// If the op does not carry exactly one output shape per output, the
// rank check below is skipped and both CPU and GPU are allowed.
if (op->output_shape_size() != op->output_size()) {
return {DeviceType::CPU, DeviceType::GPU};
}
auto tensor_shape_info = context->tensor_shape_info();
const std::string &input_0 = op->input(0);
// GPU stays eligible only when both the first output shape and the
// first input's recorded shape have 4 dimensions; presumably the
// OpenCL kernels only handle 4-D tensors (TODO confirm).
// NOTE(review): at(input_0) assumes the shape info has an entry for
// this input, and output_shape(0) assumes at least one output —
// verify both against the callers.
if (4 == op->output_shape(0).dims_size() &&
4 == tensor_shape_info->at(input_0).size()) {
return {DeviceType::CPU, DeviceType::GPU};
}
// Any other rank combination is restricted to CPU.
return {DeviceType::CPU};
}));
}
} // namespace ops
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册