提交 ca811c2c 编写于 作者: M Megvii Engine Team 提交者: Xinran Xu

fix(mgb/opr/reduce): fix grad function of reduce mean and add testcase

GitOrigin-RevId: 390854bb2f2ae4032638f89118e8d5d946fea77a
上级 ec9de227
......@@ -78,8 +78,10 @@ void ReduceForward::check_exec(const TensorLayout& src, const TensorLayout& dst,
megdnn_assert(dst.shape[i] == 1_z, "%s", errmsg().c_str());
}
}
megdnn_assert(src.dtype.category() == dst.dtype.category(),
"the category of reduce output and input must be the same");
megdnn_assert(src.dtype.category() == dst.dtype.category() ||
param().data_type == Reduce::DataType::FLOAT_O32xC32,
"the category of reduce output and input must be the same,"
" or the data_type is FLOAT_O32xC32");
if (param().data_type == DataType::DEFAULT) {
megdnn_assert(src.dtype == dst.dtype &&
(src.dtype.category() == DTypeCategory::FLOAT ||
......@@ -89,8 +91,11 @@ void ReduceForward::check_exec(const TensorLayout& src, const TensorLayout& dst,
megdnn_assert(src.dtype.enumv() == DTypeEnum::Quantized8Asymm);
} else if (param().data_type == DataType::QINT_I8xO32) {
megdnn_assert(src.dtype.enumv() == DTypeEnum::QuantizedS8);
} else {
} else if (param().data_type == DataType::FLOAT_IO16xC32 ||
param().data_type == DataType::FLOAT_O16xC32) {
megdnn_assert(src.dtype.category() == DTypeCategory::FLOAT);
} else {
megdnn_assert(param().data_type == DataType::FLOAT_O32xC32);
}
auto expected = get_out_dtype(param().data_type, src.dtype);
......
......@@ -28,6 +28,7 @@ size_t dispatch_dtype_workspace(const TensorLayout& src, const TensorLayout&,
Reduce::DataType data_type) {
using f16 = DTypeTrait<dtype::Float16>::ctype;
using f32 = DTypeTrait<dtype::Float32>::ctype;
using i32 = DTypeTrait<dtype::Int32>::ctype;
if (data_type == Reduce::DataType::DEFAULT) {
#define cb(_dt) \
case DTypeTrait<_dt>::enumv: { \
......@@ -46,6 +47,8 @@ size_t dispatch_dtype_workspace(const TensorLayout& src, const TensorLayout&,
return get_reduce_workspace_in_bytes<Op<f16, f32, f32>>(A, B, C);
else if (src.dtype == dtype::Float32())
return get_reduce_workspace_in_bytes<Op<f32, f32, f32>>(A, B, C);
else if (src.dtype == dtype::Int32())
return get_reduce_workspace_in_bytes<Op<i32, f32, f32>>(A, B, C);
} else if (data_type == Reduce::DataType::FLOAT_O16xC32) {
if (src.dtype == dtype::Float16())
return get_reduce_workspace_in_bytes<Op<f16, f16, f32>>(A, B, C);
......@@ -61,6 +64,7 @@ void dispatch_dtype(cudaStream_t stream, const TensorND& src,
size_t B, size_t C, Reduce::DataType data_type) {
using f16 = DTypeTrait<dtype::Float16>::ctype;
using f32 = DTypeTrait<dtype::Float32>::ctype;
using i32 = DTypeTrait<dtype::Int32>::ctype;
if (data_type == Reduce::DataType::DEFAULT) {
switch (src.layout.dtype.enumv()) {
#define cb(_dt) \
......@@ -80,10 +84,14 @@ void dispatch_dtype(cudaStream_t stream, const TensorND& src,
return run_reduce<Op<f16, f32, f32>, false>(
workspace.ptr<f32>(), A, B, C, stream,
Op<f16, f32, f32>(src.ptr<f16>(), dst.ptr<f32>(), B));
} else {
} else if (src.layout.dtype == dtype::Float32()) {
return run_reduce<Op<f32, f32, f32>, false>(
workspace.ptr<f32>(), A, B, C, stream,
Op<f32, f32, f32>(src.ptr<f32>(), dst.ptr<f32>(), B));
} else if (src.layout.dtype == dtype::Int32()) {
return run_reduce<Op<i32, f32, f32>, false>(
workspace.ptr<f32>(), A, B, C, stream,
Op<i32, f32, f32>(src.ptr<i32>(), dst.ptr<f32>(), B));
}
} else if (data_type == Reduce::DataType::FLOAT_O16xC32) {
if (src.layout.dtype == dtype::Float16()) {
......
......@@ -36,6 +36,7 @@ MEGDNN_FOREACH_COMPUTING_DTYPE(cb)
INST(dt_float16, dt_float16, float)
INST(dt_float16, float, float)
INST(float, dt_float16, float)
INST(int, float, float)
#undef cb
#undef INST
......
......@@ -80,6 +80,8 @@ TEST_F(CUDA, REDUCE) {
}
check(mode, dtype::Float16(), dtype::Float32(),
Reduce::DataType::FLOAT_O32xC32);
check(mode, dtype::Int32(), dtype::Float32(),
Reduce::DataType::FLOAT_O32xC32);
check(mode, dtype::Float16(), dtype::Float16(),
Reduce::DataType::FLOAT_O16xC32);
check(mode, dtype::Float32(), dtype::Float16(),
......
......@@ -50,6 +50,10 @@ TEST_F(FALLBACK, REDUCE) {
param.data_type = DataType::FLOAT_O32xC32;
config = Config(param, dtype, shape);
configs.push_back(config);
} else if (dtype == dtype::Int32()) {
Param param(mode, axis, DataType::FLOAT_O32xC32);
Config config(param, dtype, shape);
configs.push_back(config);
}
}
// large (ABC) -> (A1C) case
......
......@@ -1680,7 +1680,7 @@ void Reduce::create_megdnn_opr() {
MGB_IMPL_OPR_GRAD(Reduce) {
for (size_t i = 1; i < opr.output().size(); ++ i)
mgb_assert(!out_grad[i]);
if (wrt_idx)
if (wrt_idx || opr.input(0)->dtype().category() != DTypeCategory::FLOAT)
return InvalidGrad::make(opr, wrt_idx);
SymbolVar og{out_grad[0]}, iv{opr.input(0)}, ov{opr.output(0)};
constexpr auto cmv = Elemwise::Mode::COND_LEQ_MOV;
......@@ -1700,8 +1700,9 @@ MGB_IMPL_OPR_GRAD(Reduce) {
case Mode::MEAN: {
auto og_shape = opr::GetVarShape::make(og),
iv_shape = opr::GetVarShape::make(iv),
scale = opr::reduce_prod(og_shape, og_shape.make_scalar(1)) /
opr::reduce_prod(iv_shape, iv_shape.make_scalar(1));
scale = div(
opr::reduce_prod(og_shape, og_shape.make_scalar(1)),
opr::reduce_prod(iv_shape, iv_shape.make_scalar(1)));
return scale * Broadcast::make(og, GetVarShape::make(iv));
}
default:
......
......@@ -27,6 +27,7 @@ using namespace mgb;
namespace {
using Mode = opr::Reduce::Mode;
using DataType = opr::Reduce::Param::DataType;
template<Mode mode, typename ctype>
struct ImplTrait {
......@@ -43,6 +44,10 @@ namespace {
static ctype reduce(ctype accum, ctype v) {
return accum + v;
}
ctype finalize(ctype result) {
return result;
}
};
template<typename ctype>
......@@ -56,6 +61,10 @@ namespace {
static ctype reduce(ctype accum, ctype v) {
return accum + v * v;
}
ctype finalize(ctype result) {
return result;
}
};
template<typename ctype>
......@@ -69,6 +78,10 @@ namespace {
static ctype reduce(ctype accum, ctype v) {
return accum * v;
}
ctype finalize(ctype result) {
return result;
}
};
template<typename ctype>
......@@ -82,6 +95,10 @@ namespace {
static ctype reduce(ctype accum, ctype v) {
return std::max(accum, v);
}
ctype finalize(ctype result) {
return result;
}
};
template<typename ctype>
......@@ -95,6 +112,30 @@ namespace {
static ctype reduce(ctype accum, ctype v) {
return std::min(accum, v);
}
ctype finalize(ctype result) {
return result;
}
};
template<typename ctype>
struct ImplTrait<Mode::MEAN, ctype> {
static constexpr float GRAD_MAXERR = 1e-4, GRAD_EPS = 1e-2;
size_t nr_elems;
ctype init() {
nr_elems = 0;
return 0;
}
ctype reduce(ctype accum, ctype v) {
nr_elems ++;
return accum + v;
}
ctype finalize(ctype result) {
return result / static_cast<ctype>(nr_elems);
}
};
template<Mode mode, typename ctype>
......@@ -108,10 +149,11 @@ namespace {
return;
}
ctype val = Impl::init();
Impl impl;
ctype val = impl.init();
for (auto i: megdnn::tensor_iter_valonly<ctype>(src.as_megdnn()))
val = Impl::reduce(val, i);
dest.ptr<ctype>()[0] = val;
val = impl.reduce(val, i);
dest.ptr<ctype>()[0] = impl.finalize(val);
return;
}
......@@ -143,15 +185,16 @@ namespace {
for (size_t i = 0; i < tshp.ndim; i ++)
offset += iter.idx()[i] * src.layout().stride[i];
ctype val = Impl::init();
Impl impl;
ctype val = impl.init();
auto subspec = SubTensorSpec::make_from_offset_elem(
sub_layout, offset);
HostTensorND subt = const_cast<HostTensorND&>(src).sub(subspec);
for (ctype i:
megdnn::tensor_iter_valonly<ctype>(subt.as_megdnn())) {
val = Impl::reduce(val, i);
val = impl.reduce(val, i);
}
*iter = val;
*iter = impl.finalize(val);
}
}
......@@ -535,7 +578,7 @@ TEST(TestBasicArithReduction, DifferentNDim) {
for (auto mode :
{Reduce::Mode::PRODUCT, Reduce::Mode::MAX, Reduce::Mode::MIN,
Reduce::Mode::SUM, Reduce::Mode::SUM_SQR}) {
Reduce::Mode::SUM, Reduce::Mode::SUM_SQR, Reduce::Mode::MEAN}) {
check_mode(mode);
}
}
......@@ -606,7 +649,7 @@ TEST(TestBasicArithReduction, MultiType) {
host_tshp->ptr<int>()[3] = 22;
for (auto mode :
{Reduce::Mode::PRODUCT, Reduce::Mode::MAX, Reduce::Mode::MIN,
Reduce::Mode::SUM, Reduce::Mode::SUM_SQR}) {
Reduce::Mode::SUM, Reduce::Mode::SUM_SQR, Reduce::Mode::MEAN}) {
check_mode(mode);
}
}
......@@ -682,18 +725,19 @@ TEST(TestBasicArithReduction, AutoCheck) {
Param param;
auto make_graph = [&param](const Checker::SymInpArray& inputs)
auto make_graph = [&param](const Checker::SymInpArray& inputs, DType dtype)
-> Checker::SymOutArray {
auto inp = inputs[0];
auto tshp = inputs[1].symshape();
inp = opr::TypeCvt::make(inp, dtype::Float16());
inp = opr::TypeCvt::make(inp, dtype);
return {opr::Reduce::make(inp, param, tshp)};
};
auto fwd = [&](Checker::NumOutArray& dest, Checker::NumInpArray inp) {
auto fwd = [&](Checker::NumOutArray& dest, Checker::NumInpArray inp,
DType dtype) {
auto cn = inp[0]->storage().comp_node();
TensorShape out_shape = inp[1]->shape();
dest[0] = HostTensorND{cn, out_shape, dtype::Float32()};
HostTensorND tmp_inp{cn, inp[0]->shape(), dtype::Float16()};
HostTensorND tmp_inp{cn, inp[0]->shape(), dtype};
HostTensorND new_inp{cn, inp[0]->shape(), dtype::Float32()};
auto typecvt =
megdnn_naive_handle()->create_operator<megdnn::TypeCvt>();
......@@ -711,31 +755,38 @@ TEST(TestBasicArithReduction, AutoCheck) {
dispatch_by_mode(ctype, Mode::MAX, in, out); \
dispatch_by_mode(ctype, Mode::SUM, in, out); \
dispatch_by_mode(ctype, Mode::PRODUCT, in, out); \
dispatch_by_mode(ctype, Mode::SUM_SQR, in, out);
dispatch_by_mode(ctype, Mode::SUM_SQR, in, out); \
dispatch_by_mode(ctype, Mode::MEAN, in, out);
mgb_assert(param.data_type == Param::DataType::FLOAT_O16xC32 ||
param.data_type == Param::DataType::FLOAT_O32xC32);
mgb_assert(param.data_type == Param::DataType::FLOAT_O32xC32);
dispatch_by_dtype(dtype::Float32, new_inp, dest[0]);
#undef dispatch_by_mode
#undef dispatch_by_dtype
};
auto check = [&](Mode mode, Param::DataType data_type) {
auto check = [&](Mode mode, Param::DataType data_type, DType dtype) {
param.mode = mode;
param.data_type = data_type;
Checker::RunOptions opts;
opts.outputs_max_err = 1e-3;
opts.numdiff_max_err = 5e-1;
Checker(make_graph, fwd)
.set_input_allow_grad(1, false)
.run({TensorShape{22, 21}, {22, 1}}, opts)
.run({TensorShape{22, 21}, {1, 1}}, opts)
.run({TensorShape{22, 21}, {22, 1}}, opts);
using namespace std::placeholders;
Checker checker(std::bind(make_graph, _1, dtype),
std::bind(fwd, _1, _2, dtype));
if (dtype.category() == DTypeCategory::FLOAT) {
checker.set_input_allow_grad(1, false);
} else {
checker.disable_grad_check();
}
checker.run({TensorShape{22, 21}, {22, 1}}, opts)
.run({TensorShape{22, 21}, {1, 1}}, opts)
.run({TensorShape{22, 21}, {22, 1}}, opts);
};
for (auto mode :
{Mode::SUM, Mode::MAX, Mode::MIN, Mode::PRODUCT}) {
check(mode, Param::DataType::FLOAT_O32xC32);
{Mode::SUM, Mode::MAX, Mode::MIN, Mode::PRODUCT, Mode::MEAN}) {
check(mode, Param::DataType::FLOAT_O32xC32, dtype::Float16());
check(mode, Param::DataType::FLOAT_O32xC32, dtype::Int32());
}
}
......@@ -747,6 +798,7 @@ OPR_TEST(SUM_SQR)
OPR_TEST(PRODUCT)
OPR_TEST(MAX)
OPR_TEST(MIN)
OPR_TEST(MEAN)
TEST(TestBasicArithReduction, CompSeqRecordLevel2) {
HostTensorGenerator<> gen;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册