提交 d4969fc0 编写于 作者: Y yejianwu

Update epsilon in batch_norm op: pass it as a float operator argument ("epsilon") instead of an input tensor

上级 cd081bde
......@@ -13,12 +13,13 @@ namespace kernels {
template <DeviceType D, typename T>
struct BatchNormFunctor {
T epsilon_;
void operator()(const Tensor *input,
const Tensor *scale,
const Tensor *offset,
const Tensor *mean,
const Tensor *var,
const Tensor *epsilon,
Tensor *output) {
// Batch normalization in the paper https://arxiv.org/abs/1502.03167 .
// The calculation formula for inference is
......@@ -38,7 +39,6 @@ struct BatchNormFunctor {
Tensor::MappingGuard offset_mapper(offset);
Tensor::MappingGuard mean_mapper(mean);
Tensor::MappingGuard var_mapper(var);
Tensor::MappingGuard epsilon_mapper(epsilon);
Tensor::MappingGuard output_mapper(output);
const T *input_ptr = input->data<T>();
......@@ -46,7 +46,6 @@ struct BatchNormFunctor {
const T *offset_ptr = offset->data<T>();
const T *mean_ptr = mean->data<T>();
const T *var_ptr = var->data<T>();
const T *epsilon_ptr = epsilon->data<T>();
T *output_ptr = output->mutable_data<T>();
vector<T> new_scale(channels);
......@@ -54,7 +53,7 @@ struct BatchNormFunctor {
#pragma omp parallel for
for (index_t c = 0; c < channels; ++c) {
new_scale[c] = scale_ptr[c] / std::sqrt(var_ptr[c] + *epsilon_ptr);
new_scale[c] = scale_ptr[c] / std::sqrt(var_ptr[c] + epsilon_);
new_offset[c] = offset_ptr[c] - mean_ptr[c] * new_scale[c];
}
......@@ -81,17 +80,17 @@ void BatchNormFunctor<DeviceType::NEON, float>::operator()(
const Tensor *offset,
const Tensor *mean,
const Tensor *var,
const Tensor *epsilon,
Tensor *output);
// Partial specialization of BatchNormFunctor for DeviceType::OPENCL.
// Only the call operator is declared here; its definition is out of line.
template <typename T>
struct BatchNormFunctor<DeviceType::OPENCL, T> {
// Scalar epsilon held by the functor. BatchNormOp's constructor assigns it
// from the "epsilon" operator argument, and the OPENCL call operator passes
// it to the kernel with setArg, where it is added to var before rsqrt.
T epsilon_;
// Runs batch normalization on `input` using the scale, offset, mean and var
// tensors, writing the result to `output`.
void operator()(const Tensor *input,
const Tensor *scale,
const Tensor *offset,
const Tensor *mean,
const Tensor *var,
// NOTE(review): the out-of-line definition and the BatchNormOp call site
// shown in this listing take no epsilon tensor; presumably this parameter
// line is the pre-change side of the diff. Confirm before relying on it.
const Tensor *epsilon,
Tensor *output);
};
......
......@@ -15,7 +15,6 @@ void BatchNormFunctor<DeviceType::NEON, float>::operator()(
const Tensor *offset,
const Tensor *mean,
const Tensor *var,
const Tensor *epsilon,
Tensor *output) {
// Batch normalization in the paper https://arxiv.org/abs/1502.03167 .
// The calculation formula for inference is
......@@ -34,14 +33,13 @@ void BatchNormFunctor<DeviceType::NEON, float>::operator()(
const float *offset_ptr = offset->data<float>();
const float *mean_ptr = mean->data<float>();
const float *var_ptr = var->data<float>();
const float *epsilon_ptr = epsilon->data<float>();
float *output_ptr = output->mutable_data<float>();
index_t count = sample_size >> 2;
index_t remain_count = sample_size - (count << 2);
#pragma omp parallel for
for (index_t c = 0; c < channel; ++c) {
float new_scale = scale_ptr[c] / std::sqrt(var_ptr[c] + *epsilon_ptr);
float new_scale = scale_ptr[c] / std::sqrt(var_ptr[c] + epsilon_);
float new_offset = offset_ptr[c] - mean_ptr[c] * new_scale;
index_t pos = c * sample_size;
......@@ -69,4 +67,4 @@ void BatchNormFunctor<DeviceType::NEON, float>::operator()(
};
} // namespace kernels
} // namespace mace
\ No newline at end of file
} // namespace mace
......@@ -18,7 +18,6 @@ void BatchNormFunctor<DeviceType::OPENCL, T>::operator()(
const Tensor *offset,
const Tensor *mean,
const Tensor *var,
const Tensor *epsilon,
Tensor *output) {
const index_t batch = input->dim(0);
......@@ -48,7 +47,7 @@ void BatchNormFunctor<DeviceType::OPENCL, T>::operator()(
bm_kernel.setArg(idx++, *(static_cast<const cl::Image2D *>(offset->buffer())));
bm_kernel.setArg(idx++, *(static_cast<const cl::Image2D *>(mean->buffer())));
bm_kernel.setArg(idx++, *(static_cast<const cl::Image2D *>(var->buffer())));
bm_kernel.setArg(idx++, *(static_cast<const cl::Buffer *>(epsilon->buffer())));
bm_kernel.setArg(idx++, epsilon_);
bm_kernel.setArg(idx++, *(static_cast<cl::Image2D *>(output->buffer())));
auto params_generator = [&kwg_size]()->std::vector<std::vector<uint32_t>> {
......
......@@ -5,7 +5,7 @@ __kernel void batch_norm(__read_only image2d_t input,
__read_only image2d_t offset,
__read_only image2d_t mean,
__read_only image2d_t var,
__global const DATA_TYPE *epsilon,
__private const DATA_TYPE epsilon,
__write_only image2d_t output) {
const int ch_blk = get_global_id(0);
const int w = get_global_id(1);
......@@ -17,7 +17,7 @@ __kernel void batch_norm(__read_only image2d_t input,
DATA_TYPE4 mean_value = READ_IMAGET(mean, SAMPLER, (int2)(ch_blk, 0));
DATA_TYPE4 var_value = READ_IMAGET(var, SAMPLER, (int2)(ch_blk, 0));
DATA_TYPE4 new_scale = scale_value * rsqrt(var_value + (DATA_TYPE4)(*epsilon));
DATA_TYPE4 new_scale = scale_value * rsqrt(var_value + (DATA_TYPE4)epsilon);
DATA_TYPE4 new_offset = offset_value - mean_value * new_scale;
const int pos = ch_blk * width + w;
......
......@@ -14,7 +14,10 @@ template <DeviceType D, class T>
class BatchNormOp : public Operator<D, T> {
public:
BatchNormOp(const OperatorDef &operator_def, Workspace *ws)
: Operator<D, T>(operator_def, ws), functor_() {}
: Operator<D, T>(operator_def, ws), functor_() {
functor_.epsilon_ =
OperatorBase::GetSingleArgument<float>("epsilon", static_cast<float>(-1));
}
bool Run() override {
const Tensor *input = this->Input(INPUT);
......@@ -22,7 +25,6 @@ class BatchNormOp : public Operator<D, T> {
const Tensor *offset = this->Input(OFFSET);
const Tensor *mean = this->Input(MEAN);
const Tensor *var = this->Input(VAR);
const Tensor *epsilon = this->Input(EPSILON);
MACE_CHECK(input->dim_size() == 4, "input must be 4-dimensional. ",
input->dim_size());
......@@ -34,13 +36,11 @@ class BatchNormOp : public Operator<D, T> {
mean->dim_size());
MACE_CHECK(var->dim_size() == 1, "var must be 1-dimensional. ",
var->dim_size());
MACE_CHECK(epsilon->dim_size() == 0, "epsilon must be 0-dimensional. ",
epsilon->dim_size());
Tensor *output = this->Output(OUTPUT);
output->ResizeLike(input);
functor_(input, scale, offset, mean, var, epsilon, output);
functor_(input, scale, offset, mean, var, output);
return true;
}
......@@ -48,7 +48,7 @@ class BatchNormOp : public Operator<D, T> {
kernels::BatchNormFunctor<D, T> functor_;
protected:
OP_INPUT_TAGS(INPUT, SCALE, OFFSET, MEAN, VAR, EPSILON);
OP_INPUT_TAGS(INPUT, SCALE, OFFSET, MEAN, VAR);
OP_OUTPUT_TAGS(OUTPUT);
};
......
......@@ -21,7 +21,6 @@ static void BatchNorm(
net.AddRandomInput<D, T>("Offset", {channels});
net.AddRandomInput<D, T>("Mean", {channels});
net.AddRandomInput<D, T>("Var", {channels}, true);
net.AddInputFromArray<D, float>("Epsilon", {}, {1e-3});
if (D == DeviceType::OPENCL) {
BufferToImage<D, float>(net, "Input", "InputImage", kernels::BufferType::IN_OUT);
......@@ -35,7 +34,7 @@ static void BatchNorm(
.Input("OffsetImage")
.Input("MeanImage")
.Input("VarImage")
.Input("Epsilon")
.AddFloatArg("epsilon", 1e-3)
.Output("Output")
.Finalize(net.NewOperatorDef());
}
......@@ -46,7 +45,7 @@ static void BatchNorm(
.Input("Offset")
.Input("Mean")
.Input("Var")
.Input("Epsilon")
.AddFloatArg("epsilon", 1e-3)
.Output("Output")
.Finalize(net.NewOperatorDef());
}
......
......@@ -20,7 +20,6 @@ void Simple() {
net.AddInputFromArray<D, float>("Offset", {1}, {2.0});
net.AddInputFromArray<D, float>("Mean", {1}, {10});
net.AddInputFromArray<D, float>("Var", {1}, {11.67f});
net.AddInputFromArray<D, float>("Epsilon", {}, {1e-3});
if (D == DeviceType::OPENCL) {
BufferToImage<D, float>(net, "Input", "InputImage", kernels::BufferType::IN_OUT);
......@@ -35,7 +34,7 @@ void Simple() {
.Input("OffsetImage")
.Input("MeanImage")
.Input("VarImage")
.Input("Epsilon")
.AddFloatArg("epsilon", 1e-3)
.Output("OutputImage")
.Finalize(net.NewOperatorDef());
// Run
......@@ -50,7 +49,7 @@ void Simple() {
.Input("Offset")
.Input("Mean")
.Input("Var")
.Input("Epsilon")
.AddFloatArg("epsilon", 1e-3)
.Output("Output")
.Finalize(net.NewOperatorDef());
// Run
......@@ -180,7 +179,7 @@ TEST_F(BatchNormOpTest, SimpleRandomOPENCL) {
.Input("Offset")
.Input("Mean")
.Input("Var")
.Input("Epsilon")
.AddFloatArg("epsilon", 1e-3)
.Output("Output")
.Finalize(net.NewOperatorDef());
......@@ -190,7 +189,6 @@ TEST_F(BatchNormOpTest, SimpleRandomOPENCL) {
net.AddRandomInput<DeviceType::OPENCL, float>("Offset", {channels});
net.AddRandomInput<DeviceType::OPENCL, float>("Mean", {channels});
net.AddRandomInput<DeviceType::OPENCL, float>("Var", {channels}, true);
net.AddInputFromArray<DeviceType::OPENCL, float>("Epsilon", {}, {1e-3});
// run cpu
net.RunOp();
......@@ -212,7 +210,7 @@ TEST_F(BatchNormOpTest, SimpleRandomOPENCL) {
.Input("OffsetImage")
.Input("MeanImage")
.Input("VarImage")
.Input("Epsilon")
.AddFloatArg("epsilon", 1e-3)
.Output("OutputImage")
.Finalize(net.NewOperatorDef());
......@@ -246,7 +244,7 @@ TEST_F(BatchNormOpTest, ComplexRandomOPENCL) {
.Input("Offset")
.Input("Mean")
.Input("Var")
.Input("Epsilon")
.AddFloatArg("epsilon", 1e-3)
.Output("Output")
.Finalize(net.NewOperatorDef());
......@@ -256,7 +254,6 @@ TEST_F(BatchNormOpTest, ComplexRandomOPENCL) {
net.AddRandomInput<DeviceType::OPENCL, float>("Offset", {channels});
net.AddRandomInput<DeviceType::OPENCL, float>("Mean", {channels});
net.AddRandomInput<DeviceType::OPENCL, float>("Var", {channels}, true);
net.AddInputFromArray<DeviceType::OPENCL, float>("Epsilon", {}, {1e-3});
// run cpu
net.RunOp();
......@@ -279,7 +276,7 @@ TEST_F(BatchNormOpTest, ComplexRandomOPENCL) {
.Input("OffsetImage")
.Input("MeanImage")
.Input("VarImage")
.Input("Epsilon")
.AddFloatArg("epsilon", 1e-3)
.Output("OutputImage")
.Finalize(net.NewOperatorDef());
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册