Unverified commit 79482620, authored by ZZK, committed by GitHub

Debug dim scatter (#5371)

* startup of dev scatter ops

* use dim scatter base class

* refine (use binop to abstract scatter update and add)

* refine (use macros to implement kernel class and functors)

* refine (descriptions for registering scatter ops/kernels)

* refine

* add inplace ops

* python wrapper for inplace scatter_add

* dev inplace ops

* refine dim_gather (using macro register mechanism)

* add grad of scatter_add_like

* refine (add src, like versions for scatter)

* refine src/like tensor

* gather refine (no need for outplace/inplace versions)

* reformat

* refine

* test case of dim scatter

* test case for dim_scatter_add_like

* 1n2d test case for dim_scatter_add_like

* refine scatter sbp

* fail to scatter_add_like on 1n2d

* refine sbp

* refine test case, unify add and update like ops

* test case for scatter_add/update like ops finished

* test cases for scatter ops

* refine, merge test class

* startup of api docs

* add scatter api docs and assertion in python

* fix make error, but segmentation fault remains

* annotate sbp infer

* rewrite scatter kernel logic

* remove inplace proposal and fix macro name

* remove outdated atomic add

* move sbp infer

* add const and throw error

* add check

* set grad op

* add scatter scalar

* add scatter scalar gpu kernel

* add torch style backprop

* add torch style backprop check

* align with master

* remove redundant sbp check

* add test

* add float16 register

* fix sbp

* fix sbp

* add api doc

* make format

* add new line

* refine

* revert dim gather

* extract dim_scatter_add

* extract scatter update ops

* add add/update functor

* rewriting with functors

* refine

* remove dim_gather_scatter_uitl.h

* add blank line

* refine macros for registering kernels

* refine dim_scatter_scalar file names

* refine

* refine register ops

* refine

* add F.dim_scatter_scalar

* add scatter op

* refine docstr

* add scatter reduce arg

* finally(!): a draft for scatter consistent with PyTorch

* change import package name

* remove lazy test and add scatter_add and scatter_mul

* startup of scatter backward op

* add backward for scatter

* scatter ops backward finished

* add scatter, scatter_add test cases

* remove useless scatter_update_like

* reformat

* refine test cases

* refine according to comments

* revert op_expr_helper

* fixed index element

* fix scatter update like expr for dim gather backward
Co-authored-by: Ndoombeaker <later@usopp.net>
Co-authored-by: Noneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com>
Parent 0a44b54e
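
The ops in this diff implement PyTorch-style dim scatter: along the chosen dim, each element of src is written ("update") or accumulated ("add") into the output at the position given by the corresponding index element, while all other coordinates are kept. A minimal standalone sketch of these semantics for a 2x3 case with dim = 0 (illustrative only, not part of the diff):

// Dim-scatter semantics for dim = 0 on 2-D arrays:
// update: out[idx[i][j]][j] = src[i][j];  add: out[idx[i][j]][j] += src[i][j].
#include <array>
#include <cstdio>

int main() {
  // Output starts as a copy of the input tensor (the dim_scatter_update/dim_scatter_add case).
  std::array<std::array<float, 3>, 2> out = {{{10, 20, 30}, {40, 50, 60}}};
  std::array<std::array<int, 3>, 2> idx = {{{1, 0, 1}, {0, 1, 0}}};
  std::array<std::array<float, 3>, 2> src = {{{1, 2, 3}, {4, 5, 6}}};
  for (int i = 0; i < 2; ++i) {
    for (int j = 0; j < 3; ++j) {
      out[idx[i][j]][j] = src[i][j];  // "update"; use += for "add"
    }
  }
  for (const auto& row : out) { std::printf("%g %g %g\n", row[0], row[1], row[2]); }
  return 0;  // prints: 4 2 6 / 1 5 3
}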
......@@ -29,6 +29,8 @@ oneflow
reshape,
save,
saved_model,
scatter,
scatter_add,
scatter_nd,
selu,
silu,
......
......@@ -72,7 +72,7 @@ Maybe<void> DimGather::Apply(const DimGatherInterpState* ctx, const TensorTuple&
MutableAttrMap attrs;
JUST(attrs.SetAttr<int32_t>("dim", ctx->dim));
in_grads->at(0) = JUST(
OpInterpUtil::Dispatch<Tensor>(*bw_dim_gather_op_, {like, out_grads.at(0), index}, attrs));
OpInterpUtil::Dispatch<Tensor>(*bw_dim_gather_op_, {like, index, out_grads.at(0)}, attrs));
return Maybe<void>::Ok();
}
......
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/op_expr_grad_function.h"
#include "oneflow/core/framework/op_builder.h"
#include "oneflow/core/framework/op_expr.h"
#include "oneflow/core/framework/op_expr_helper.h"
#include "oneflow/core/framework/op_interpreter/op_interpreter_util.h"
#include "oneflow/core/functional/functional.h"
namespace oneflow {
namespace one {
struct DimScatterInterpState : public OpExprInterpState {
int32_t dim;
bool input_requires_grad;
bool src_requires_grad;
};
enum SCATTER_TYPE { SCATTER_UPDATE, SCATTER_ADD };
template<SCATTER_TYPE T>
class DimScatter : public OpExprGradFunction<DimScatterInterpState> {
public:
Maybe<void> Init(const OpExpr& op) override;
Maybe<void> Capture(DimScatterInterpState* ctx, const TensorTuple& inputs,
const TensorTuple& outputs, const AttrMap& attrs) const override;
Maybe<void> Apply(const DimScatterInterpState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const override;
Maybe<void> ApplyCommon(const DimScatterInterpState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const;
private:
AttrMap base_attrs_;
};
template<SCATTER_TYPE T>
Maybe<void> DimScatter<T>::Init(const OpExpr& op) {
const UserOpExpr* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);
CHECK_NOTNULL_OR_RETURN(fw_op_expr);
base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());
return Maybe<void>::Ok();
}
template<SCATTER_TYPE T>
Maybe<void> DimScatter<T>::Capture(DimScatterInterpState* ctx, const TensorTuple& inputs,
const TensorTuple& outputs, const AttrMap& attrs) const {
CHECK_EQ_OR_RETURN(inputs.size(), 3);
CHECK_EQ_OR_RETURN(outputs.size(), 1);
ctx->input_requires_grad = inputs.at(0)->requires_grad();
ctx->src_requires_grad = inputs.at(2)->requires_grad();
if ((!ctx->input_requires_grad) && (!ctx->src_requires_grad)) { return Maybe<void>::Ok(); }
ctx->SaveTensorForBackward(inputs.at(1)); // index saved
ComposedAttrMap composed_attrs(attrs, base_attrs_);
ctx->dim = JUST(composed_attrs.GetAttr<int32_t>("dim"));
return Maybe<void>::Ok();
}
template<SCATTER_TYPE T>
Maybe<void> DimScatter<T>::ApplyCommon(const DimScatterInterpState* ctx,
const TensorTuple& out_grads, TensorTuple* in_grads) const {
const std::shared_ptr<oneflow::one::Tensor>& index = ctx->SavedTensors().at(0);
in_grads->resize(3);
if (ctx->src_requires_grad) {
in_grads->at(2) = JUST(functional::DimGather(out_grads.at(0), index, ctx->dim));
}
return Maybe<void>::Ok();
}
template<>
Maybe<void> DimScatter<SCATTER_TYPE::SCATTER_UPDATE>::Apply(const DimScatterInterpState* ctx,
const TensorTuple& out_grads,
TensorTuple* in_grads) const {
if ((!ctx->input_requires_grad) && (!ctx->src_requires_grad)) { return Maybe<void>::Ok(); }
CHECK_EQ_OR_RETURN(out_grads.size(), 1);
JUST(ApplyCommon(ctx, out_grads, in_grads));
if (ctx->input_requires_grad) {
const std::shared_ptr<oneflow::one::Tensor>& index = ctx->SavedTensors().at(0);
in_grads->at(0) =
JUST(functional::DimScatterUpdateScalar(out_grads.at(0), index, 0.0f, ctx->dim));
}
return Maybe<void>::Ok();
}
template<>
Maybe<void> DimScatter<SCATTER_TYPE::SCATTER_ADD>::Apply(const DimScatterInterpState* ctx,
const TensorTuple& out_grads,
TensorTuple* in_grads) const {
if ((!ctx->input_requires_grad) && (!ctx->src_requires_grad)) { return Maybe<void>::Ok(); }
CHECK_EQ_OR_RETURN(out_grads.size(), 1);
JUST(ApplyCommon(ctx, out_grads, in_grads));
if (ctx->input_requires_grad) { in_grads->at(0) = out_grads.at(0); }
return Maybe<void>::Ok();
}
class DimScatterUpdateScalar : public OpExprGradFunction<DimScatterInterpState> {
public:
Maybe<void> Init(const OpExpr& op) override;
Maybe<void> Capture(DimScatterInterpState* ctx, const TensorTuple& inputs,
const TensorTuple& outputs, const AttrMap& attrs) const override;
Maybe<void> Apply(const DimScatterInterpState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const override;
private:
AttrMap base_attrs_;
};
Maybe<void> DimScatterUpdateScalar::Init(const OpExpr& op) {
const UserOpExpr* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);
CHECK_NOTNULL_OR_RETURN(fw_op_expr);
base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());
return Maybe<void>::Ok();
}
Maybe<void> DimScatterUpdateScalar::Capture(DimScatterInterpState* ctx, const TensorTuple& inputs,
const TensorTuple& outputs,
const AttrMap& attrs) const {
CHECK_EQ_OR_RETURN(inputs.size(), 2);
CHECK_EQ_OR_RETURN(outputs.size(), 1);
ctx->input_requires_grad = inputs.at(0)->requires_grad();
if (!ctx->input_requires_grad) { return Maybe<void>::Ok(); }
ctx->SaveTensorForBackward(inputs.at(1)); // index saved
ComposedAttrMap composed_attrs(attrs, base_attrs_);
ctx->dim = JUST(composed_attrs.GetAttr<int32_t>("dim"));
return Maybe<void>::Ok();
}
Maybe<void> DimScatterUpdateScalar::Apply(const DimScatterInterpState* ctx,
const TensorTuple& out_grads,
TensorTuple* in_grads) const {
if (!ctx->input_requires_grad) { return Maybe<void>::Ok(); }
CHECK_EQ_OR_RETURN(out_grads.size(), 1);
const std::shared_ptr<oneflow::one::Tensor>& index = ctx->SavedTensors().at(0);
in_grads->resize(2);
MutableAttrMap attrs;
JUST(attrs.SetAttr<int32_t>("dim", ctx->dim));
JUST(attrs.SetAttr<float>("src_scalar", 0.0f));
in_grads->at(0) =
JUST(functional::DimScatterUpdateScalar(out_grads.at(0), index, 0.0f, ctx->dim));
return Maybe<void>::Ok();
}
REGISTER_OP_EXPR_GRAD_FUNCTION("dim_scatter_update", DimScatter<SCATTER_TYPE::SCATTER_UPDATE>);
REGISTER_OP_EXPR_GRAD_FUNCTION("dim_scatter_add", DimScatter<SCATTER_TYPE::SCATTER_ADD>);
REGISTER_OP_EXPR_GRAD_FUNCTION("dim_scatter_update_scalar", DimScatterUpdateScalar);
} // namespace one
} // namespace oneflow
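
The two Apply specializations above encode the gradient rules: for dim_scatter_add, d(input) is out_grad unchanged and d(src) gathers out_grad at index; for dim_scatter_update, d(input) is out_grad with zeros written at the indexed positions (via DimScatterUpdateScalar with 0.0) and d(src) is the same gather. A 1-D sketch of those rules with plain arrays (illustrative only; names are not from the diff):

// 1-D illustration of the scatter backward rules (dim = 0).
#include <cstdio>
#include <vector>

int main() {
  std::vector<float> out_grad = {1, 2, 3, 4};
  std::vector<int> index = {3, 0};  // src has 2 elements, scattered into positions 3 and 0

  // d(src), for both update and add: gather out_grad at index.
  std::vector<float> src_grad(index.size());
  for (size_t i = 0; i < index.size(); ++i) { src_grad[i] = out_grad[index[i]]; }

  // d(input) for scatter_update: out_grad with zeros at the scattered positions.
  // For scatter_add, d(input) is just out_grad.
  std::vector<float> input_grad = out_grad;
  for (int pos : index) { input_grad[pos] = 0.0f; }

  std::printf("src_grad: %g %g\n", src_grad[0], src_grad[1]);  // 4 1
  std::printf("input_grad: %g %g %g %g\n", input_grad[0], input_grad[1], input_grad[2],
              input_grad[3]);  // 0 2 3 0
  return 0;
}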
......@@ -637,8 +637,8 @@ Maybe<one::UserOpExpr> DimScatterAddLikeOp(const int32_t dim) {
Maybe<one::UserOpExpr> DimScatterAddLikeOp(const int32_t dim, const std::string& name) {
return one::OpBuilder("dim_scatter_add_like", name)
.Input("like")
.Input("input")
.Input("index")
.Input("src")
.Output("output")
.Attr<int32_t>("dim", dim)
.Build();
......
......@@ -775,6 +775,22 @@
signature: "Tensor TensorGetItem(Tensor x, *, TensorIndex index)"
bind_python: True
- name: "dim_scatter"
signature: "Tensor DimScatter(Tensor input, Tensor index, Tensor src, *, Int32 dim)"
bind_python: True
- name: "dim_scatter_add"
signature: "Tensor DimScatterAdd(Tensor input, Tensor index, Tensor src, *, Int32 dim)"
bind_python: True
- name: "dim_scatter_scalar"
signature: "Tensor DimScatterUpdateScalar(Tensor input, Tensor index, *, Float src, Int32 dim)"
bind_python: True
- name: "dim_scatter_add_scalar"
signature: "Tensor DimScatterAddScalar(Tensor input, Tensor index, *, Float src, Int32 dim)"
bind_python: True
- name: "tensor_setitem"
signature: "Void TensorSetItem(Tensor x, *, TensorIndex index, Tensor value)"
bind_python: True
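
Assuming the generated functional API exposes the entries above as oneflow::one::functional::* (as the gradient code in this PR already does for DimGather and DimScatterUpdateScalar), the C++ call sites would look roughly like this sketch:

// Hedged sketch of C++ call sites for the signatures above; not part of the diff.
#include <memory>
#include "oneflow/core/functional/functional.h"

namespace oneflow {
namespace one {

Maybe<Tensor> ScatterExample(const std::shared_ptr<Tensor>& input,
                             const std::shared_ptr<Tensor>& index,
                             const std::shared_ptr<Tensor>& src) {
  const int32_t dim = 0;
  // Tensor/Tensor variants: write or accumulate src into a copy of input.
  const auto updated = JUST(functional::DimScatter(input, index, src, dim));
  const auto added = JUST(functional::DimScatterAdd(input, index, src, dim));
  (void)updated;
  // Scalar variants take a float value in place of the src tensor.
  return functional::DimScatterAddScalar(added, index, /*src=*/1.0f, dim);
}

}  // namespace one
}  // namespace oneflow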
......
......@@ -344,6 +344,138 @@ class DimGatherFunctor {
std::shared_ptr<OpExpr> op_;
};
class DimScatterFunctor {
public:
DimScatterFunctor() {
op_ = CHECK_JUST(one::OpBuilder("dim_scatter_update")
.Input("input")
.Input("index")
.Input("src")
.Output("output")
.Build());
}
Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& input,
const std::shared_ptr<one::Tensor>& index,
const std::shared_ptr<one::Tensor>& src, const int32_t& dim) const {
MutableAttrMap attrs;
JUST(attrs.SetAttr<int32_t>("dim", dim));
return OpInterpUtil::Dispatch<Tensor>(*op_, {input, index, src}, attrs);
}
private:
std::shared_ptr<OpExpr> op_;
};
class DimScatterAddFunctor {
public:
DimScatterAddFunctor() {
op_ = CHECK_JUST(one::OpBuilder("dim_scatter_add")
.Input("input")
.Input("index")
.Input("src")
.Output("output")
.Build());
}
Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& input,
const std::shared_ptr<one::Tensor>& index,
const std::shared_ptr<one::Tensor>& src, const int32_t& dim) const {
MutableAttrMap attrs;
JUST(attrs.SetAttr<int32_t>("dim", dim));
return OpInterpUtil::Dispatch<Tensor>(*op_, {input, index, src}, attrs);
}
private:
std::shared_ptr<OpExpr> op_;
};
class DimScatterMulFunctor {
public:
DimScatterMulFunctor() {
op_ = CHECK_JUST(one::OpBuilder("dim_scatter_mul")
.Input("input")
.Input("index")
.Input("src")
.Output("output")
.Build());
}
Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& input,
const std::shared_ptr<one::Tensor>& index,
const std::shared_ptr<one::Tensor>& src, const int32_t& dim) const {
MutableAttrMap attrs;
JUST(attrs.SetAttr<int32_t>("dim", dim));
return OpInterpUtil::Dispatch<Tensor>(*op_, {input, index, src}, attrs);
}
private:
std::shared_ptr<OpExpr> op_;
};
class DimScatterUpdateScalarFunctor {
public:
DimScatterUpdateScalarFunctor() {
op_ = CHECK_JUST(one::OpBuilder("dim_scatter_update_scalar")
.Input("input")
.Input("index")
.Output("output")
.Build());
}
Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& input,
const std::shared_ptr<one::Tensor>& index, const float& src,
const int32_t& dim) const {
MutableAttrMap attrs;
JUST(attrs.SetAttr<int32_t>("dim", dim));
JUST(attrs.SetAttr<float>("src_scalar", src));
return OpInterpUtil::Dispatch<Tensor>(*op_, {input, index}, attrs);
}
private:
std::shared_ptr<OpExpr> op_;
};
class DimScatterAddScalarFunctor {
public:
DimScatterAddScalarFunctor() {
op_ = CHECK_JUST(one::OpBuilder("dim_scatter_add_scalar")
.Input("input")
.Input("index")
.Output("output")
.Build());
}
Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& input,
const std::shared_ptr<one::Tensor>& index, const float& src,
const int32_t& dim) const {
MutableAttrMap attrs;
JUST(attrs.SetAttr<int32_t>("dim", dim));
JUST(attrs.SetAttr<float>("src_scalar", src));
return OpInterpUtil::Dispatch<Tensor>(*op_, {input, index}, attrs);
}
private:
std::shared_ptr<OpExpr> op_;
};
class DimScatterMulScalarFunctor {
public:
DimScatterMulScalarFunctor() {
op_ = CHECK_JUST(one::OpBuilder("dim_scatter_mul_scalar")
.Input("input")
.Input("index")
.Output("output")
.Build());
}
Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& input,
const std::shared_ptr<one::Tensor>& index, const float& src,
const int32_t& dim) const {
MutableAttrMap attrs;
JUST(attrs.SetAttr<int32_t>("dim", dim));
JUST(attrs.SetAttr<float>("src_scalar", src));
return OpInterpUtil::Dispatch<Tensor>(*op_, {input, index}, attrs);
}
private:
std::shared_ptr<OpExpr> op_;
};
class GatherNdFunctor {
public:
GatherNdFunctor() {
......@@ -1153,6 +1285,12 @@ ONEFLOW_FUNCTION_LIBRARY(m) {
m.add_functor<impl::DiagFunctor>("Diag");
m.add_functor<impl::DiagGradFunctor>("DiagGrad");
m.add_functor<impl::TensorGetItemFunctor>("TensorGetItem");
m.add_functor<impl::DimScatterFunctor>("DimScatter");
m.add_functor<impl::DimScatterAddFunctor>("DimScatterAdd");
m.add_functor<impl::DimScatterMulFunctor>("DimScatterMul");
m.add_functor<impl::DimScatterUpdateScalarFunctor>("DimScatterUpdateScalar");
m.add_functor<impl::DimScatterAddScalarFunctor>("DimScatterAddScalar");
m.add_functor<impl::DimScatterMulScalarFunctor>("DimScatterMulScalar");
m.add_functor<impl::TensorSetItemFunctor>("TensorSetItem");
m.add_functor<impl::ElementwiseMinimumGradFunctor>("ElementwiseMinGrad");
m.add_functor<impl::ElementwiseMaximumGradFunctor>("ElementwiseMaxGrad");
......
......@@ -30,20 +30,7 @@ struct DimGatherFunctor<DeviceType::kCPU, IN_T, IDX_T> final {
}
};
template<typename IN_T, typename IDX_T>
struct DimScatterAddFunctor<DeviceType::kCPU, IN_T, IDX_T> final {
void operator()(DeviceCtx* ctx, const DimOpIndexNdHelper<IDX_T>& input_nd_helper,
const DimOpIndexNdHelper<IDX_T>& output_nd_helper, int ndim, int64_t elem_cnt,
int32_t dim, const IDX_T* index, const IN_T* input, IN_T* output) {
DoDimScatterAdd<IN_T, IDX_T>(input_nd_helper, output_nd_helper, ndim, elem_cnt, dim, index,
input, output);
}
};
OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_DIM_GATHER_FUNCTOR, (DeviceType::kCPU),
DIM_GATHER_SCATTER_DATA_TYPE_CPU_SEQ, INDEX_DATA_TYPE_SEQ);
OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_DIM_SCATTER_ADD_FUNCTOR, (DeviceType::kCPU),
DIM_GATHER_SCATTER_DATA_TYPE_CPU_SEQ, INDEX_DATA_TYPE_SEQ);
} // namespace user_op
} // namespace oneflow
......@@ -13,7 +13,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include <cstdint>
#ifdef WITH_CUDA
#include "oneflow/core/framework/framework.h"
#include "oneflow/user/kernels/dim_gather_kernel_util.h"
......@@ -53,41 +52,8 @@ struct DimGatherFunctor<DeviceType::kGPU, float16, IDX_T> final {
}
};
template<typename IN_T, typename IDX_T>
__global__ void DoCUDAScatterDimAdd(const DimOpIndexNdHelper<IDX_T> input_nd_helper,
const DimOpIndexNdHelper<IDX_T> output_nd_helper, int ndim,
int64_t elem_cnt, int32_t dim, const IDX_T* index,
const IN_T* input, IN_T* output) {
DoDimScatterAdd<IN_T, IDX_T>(input_nd_helper, output_nd_helper, ndim, elem_cnt, dim, index, input,
output);
}
template<typename IN_T, typename IDX_T>
struct DimScatterAddFunctor<DeviceType::kGPU, IN_T, IDX_T> final {
void operator()(DeviceCtx* ctx, const DimOpIndexNdHelper<IDX_T>& input_nd_helper,
const DimOpIndexNdHelper<IDX_T>& output_nd_helper, int ndim, int64_t elem_cnt,
int32_t dim, const IDX_T* index, const IN_T* input, IN_T* output) {
RUN_CUDA_KERNEL((DoCUDAScatterDimAdd<IN_T, IDX_T>), ctx, BlocksNum4ThreadsNum(elem_cnt),
input_nd_helper, output_nd_helper, ndim, elem_cnt, dim, index, input, output);
}
};
// float16 special case of DimScatterAddFunctor template
template<typename IDX_T>
struct DimScatterAddFunctor<DeviceType::kGPU, float16, IDX_T> final {
void operator()(DeviceCtx* ctx, const DimOpIndexNdHelper<IDX_T>& input_nd_helper,
const DimOpIndexNdHelper<IDX_T>& output_nd_helper, int ndim, int64_t elem_cnt,
int32_t dim, const IDX_T* index, const float16* input, float16* output) {
RUN_CUDA_KERNEL((DoCUDAScatterDimAdd<half, IDX_T>), ctx, BlocksNum4ThreadsNum(elem_cnt),
input_nd_helper, output_nd_helper, ndim, elem_cnt, dim, index,
reinterpret_cast<const half*>(input), reinterpret_cast<half*>(output));
}
};
OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_DIM_GATHER_FUNCTOR, (DeviceType::kGPU),
DIM_GATHER_SCATTER_DATA_TYPE_GPU_SEQ, INDEX_DATA_TYPE_SEQ);
OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_DIM_SCATTER_ADD_FUNCTOR, (DeviceType::kGPU),
DIM_GATHER_SCATTER_DATA_TYPE_GPU_SEQ, INDEX_DATA_TYPE_SEQ);
} // namespace user_op
} // namespace oneflow
......
......@@ -99,10 +99,6 @@ OF_DEVICE_FUNC void DoDimScatterAdd(const DimOpIndexNdHelper<IDX_T>& input_nd_he
template struct DimGatherFunctor<device_type_v, OF_PP_PAIR_FIRST(dtype_pair), \
OF_PP_PAIR_FIRST(itype_pair)>;
#define INSTANTIATE_DIM_SCATTER_ADD_FUNCTOR(device_type_v, dtype_pair, itype_pair) \
template struct DimScatterAddFunctor<device_type_v, OF_PP_PAIR_FIRST(dtype_pair), \
OF_PP_PAIR_FIRST(itype_pair)>;
} // namespace user_op
} // namespace oneflow
......
......@@ -66,44 +66,6 @@ class DimGatherKernel final : public user_op::OpKernel {
bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }
};
template<DeviceType device_type, typename IN_T, typename IDX_T>
class ScatterDimKernel final : public user_op::OpKernel {
public:
ScatterDimKernel() = default;
~ScatterDimKernel() override = default;
private:
void Compute(KernelComputeContext* ctx) const override {
const Tensor* input_tensor = ctx->Tensor4ArgNameAndIndex("input", 0);
const Tensor* index_tensor = ctx->Tensor4ArgNameAndIndex("index", 0);
Tensor* out_tensor = ctx->Tensor4ArgNameAndIndex("output", 0);
const int32_t dim = ctx->Attr<int32_t>("dim");
const IN_T* src = input_tensor->dptr<IN_T>();
const IDX_T* index = index_tensor->dptr<IDX_T>();
IN_T* output = out_tensor->mut_dptr<IN_T>();
size_t out_bytes_size =
out_tensor->shape().elem_cnt() * GetSizeOfDataType(out_tensor->data_type());
Memset<device_type>(ctx->device_ctx(), output, 0, out_bytes_size);
int ndim = input_tensor->shape().NumAxes();
fixed_vector<IDX_T, kDimGatherMaxDimCount> shape_vec(ndim);
auto shape2dims = [&shape_vec, &ndim](const ShapeView& tensor_shape) -> void {
std::transform(tensor_shape.ptr(), tensor_shape.ptr() + ndim, shape_vec.begin(),
[](int64_t dim) -> IDX_T { return static_cast<IDX_T>(dim); });
};
shape2dims(input_tensor->shape());
DimOpIndexNdHelper<IDX_T> input_nd_helper(shape_vec.data(), ndim);
shape2dims(out_tensor->shape());
DimOpIndexNdHelper<IDX_T> output_nd_helper(shape_vec.data(), ndim);
DimScatterAddFunctor<device_type, IN_T, IDX_T>()(
ctx->device_ctx(), input_nd_helper, output_nd_helper, ndim,
input_tensor->shape().elem_cnt(), dim, index, src, output);
}
bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }
};
#define REGISTER_DIM_GATHER_KERNEL(device, dtype, itype) \
REGISTER_USER_KERNEL("dim_gather") \
.SetCreateFn<DimGatherKernel<device, dtype, itype>>() \
......@@ -111,13 +73,6 @@ class ScatterDimKernel final : public user_op::OpKernel {
& (user_op::HobDataType("input", 0) == GetDataType<dtype>::value) \
& (user_op::HobDataType("index", 0) == GetDataType<itype>::value));
#define REGISTER_DIM_SCATTER_KERNEL(device, dtype, itype) \
REGISTER_USER_KERNEL("dim_scatter_add_like") \
.SetCreateFn<ScatterDimKernel<device, dtype, itype>>() \
.SetIsMatchedHob((user_op::HobDeviceTag() == device) \
& (user_op::HobDataType("input", 0) == GetDataType<dtype>::value) \
& (user_op::HobDataType("index", 0) == GetDataType<itype>::value));
#define REGISTER_DIM_GATHER_KERNELS_WITH_DEVICE(device) \
REGISTER_DIM_GATHER_KERNEL(device, float, int32_t) \
REGISTER_DIM_GATHER_KERNEL(device, double, int32_t) \
......@@ -126,23 +81,11 @@ class ScatterDimKernel final : public user_op::OpKernel {
REGISTER_DIM_GATHER_KERNEL(device, double, int64_t) \
REGISTER_DIM_GATHER_KERNEL(device, int32_t, int64_t)
#define REGISTER_DIM_SCATTER_ADD_LIKE_KERNELS_WITH_DEVICE(device) \
REGISTER_DIM_SCATTER_KERNEL(device, float, int32_t) \
REGISTER_DIM_SCATTER_KERNEL(device, double, int32_t) \
REGISTER_DIM_SCATTER_KERNEL(device, int32_t, int32_t) \
REGISTER_DIM_SCATTER_KERNEL(device, float, int64_t) \
REGISTER_DIM_SCATTER_KERNEL(device, double, int64_t) \
REGISTER_DIM_SCATTER_KERNEL(device, int32_t, int64_t)
REGISTER_DIM_GATHER_KERNELS_WITH_DEVICE(DeviceType::kCPU);
REGISTER_DIM_SCATTER_ADD_LIKE_KERNELS_WITH_DEVICE(DeviceType::kCPU);
#ifdef WITH_CUDA
REGISTER_DIM_GATHER_KERNELS_WITH_DEVICE(DeviceType::kGPU);
REGISTER_DIM_SCATTER_ADD_LIKE_KERNELS_WITH_DEVICE(DeviceType::kGPU);
REGISTER_DIM_GATHER_KERNEL(DeviceType::kGPU, float16, int32_t);
REGISTER_DIM_SCATTER_KERNEL(DeviceType::kGPU, float16, int32_t);
REGISTER_DIM_SCATTER_KERNEL(DeviceType::kGPU, float16, int64_t);
REGISTER_DIM_GATHER_KERNEL(DeviceType::kGPU, float16, int64_t);
#endif // WITH_CUDA
......
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/framework.h"
#include "oneflow/user/kernels/dim_scatter_kernel_util.h"
namespace oneflow {
namespace user_op {
template<typename IN_T, typename IDX_T, template<typename T> class Opt>
struct DimScatterFunctor<DeviceType::kCPU, IN_T, IDX_T, Opt> final {
void operator()(DeviceCtx* ctx, const DimOpIndexNdHelper<IDX_T>& src_nd_helper,
const DimOpIndexNdHelper<IDX_T>& idx_nd_helper,
const DimOpIndexNdHelper<IDX_T>& output_nd_helper, const int ndim,
const int64_t elem_cnt, const int32_t dim, const int64_t upper_bound,
const IDX_T* index, const IN_T* src, IN_T* output) {
DoDimScatter<IN_T, IDX_T, Opt>(src_nd_helper, idx_nd_helper, output_nd_helper, ndim, elem_cnt,
dim, upper_bound, index, src, output);
}
};
INSTANTIATE_DIM_SCATTER_FUNCTORS(DeviceType::kCPU, BinOpAddFunctor);
INSTANTIATE_DIM_SCATTER_FUNCTORS(DeviceType::kCPU, BinOpUpdateFunctor);
} // namespace user_op
} // namespace oneflow
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifdef WITH_CUDA
#include "oneflow/user/kernels/dim_scatter_kernel_util.h"
namespace oneflow {
namespace user_op {
template<typename IN_T, typename IDX_T, template<typename T> class Opt>
__global__ void DoCUDADimScatter(const DimOpIndexNdHelper<IDX_T> src_nd_helper,
const DimOpIndexNdHelper<IDX_T> idx_nd_helper,
const DimOpIndexNdHelper<IDX_T> output_nd_helper, const int ndim,
const int64_t elem_cnt, const int32_t dim,
const int64_t upper_bound, const IDX_T* index, const IN_T* src,
IN_T* output) {
DoDimScatter<IN_T, IDX_T, Opt>(src_nd_helper, idx_nd_helper, output_nd_helper, ndim, elem_cnt,
dim, upper_bound, index, src, output);
}
template<typename IN_T, typename IDX_T, template<typename T> class Opt>
struct DimScatterFunctor<DeviceType::kGPU, IN_T, IDX_T, Opt> final {
void operator()(DeviceCtx* ctx, const DimOpIndexNdHelper<IDX_T>& src_nd_helper,
const DimOpIndexNdHelper<IDX_T>& idx_nd_helper,
const DimOpIndexNdHelper<IDX_T>& output_nd_helper, const int ndim,
const int64_t elem_cnt, const int32_t dim, const int64_t upper_bound,
const IDX_T* index, const IN_T* src, IN_T* output) {
RUN_CUDA_KERNEL((DoCUDADimScatter<IN_T, IDX_T, Opt>), ctx, BlocksNum4ThreadsNum(elem_cnt),
src_nd_helper, idx_nd_helper, output_nd_helper, ndim, elem_cnt, dim,
upper_bound, index, src, output);
}
};
template<typename IDX_T, template<typename T> class Opt>
struct DimScatterFunctor<DeviceType::kGPU, float16, IDX_T, Opt> final {
void operator()(DeviceCtx* ctx, const DimOpIndexNdHelper<IDX_T>& src_nd_helper,
const DimOpIndexNdHelper<IDX_T>& idx_nd_helper,
const DimOpIndexNdHelper<IDX_T>& output_nd_helper, const int ndim,
const int64_t elem_cnt, const int32_t dim, const int64_t upper_bound,
const IDX_T* index, const float16* src, float16* output) {
RUN_CUDA_KERNEL((DoCUDADimScatter<half, IDX_T, Opt>), ctx, BlocksNum4ThreadsNum(elem_cnt),
src_nd_helper, idx_nd_helper, output_nd_helper, ndim, elem_cnt, dim,
upper_bound, index, reinterpret_cast<const half*>(src),
reinterpret_cast<half*>(output));
}
};
INSTANTIATE_DIM_SCATTER_FUNCTORS(DeviceType::kGPU, BinOpAddFunctor);
INSTANTIATE_DIM_SCATTER_FUNCTORS(DeviceType::kGPU, BinOpUpdateFunctor);
} // namespace user_op
} // namespace oneflow
#endif // WITH_CUDA
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_USER_KERNELS_DIM_SCATTER_KERNEL_UTIL_H_
#define ONEFLOW_USER_KERNELS_DIM_SCATTER_KERNEL_UTIL_H_
#ifdef WITH_CUDA
#include "oneflow/core/cuda/atomic.cuh"
#endif // WITH_CUDA
#include "oneflow/core/ndarray/xpu_util.h"
#include "oneflow/core/common/nd_index_offset_helper.h"
#include "oneflow/core/framework/framework.h"
#include "oneflow/core/common/data_type.h"
#include "oneflow/core/common/shape_view.h"
#include "oneflow/core/common/error.pb.h"
namespace oneflow {
namespace user_op {
constexpr int kDimGatherMaxDimCount = 8;
template<typename T>
using DimOpIndexNdHelper = NdIndexOffsetHelper<T, kDimGatherMaxDimCount>;
#define INSTANTIATE_DIM_SCATTER_FUNCTORS(device_type, opt) \
template struct DimScatterFunctor<device_type, int32_t, int32_t, opt>; \
template struct DimScatterFunctor<device_type, float, int32_t, opt>; \
template struct DimScatterFunctor<device_type, double, int32_t, opt>; \
template struct DimScatterFunctor<device_type, int32_t, int64_t, opt>; \
template struct DimScatterFunctor<device_type, float, int64_t, opt>; \
template struct DimScatterFunctor<device_type, double, int64_t, opt>;
template<typename T>
struct BinOpAddFunctor {
OF_DEVICE_FUNC static void apply(const T* x, T* y) {
#ifdef __CUDA_ARCH__
cuda::atomic::Add(y, *x);
#else
*y += *x;
#endif
}
};
template<typename T>
struct BinOpUpdateFunctor {
OF_DEVICE_FUNC static void apply(const T* x, T* y) { *y = *x; }
};
template<DeviceType device_type, typename IN_T, typename IDX_T, template<typename T> class Opt>
struct DimScatterFunctor final {
void operator()(DeviceCtx* ctx, const DimOpIndexNdHelper<IDX_T>& src_nd_helper,
const DimOpIndexNdHelper<IDX_T>& idx_nd_helper,
const DimOpIndexNdHelper<IDX_T>& output_nd_helper, const int ndim,
const int64_t elem_cnt, const int32_t dim, const int64_t upper_bound,
const IDX_T* index, const IN_T* src, IN_T* output);
};
template<typename IN_T, typename IDX_T, template<typename T> class Opt>
OF_DEVICE_FUNC void DoDimScatter(const DimOpIndexNdHelper<IDX_T>& src_nd_helper,
const DimOpIndexNdHelper<IDX_T>& idx_nd_helper,
const DimOpIndexNdHelper<IDX_T>& output_nd_helper, const int ndim,
const int64_t elem_cnt, const int32_t dim, int64_t upper_bound,
const IDX_T* index, const IN_T* src, IN_T* output) {
XPU_1D_KERNEL_LOOP(idx_offset, elem_cnt) {
IDX_T coordinate[kDimGatherMaxDimCount] = {0};
idx_nd_helper.OffsetToNdIndex(idx_offset, coordinate, ndim); // idx_offset -> ijk
IDX_T idx_elem = index[idx_offset];
if (idx_elem >= upper_bound) {
#if __CUDA_ARCH__
__trap();
#else
std::cout << "The index element " << idx_elem << " is out of bounds for dimension " << dim
<< " with size " << upper_bound << std::endl;
throw Error::CheckFailedError();
#endif
}
IDX_T src_offset = src_nd_helper.NdIndexToOffset(coordinate, ndim);
coordinate[dim] = idx_elem;
IDX_T output_offset = output_nd_helper.NdIndexToOffset(coordinate, ndim);
Opt<IN_T>::apply(src + src_offset, output + output_offset);
}
}
} // namespace user_op
} // namespace oneflow
#endif // ONEFLOW_USER_KERNELS_DIM_SCATTER_KERNEL_UTIL_H_
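
DoDimScatter above works by offset remapping: each flat offset into index is decoded to an nd-coordinate, the matching src element is read, the coordinate's entry along dim is replaced by the index value (after the upper_bound check), and the coordinate is re-encoded into an output offset where Opt applies the update or the atomic add. A standalone sketch of that remapping for a row-major 2x3 case with dim = 1 (plain loops standing in for NdIndexOffsetHelper; not part of the diff):

// Standalone sketch of the DoDimScatter remapping for a 2-D, dim = 1 case.
#include <cstdio>
#include <cstdlib>
#include <vector>

int main() {
  const int rows = 2, cols = 3, dim = 1;
  std::vector<float> src = {1, 2, 3, 4, 5, 6};   // 2x3, row-major
  std::vector<int> index = {2, 0, 1, 1, 2, 0};   // same shape as src
  std::vector<float> output(rows * cols, 0.0f);  // the zero-initialized ("like") case
  const int upper_bound = cols;                  // output extent along dim

  for (int offset = 0; offset < rows * cols; ++offset) {
    int coord[2] = {offset / cols, offset % cols};  // OffsetToNdIndex
    const int idx_elem = index[offset];
    if (idx_elem >= upper_bound) { std::abort(); }  // kernel __trap()s / throws here
    const float v = src[coord[0] * cols + coord[1]];
    coord[dim] = idx_elem;                          // redirect along dim
    output[coord[0] * cols + coord[1]] += v;        // BinOpAddFunctor; '=' for update
  }
  for (int r = 0; r < rows; ++r) {
    std::printf("%g %g %g\n", output[r * cols], output[r * cols + 1], output[r * cols + 2]);
  }
  return 0;  // prints: 2 3 1 / 6 4 5
}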
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/common/error.pb.h"
#include "oneflow/core/common/util.h"
#include "oneflow/user/kernels/dim_scatter_kernel_util.h"
namespace oneflow {
namespace user_op {
template<DeviceType device_type, typename IN_T, typename IDX_T, template<typename T> class Opt>
class DimScatterKernel final : public user_op::OpKernel {
public:
DimScatterKernel() = default;
~DimScatterKernel() override = default;
private:
void Compute(KernelComputeContext* ctx) const override {
const Tensor* input_tensor = ctx->Tensor4ArgNameAndIndex("input", 0);
const Tensor* index_tensor = ctx->Tensor4ArgNameAndIndex("index", 0);
Tensor* out_tensor = ctx->Tensor4ArgNameAndIndex("output", 0);
const Tensor* src_tensor = ctx->Tensor4ArgNameAndIndex("src", 0);
const int32_t dim = ctx->Attr<int32_t>("dim");
const IDX_T* index = index_tensor->dptr<IDX_T>();
IN_T* output = out_tensor->mut_dptr<IN_T>();
size_t out_bytes_size =
out_tensor->shape().elem_cnt() * GetSizeOfDataType(out_tensor->data_type());
Tensor* like_tensor = ctx->Tensor4ArgNameAndIndex("like", 0);
const IN_T* src = src_tensor->dptr<IN_T>();
if (input_tensor) {
Memcpy<device_type>(ctx->device_ctx(), output, input_tensor->dptr<IN_T>(), out_bytes_size);
} else if (like_tensor) {
Memset<device_type>(ctx->device_ctx(), output, 0, out_bytes_size);
} else {
std::cout << "Unimplemented Error" << std::endl;
throw Error::Unimplemented();
}
const int ndim = src_tensor->shape().NumAxes();
fixed_vector<IDX_T, kDimGatherMaxDimCount> shape_vec(ndim);
auto shape2dims = [&shape_vec, &ndim](const ShapeView& tensor_shape) -> void {
std::transform(tensor_shape.ptr(), tensor_shape.ptr() + ndim, shape_vec.begin(),
[](int32_t dim) -> IDX_T { return static_cast<IDX_T>(dim); });
};
shape2dims(src_tensor->shape());
DimOpIndexNdHelper<IDX_T> src_nd_helper(shape_vec.data(), ndim);
shape2dims(index_tensor->shape());
DimOpIndexNdHelper<IDX_T> idx_nd_helper(shape_vec.data(), ndim);
shape2dims(out_tensor->shape());
DimOpIndexNdHelper<IDX_T> output_nd_helper(shape_vec.data(), ndim);
int64_t upper_bound = 0;
if (input_tensor) {
upper_bound = input_tensor->shape().At(dim); // ensure the idx is smaller than upperbound
} else {
upper_bound = like_tensor->shape().At(dim); // ensure the idx is smaller than upperbound
}
DimScatterFunctor<device_type, IN_T, IDX_T, Opt>()(
ctx->device_ctx(), src_nd_helper, idx_nd_helper, output_nd_helper, ndim,
index_tensor->shape().elem_cnt(), dim, upper_bound, index, src, output);
}
bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }
};
#define REGISTER_DIM_SCATTER_LIKE_KERNEL(op_type, device, dtype, itype, opt) \
REGISTER_USER_KERNEL(op_type) \
.SetCreateFn<DimScatterKernel<device, dtype, itype, opt>>() \
.SetIsMatchedHob((user_op::HobDeviceTag() == device) \
& (user_op::HobDataType("like", 0) == GetDataType<dtype>::value) \
& (user_op::HobDataType("index", 0) == GetDataType<itype>::value));
#define REGISTER_DIM_SCATTER_LIKE_CPU_KERNELS(op_type, opt) \
REGISTER_DIM_SCATTER_LIKE_KERNEL(op_type, DeviceType::kCPU, float, int32_t, opt); \
REGISTER_DIM_SCATTER_LIKE_KERNEL(op_type, DeviceType::kCPU, double, int32_t, opt); \
REGISTER_DIM_SCATTER_LIKE_KERNEL(op_type, DeviceType::kCPU, int32_t, int32_t, opt); \
REGISTER_DIM_SCATTER_LIKE_KERNEL(op_type, DeviceType::kCPU, float, int64_t, opt); \
REGISTER_DIM_SCATTER_LIKE_KERNEL(op_type, DeviceType::kCPU, double, int64_t, opt); \
REGISTER_DIM_SCATTER_LIKE_KERNEL(op_type, DeviceType::kCPU, int32_t, int64_t, opt);
#define REGISTER_DIM_SCATTER_LIKE_GPU_KERNELS(op_type, opt) \
REGISTER_DIM_SCATTER_LIKE_KERNEL(op_type, DeviceType::kGPU, float, int32_t, opt); \
REGISTER_DIM_SCATTER_LIKE_KERNEL(op_type, DeviceType::kGPU, double, int32_t, opt); \
REGISTER_DIM_SCATTER_LIKE_KERNEL(op_type, DeviceType::kGPU, int32_t, int32_t, opt); \
REGISTER_DIM_SCATTER_LIKE_KERNEL(op_type, DeviceType::kGPU, float, int64_t, opt); \
REGISTER_DIM_SCATTER_LIKE_KERNEL(op_type, DeviceType::kGPU, double, int64_t, opt); \
REGISTER_DIM_SCATTER_LIKE_KERNEL(op_type, DeviceType::kGPU, int32_t, int64_t, opt);
#define REGISTER_DIM_SCATTER_KERNEL(op_type, device, dtype, itype, opt) \
REGISTER_USER_KERNEL(op_type) \
.SetCreateFn<DimScatterKernel<device, dtype, itype, opt>>() \
.SetIsMatchedHob((user_op::HobDeviceTag() == device) \
& (user_op::HobDataType("input", 0) == GetDataType<dtype>::value) \
& (user_op::HobDataType("index", 0) == GetDataType<itype>::value));
#define REGISTER_DIM_SCATTER_CPU_KERNELS(op_type, opt) \
REGISTER_DIM_SCATTER_KERNEL(op_type, DeviceType::kCPU, float, int32_t, opt); \
REGISTER_DIM_SCATTER_KERNEL(op_type, DeviceType::kCPU, double, int32_t, opt); \
REGISTER_DIM_SCATTER_KERNEL(op_type, DeviceType::kCPU, int32_t, int32_t, opt); \
REGISTER_DIM_SCATTER_KERNEL(op_type, DeviceType::kCPU, float, int64_t, opt); \
REGISTER_DIM_SCATTER_KERNEL(op_type, DeviceType::kCPU, double, int64_t, opt); \
REGISTER_DIM_SCATTER_KERNEL(op_type, DeviceType::kCPU, int32_t, int64_t, opt);
#define REGISTER_DIM_SCATTER_GPU_KERNELS(op_type, opt) \
REGISTER_DIM_SCATTER_KERNEL(op_type, DeviceType::kGPU, float, int32_t, opt); \
REGISTER_DIM_SCATTER_KERNEL(op_type, DeviceType::kGPU, double, int32_t, opt); \
REGISTER_DIM_SCATTER_KERNEL(op_type, DeviceType::kGPU, int32_t, int32_t, opt); \
REGISTER_DIM_SCATTER_KERNEL(op_type, DeviceType::kGPU, float, int64_t, opt); \
REGISTER_DIM_SCATTER_KERNEL(op_type, DeviceType::kGPU, double, int64_t, opt); \
REGISTER_DIM_SCATTER_KERNEL(op_type, DeviceType::kGPU, int32_t, int64_t, opt);
REGISTER_DIM_SCATTER_LIKE_CPU_KERNELS("dim_scatter_add_like", BinOpAddFunctor);
REGISTER_DIM_SCATTER_CPU_KERNELS("dim_scatter_add", BinOpAddFunctor);
REGISTER_DIM_SCATTER_CPU_KERNELS("dim_scatter_update", BinOpUpdateFunctor);
#ifdef WITH_CUDA
REGISTER_DIM_SCATTER_LIKE_GPU_KERNELS("dim_scatter_add_like", BinOpAddFunctor);
REGISTER_DIM_SCATTER_GPU_KERNELS("dim_scatter_add", BinOpAddFunctor);
REGISTER_DIM_SCATTER_GPU_KERNELS("dim_scatter_update", BinOpUpdateFunctor);
#endif // WITH_CUDA
} // namespace user_op
} // namespace oneflow
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/user/kernels/dim_scatter_scalar_kernel_util.h"
namespace oneflow {
namespace user_op {
template<typename IN_T, typename IDX_T, template<typename T> class Opt>
struct DimScatterScalarFunctor<DeviceType::kCPU, IN_T, IDX_T, Opt> final {
void operator()(DeviceCtx* ctx, const DimOpIndexNdHelper<IDX_T>& idx_nd_helper,
const DimOpIndexNdHelper<IDX_T>& output_nd_helper, const int ndim,
const int64_t elem_cnt, const int32_t dim, int64_t upper_bound,
const IDX_T* index, const IN_T src, IN_T* output) {
DoScatterScalarFunctor<IN_T, IDX_T, Opt>(idx_nd_helper, output_nd_helper, ndim, elem_cnt, dim,
upper_bound, index, src, output);
}
};
INSTANTIATE_DIM_SCATTER_SCARLAR_FUNCTORS(DeviceType::kCPU, UpdateScalarFunctor);
INSTANTIATE_DIM_SCATTER_SCARLAR_FUNCTORS(DeviceType::kCPU, AddScalarFunctor);
} // namespace user_op
} // namespace oneflow
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifdef WITH_CUDA
#include "oneflow/user/kernels/dim_scatter_scalar_kernel_util.h"
namespace oneflow {
namespace user_op {
template<typename IN_T, typename IDX_T, template<typename T> class Opt>
__global__ void DoCUDADimScatterScalar(const DimOpIndexNdHelper<IDX_T> idx_nd_helper,
const DimOpIndexNdHelper<IDX_T> output_nd_helper,
const int ndim, const int64_t elem_cnt, const int32_t dim,
const int64_t upper_bound, const IDX_T* index,
const IN_T src_scalar, IN_T* output) {
DoScatterScalarFunctor<IN_T, IDX_T, Opt>(idx_nd_helper, output_nd_helper, ndim, elem_cnt, dim,
upper_bound, index, src_scalar, output);
}
template<typename IN_T, typename IDX_T, template<typename T> class Opt>
struct DimScatterScalarFunctor<DeviceType::kGPU, IN_T, IDX_T, Opt> final {
void operator()(DeviceCtx* ctx, const DimOpIndexNdHelper<IDX_T>& idx_nd_helper,
const DimOpIndexNdHelper<IDX_T>& output_nd_helper, const int ndim,
const int64_t elem_cnt, const int32_t dim, int64_t upper_bound,
const IDX_T* index, const IN_T src, IN_T* output) {
RUN_CUDA_KERNEL((DoCUDADimScatterScalar<IN_T, IDX_T, Opt>), ctx, BlocksNum4ThreadsNum(elem_cnt),
idx_nd_helper, output_nd_helper, ndim, elem_cnt, dim, upper_bound, index, src,
output);
}
};
INSTANTIATE_DIM_SCATTER_SCARLAR_FUNCTORS(DeviceType::kGPU, UpdateScalarFunctor);
INSTANTIATE_DIM_SCATTER_SCARLAR_FUNCTORS(DeviceType::kGPU, AddScalarFunctor);
} // namespace user_op
} // namespace oneflow
#endif
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_USER_KERNELS_DIM_SCATTER_SCALAR_KERNEL_UTIL_H_
#define ONEFLOW_USER_KERNELS_DIM_SCATTER_SCALAR_KERNEL_UTIL_H_
#ifdef WITH_CUDA
#include "oneflow/core/cuda/atomic.cuh"
#endif // WITH_CUDA
#include "oneflow/core/device/device_context.h"
#include "oneflow/core/ndarray/xpu_util.h"
#include "oneflow/core/common/nd_index_offset_helper.h"
#include "oneflow/core/framework/framework.h"
#include "oneflow/core/common/data_type.h"
namespace oneflow {
namespace user_op {
constexpr int kDimGatherMaxDimCount = 8;
template<typename T>
struct AddScalarFunctor {
OF_DEVICE_FUNC static void apply(const T x, T* y) {
#ifdef __CUDA_ARCH__
cuda::atomic::Add(y, x);
#else
*y += x;
#endif
}
};
template<typename T>
struct UpdateScalarFunctor {
OF_DEVICE_FUNC static void apply(const T x, T* y) { *y = x; }
};
#define INSTANTIATE_DIM_SCATTER_SCARLAR_FUNCTORS(device_type, opt) \
template struct DimScatterScalarFunctor<device_type, int32_t, int32_t, opt>; \
template struct DimScatterScalarFunctor<device_type, float, int32_t, opt>; \
template struct DimScatterScalarFunctor<device_type, double, int32_t, opt>; \
template struct DimScatterScalarFunctor<device_type, int32_t, int64_t, opt>; \
template struct DimScatterScalarFunctor<device_type, float, int64_t, opt>; \
template struct DimScatterScalarFunctor<device_type, double, int64_t, opt>;
template<typename T>
using DimOpIndexNdHelper = NdIndexOffsetHelper<T, kDimGatherMaxDimCount>;
template<DeviceType device_type, typename IN_T, typename IDX_T, template<typename T> class Opt>
struct DimScatterScalarFunctor final {
void operator()(DeviceCtx* ctx, const DimOpIndexNdHelper<IDX_T>& idx_nd_helper,
const DimOpIndexNdHelper<IDX_T>& output_nd_helper, const int ndim,
const int64_t elem_cnt, const int32_t dim, int64_t upper_bound,
const IDX_T* index, const IN_T src, IN_T* output);
};
template<typename IN_T, typename IDX_T, template<typename T> class Opt>
OF_DEVICE_FUNC void DoScatterScalarFunctor(const DimOpIndexNdHelper<IDX_T>& idx_nd_helper,
const DimOpIndexNdHelper<IDX_T>& output_nd_helper,
const int ndim, const int64_t elem_cnt,
const int32_t dim, int64_t upper_bound,
const IDX_T* index, const IN_T src, IN_T* output) {
XPU_1D_KERNEL_LOOP(idx_offset, elem_cnt) {
IDX_T coordinate[kDimGatherMaxDimCount] = {0};
idx_nd_helper.OffsetToNdIndex(idx_offset, coordinate, ndim); // idx_offset -> ijk
IDX_T idx_elem = index[idx_offset];
if (idx_elem >= upper_bound) {
#if __CUDA_ARCH__
__trap();
#else
std::cout << "The index element " << idx_elem << " is out of bounds for dimension " << dim
<< " with size " << upper_bound << std::endl;
throw Error::CheckFailedError();
#endif
}
coordinate[dim] = idx_elem;
IDX_T output_offset = output_nd_helper.NdIndexToOffset(coordinate, ndim);
Opt<IN_T>::apply(src, output + output_offset);
}
}
} // namespace user_op
} // namespace oneflow
#endif // ONEFLOW_USER_KERNELS_DIM_SCATTER_SCALAR_KERNEL_UTIL_H_
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/user/kernels/dim_scatter_scalar_kernel_util.h"
namespace oneflow {
namespace user_op {
template<DeviceType device_type, typename IN_T, typename IDX_T, template<typename T> class Opt>
class DimScatterScalarKernel final : public user_op::OpKernel {
public:
DimScatterScalarKernel() = default;
~DimScatterScalarKernel() = default;
private:
void Compute(KernelComputeContext* ctx) const override {
const Tensor* input_tensor = ctx->Tensor4ArgNameAndIndex("input", 0);
const Tensor* index_tensor = ctx->Tensor4ArgNameAndIndex("index", 0);
Tensor* out_tensor = ctx->Tensor4ArgNameAndIndex("output", 0);
const int32_t dim = ctx->Attr<int32_t>("dim");
const IDX_T* index = index_tensor->dptr<IDX_T>();
IN_T* output = out_tensor->mut_dptr<IN_T>();
size_t out_bytes_size =
out_tensor->shape().elem_cnt() * GetSizeOfDataType(out_tensor->data_type());
Tensor* like_tensor = ctx->Tensor4ArgNameAndIndex("like", 0);
const IN_T src_scalar = static_cast<IN_T>(ctx->Attr<float>("src_scalar"));
if (input_tensor) {
Memcpy<device_type>(ctx->device_ctx(), output, input_tensor->dptr<IN_T>(), out_bytes_size);
} else if (like_tensor) {
Memset<device_type>(ctx->device_ctx(), output, 0, out_bytes_size);
} else {
std::cout << "Unimplemented Error" << std::endl;
throw Error::Unimplemented();
}
const int ndim = out_tensor->shape().NumAxes();
fixed_vector<IDX_T, kDimGatherMaxDimCount> shape_vec(ndim);
auto shape2dims = [&shape_vec, &ndim](const ShapeView& tensor_shape) -> void {
std::transform(tensor_shape.ptr(), tensor_shape.ptr() + ndim, shape_vec.begin(),
[](int32_t dim) -> IDX_T { return static_cast<IDX_T>(dim); });
};
shape2dims(index_tensor->shape());
DimOpIndexNdHelper<IDX_T> idx_nd_helper(shape_vec.data(), ndim);
shape2dims(out_tensor->shape());
DimOpIndexNdHelper<IDX_T> output_nd_helper(shape_vec.data(), ndim);
int64_t upper_bound = input_tensor->shape().At(dim);
DimScatterScalarFunctor<device_type, IN_T, IDX_T, Opt>()(
ctx->device_ctx(), idx_nd_helper, output_nd_helper, ndim, index_tensor->shape().elem_cnt(),
dim, upper_bound, index, src_scalar, output);
}
bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }
};
#define REGISTER_SCATTERSCALAR_KERNEL(op_type_name, device, dtype, itype, opt) \
REGISTER_USER_KERNEL(op_type_name) \
.SetCreateFn<DimScatterScalarKernel<device, dtype, itype, opt>>() \
.SetIsMatchedHob((user_op::HobDeviceTag() == device) \
& (user_op::HobDataType("input", 0) == GetDataType<dtype>::value) \
& (user_op::HobDataType("index", 0) == GetDataType<itype>::value));
#define REGISTER_SCATTER_SCALAR_CPU_KERNELS(op_type_name, opt) \
REGISTER_SCATTERSCALAR_KERNEL(op_type_name, DeviceType::kCPU, float, int32_t, opt); \
REGISTER_SCATTERSCALAR_KERNEL(op_type_name, DeviceType::kCPU, float, int64_t, opt); \
REGISTER_SCATTERSCALAR_KERNEL(op_type_name, DeviceType::kCPU, double, int32_t, opt); \
REGISTER_SCATTERSCALAR_KERNEL(op_type_name, DeviceType::kCPU, double, int64_t, opt);
#define REGISTER_SCATTER_SCALAR_GPU_KERNELS(op_type_name, opt) \
REGISTER_SCATTERSCALAR_KERNEL(op_type_name, DeviceType::kGPU, float, int32_t, opt); \
REGISTER_SCATTERSCALAR_KERNEL(op_type_name, DeviceType::kGPU, float, int64_t, opt); \
REGISTER_SCATTERSCALAR_KERNEL(op_type_name, DeviceType::kGPU, double, int32_t, opt); \
REGISTER_SCATTERSCALAR_KERNEL(op_type_name, DeviceType::kGPU, double, int64_t, opt);
REGISTER_SCATTER_SCALAR_CPU_KERNELS("dim_scatter_update_scalar", UpdateScalarFunctor);
REGISTER_SCATTER_SCALAR_CPU_KERNELS("dim_scatter_add_scalar", AddScalarFunctor);
#ifdef WITH_CUDA
REGISTER_SCATTER_SCALAR_GPU_KERNELS("dim_scatter_update_scalar", UpdateScalarFunctor);
REGISTER_SCATTER_SCALAR_GPU_KERNELS("dim_scatter_add_scalar", AddScalarFunctor);
#endif // WITH_CUDA
} // namespace user_op
} // namespace oneflow
......@@ -17,8 +17,8 @@ limitations under the License.
#include "oneflow/user/kernels/dim_gather_kernel_util.h"
namespace oneflow {
namespace user_op {
REGISTER_USER_OP("dim_gather")
.Input("input")
.Input("index")
......@@ -40,11 +40,6 @@ REGISTER_USER_OP("dim_gather")
CHECK_EQ_OR_RETURN(in.is_dynamic(), index.is_dynamic());
FOR_RANGE(int64_t, i, 0, input_num_axes) {
if (i == dim) { continue; }
CHECK_EQ_OR_RETURN(in.shape().At(i), index.shape().At(i));
}
user_op::TensorDesc* out = ctx->OutputTensorDesc("output", 0);
*out->mut_shape() = index.shape();
......@@ -86,94 +81,10 @@ REGISTER_USER_OP("dim_gather")
.Build();
}
}
ctx->NewBuilder()
.PartialSum(user_op::OpArg("input", 0))
.Broadcast(user_op::OpArg("index", 0))
.PartialSum(user_op::OpArg("output", 0))
.Build();
return Maybe<void>::Ok();
});
REGISTER_USER_OP("dim_scatter_add_like")
.Input("like")
.Input("input")
.Input("index")
.Output("output")
.Attr<int32_t>("dim")
.SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe<void> {
const TensorDesc& input = ctx->InputTensorDesc("input", 0);
const TensorDesc& index = ctx->InputTensorDesc("index", 0);
const TensorDesc& like = ctx->InputTensorDesc("like", 0);
const Shape& like_shape = like.shape();
int64_t input_num_axes = input.shape().NumAxes();
CHECK_GT_OR_RETURN(input_num_axes, 0);
CHECK_LE_OR_RETURN(input_num_axes, kDimGatherMaxDimCount);
int64_t index_num_axes = index.shape().NumAxes();
CHECK_EQ_OR_RETURN(input_num_axes, index_num_axes);
CHECK_EQ_OR_RETURN(input_num_axes, like_shape.NumAxes());
FOR_RANGE(int64_t, i, 0, input_num_axes) {
CHECK_EQ_OR_RETURN(index.shape().At(i), input.shape().At(i));
}
user_op::TensorDesc* out = ctx->OutputTensorDesc("output", 0);
*out->mut_shape() = like_shape;
return Maybe<void>::Ok();
})
.SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe<void> {
const TensorDesc& input = ctx->InputTensorDesc("input", 0);
user_op::TensorDesc* out = ctx->OutputTensorDesc("output", 0);
*out->mut_data_type() = input.data_type();
return Maybe<void>::Ok();
})
.SetInputArgModifyFn([](user_op::GetInputArgModifier GetInputArgModifierFn,
const user_op::UserOpConfWrapper&) -> Maybe<void> {
user_op::InputArgModifier* like_arg_modifier = GetInputArgModifierFn("like", 0);
CHECK_OR_RETURN(like_arg_modifier != nullptr);
like_arg_modifier->set_requires_grad(false);
return Maybe<void>::Ok();
})
.SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe<void> {
const user_op::TensorDesc& index_tensor =
ctx->LogicalTensorDesc4InputArgNameAndIndex("index", 0);
int64_t index_num_axes = index_tensor.shape().NumAxes();
const int32_t dim = ctx->Attr<int32_t>("dim");
FOR_RANGE(int64_t, i, 0, index_num_axes) {
if (i != dim) {
ctx->NewBuilder()
.Split(user_op::OpArg("index", 0), i)
.Split(user_op::OpArg("input", 0), i)
.Split(user_op::OpArg("output", 0), i)
.Split(user_op::OpArg("like", 0), i)
.Build();
} else {
ctx->NewBuilder()
.Split(user_op::OpArg("index", 0), i)
.Split(user_op::OpArg("input", 0), i)
.PartialSum(user_op::OpArg("output", 0))
.Broadcast(user_op::OpArg("like", 0))
.Build();
ctx->NewBuilder()
.Split(user_op::OpArg("index", 0), i)
.Split(user_op::OpArg("input", 0), i)
.PartialSum(user_op::OpArg("output", 0))
.PartialSum(user_op::OpArg("like", 0))
.Build();
}
}
ctx->NewBuilder()
.PartialSum(user_op::OpArg("input", 0))
.Broadcast(user_op::OpArg("index", 0))
.PartialSum(user_op::OpArg("output", 0))
.PartialSum(user_op::OpArg("like", 0))
.Build();
return Maybe<void>::Ok();
});
......@@ -185,10 +96,10 @@ REGISTER_USER_OP_GRAD("dim_gather")
ctx->DefineOp(op_grad_name, [&ctx](user_op::BackwardOpBuilder& builder) {
return builder
.OpTypeName(
"dim_scatter_add_like") // dim_scatter_add_like(like, dim, index, input) -> output
"dim_scatter_add_like") // dim_scatter_add_like(like, dim, index, src) -> output
.InputBind("index", ctx->FwOp().input("index", 0)) // scatter.index <- gather.index
.InputBind("input",
ctx->FwOp().output_grad("output", 0)) // scatter.input <- grad of gather.out
.InputBind("src",
ctx->FwOp().output_grad("output", 0)) // scatter.src <- grad of gather.out
.InputBind("like", ctx->FwOp().input("input", 0))
.Output("output")
.Attr("dim", ctx->FwOp().attr<int32_t>("dim"))
......
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/common/error.h"
#include "oneflow/core/common/maybe.h"
#include "oneflow/core/framework/user_op_registry.h"
#include "oneflow/user/kernels/dim_scatter_kernel_util.h"
namespace oneflow {
namespace user_op {
namespace {
Maybe<void> InferTensorDesc(user_op::InferContext* ctx) {
const TensorDesc* input = ctx->TensorDesc4ArgNameAndIndex("input", 0);
const TensorDesc* index = ctx->TensorDesc4ArgNameAndIndex("index", 0);
const TensorDesc* like = ctx->TensorDesc4ArgNameAndIndex("like", 0);
const TensorDesc* src = ctx->TensorDesc4ArgNameAndIndex("src", 0);
int32_t dim = ctx->Attr<int32_t>("dim");
// check index.numaxes == src.num_axes == input/like.numaxes
int64_t src_num_axes = src->shape().NumAxes();
CHECK_GT_OR_RETURN(src_num_axes, 0);
CHECK_LE_OR_RETURN(src_num_axes, kDimGatherMaxDimCount);
int64_t index_num_axes = index->shape().NumAxes();
CHECK_EQ_OR_RETURN(src_num_axes, index_num_axes);
int64_t output_num_axes = 0;
if (input) {
output_num_axes = input->shape().NumAxes();
} else if (like) {
output_num_axes = like->shape().NumAxes();
} else {
throw Error::Unimplemented();
}
CHECK_EQ_OR_RETURN(output_num_axes, index_num_axes);
// check index.shape(i) <= input/like.shape(i)
FOR_RANGE(int64_t, i, 0, index_num_axes) {
if (i == dim) continue;
if (input) {
CHECK_LE_OR_RETURN(index->shape().At(i), input->shape().At(i));
} else {
CHECK_LE_OR_RETURN(index->shape().At(i), like->shape().At(i));
}
}
// check index.shape(i) <= src.shape(i)
FOR_RANGE(int64_t, i, 0, index_num_axes) {
if (i == dim) continue;
CHECK_LE_OR_RETURN(index->shape().At(i), src->shape().At(i));
}
user_op::TensorDesc* out = ctx->TensorDesc4ArgNameAndIndex("output", 0);
*out->mut_shape() = input ? input->shape() : like->shape();
return Maybe<void>::Ok();
}
Maybe<void> InferScalarTensorDesc(user_op::InferContext* ctx) {
const TensorDesc* input = ctx->TensorDesc4ArgNameAndIndex("input", 0);
const TensorDesc* index = ctx->TensorDesc4ArgNameAndIndex("index", 0);
int32_t dim = ctx->Attr<int32_t>("dim");
// check index.numaxes == src.num_axes == input/like.numaxes
int64_t output_num_axes = input->shape().NumAxes();
int64_t index_num_axes = index->shape().NumAxes();
CHECK_EQ_OR_RETURN(output_num_axes, index_num_axes);
// check index.shape(i) <= input/like.shape(i)
FOR_RANGE(int64_t, i, 0, index_num_axes) {
if (i == dim) continue;
CHECK_LE_OR_RETURN(index->shape().At(i), input->shape().At(i));
}
TensorDesc* out = ctx->TensorDesc4ArgNameAndIndex("output", 0);
*out->mut_shape() = input->shape();
return Maybe<void>::Ok();
}
Maybe<void> InputArgModifierFn(user_op::GetInputArgModifier GetInputArgModifierFn,
const user_op::UserOpConfWrapper&) {
user_op::InputArgModifier* indices_modifier = GetInputArgModifierFn("index", 0);
CHECK(indices_modifier != nullptr);
indices_modifier->set_requires_grad(false);
return Maybe<void>::Ok();
}
Maybe<void> InputScalarArgModifierFn(user_op::GetInputArgModifier GetInputArgModifierFn,
const user_op::UserOpConfWrapper&) {
user_op::InputArgModifier* indices_modifier = GetInputArgModifierFn("index", 0);
CHECK(indices_modifier != nullptr);
indices_modifier->set_requires_grad(false);
return Maybe<void>::Ok();
}
void _SetSbp(user_op::SbpContext* ctx, const char* like_or_input) {
const user_op::TensorDesc& index_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("index", 0);
int64_t index_num_axes = index_tensor.shape().NumAxes();
const int32_t dim = ctx->Attr<int32_t>("dim");
FOR_RANGE(int64_t, i, 0, index_num_axes) {
if (i != dim) {
ctx->NewBuilder()
.Split(user_op::OpArg("index", 0), i)
.Split(user_op::OpArg("src", 0), i)
.Split(user_op::OpArg("output", 0), i)
.Split(user_op::OpArg(like_or_input, 0), i)
.Build();
} else {
ctx->NewBuilder()
.Split(user_op::OpArg("index", 0), i)
.Split(user_op::OpArg("src", 0), i)
.PartialSum(user_op::OpArg("output", 0))
.Broadcast(user_op::OpArg(like_or_input, 0))
.Build();
ctx->NewBuilder()
.Split(user_op::OpArg("index", 0), i)
.Split(user_op::OpArg("src", 0), i)
.PartialSum(user_op::OpArg("output", 0))
.PartialSum(user_op::OpArg(like_or_input, 0))
.Build();
}
}
ctx->NewBuilder()
.PartialSum(user_op::OpArg("src", 0))
.Broadcast(user_op::OpArg("index", 0))
.PartialSum(user_op::OpArg("output", 0))
.PartialSum(user_op::OpArg(like_or_input, 0))
.Build();
}
Maybe<void> SetSbpLike(user_op::SbpContext* ctx) {
_SetSbp(ctx, "like");
return Maybe<void>::Ok();
}
Maybe<void> SetSbpScatter(user_op::SbpContext* ctx) {
_SetSbp(ctx, "input");
return Maybe<void>::Ok();
}
Maybe<void> InferDtype(user_op::InferContext* ctx) {
const TensorDesc* index = ctx->TensorDesc4ArgNameAndIndex("index", 0);
CHECK_OR_RETURN(IsIndexDataType(index->data_type()));
const TensorDesc* input = ctx->TensorDesc4ArgNameAndIndex("input", 0);
if (input) {
CHECK_EQ_OR_RETURN(ctx->InputDType("input", 0), ctx->InputDType("src", 0));
} else {
CHECK_EQ_OR_RETURN(ctx->InputDType("like", 0), ctx->InputDType("src", 0));
}
*ctx->OutputDType("output", 0) = ctx->InputDType("src", 0);
return Maybe<void>::Ok();
}
Maybe<void> InferScalarDtype(user_op::InferContext* ctx) {
const TensorDesc* index = ctx->TensorDesc4ArgNameAndIndex("index", 0);
CHECK_OR_RETURN(IsIndexDataType(index->data_type()));
*ctx->OutputDType("output", 0) = ctx->InputDType("input", 0);
return Maybe<void>::Ok();
}
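// Backward used by dim_scatter_add / dim_scatter_update:
//   d(src)   = dim_gather(d(output), index, dim)
//   d(input) = d(output) with the positions written through index set to zero
//              (dim_scatter_update_scalar with src_scalar = 0)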
Maybe<void> ScatterBackward(user_op::BackwardOpConfContext* ctx) {
const TensorDesc& src = ctx->FwOp().TensorDesc4ArgNameAndIndex("src", 0);
const TensorDesc& index = ctx->FwOp().TensorDesc4ArgNameAndIndex("index", 0);
const int64_t ndim = src.shape().NumAxes();
FOR_RANGE(int64_t, i, 0, ndim) {
if (index.shape().At(i) != src.shape().At(i)) {
UNIMPLEMENTED() << "The backward pass is implemented only for src.shape == index.shape.\n";
}
}
const auto op_src_grad_name = ctx->FwOp().op_name() + "_src_grad";
ctx->DefineOp(op_src_grad_name, [&ctx](user_op::BackwardOpBuilder& builder) {
return builder.OpTypeName("dim_gather")
.InputBind("index", ctx->FwOp().input("index", 0))
.InputBind("input", ctx->FwOp().output_grad("output", 0))
.Output("output")
.Attr("dim", ctx->FwOp().attr<int32_t>("dim"))
.Build();
});
ctx->FwOp().InputGradBind(user_op::OpArg("src", 0),
[&ctx, &op_src_grad_name]() -> const std::string& {
return ctx->GetOp(op_src_grad_name).output("output", 0);
});
const auto op_input_grad_name = ctx->FwOp().op_name() + "_input_grad";
ctx->DefineOp(op_input_grad_name, [&ctx](user_op::BackwardOpBuilder& builder) {
return builder.OpTypeName("dim_scatter_update_scalar")
.InputBind("index", ctx->FwOp().input("index", 0))
.InputBind("input", ctx->FwOp().output_grad("output", 0))
.Output("output")
.Attr("dim", ctx->FwOp().attr<int32_t>("dim"))
.Attr("src_scalar", static_cast<float>(0.0))
.Build();
});
ctx->FwOp().InputGradBind(user_op::OpArg("input", 0),
[&ctx, &op_input_grad_name]() -> const std::string& {
return ctx->GetOp(op_input_grad_name).output("output", 0);
});
return Maybe<void>::Ok();
}
} // namespace
#define REGISTER_SCATTER_LIKE_OP(optypename) \
REGISTER_USER_OP(optypename) \
.Input("like") \
.Input("index") \
.Input("src") \
.Output("output") \
.Attr<int32_t>("dim") \
.SetTensorDescInferFn(InferTensorDesc) \
.SetInputArgModifyFn(InputArgModifierFn) \
.SetDataTypeInferFn(InferDtype) \
.SetGetSbpFn(SetSbpLike)
#define REGISTER_SCATTER_OP(optypename) \
REGISTER_USER_OP(optypename) \
.Input("input") \
.Input("index") \
.Input("src") \
.Output("output") \
.Attr<int32_t>("dim") \
.SetTensorDescInferFn(InferTensorDesc) \
.SetInputArgModifyFn(InputArgModifierFn) \
.SetDataTypeInferFn(InferDtype) \
.SetGetSbpFn(SetSbpScatter)
#define REGISTER_SCATTER_SCALAR_OP(optypename) \
REGISTER_USER_OP(optypename) \
.Input("input") \
.Input("index") \
.Attr<float>("src_scalar") \
.Output("output") \
.Attr<int32_t>("dim") \
.SetTensorDescInferFn(InferScalarTensorDesc) \
.SetInputArgModifyFn(InputScalarArgModifierFn) \
.SetDataTypeInferFn(InferScalarDtype) \
.SetGetSbpFn(SetSbpScatter)
#define REGISTER_SCATTER_GRAD(optypename) \
REGISTER_USER_OP_GRAD(optypename).SetBackwardOpConfGenFn(ScatterBackward);
#define REGISTER_SCATTER_SCALAR_GRAD(optypename) \
REGISTER_USER_OP_GRAD(optypename) \
.SetBackwardOpConfGenFn([](user_op::BackwardOpConfContext* ctx) -> Maybe<void> { \
const auto op_input_grad_name = ctx->FwOp().op_name() + "_input_grad"; \
ctx->DefineOp(op_input_grad_name, [&ctx](user_op::BackwardOpBuilder& builder) { \
return builder.OpTypeName("dim_scatter_update_scalar") \
.InputBind("index", ctx->FwOp().input("index", 0)) \
.InputBind("input", ctx->FwOp().output_grad("output", 0)) \
.Output("output") \
.Attr("dim", ctx->FwOp().attr<int32_t>("dim")) \
.Attr("src_scalar", static_cast<float>(0.0)) \
.Build(); \
}); \
ctx->FwOp().InputGradBind(user_op::OpArg("input", 0), \
[&ctx, &op_input_grad_name]() -> const std::string& { \
return ctx->GetOp(op_input_grad_name).output("output", 0); \
}); \
return Maybe<void>::Ok(); \
});
REGISTER_SCATTER_LIKE_OP("dim_scatter_add_like");
REGISTER_SCATTER_OP("dim_scatter_add");
REGISTER_SCATTER_OP("dim_scatter_update");
REGISTER_SCATTER_OP("dim_scatter_mul");
REGISTER_SCATTER_SCALAR_OP("dim_scatter_update_scalar");
REGISTER_SCATTER_SCALAR_OP("dim_scatter_add_scalar");
REGISTER_SCATTER_SCALAR_OP("dim_scatter_mul_scalar");
REGISTER_SCATTER_GRAD("dim_scatter_add");
REGISTER_SCATTER_GRAD("dim_scatter_update");
REGISTER_SCATTER_SCALAR_GRAD("dim_scatter_update_scalar");
} // namespace user_op
} // namespace oneflow
......@@ -370,5 +370,6 @@ from oneflow.ops.user_op_builder import api_user_op_builder as user_op_builder
from oneflow.ops.user_op_builder import (
api_user_op_module_builder as user_op_module_builder,
)
from oneflow.nn.modules.scatter import *
from . import autograd, distributed, linalg, optim, saved_model, sbp
"""
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import oneflow as flow
from oneflow.framework.tensor import Tensor
from oneflow.nn.module import Module
__all__ = ["scatter", "scatter_add"]
def scatter(input, dim, index, src):
r"""This operator writes the elements specified by `index` along with the axis
`dim` from the `src` into the `input`.
Take a 3-D blob as example, the output is specified by:
.. code-block:: python
input[index[i][j][k]][j][k] = src[i][j][k] # if dim == 0
input[i][index[i][j][k]][k] = src[i][j][k] # if dim == 1
input[i][j][index[i][j][k]] = src[i][j][k] # if dim == 2
input, index and src (if it is a Tensor) should all have the same number of dimensions.
It is also required that index.shape(d) <= src.shape(d) for all dimensions d,
and that index.shape(d) <= input.shape(d) for all dimensions d != dim.
Note that index and src do not broadcast.
Args:
input (Tensor): The input blob.
dim (int): The axis along which to index
index (Tensor): The index blob of elements to scatter.
src (Tensor or float): The source blob whose elements will be scattered into the output.
Returns:
Tensor: The scattered Tensor.
For example:
.. code-block:: python
>>> import oneflow as flow
>>> import numpy as np
>>> input = flow.ones((3,5))*2
>>> index = flow.tensor(np.array([[0,1,2],[0,1,4]], ), dtype=flow.int32)
>>> src = flow.Tensor(np.array([[0,10,20,30,40],[50,60,70,80,90]]))
>>> out = flow.scatter(input, 1, index, src)
>>> out
tensor([[ 0., 10., 20., 2., 2.],
[50., 60., 2., 2., 70.],
[ 2., 2., 2., 2., 2.]], dtype=oneflow.float32)
"""
assert type(src) in [
flow.Tensor,
float,
], f"type of src must be oneflow.Tensor or float, but {type(src)} was given"
if isinstance(src, flow.Tensor):
return flow.F.dim_scatter(input, index, src, dim)
elif isinstance(src, float):
return flow.F.dim_scatter_scalar(input, index, src, dim)
def scatter_add(input, dim, index, src):
r"""This operator scatter the src with addition operation according to index along dim into the input.
Take a 3-D blob as example, the output is specified by:
.. code-block:: python
input[index[i][j][k]][j][k] += src[i][j][k] # if dim == 0
input[i][index[i][j][k]][k] += src[i][j][k] # if dim == 1
input[i][j][index[i][j][k]] += src[i][j][k] # if dim == 2
Args:
input (Tensor): The input blob.
dim (int): The axis along which to index
index (Tensor): The index blob of elements to scatter.
src (Tensor): The source blob whose elements will be scattered and added to the output.
Returns:
Tensor: The scattered Tensor.
For example:
.. code-block:: python
>>> import oneflow as flow
>>> import numpy as np
>>> input = flow.ones((3,5))*2
>>> index = flow.tensor(np.array([[0,1,2],[0,1,4]], ), dtype=flow.int32)
>>> src = flow.Tensor(np.array([[0,10,20,30,40],[50,60,70,80,90]]))
>>> out = flow.scatter_add(input, 1, index, src)
>>> out
tensor([[ 2., 12., 22., 2., 2.],
[52., 62., 2., 2., 72.],
[ 2., 2., 2., 2., 2.]], dtype=oneflow.float32)
"""
assert type(src) in [
flow.Tensor
], f"type of src must be oneflow.Tensor, but {type(src)} was given"
return flow.F.dim_scatter_add(input, index, src, dim)
if __name__ == "__main__":
import doctest
doctest.testmod(raise_on_error=True)
"""
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import unittest
import oneflow as flow
import oneflow.unittest
import numpy as np
from automated_test_util import *
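# Each autotest case builds the same computation with PyTorch and OneFlow and
# compares the forward results and gradients between the two frameworks.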
@flow.unittest.skip_unless_1n1d()
class TestScatterOpsModule(flow.unittest.TestCase):
@autotest(n=5)
def test_scatter_random_data_at_dim_0(test_case):
device = random_device()
input = random_pytorch_tensor(ndim=2, dim0=2, dim1=2).to(device)
src = random_pytorch_tensor(ndim=2, dim0=2, dim1=2).to(device)
index = constant(
torch.tensor(np.array([[0, 1], [1, 0]]), dtype=torch.int64, device=device)
)
y = torch.scatter(input, 0, index, src)
return y
@autotest(n=5)
def test_scatter_random_data_at_dim_1(test_case):
device = random_device()
input = random_pytorch_tensor(ndim=2, dim0=2, dim1=2).to(device)
src = random_pytorch_tensor(ndim=2, dim0=2, dim1=2).to(device)
index = constant(
torch.tensor(np.array([[1, 0], [0, 1]]), dtype=torch.int64, device=device)
)
y = torch.scatter(input, 1, index, src)
return y
@autotest(n=5)
def test_scatter_scalar_random_data_at_dim0(test_case):
device = random_device()
input = random_pytorch_tensor(ndim=2, dim0=2, dim1=2).to(device)
src = random_pytorch_tensor(ndim=2, dim0=2, dim1=2).to(device)
index = constant(
torch.tensor(np.array([[0, 1], [1, 0]]), dtype=torch.int64, device=device)
)
y = torch.scatter(input, 0, index, 3.14)
return y
@autotest(n=5)
def test_scatter_scalar_random_data_at_dim1(test_case):
device = random_device()
input = random_pytorch_tensor(ndim=2, dim0=2, dim1=2).to(device)
src = random_pytorch_tensor(ndim=2, dim0=2, dim1=2).to(device)
index = constant(
torch.tensor(np.array([[1, 0], [0, 1]]), dtype=torch.int64, device=device)
)
y = torch.scatter(input, 1, index, 3.14)
return y
@autotest(n=5)
def test_scatter_add_random_data_at_dim0(test_case):
device = random_device()
input = random_pytorch_tensor(ndim=2, dim0=2, dim1=2).to(device)
src = random_pytorch_tensor(ndim=2, dim0=2, dim1=2).to(device)
index = constant(
torch.tensor(np.array([[1, 0], [0, 1]]), dtype=torch.int64, device=device)
)
y = torch.scatter_add(input, 0, index, src)
return y
@autotest(n=5)
def test_scatter_add_random_data_at_dim1(test_case):
device = random_device()
input = random_pytorch_tensor(ndim=2, dim0=2, dim1=2).to(device)
src = random_pytorch_tensor(ndim=2, dim0=2, dim1=2).to(device)
index = constant(
torch.tensor(np.array([[0, 1], [1, 0]]), dtype=torch.int64, device=device)
)
y = torch.scatter_add(input, 1, index, src)
return y
if __name__ == "__main__":
unittest.main()