/**
 * \file src/opr/impl/rand.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */

#include "megbrain/opr/rand.h"
#include "megbrain/graph/grad_impl.h"
#include "megbrain/opr/utility.h"

#include "./internal/megdnn_opr_wrapper.inl"

using namespace mgb;
using namespace opr;
using namespace intl;

template <typename MegDNNOpr>
RNGOprBase<MegDNNOpr>::RNGOprBase(
        const OperatorNodeBaseCtorParam& opr, const Param& param)
        : Super(opr), m_param(param) {}

template <class MegDNNOpr>
UniqPtrWithCN<MegDNNOpr> RNGOprBase<MegDNNOpr>::create_megdnn_opr() {
    auto opr = intl::create_megdnn_opr<MegDNNOpr>(comp_node());
    opr->param() = param();
    return opr;
}

template <typename MegDNNOpr>
void RNGOprBase<MegDNNOpr>::ensure_megdnn_opr() {
    if (!m_dnn_opr || m_dnn_opr.comp_node() != comp_node()) {
        // activate comp_node for curandCreateGenerator in create_megdnn_opr
        comp_node().activate();
        m_dnn_opr = create_megdnn_opr();
    }
}

/* ================= RNG with shape =================  */
#define _INST_RNG_OPR_WITH_SHAPE(RNGOpr, name)                                        \
    MGB_DYN_TYPE_OBJ_FINAL_IMPL(RNGOpr);                                              \
    cg::OperatorNodeBase::NodeProp* RNGOpr::do_make_node_prop() const {               \
        auto prop = Super::do_make_node_prop();                                       \
        prop->add_flag(NodeProp::Flag::IMPURE_FUNC);                                  \
        prop->reset_dep_type(input(), {NodeProp::DepType::HOST_VALUE});               \
        for (auto i : input()) {                                                      \
            prop->add_dep_type_existing_var(i, NodeProp::DepType::VALUE_ALLOW_EMPTY); \
        }                                                                             \
        return prop;                                                                  \
    }                                                                                 \
    RNGOpr::RNGOpr(                                                                   \
            VarNode* shape, const Param& param, const OperatorNodeConfig& config)     \
            : Super({shape->owner_graph(), config, (name), {shape}}, param) {         \
        DType dtype = DType::from_enum(param.dtype);                                  \
        add_input({shape});                                                           \
        add_output(None)->dtype(dtype).add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE);    \
        cg::add_workspace_output(this);                                               \
        add_equivalence_component<ScalarHash<void*>>(this);                           \
    }                                                                                 \
    SymbolVar RNGOpr::make(                                                           \
            SymbolVar shape, const Param& param, const OperatorNodeConfig& config) {  \
        return shape.insert_single_output_opr<RNGOpr>(shape.node(), param, config);   \
    }                                                                                 \
    void RNGOpr::init_output_static_infer_desc() {                                    \
        using namespace cg::static_infer;                                             \
        auto&& mgr = owner_graph()->static_infer_manager();                           \
        auto infer_out = [](TensorShape& dest, const InpVal& inp) {                   \
            cg::copy_tensor_value_to_shape(dest, inp.val.at(0).value());              \
            return true;                                                              \
        };                                                                            \
        auto infer_wk = [this](TensorShape& dest, const InpVal& inp) {                \
            ensure_megdnn_opr();                                                      \
            dest.ndim = 1;                                                            \
            dest.shape[0] = m_dnn_opr->get_workspace_in_bytes(                        \
                    {inp.val.at(0).shape(), output(0)->dtype()});                     \
            return true;                                                              \
        };                                                                            \
        mgr.register_shape_infer(                                                     \
                output(0),                                                            \
                {SourceType::DEP, {{input(0), DepType::VALUE}}, infer_out});          \
        mgr.register_shape_infer(                                                     \
                output(1),                                                            \
                {SourceType::DEP, {{output(0), DepType::SHAPE}}, infer_wk});          \
    }                                                                                 \
    void RNGOpr::scn_do_execute() {                                                   \
        auto&& ret = output(0);                                                       \
        if (ret->layout().is_empty()) {                                               \
            mgb_assert(ret->dev_tensor().empty());                                    \
            return;                                                                   \
        }                                                                             \
        m_dnn_opr->exec(                                                              \
                ret->dev_tensor().as_megdnn(),                                        \
                get_megdnn_workspace_from_var(output(1)));                            \
    }

_INST_RNG_OPR_WITH_SHAPE(UniformRNG, "uniform_rng")
_INST_RNG_OPR_WITH_SHAPE(GaussianRNG, "gaussian_rng")
_INST_RNG_OPR_WITH_SHAPE(PermutationRNG, "permutation_rng")
#undef _INST_RNG_OPR_WITH_SHAPE
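
// A hypothetical usage sketch (the graph handle, host shape tensor and seed
// value are assumptions for illustration): the shape input is consumed as a
// host value, so the output shape can be statically inferred whenever the
// shape var is statically known:
//
//     auto shp = opr::ImmutableTensor::make(*graph, host_shape_tensor);
//     auto x = opr::UniformRNG::make(shp, {seed});  // output dtype from param
//
// Note that each instance hashes its own address into its equivalence
// component, so two RNG oprs with identical inputs are never deduplicated
// by the graph.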

/* ================= RNG with input =================  */
#define _AS_MEGDNN(idx) input((idx))->dev_tensor().as_megdnn()
#define _INFER_WK_DEPS(idx) \
    { input((idx)), DepType::SHAPE }
#define _INFER_WK_ARGS(idx) \
    { inp.val.at((idx)).shape(), input((idx))->dtype() }

#define _INST_RNG_OPR_WITH_INPUT(RNGOpr, name)                                         \
    MGB_DYN_TYPE_OBJ_FINAL_IMPL(RNGOpr);                                               \
    RNGOpr::RNGOpr(                                                                    \
            _INPUTS(VarNode*, ), const Param& param, const OperatorNodeConfig& config) \
            : Super({i0->owner_graph(), config, (name), {_INPUTS(, )}}, param) {       \
        add_input({_INPUTS(, )});                                                      \
        add_output(None)                                                               \
                ->dtype(i0->dtype())                                                   \
                .add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE);                           \
        cg::add_workspace_output(this);                                                \
        add_equivalence_component<ScalarHash<void*>>(this);                            \
    }                                                                                  \
    SymbolVar RNGOpr::make(                                                            \
            _INPUTS(SymbolVar, ), const Param& param,                                  \
            const OperatorNodeConfig& config) {                                        \
        return i0.insert_single_output_opr<RNGOpr>(_INPUTS(, .node()), param, config); \
    }                                                                                  \
    void RNGOpr::init_output_static_infer_desc() {                                     \
        using namespace cg::static_infer;                                              \
        auto&& mgr = owner_graph()->static_infer_manager();                            \
        auto infer_wk = [this](TensorShape& dest, const InpVal& inp) {                 \
            ensure_megdnn_opr();                                                       \
            dest.ndim = 1;                                                             \
            dest.shape[0] = m_dnn_opr->get_workspace_in_bytes(                         \
                    _FOR_EACH(_INFER_WK_ARGS),                                         \
                    {output(0)->shape(), output(0)->dtype()});                         \
            return true;                                                               \
        };                                                                             \
        mgr.register_shape_infer(output(0), ShapeInferDesc::make_identity(input(0)));  \
        mgr.register_shape_infer(                                                      \
                output(1), {SourceType::DEP, {_FOR_EACH(_INFER_WK_DEPS)}, infer_wk});  \
    }                                                                                  \
    void RNGOpr::add_input_layout_constraint() {                                       \
        for (auto i : input())                                                         \
            i->add_layout_constraint_contiguous();                                     \
    };                                                                                 \
    void RNGOpr::scn_do_execute() {                                                    \
        auto&& ret = output(0);                                                        \
        if (ret->layout().is_empty()) {                                                \
            mgb_assert(ret->dev_tensor().empty());                                     \
            return;                                                                    \
        }                                                                              \
        m_dnn_opr->exec(                                                               \
                _FOR_EACH(_AS_MEGDNN), output(0)->dev_tensor().as_megdnn(),            \
                get_megdnn_workspace_from_var(output(1)));                             \
    }                                                                                  \
    cg::OperatorNodeBase::NodeProp* RNGOpr::do_make_node_prop() const {                \
        auto prop = Super::do_make_node_prop();                                        \
        prop->add_flag(NodeProp::Flag::IMPURE_FUNC);                                   \
        for (auto i : input()) {                                                       \
            prop->add_dep_type_existing_var(i, NodeProp::DepType::VALUE_ALLOW_EMPTY);  \
        }                                                                              \
        return prop;                                                                   \
    }

/* ================= 1 input =================  */
#define _INPUTS(prefix, suffix) prefix i0 suffix
#define _FOR_EACH(cb)           cb(0)
_INST_RNG_OPR_WITH_INPUT(PoissonRNG, "poisson_rng")
#undef _INPUTS
#undef _FOR_EACH

/* ================= 2 input =================  */
#define _INPUTS(prefix, suffix) prefix i0 suffix, prefix i1 suffix
#define _FOR_EACH(cb)           cb(0), cb(1)
_INST_RNG_OPR_WITH_INPUT(BetaRNG, "beta_rng")
_INST_RNG_OPR_WITH_INPUT(GammaRNG, "gamma_rng")
#undef _INPUTS
#undef _FOR_EACH

#undef _AS_MEGDNN
#undef _INFER_WK_DEPS
#undef _INFER_WK_ARGS
#undef _INST_RNG_OPR_WITH_INPUT
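
// Unlike the shape-based oprs above, these sample element-wise against the
// distribution-parameter tensors, and the output shape follows input(0)
// (see ShapeInferDesc::make_identity in the macro). A minimal sketch,
// assuming `lam`, `alpha` and `beta` are float32 SymbolVars already in the
// graph and `seed` is a value chosen by the caller:
//
//     auto p = opr::PoissonRNG::make(lam, {seed});        // lam: rate tensor
//     auto b = opr::BetaRNG::make(alpha, beta, {seed});   // per-element params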

#define IMPL(_cls)                              \
    MGB_IMPL_OPR_GRAD(_cls) {                   \
        MGB_MARK_USED_VAR(out_grad);            \
        return InvalidGrad::make(opr, wrt_idx); \
    }
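
// Sampling is not differentiable w.r.t. the shape input or the distribution
// parameters, so taking the gradient of any plain RNG opr yields InvalidGrad.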

namespace mgb {
namespace opr {
namespace intl {
template class RNGOprBase<::megdnn::GaussianRNG>;
template class RNGOprBase<::megdnn::UniformRNG>;
template class RNGOprBase<::megdnn::GammaRNG>;
template class RNGOprBase<::megdnn::PermutationRNG>;
template class RNGOprBase<::megdnn::BetaRNG>;
template class RNGOprBase<::megdnn::PoissonRNG>;
template class RNGOprBase<::megdnn::ShuffleRNGForward>;
template class RNGOprBase<::megdnn::ShuffleRNGBackward>;
#if MGB_ENABLE_GRAD
IMPL(GaussianRNG);
IMPL(UniformRNG);
IMPL(GammaRNG);
IMPL(PoissonRNG);
IMPL(PermutationRNG);
IMPL(BetaRNG);
#endif
}  // namespace intl
}  // namespace opr
}  // namespace mgb

/* ================= ShuffleRNGForward =================  */

MGB_DYN_TYPE_OBJ_FINAL_IMPL(ShuffleRNGForward);

ShuffleRNGForward::ShuffleRNGForward(
        VarNode* data, const Param& param, const OperatorNodeConfig& config)
        : Super({data->owner_graph(), config, "shuffle_rng", {data}}, param) {
    add_input({data});
    add_output(None)->dtype(data->dtype()).add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE);
    add_output(None)->dtype(dtype::Int32{}).add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE);
    cg::add_workspace_output(this);
    add_equivalence_component<ScalarHash<void*>>(this);
}

SymbolVarArray ShuffleRNGForward::make(
        SymbolVar in_tensor, const Param& param, const OperatorNodeConfig& config) {
    auto node = in_tensor.node()->owner_graph()->insert_opr(
            std::make_unique<ShuffleRNGForward>(in_tensor.node(), param, config));
    mgb_assert(node->output().size() == 3);
    return {node->output(0), node->output(1)};
}
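
// A minimal usage sketch (assuming `x` is a SymbolVar already in the graph;
// the seed field of Param is an assumption for illustration):
//
//     auto outs = opr::ShuffleRNGForward::make(x, {seed});
//     // outs[0]: x shuffled along its first axis
//     // outs[1]: the Int32 permutation indices, consumed by the backward opr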

void ShuffleRNGForward::init_output_static_infer_desc() {
    using namespace cg::static_infer;
    auto&& mgr = owner_graph()->static_infer_manager();

    mgr.register_shape_infer(output(0), ShapeInferDesc::make_identity(input(0)));

    auto infer_oshp1 = [this](TensorShape& dest, const InpVal& iv) {
        // the megdnn opr may not exist yet when static infer first runs,
        // so create it lazily (same pattern as infer_wk below)
        ensure_megdnn_opr();
        TensorLayout o0, o1;
        m_dnn_opr->deduce_layout({iv.val[0].shape(), input(0)->dtype()}, o0, o1);
        dest = o1;
        return true;
    };
    mgr.register_shape_infer(
            output(1), {SourceType::DEP, {{input(0), DepType::SHAPE}}, infer_oshp1});

    auto infer_wk = [this](TensorShape& dest, const InpVal& inp) {
        ensure_megdnn_opr();
        dest.ndim = 1;
        dest.shape[0] = m_dnn_opr->get_workspace_in_bytes(
                {inp.val[0].shape(), input(0)->dtype()},
                {output(0)->shape(), output(0)->dtype()},
                {output(1)->shape(), output(1)->dtype()});
        return true;
    };
    mgr.register_shape_infer(
            output(2), {SourceType::DEP, {{input(0), DepType::SHAPE}}, infer_wk});
}

void ShuffleRNGForward::add_input_layout_constraint() {
    input(0)->add_layout_constraint_contiguous();
};

void ShuffleRNGForward::scn_do_execute() {
    auto&& ret = output(0);
    if (ret->layout().is_empty()) {
        mgb_assert(ret->dev_tensor().empty());
        return;
    }
    m_dnn_opr->exec(
            input(0)->dev_tensor().as_megdnn(), output(0)->dev_tensor().as_megdnn(),
            output(1)->dev_tensor().as_megdnn(),
            get_megdnn_workspace_from_var(output(2)));
}

cg::OperatorNodeBase::NodeProp* ShuffleRNGForward::do_make_node_prop() const {
    auto prop = Super::do_make_node_prop();
    prop->add_flag(NodeProp::Flag::IMPURE_FUNC);
    for (auto i : input()) {
        prop->add_dep_type_existing_var(i, NodeProp::DepType::VALUE_ALLOW_EMPTY);
    }
    return prop;
}

#if MGB_ENABLE_GRAD
MGB_IMPL_OPR_GRAD(ShuffleRNGForward) {
    mgb_assert(out_grad.size() == 3 && wrt_idx == 0 && !out_grad[2]);
    if (!out_grad[0])
        return nullptr;
    return ShuffleRNGBackward::make(out_grad[0], opr.output(1), opr.input(0)).node();
}
#endif
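
// The backward opr uses the index output saved by the forward pass
// (opr.output(1)) to route each slice of the incoming gradient back to its
// pre-shuffle position.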

MGB_DYN_TYPE_OBJ_FINAL_IMPL(ShuffleRNGBackward);
MEGDNN_OPR_INIT3(ShuffleRNGBackward, "shuffle_rng_bwd", 2, true)

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}