未验证 提交 a5e00389 编写于 作者: L Liu Qi 提交者: GitHub

Merge pull request #368 from TCLResearchEurope/one_hot

Feature: Add one_hot operator.
......@@ -33,6 +33,7 @@ Operator lists
"LSTM","",""
"MATMUL","Y","Only CPU is supported."
"MAX_POOL_2D","Y",""
"ONE_HOT","Y","Only TensorFlow model is supported."
"PAD","Y",""
"PSROI_ALIGN","Y",""
"PRELU","Y","Only Caffe model is supported"
......
......@@ -79,10 +79,12 @@ MemoryBlock MemoryOptimizer::CreateMemoryBlock(
*op_def, "buffer_type", OpenCLBufferType::IN_OUT_CHANNEL));
}
std::vector<size_t> image_shape;
if (shape.size() == 2) {
if (shape.size() == 1) {
shape = {shape[0], 1, 1, 1};
} else if (shape.size() == 2) {
shape = {shape[0], 1, 1, shape[1]};
} else {
MACE_CHECK(shape.size() == 4) << "GPU only support 2D/4D input";
MACE_CHECK(shape.size() == 4) << "GPU only support 1D/2D/4D input";
}
OpenCLUtil::CalImage2DShape(shape, buffer_type, &image_shape);
block.set_x(image_shape[0]);
......
// Copyright 2018 The MACE Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <vector>
#include <memory>
#include "mace/core/operator.h"
#ifdef MACE_ENABLE_OPENCL
#include "mace/ops/opencl/image/one_hot.h"
#endif // MACE_ENABLE_OPENCL
namespace mace {
namespace ops {
// Shared state for the CPU and GPU OneHot operators. All operator arguments
// are read once, at construction time:
//   depth     - size of the newly inserted one-hot dimension (must be > 0)
//   on_value  - value written where the input equals the depth index
//   off_value - value written everywhere else
//   axis      - position of the new dimension in the output; -1 is resolved
//               by the derived operators to "after the last input dimension"
class OneHotOpBase : public Operation {
 public:
  explicit OneHotOpBase(OpConstructContext *context)
      : Operation(context),
        depth_(Operation::GetOptionalArg<int>("depth", 0)),
        on_value_(Operation::GetOptionalArg<float>("on_value", 1)),
        off_value_(Operation::GetOptionalArg<float>("off_value", 0)),
        axis_(Operation::GetOptionalArg<int>("axis", -1)) {
    // "depth" falls back to 0 when the argument is absent, so this check also
    // rejects a missing argument; say so instead of failing without context.
    MACE_CHECK(depth_ > 0, "OneHot requires a positive 'depth' argument");
  }

 protected:
  int depth_;
  float on_value_;
  float off_value_;
  int axis_;
};
template <DeviceType D, typename T>
class OneHotOp;
// CPU implementation of OneHot.
//
// The output has one more dimension than the input: a dimension of size
// depth_ is inserted at position `axis`. Each output element is on_value_
// when the corresponding input element equals its index along the inserted
// dimension, and off_value_ otherwise.
template <typename T>
class OneHotOp<DeviceType::CPU, T> : public OneHotOpBase {
public:
explicit OneHotOp(OpConstructContext *context) : OneHotOpBase(context) {}
// Resizes Output(0) to the one-hot shape and fills it from Input(0).
// Returns the error from Resize() if it fails, MACE_SUCCESS otherwise.
MaceStatus Run(OpContext *context) override {
MACE_UNUSED(context);
const Tensor *input = this->Input(0);
Tensor *output = this->Output(0);
// axis_ == -1 means "append the new dimension after the last input one".
index_t axis = axis_ == -1 ? input->dim_size() : axis_;
const std::vector<index_t> &input_shape = input->shape();
std::vector<index_t> output_shape(input_shape.size() + 1);
MACE_CHECK(input->dim_size() < 100); // prevents too deep recursion
MACE_CHECK(axis >= 0 && axis <= input->dim_size());
// Copy the input dimensions in order, inserting depth_ at position `axis`.
for (size_t in = 0, out = 0; out < output_shape.size(); ++out) {
if (static_cast<index_t>(out) == axis) {
output_shape[out] = depth_;
} else {
output_shape[out] = input_shape[in];
++in;
}
}
MACE_RETURN_IF_ERROR(output->Resize(output_shape));
Tensor::MappingGuard input_guard(input);
Tensor::MappingGuard output_guard(output);
const T *input_ptr = input->data<T>();
T *output_ptr = output->mutable_data<T>();
// 1-D input is handled with two flat loops; every other rank goes through
// the recursive run() below.
if (input_shape.size() == 1) {
const index_t batch = input->dim(0);
if (axis == 1) {
// Output layout is [batch, depth_].
#pragma omp parallel for collapse(2)
for (index_t i = 0; i < batch; ++i) {
for (index_t j = 0; j < depth_; ++j) {
output_ptr[i * depth_ + j] = input_ptr[i] == j ? on_value_ :
off_value_;
}
}
} else {
// axis == 0: output layout is [depth_, batch].
#pragma omp parallel for collapse(2)
for (index_t i = 0; i < depth_; ++i) {
for (index_t j = 0; j < batch; ++j) {
output_ptr[i * batch + j] = input_ptr[j] == i ? on_value_ :
off_value_;
}
}
}
} else {
run(input, &input_ptr, &output_ptr, axis, 0, 0, input_shape.size(), 0);
}
return MaceStatus::MACE_SUCCESS;
}
private:
// Recursively walks the output dimensions from outermost to innermost,
// writing output elements sequentially through *output_ptr.
//
//   input       - source tensor; only used here to query dimension sizes
//   input_ptr   - cursor into the input data, advanced as elements are read
//   output_ptr  - cursor into the output data, advanced on every write
//   axis        - output dimension that holds the one-hot values
//   current_in  - index of the input dimension matching this level
//   current_out - index of the output dimension handled by this call
//   left        - number of output dimensions still below this one
//                 (0 means this call handles the innermost dimension)
//   test        - index along the one-hot dimension that input values are
//                 compared against once that dimension has been passed
void run(const Tensor *input, const T **input_ptr,
T **output_ptr, const index_t axis,
const index_t current_in, const index_t current_out,
const index_t left, const index_t test) const {
if (current_out == axis) {
const index_t length = depth_;
if (left == 0) {
// One-hot dimension is innermost: expand one input element into depth_
// outputs, then move on to the next input element.
for (index_t i = 0; i < length; ++i) {
**output_ptr = **input_ptr == i ? on_value_ : off_value_;
++(*output_ptr);
}
++(*input_ptr);
} else {
// One-hot dimension is not innermost: the same input sub-block is
// re-read once per depth index, so the cursor is rewound before each
// recursive pass.
const T *in = *input_ptr;
for (index_t i = 0; i < length; ++i) {
*input_ptr = in;
run(input, input_ptr, output_ptr, axis, current_in,
current_out + 1, left - 1, i);
}
}
} else {
const index_t length = input->dim(current_in);
if (left == 0) {
// Innermost ordinary dimension: compare each input element against the
// depth index passed down in `test`.
for (index_t i = 0; i < length; ++i) {
**output_ptr = **input_ptr == test ? on_value_ : off_value_;
++(*output_ptr);
++(*input_ptr);
}
} else {
for (index_t i = 0; i < length; ++i) {
run(input, input_ptr, output_ptr, axis, current_in + 1,
current_out + 1, left - 1, test);
}
}
}
}
};
#ifdef MACE_ENABLE_OPENCL
// GPU implementation of OneHot. All work is delegated to an OpenCL kernel
// object; only image memory is implemented.
template <typename T>
class OneHotOp<DeviceType::GPU, T> : public OneHotOpBase {
 public:
  explicit OneHotOp(OpConstructContext *context) : OneHotOpBase(context) {
    const bool use_image_memory =
        context->device()->gpu_runtime()->UseImageMemory();
    if (!use_image_memory) {
      // No kernel is provided for the non-image memory path.
      MACE_NOT_IMPLEMENTED;
    } else {
      kernel_.reset(new opencl::image::OneHotKernel<T>(
          depth_, on_value_, off_value_, axis_));
    }
  }

  // Forwards the first input and first output to the kernel and returns its
  // status.
  MaceStatus Run(OpContext *context) override {
    const Tensor *src = this->Input(0);
    Tensor *dst = this->Output(0);
    return kernel_->Compute(context, src, dst);
  }

 private:
  std::unique_ptr<OpenCLOneHotKernel> kernel_;
};
#endif // MACE_ENABLE_OPENCL
// Registers the OneHot operator: float on CPU, and float/half on GPU when
// OpenCL is enabled. The GPU build also installs a device placer that
// decides which devices may run a given OneHot node.
void RegisterOneHot(OpRegistryBase *op_registry) {
  MACE_REGISTER_OP(op_registry, "OneHot", OneHotOp, DeviceType::CPU, float);

#ifdef MACE_ENABLE_OPENCL
  MACE_REGISTER_OP(op_registry, "OneHot", OneHotOp, DeviceType::GPU, float);
  MACE_REGISTER_OP(op_registry, "OneHot", OneHotOp, DeviceType::GPU, half);

  MACE_REGISTER_OP_CONDITION(
      op_registry,
      OpConditionBuilder("OneHot")
          .SetDevicePlacerFunc(
              [](OpConstructContext *context) -> std::set<DeviceType> {
                auto op = context->operator_def();
                // When every output has a recorded shape, GPU is offered
                // only for a 2-D first output. Without complete shape
                // information both devices stay eligible.
                const bool shapes_recorded =
                    op->output_shape_size() == op->output_size();
                if (shapes_recorded &&
                    op->output_shape(0).dims_size() != 2) {
                  return { DeviceType::CPU };
                }
                return { DeviceType::CPU, DeviceType::GPU };
              }));
#endif  // MACE_ENABLE_OPENCL
}
} // namespace ops
} // namespace mace
// Copyright 2018 The MACE Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mace/core/testing/test_benchmark.h"
#include "mace/ops/ops_test_util.h"
namespace mace {
namespace ops {
namespace test {
namespace {
// Benchmark body: builds a OneHot net over a random 1-D input of `batch`
// elements, warms it up, then times `iters` runs on device D with type T.
template <DeviceType D, typename T>
void OneHot(int iters, int batch, int depth, int axis) {
  // Keep net construction and warm-up out of the measured interval.
  mace::testing::StopTiming();

  OpsTestNet net;
  net.AddRandomInput<D, T>("Input", {batch});

  const int data_type = static_cast<int>(DataTypeToEnum<T>::value);
  OpDefBuilder("OneHot", "OneHotTest")
      .Input("Input")
      .Output("Output")
      .AddIntArg("depth", depth)
      .AddIntArg("axis", axis)
      .AddIntArg("T", data_type)
      .Finalize(net.NewOperatorDef());

  // Warm-up
  const int warm_up_runs = 5;
  for (int run = 0; run != warm_up_runs; ++run) {
    net.RunOp(D);
  }
  net.Sync();

  mace::testing::StartTiming();
  for (int remaining = iters; remaining != 0; --remaining) {
    net.Run();
  }
  net.Sync();
}
} // namespace
// Defines and registers one benchmark function for a 1-D input of N elements
// with the given DEPTH and AXIS, element TYPE and DEVICE. The reported byte
// count is iters * N * sizeof(TYPE).
#define MACE_BM_ONE_HOT_MACRO(N, DEPTH, AXIS, TYPE, DEVICE) \
static void MACE_BM_ONE_HOT_##N##_##DEPTH##_##AXIS##_##TYPE##_##DEVICE( \
int iters) { \
const int64_t tot = static_cast<int64_t>(iters) * N; \
mace::testing::BytesProcessed(tot *(sizeof(TYPE))); \
OneHot<DEVICE, TYPE>(iters, N, DEPTH, AXIS); \
} \
MACE_BENCHMARK(MACE_BM_ONE_HOT_##N##_##DEPTH##_##AXIS##_##TYPE##_##DEVICE)
// Instantiates the benchmark for float on CPU, float on GPU and half on GPU.
#define MACE_BM_ONE_HOT(N, DEPTH, AXIS) \
MACE_BM_ONE_HOT_MACRO(N, DEPTH, AXIS, float, CPU); \
MACE_BM_ONE_HOT_MACRO(N, DEPTH, AXIS, float, GPU); \
MACE_BM_ONE_HOT_MACRO(N, DEPTH, AXIS, half, GPU);
// Benchmarked configurations: (input length, depth, axis).
MACE_BM_ONE_HOT(512, 16, 0);
MACE_BM_ONE_HOT(512, 16, 1);
MACE_BM_ONE_HOT(5000, 5000, 0);
MACE_BM_ONE_HOT(5000, 5000, 1);
MACE_BM_ONE_HOT(15000, 500, 0);
MACE_BM_ONE_HOT(15000, 500, 1);
MACE_BM_ONE_HOT(15000, 5000, 0);
MACE_BM_ONE_HOT(15000, 5000, 1);
} // namespace test
} // namespace ops
} // namespace mace
// Copyright 2018 The MACE Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mace/ops/ops_test_util.h"
namespace mace {
namespace ops {
namespace test {
// Test fixture for the OneHot operator tests; adds nothing to OpsTestBase.
class OneHotTest : public OpsTestBase {};
namespace {
template <DeviceType D, typename T>
void TestOneHot(const std::vector<index_t> &input_shape,
const std::vector<float> &input_data,
const std::vector<index_t> &expected_shape,
const std::vector<float> &expected_data,
const int depth,
const int axis,
const int on_value = 1,
const int off_value = 0) {
// Construct graph
OpsTestNet net;
std::string input("Input");
std::string output("Output");
// Add input data
net.AddInputFromArray<D, float>(input, input_shape, input_data);
OpDefBuilder("OneHot", "OneHotTest")
.Input(input)
.Output(output)
.AddIntArg("depth", depth)
.AddFloatArg("on_value", on_value)
.AddFloatArg("off_value", off_value)
.AddIntArg("axis", axis)
.AddIntArg("data_format", DataFormat::NHWC)
.Finalize(net.NewOperatorDef());
// Run
net.RunOp(D);
auto actual = net.GetTensor(output.c_str());
auto expected = net.CreateTensor<float>(expected_shape, expected_data);
ExpectTensorNear<float>(*expected, *actual, 1e-5);
}
} // namespace
// 1-D input of 10 indices with depth 5, checked on CPU, GPU float and GPU
// half for both supported axis placements.
TEST_F(OneHotTest, Dim1) {
const std::vector<index_t> input_shape{10};
const std::vector<float> input_data{1, 3, 1, 8, 3, 2, 2, 3, 1, 2};
// axis = -1: one row of `depth` values per input element. The fourth row is
// all zeros because its input value (8) is outside [0, depth).
std::vector<index_t> expected_shape{10, 5};
std::vector<float> expected_data{
0, 1, 0, 0, 0,
0, 0, 0, 1, 0,
0, 1, 0, 0, 0,
0, 0, 0, 0, 0,
0, 0, 0, 1, 0,
0, 0, 1, 0, 0,
0, 0, 1, 0, 0,
0, 0, 0, 1, 0,
0, 1, 0, 0, 0,
0, 0, 1, 0, 0,
};
TestOneHot<DeviceType::CPU, float>(input_shape, input_data, expected_shape,
expected_data, 5, -1);
TestOneHot<DeviceType::GPU, float>(input_shape, input_data, expected_shape,
expected_data, 5, -1);
TestOneHot<DeviceType::GPU, half>(input_shape, input_data, expected_shape,
expected_data, 5, -1);
// axis = 0: the depth dimension comes first, i.e. the transposed layout.
expected_shape = {5, 10};
expected_data = {
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
1, 0, 1, 0, 0, 0, 0, 0, 1, 0,
0, 0, 0, 0, 0, 1, 1, 0, 0, 1,
0, 1, 0, 0, 1, 0, 0, 1, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
};
TestOneHot<DeviceType::CPU, float>(input_shape, input_data, expected_shape,
expected_data, 5, 0);
TestOneHot<DeviceType::GPU, float>(input_shape, input_data, expected_shape,
expected_data, 5, 0);
TestOneHot<DeviceType::GPU, half>(input_shape, input_data, expected_shape,
expected_data, 5, 0);
}
// Checks that custom on/off values (7 and 8) are written instead of the
// default 1 and 0, on CPU, GPU float and GPU half.
TEST_F(OneHotTest, OnOffValue) {
  const int depth = 6;
  const int axis = -1;
  const int on = 7;
  const int off = 8;

  const std::vector<index_t> in_shape{3};
  const std::vector<float> indices{0, 2, 5};

  const std::vector<index_t> out_shape{3, 6};
  const std::vector<float> out_values{
      7, 8, 8, 8, 8, 8,
      8, 8, 7, 8, 8, 8,
      8, 8, 8, 8, 8, 7,
  };

  TestOneHot<DeviceType::CPU, float>(in_shape, indices, out_shape,
                                     out_values, depth, axis, on, off);
  TestOneHot<DeviceType::GPU, float>(in_shape, indices, out_shape,
                                     out_values, depth, axis, on, off);
  TestOneHot<DeviceType::GPU, half>(in_shape, indices, out_shape,
                                    out_values, depth, axis, on, off);
}
// 2-D input of shape {2, 3} with depth 4, checked on CPU for every axis
// position (0, 1 and 2).
TEST_F(OneHotTest, Dim2) {
const std::vector<index_t> input_shape{2, 3};
const std::vector<float> input_data{
1, 3, 2,
0, 1, 1,
};
// axis = 0: depth is the outermost dimension.
std::vector<index_t> expected_shape{4, 2, 3};
std::vector<float> expected_data{
0, 0, 0,
1, 0, 0,
1, 0, 0,
0, 1, 1,
0, 0, 1,
0, 0, 0,
0, 1, 0,
0, 0, 0,
};
TestOneHot<DeviceType::CPU, float>(input_shape, input_data, expected_shape,
expected_data, 4, 0);
// axis = 1: depth sits between the two input dimensions.
expected_shape = {2, 4, 3};
expected_data = {
0, 0, 0,
1, 0, 0,
0, 0, 1,
0, 1, 0,
1, 0, 0,
0, 1, 1,
0, 0, 0,
0, 0, 0,
};
TestOneHot<DeviceType::CPU, float>(input_shape, input_data, expected_shape,
expected_data, 4, 1);
// axis = 2: depth is the innermost dimension.
expected_shape = {2, 3, 4};
expected_data = {
0, 1, 0, 0,
0, 0, 0, 1,
0, 0, 1, 0,
1, 0, 0, 0,
0, 1, 0, 0,
0, 1, 0, 0,
};
TestOneHot<DeviceType::CPU, float>(input_shape, input_data, expected_shape,
expected_data, 4, 2);
}
// 3-D input of shape {2, 3, 4} with depth 4, checked on CPU for every axis
// position (0 through 3).
TEST_F(OneHotTest, Dim3) {
const std::vector<index_t> input_shape{2, 3, 4};
const std::vector<float> input_data{
3, 1, 3, 0,
0, 1, 3, 1,
2, 2, 1, 0,
1, 2, 0, 1,
3, 2, 1, 1,
0, 1, 3, 0,
};
// axis = 0: depth is the outermost dimension.
std::vector<index_t> expected_shape{4, 2, 3, 4};
std::vector<float> expected_data{
0, 0, 0, 1,
1, 0, 0, 0,
0, 0, 0, 1,
0, 0, 1, 0,
0, 0, 0, 0,
1, 0, 0, 1,
0, 1, 0, 0,
0, 1, 0, 1,
0, 0, 1, 0,
1, 0, 0, 1,
0, 0, 1, 1,
0, 1, 0, 0,
0, 0, 0, 0,
0, 0, 0, 0,
1, 1, 0, 0,
0, 1, 0, 0,
0, 1, 0, 0,
0, 0, 0, 0,
1, 0, 1, 0,
0, 0, 1, 0,
0, 0, 0, 0,
0, 0, 0, 0,
1, 0, 0, 0,
0, 0, 1, 0,
};
TestOneHot<DeviceType::CPU, float>(input_shape, input_data, expected_shape,
expected_data, 4, 0);
// axis = 1: depth follows the first input dimension.
expected_shape = {2, 4, 3, 4};
expected_data = {
0, 0, 0, 1,
1, 0, 0, 0,
0, 0, 0, 1,
0, 1, 0, 0,
0, 1, 0, 1,
0, 0, 1, 0,
0, 0, 0, 0,
0, 0, 0, 0,
1, 1, 0, 0,
1, 0, 1, 0,
0, 0, 1, 0,
0, 0, 0, 0,
0, 0, 1, 0,
0, 0, 0, 0,
1, 0, 0, 1,
1, 0, 0, 1,
0, 0, 1, 1,
0, 1, 0, 0,
0, 1, 0, 0,
0, 1, 0, 0,
0, 0, 0, 0,
0, 0, 0, 0,
1, 0, 0, 0,
0, 0, 1, 0,
};
TestOneHot<DeviceType::CPU, float>(input_shape, input_data, expected_shape,
expected_data, 4, 1);
// axis = 2: depth follows the second input dimension.
expected_shape = {2, 3, 4, 4};
expected_data = {
0, 0, 0, 1,
0, 1, 0, 0,
0, 0, 0, 0,
1, 0, 1, 0,
1, 0, 0, 0,
0, 1, 0, 1,
0, 0, 0, 0,
0, 0, 1, 0,
0, 0, 0, 1,
0, 0, 1, 0,
1, 1, 0, 0,
0, 0, 0, 0,
0, 0, 1, 0,
1, 0, 0, 1,
0, 1, 0, 0,
0, 0, 0, 0,
0, 0, 0, 0,
0, 0, 1, 1,
0, 1, 0, 0,
1, 0, 0, 0,
1, 0, 0, 1,
0, 1, 0, 0,
0, 0, 0, 0,
0, 0, 1, 0,
};
TestOneHot<DeviceType::CPU, float>(input_shape, input_data, expected_shape,
expected_data, 4, 2);
// axis = 3: depth is the innermost dimension, one row per input element.
expected_shape = {2, 3, 4, 4};
expected_data = {
0, 0, 0, 1,
0, 1, 0, 0,
0, 0, 0, 1,
1, 0, 0, 0,
1, 0, 0, 0,
0, 1, 0, 0,
0, 0, 0, 1,
0, 1, 0, 0,
0, 0, 1, 0,
0, 0, 1, 0,
0, 1, 0, 0,
1, 0, 0, 0,
0, 1, 0, 0,
0, 0, 1, 0,
1, 0, 0, 0,
0, 1, 0, 0,
0, 0, 0, 1,
0, 0, 1, 0,
0, 1, 0, 0,
0, 1, 0, 0,
1, 0, 0, 0,
0, 1, 0, 0,
0, 0, 0, 1,
1, 0, 0, 0,
};
TestOneHot<DeviceType::CPU, float>(input_shape, input_data, expected_shape,
expected_data, 4, 3);
}
// Requests a GPU run for input ranks 1 through 6 (all dimensions of size 1)
// with the output shape declared on the op. No output values are checked;
// the test only requires that the run completes.
// NOTE(review): presumably this exercises the device placer, which restricts
// OneHot to CPU when the declared output is not 2-D - confirm.
TEST_F(OneHotTest, CPUFallback) {
  const int max_input_rank = 6;
  for (int rank = 1; rank <= max_input_rank; ++rank) {
    std::vector<index_t> input_dims(rank, 1);
    std::vector<index_t> output_dims(rank + 1, 1);

    OpsTestNet net;
    net.AddRepeatedInput<DeviceType::GPU, float>("Input", input_dims, 0);

    OpDefBuilder("OneHot", "OneHotTest")
        .Input("Input")
        .Output("Output")
        .OutputShape(output_dims)
        .AddIntArg("depth", 1)
        .Finalize(net.NewOperatorDef());

    net.RunOp(DeviceType::GPU);
  }
}
} // namespace test
} // namespace ops
} // namespace mace
#include <common.h>
// One-hot encoding of a 1-D input. Each work item produces one output pixel
// of four values at (channel_idx, batch_idx).
//
// With AXIS_0 defined, batch_idx is the depth index and the four lanes map
// to four consecutive input elements (in_size bounds the input length).
// Otherwise batch_idx selects the input element and the four lanes map to
// four consecutive depth indices.
__kernel void one_hot(OUT_OF_RANGE_PARAMS
                      GLOBAL_WORK_GROUP_SIZE_DIM2
                      __read_only image2d_t input,
                      __write_only image2d_t output,
#ifdef AXIS_0
                      __private const int in_size,
#endif
                      __private const float on_value,
                      __private const float off_value) {
  const int channel_idx = get_global_id(0);
  const int batch_idx = get_global_id(1);

#ifndef NON_UNIFORM_WORK_GROUP
  if (channel_idx >= global_size_dim0 || batch_idx >= global_size_dim1) {
    return;
  }
#endif

  // Start from "all off" and switch individual lanes on below.
  DATA_TYPE4 out = off_value;

#ifdef AXIS_0
  const int first_idx = channel_idx * 4;

  // Lane 0 is read unconditionally; the remaining lanes are read only while
  // they stay inside the input.
  DATA_TYPE4 in = READ_IMAGET(input, SAMPLER, (int2)(0, first_idx));
  if (in.s0 == batch_idx) {
    out.s0 = on_value;
  }

  if (first_idx + 1 < in_size) {
    in = READ_IMAGET(input, SAMPLER, (int2)(0, first_idx + 1));
    if (in.s0 == batch_idx) {
      out.s1 = on_value;
    }
  }

  if (first_idx + 2 < in_size) {
    in = READ_IMAGET(input, SAMPLER, (int2)(0, first_idx + 2));
    if (in.s0 == batch_idx) {
      out.s2 = on_value;
    }
  }

  if (first_idx + 3 < in_size) {
    in = READ_IMAGET(input, SAMPLER, (int2)(0, first_idx + 3));
    if (in.s0 == batch_idx) {
      out.s3 = on_value;
    }
  }
#else
  DATA_TYPE4 in = READ_IMAGET(input, SAMPLER, (int2)(0, batch_idx));
  const int hot = in.s0;

  // Only the pixel that contains depth index `hot` gets a lane switched on.
  if (hot / 4 == channel_idx) {
    const int lane = hot % 4;
    if (lane == 0) {
      out.s0 = on_value;
    } else if (lane == 1) {
      out.s1 = on_value;
    } else if (lane == 2) {
      out.s2 = on_value;
    } else if (lane == 3) {
      out.s3 = on_value;
    }
  }
#endif

  WRITE_IMAGET(output, (int2)(channel_idx, batch_idx), out);
}
......@@ -34,6 +34,8 @@ std::vector<index_t> FormatBufferShape(
return buffer_shape;
} else if (buffer_shape_size == 2) { // NC
return {buffer_shape[0], 1, 1, buffer_shape[1]};
} else if (buffer_shape_size == 1) { // N
return {buffer_shape[0], 1, 1, 1};
} else {
LOG(FATAL) << "GPU only support 2D or 4D input and output";
}
......
// Copyright 2018 The MACE Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef MACE_OPS_OPENCL_IMAGE_ONE_HOT_H_
#define MACE_OPS_OPENCL_IMAGE_ONE_HOT_H_
#include "mace/ops/opencl/one_hot.h"
#include <memory>
#include <vector>
#include <set>
#include <string>
#include "mace/core/op_context.h"
#include "mace/core/tensor.h"
#include "mace/ops/opencl/helper.h"
namespace mace {
namespace ops {
namespace opencl {
namespace image {
// OpenCL image-memory implementation of the OneHot kernel interface.
// Stores the operator arguments, the OpenCL kernel object and the input
// shape seen by the last Compute() call.
template <typename T>
class OneHotKernel : public OpenCLOneHotKernel {
 public:
  OneHotKernel(const int depth, const float on_value,
               const float off_value, const int axis)
      : depth_(depth),
        on_value_(on_value),
        off_value_(off_value),
        axis_(axis) {}

  MaceStatus Compute(OpContext *context,
                     const Tensor *input,
                     Tensor *output) override;

 private:
  // Operator arguments.
  int depth_;
  float on_value_;
  float off_value_;
  int axis_;

  // OpenCL state.
  cl::Kernel kernel_;
  uint32_t kwg_size_;
  std::vector<index_t> input_shape_;
};
// Runs the "one_hot" OpenCL kernel on a 1-D input and writes a 2-D output.
//
// The output shape is {depth_, N} when the resolved axis is 0 and
// {N, depth_} otherwise, where N is the input length. Returns the first
// error reported by ResizeImage, BuildKernel or TuningOrRun2DKernel, and
// MACE_SUCCESS otherwise.
//
// NOTE(review): the MACE_OUT_OF_RANGE_*, MACE_NON_UNIFORM_WG_CONFIG and
// MACE_SET_2D_GWS_ARGS macros are defined outside this file; they presumably
// rely on the local names used below (runtime, built_options, idx, gws), so
// keep those names and the statement order intact - confirm before
// refactoring.
template <typename T>
MaceStatus OneHotKernel<T>::Compute(
OpContext *context,
const Tensor *input,
Tensor *output) {
auto input_shape = input->shape();
// axis_ == -1 means "after the last input dimension".
index_t axis = axis_ == -1 ? input->dim_size() : axis_;
MACE_CHECK(input->dim_size() == 1, "OneHot GPU only supports 1D input");
MACE_CHECK(axis >= 0 && axis <= input->dim_size());
std::vector<index_t> output_shape =
axis == 0 ? std::vector<index_t>{depth_, input_shape[0]} :
std::vector<index_t>{input_shape[0], depth_};
// Image width covers the second output dimension in groups of four values;
// image height is the first output dimension.
std::vector<size_t> output_image_shape{
static_cast<size_t>(RoundUpDiv4(output_shape[1])),
static_cast<size_t>(output_shape[0])};
MACE_RETURN_IF_ERROR(output->ResizeImage(output_shape, output_image_shape));
auto runtime = context->device()->gpu_runtime()->opencl_runtime();
MACE_OUT_OF_RANGE_DEFINITION;
// The kernel is compiled on first use only; AXIS_0 selects the axis-0
// variant of the kernel source at build time.
if (kernel_.get() == nullptr) {
std::set<std::string> built_options;
MACE_OUT_OF_RANGE_CONFIG;
MACE_NON_UNIFORM_WG_CONFIG;
std::string kernel_name = MACE_OBFUSCATE_SYMBOL("one_hot");
built_options.emplace("-Done_hot=" + kernel_name);
auto dt = DataTypeToEnum<T>::value;
built_options.emplace("-DDATA_TYPE=" + DtToCLDt(dt));
built_options.emplace("-DCMD_DATA_TYPE=" + DtToCLCMDDt(dt));
if (axis == 0) {
built_options.emplace("-DAXIS_0");
}
MACE_RETURN_IF_ERROR(runtime->BuildKernel("one_hot", kernel_name,
built_options, &kernel_));
kwg_size_ =
static_cast<uint32_t>(runtime->GetKernelMaxWorkGroupSize(kernel_));
}
// One work item per output image pixel.
const uint32_t gws[2] = {
static_cast<uint32_t>(output_image_shape[0]),
static_cast<uint32_t>(output_image_shape[1])
};
MACE_OUT_OF_RANGE_INIT(kernel_);
// Kernel arguments are (re)bound only when the input shape differs from the
// one cached by the previous call.
if (!IsVecEqual(input_shape_, input->shape())) {
int idx = 0;
MACE_OUT_OF_RANGE_SET_ARGS(kernel_);
MACE_SET_2D_GWS_ARGS(kernel_, gws);
kernel_.setArg(idx++, *(input->opencl_image()));
kernel_.setArg(idx++, *(output->opencl_image()));
// The input length is passed only for axis 0, matching the kernel built
// with AXIS_0.
if (axis == 0) {
kernel_.setArg(idx++, static_cast<int>(input_shape[0]));
}
kernel_.setArg(idx++, on_value_);
kernel_.setArg(idx++, off_value_);
input_shape_ = input->shape();
}
const std::vector<uint32_t> lws = {kwg_size_ / 64, 64, 0};
std::string tuning_key = Concat("one_hot", output->dim(0), output->dim(1));
MACE_RETURN_IF_ERROR(TuningOrRun2DKernel(runtime, kernel_, tuning_key,
gws, lws, context->future()));
MACE_OUT_OF_RANGE_VALIDATION;
return MaceStatus::MACE_SUCCESS;
}
} // namespace image
} // namespace opencl
} // namespace ops
} // namespace mace
#endif // MACE_OPS_OPENCL_IMAGE_ONE_HOT_H_
// Copyright 2018 The MACE Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef MACE_OPS_OPENCL_ONE_HOT_H_
#define MACE_OPS_OPENCL_ONE_HOT_H_
#include "mace/public/mace.h"
#include "mace/utils/utils.h"
namespace mace {
class OpContext;
class Tensor;
namespace ops {
// Abstract interface for OpenCL implementations of the OneHot operator.
class OpenCLOneHotKernel {
 public:
  // Computes the one-hot encoding of `input` into `output` and returns the
  // resulting status.
  virtual MaceStatus Compute(OpContext *context, const Tensor *input,
                             Tensor *output) = 0;

  MACE_EMPTY_VIRTUAL_DESTRUCTOR(OpenCLOneHotKernel);
};
} // namespace ops
} // namespace mace
#endif // MACE_OPS_OPENCL_ONE_HOT_H_
......@@ -43,6 +43,7 @@ extern void RegisterIdentity(OpRegistryBase *op_registry);
extern void RegisterInferConv2dShape(OpRegistryBase *op_registry);
extern void RegisterLocalResponseNorm(OpRegistryBase *op_registry);
extern void RegisterMatMul(OpRegistryBase *op_registry);
extern void RegisterOneHot(OpRegistryBase *op_registry);
extern void RegisterPad(OpRegistryBase *op_registry);
extern void RegisterPNorm(OpRegistryBase *op_registry);
extern void RegisterPooling(OpRegistryBase *op_registry);
......@@ -110,6 +111,7 @@ OpRegistry::OpRegistry() : OpRegistryBase() {
ops::RegisterInferConv2dShape(this);
ops::RegisterLocalResponseNorm(this);
ops::RegisterMatMul(this);
ops::RegisterOneHot(this);
ops::RegisterPad(this);
ops::RegisterPNorm(this);
ops::RegisterPooling(this);
......
......@@ -128,6 +128,7 @@ MaceSupportedOps = [
'LSTMCell',
# 'LstmNonlinear',
'MatMul',
'OneHot',
'Pad',
'PNorm',
'Pooling',
......
......@@ -117,6 +117,7 @@ TFSupportedOps = [
'Sqrt',
'MirrorPad',
'Cumsum',
'OneHot',
]
TFOpType = Enum('TFOpType', [(op, op) for op in TFSupportedOps], type=str)
......@@ -255,6 +256,7 @@ class TensorflowConverter(base_converter.ConverterInterface):
TFOpType.Sqrt.name: self.convert_elementwise,
TFOpType.MirrorPad.name: self.convert_pad,
TFOpType.Cumsum.name: self.convert_cumsum,
TFOpType.OneHot.name: self.convert_one_hot,
}
self._option = option
self._mace_net_def = mace_pb2.NetDef()
......@@ -576,6 +578,29 @@ class TensorflowConverter(base_converter.ConverterInterface):
op = self.convert_general_op(tf_op)
op.type = MaceOp.BiasAdd.name
def convert_one_hot(self, tf_op):
    """Convert a TensorFlow OneHot op into a MACE OneHot op.

    The depth, on_value and off_value inputs (inputs 1-3) are evaluated and
    stored as operator arguments, together with the op's axis attribute.
    Those three inputs are then added to the skip set and removed from the
    converted op, leaving only the first input.
    """
    op = self.convert_general_op(tf_op)
    op.type = MaceOp.OneHot.name

    depth_arg = op.arg.add()
    depth_arg.name = 'depth'
    depth_arg.i = tf_op.inputs[1].eval().astype(np.int32)

    # on_value and off_value come from inputs 2 and 3, in that order.
    for arg_name, input_index in (('on_value', 2), ('off_value', 3)):
        value_arg = op.arg.add()
        value_arg.name = arg_name
        value_arg.f = tf_op.inputs[input_index].eval().astype(np.float32)

    axis_arg = op.arg.add()
    axis_arg.name = tf_axis
    axis_arg.i = tf_op.get_attr(tf_axis)

    # Everything after the first input has been folded into arguments.
    folded_inputs = [tensor.name for tensor in tf_op.inputs][1:]
    self._skip_tensor.update(folded_inputs)
    del op.input[1:]
def convert_add(self, tf_op):
if len(tf_op.inputs) == 2:
self.convert_elementwise(tf_op)
......
......@@ -50,6 +50,7 @@ def _opencl_encrypt_kernel_impl(repository_ctx):
unused_var = repository_ctx.path(Label("//:mace/ops/opencl/cl/fully_connected.cl"))
unused_var = repository_ctx.path(Label("//:mace/ops/opencl/cl/lstmcell.cl"))
unused_var = repository_ctx.path(Label("//:mace/ops/opencl/cl/matmul.cl"))
unused_var = repository_ctx.path(Label("//:mace/ops/opencl/cl/one_hot.cl"))
unused_var = repository_ctx.path(Label("//:mace/ops/opencl/cl/pad.cl"))
unused_var = repository_ctx.path(Label("//:mace/ops/opencl/cl/pooling.cl"))
unused_var = repository_ctx.path(Label("//:mace/ops/opencl/cl/pooling_buffer.cl"))
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册