Commit d2e4d092 authored by guo ran, committed by Juncheng

add constant like op, sqrt_grad, square_grad (#2578)

* constant like op

* sqrt grad

* square grad

* fix

* fix

* add test case

* format
Parent 062326a2
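The two backward rules added here follow from the chain rule: for y = sqrt(x), dL/dx = (dL/dy) / (2 * sqrt(x)) = (dL/dy) / (2 * y), which the sqrt generator below builds as a broadcast_div by the forward output followed by a scalar_mul with operand 0.5; for y = x * x, dL/dx = 2 * x * dL/dy, built as an elementwise multiply by the forward input followed by a scalar_mul with operand 2.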
#include "oneflow/core/job_completer/autograd.h"
namespace oneflow {
namespace {
void GenerateBackwardOpConf(
const Operator& op, std::vector<OperatorConf>* op_confs,
const std::function<LogicalBlobId*(const std::string&)>& DiffLbi4BnInOp) {
CHECK(op.op_conf().has_sqrt_conf());
if (DiffLbi4BnInOp("in") != nullptr) {
OperatorConf broadcast_div_op;
broadcast_div_op.set_name(op.op_name() + "_grad_broadcast_div");
BroadcastDivOpConf* broadcast_div_op_conf = broadcast_div_op.mutable_broadcast_div_conf();
broadcast_div_op_conf->set_a(GenLogicalBlobName(*DiffLbi4BnInOp("out")));
broadcast_div_op_conf->set_b(GenLogicalBlobName(op.BnInOp2Lbi("out")));
broadcast_div_op_conf->set_out("out");
op_confs->push_back(broadcast_div_op);
OperatorConf scalar_mul_op;
scalar_mul_op.set_name(op.op_name() + "_grad_scalar_mul");
ScalarMulOpConf* scalar_mul_op_conf = scalar_mul_op.mutable_scalar_mul_conf();
scalar_mul_op_conf->set_float_operand(0.5);
scalar_mul_op_conf->set_in(
GenLogicalBlobName(broadcast_div_op.name(), broadcast_div_op_conf->out()));
scalar_mul_op_conf->set_out("out");
op_confs->push_back(scalar_mul_op);
DiffLbi4BnInOp("in")->set_op_name(scalar_mul_op.name());
DiffLbi4BnInOp("in")->set_blob_name(scalar_mul_op_conf->out());
}
}
} // namespace
REGISTER_OP_GRAD(OperatorConf::kSqrtConf, &GenerateBackwardOpConf);
} // namespace oneflow
#include "oneflow/core/job_completer/autograd.h"
namespace oneflow {
namespace {
void GenerateBackwardOpConf(
const Operator& op, std::vector<OperatorConf>* op_confs,
const std::function<LogicalBlobId*(const std::string&)>& DiffLbi4BnInOp) {
CHECK(op.op_conf().has_square_conf());
if (DiffLbi4BnInOp("in") != nullptr) {
OperatorConf multiply_in_op;
multiply_in_op.set_name(op.op_name() + "_grad_multiply_in");
MultiplyOpConf* multiply_in_op_conf = multiply_in_op.mutable_multiply_conf();
multiply_in_op_conf->set_in_0(GenLogicalBlobName(*DiffLbi4BnInOp("out")));
multiply_in_op_conf->set_in_1(GenLogicalBlobName(op.BnInOp2Lbi("in")));
multiply_in_op_conf->set_out("out");
op_confs->push_back(multiply_in_op);
OperatorConf scalar_mul_op;
scalar_mul_op.set_name(op.op_name() + "_grad_scalar_mul");
ScalarMulOpConf* scalar_mul_op_conf = scalar_mul_op.mutable_scalar_mul_conf();
scalar_mul_op_conf->set_float_operand(2);
scalar_mul_op_conf->set_in(
GenLogicalBlobName(multiply_in_op.name(), multiply_in_op_conf->out()));
scalar_mul_op_conf->set_out("out");
op_confs->push_back(scalar_mul_op);
DiffLbi4BnInOp("in")->set_op_name(scalar_mul_op.name());
DiffLbi4BnInOp("in")->set_blob_name(scalar_mul_op_conf->out());
}
}
} // namespace
REGISTER_OP_GRAD(OperatorConf::kSquareConf, &GenerateBackwardOpConf);
} // namespace oneflow
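For reference, the two chains generated above compute the following; this is a minimal numpy sketch, with out, x, and out_diff as illustrative stand-ins for the blobs wired through DiffLbi4BnInOp (not part of the commit):

import numpy as np

def sqrt_grad_reference(out, out_diff):
    # broadcast_div(out_diff, out) followed by scalar_mul(0.5): x_diff = out_diff / (2 * sqrt(x))
    return 0.5 * (out_diff / out)

def square_grad_reference(x, out_diff):
    # multiply(out_diff, x) followed by scalar_mul(2): x_diff = 2 * x * out_diff
    return 2.0 * (out_diff * x)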
#include "oneflow/core/kernel/kernel.h"
#include "oneflow/core/kernel/new_kernel_util.h"
namespace oneflow {
template<DeviceType device_type, typename T>
class ConstantLikeKernel final : public KernelIf<device_type> {
public:
OF_DISALLOW_COPY_AND_MOVE(ConstantLikeKernel);
ConstantLikeKernel() : is_init_(false) {}
~ConstantLikeKernel() = default;
private:
mutable bool is_init_;
void ForwardDataContent(const KernelCtx& ctx,
std::function<Blob*(const std::string&)> BnInOp2Blob) const override {
if (is_init_) { return; }
Blob* out_blob = BnInOp2Blob("out");
T value = 0;
const auto& conf = this->op_conf().constant_like_conf();
if (conf.has_int_operand()) {
value = static_cast<T>(conf.int_operand());
} else if (conf.has_float_operand()) {
value = static_cast<T>(conf.float_operand());
} else {
UNIMPLEMENTED();
}
NewKernelUtil<device_type>::Fill(ctx.device_ctx, out_blob->static_shape().elem_cnt(), value,
out_blob->mut_dptr<T>());
is_init_ = true;
}
};
#define REGISTER_CONSTANT_LIKE_KERNEL(dtype) \
REGISTER_KERNEL_WITH_DEVICE_AND_DTYPE(OperatorConf::kConstantLikeConf, DeviceType::kCPU, dtype, \
ConstantLikeKernel<DeviceType::kCPU, dtype>) \
REGISTER_KERNEL_WITH_DEVICE_AND_DTYPE(OperatorConf::kConstantLikeConf, DeviceType::kGPU, dtype, \
ConstantLikeKernel<DeviceType::kGPU, dtype>)
REGISTER_CONSTANT_LIKE_KERNEL(float);
REGISTER_CONSTANT_LIKE_KERNEL(double);
REGISTER_CONSTANT_LIKE_KERNEL(int8_t);
REGISTER_CONSTANT_LIKE_KERNEL(int32_t);
REGISTER_CONSTANT_LIKE_KERNEL(int64_t);
#undef REGISTER_CONSTANT_LIKE_KERNEL
} // namespace oneflow
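Functionally, the kernel above fills the output blob (shaped like the "like" input) with a single scalar, and only does so once per run because of the is_init_ flag. A numpy reference sketch under that reading (the helper name is illustrative, not part of the commit):

import numpy as np

def constant_like_reference(like, value, dtype=None):
    # Shape comes from `like`; dtype falls back to like's dtype unless overridden.
    return np.full(like.shape, value, dtype=dtype if dtype is not None else like.dtype)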
@@ -133,6 +133,11 @@ __global__ void MulByScalarGpu<half>(const int64_t n, const half* x, const half
#endif // __CUDA_ARCH__ >= 530 || !defined(__CUDA_ARCH__)
}
template<typename T>
__global__ void FillGpu(const int64_t n, const T value, T* y) {
CUDA_1D_KERNEL_LOOP(i, n) { y[i] = value; }
}
} // namespace
#define MUL_BY_SCALAR(T) \
@@ -155,4 +160,19 @@ void ArithemeticIf<DeviceType::kGPU>::MulByScalar(DeviceCtx* ctx, const int64_t
n, reinterpret_cast<const half*>(x), float16_2half(y), reinterpret_cast<half*>(z));
}
#define FILL(T) \
void ArithemeticIf<DeviceType::kGPU>::Fill(DeviceCtx* ctx, const int64_t n, const T value, \
T* y) { \
FillGpu<T><<<BlocksNum4ThreadsNum(n), kCudaThreadsNumPerBlock, 0, ctx->cuda_stream()>>>( \
n, value, y); \
}
FILL(float)
FILL(double)
FILL(int8_t)
FILL(int32_t)
FILL(int64_t)
#undef FILL
} // namespace oneflow
@@ -34,6 +34,12 @@ struct ArithemeticIf<DeviceType::kGPU> {
int32_t* z);
static void MulByScalar(DeviceCtx* ctx, const int64_t n, const int64_t* x, const int64_t y,
int64_t* z);
static void Fill(DeviceCtx* ctx, const int64_t n, const float value, float* y);
static void Fill(DeviceCtx* ctx, const int64_t n, const double value, double* y);
static void Fill(DeviceCtx* ctx, const int64_t n, const int8_t value, int8_t* y);
static void Fill(DeviceCtx* ctx, const int64_t n, const int32_t value, int32_t* y);
static void Fill(DeviceCtx* ctx, const int64_t n, const int64_t value, int64_t* y);
};
} // namespace oneflow
@@ -112,4 +112,18 @@ MUL_BY_SCALAR(int64_t);
#undef MUL_BY_SCALAR
#define FILL(T) \
void ArithemeticIf<DeviceType::kCPU>::Fill(DeviceCtx* ctx, const int64_t n, const T value, \
T* y) { \
std::fill_n(y, n, value); \
}
FILL(float);
FILL(double);
FILL(int8_t);
FILL(int32_t);
FILL(int64_t);
#undef FILL
} // namespace oneflow
@@ -29,6 +29,12 @@ struct ArithemeticIf<DeviceType::kCPU> {
int32_t* z);
static void MulByScalar(DeviceCtx* ctx, const int64_t n, const int64_t* x, const int64_t y,
int64_t* z);
static void Fill(DeviceCtx* ctx, const int64_t n, const float value, float* y);
static void Fill(DeviceCtx* ctx, const int64_t n, const double value, double* y);
static void Fill(DeviceCtx* ctx, const int64_t n, const int8_t value, int8_t* y);
static void Fill(DeviceCtx* ctx, const int64_t n, const int32_t value, int32_t* y);
static void Fill(DeviceCtx* ctx, const int64_t n, const int64_t value, int64_t* y);
};
} // namespace oneflow
#include "oneflow/core/operator/operator.h"
namespace oneflow {
class ConstantLikeOp final : public Operator {
public:
OF_DISALLOW_COPY_AND_MOVE(ConstantLikeOp);
ConstantLikeOp() = default;
~ConstantLikeOp() = default;
void InitFromOpConf() override {
CHECK(op_conf().has_constant_like_conf());
EnrollInputBn("like", false);
EnrollOutputBn("out", false);
}
const PbMessage& GetCustomizedConf() const override {
return this->op_conf().constant_like_conf();
}
Maybe<void> InferBlobDescs(std::function<BlobDesc*(const std::string&)> GetBlobDesc4BnInOp,
const ParallelContext* parallel_ctx, const SbpSignature* sbp_signature,
std::function<void(OpContext*)> EnrollOpCtx) const override {
const ConstantLikeOpConf& conf = op_conf().constant_like_conf();
BlobDesc* out_blob_desc = GetBlobDesc4BnInOp("out");
*out_blob_desc = *GetBlobDesc4BnInOp("like");
if (conf.has_data_type()) { out_blob_desc->set_data_type(conf.data_type()); }
return Maybe<void>::Ok();
}
private:
Maybe<void> InferBatchAxis(
std::function<OptInt64*(const std::string&)> BatchAxis4BnInOp) const override {
return NaiveInferBatchAxis(BatchAxis4BnInOp);
}
Maybe<void> GetSbpSignatures(
const std::function<Maybe<const BlobDesc*>(const std::string&)>& LogicalBlobDesc4Ibn,
SbpSignatureList* sbp_sig_list) const {
SbpSignatureBuilder()
.Split("like", 0)
.Split("out", 0)
.MakeSplitSignatureListBuilder(JUST(LogicalBlobDesc4Ibn("like"))->shape().NumAxes())
.Build(sbp_sig_list);
SbpSignatureBuilder().PartialSum("like").Broadcast("out").Build(
sbp_sig_list->mutable_sbp_signature()->Add());
return Maybe<void>::Ok();
}
};
REGISTER_OP(OperatorConf::kConstantLikeConf, ConstantLikeOp);
REGISTER_OP_SAME_OUTPUT_BLOB_REGST_NUM(OperatorConf::kConstantLikeConf, 1);
} // namespace oneflow
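A note on the SBP signatures above: the kernel never reads the values of "like", only its shape, so the op can pair Split(like, i) with Split(out, i) for every axis i of "like", and PartialSum(like) with Broadcast(out), since the output is the same no matter how the partial sum would be materialized.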
@@ -1623,6 +1623,16 @@ message IndexedSlicesReduceSumOpConf {
  required string num_unique = 5;
}

message ConstantLikeOpConf {
  required string like = 1;
  required string out = 2;
  optional DataType data_type = 3;
  oneof scalar_operand {
    int64 int_operand = 4;
    double float_operand = 5;
  }
}

message OperatorConf {
  required string name = 1;
  optional bool trainable = 3 [default = true];
@@ -1798,6 +1808,7 @@ message OperatorConf {
    RsqrtOpConf rsqrt_conf = 515;
    ScalarAddOpConf scalar_add_conf = 516;
    ScalarMulOpConf scalar_mul_conf = 517;
    ConstantLikeOpConf constant_like_conf = 518;
    // mutable input op
    AxpyOpConf axpy_conf = 752;
@@ -48,18 +48,33 @@ def constant_scalar(value, dtype=None, name=None):
@oneflow_export("constant_like")
def constant_like(input, value, name=None):
def constant_like(like, value, dtype=None, name=None):
op_conf = op_conf_util.OperatorConf()
setattr(
op_conf,
"name",
name if name is not None else id_util.UniqueStr("ConstantLike_"),
)
setattr(op_conf.constant_like_conf, "in", input.logical_blob_name)
setattr(op_conf.constant_like_conf, "scalar", value)
setattr(op_conf.constant_like_conf, "like", like.logical_blob_name)
if isinstance(value, int):
op_conf.constant_like_conf.int_operand = value
elif isinstance(value, float):
op_conf.constant_like_conf.float_operand = value
else:
raise NotImplementedError
if dtype is not None:
setattr(op_conf.constant_like_conf, "data_type", dtype)
setattr(op_conf.constant_like_conf, "out", "out")
compile_context.CurJobAddOp(op_conf)
out_lbi = logical_blob_id_util.LogicalBlobId()
setattr(out_lbi, "op_name", op_conf.name)
setattr(out_lbi, "blob_name", "out")
return remote_blob_util.RemoteBlob(out_lbi)
@oneflow_export("ones_like")
def ones_like(like, dtype=None, name=None):
return constant_like(like, 1, dtype=dtype, name=name)
@oneflow_export("zeros_like")
def zeros_like(like, dtype=None, name=None):
return constant_like(like, 0, dtype=dtype, name=name)
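A minimal usage sketch of the new Python wrappers (assuming x is a blob available inside a job function; the variable names are illustrative):

ones = flow.ones_like(x)                        # same shape as x, filled with 1
zeros = flow.zeros_like(x, dtype=flow.float)    # dtype may override x's data type
sevens = flow.constant_like(x, 7, dtype=flow.int32)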
@@ -336,9 +336,8 @@ def unsorted_batch_segment_sum(data, segment_ids, num_segments, name=None):
return remote_blob_util.RemoteBlob(lbi)
@oneflow_export("math.sqrt")
def sqrt(x, name=None):
# TODO: not ready yet
raise NotImplementedError
op_conf = op_conf_util.OperatorConf()
setattr(op_conf, "name", name if name is not None else id_util.UniqueStr("Sqrt_"))
setattr(op_conf.sqrt_conf, "in", x.logical_blob_name)
@@ -637,3 +636,19 @@ def elem_cnt(input_blob, axis=None, dtype=None, name=None):
out_lbi.op_name = op_conf.name
out_lbi.blob_name = "y"
return remote_blob_util.RemoteBlob(out_lbi)
@oneflow_export('math.square')
def square(x, name=None):
    op_conf = op_conf_util.OperatorConf()
    setattr(
        op_conf,
        "name",
        name if name is not None else id_util.UniqueStr("Square_"),
    )
    setattr(op_conf.square_conf, "in", x.logical_blob_name)
    setattr(op_conf.square_conf, "out", "out")
    compile_context.CurJobAddOp(op_conf)
    lbi = logical_blob_id_util.LogicalBlobId()
    lbi.op_name = op_conf.name
    lbi.blob_name = "out"
    return remote_blob_util.RemoteBlob(lbi)
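A minimal usage sketch of the two math ops (again assuming x is a blob inside a job function); the backward passes added by this commit reuse the forward output and input respectively:

y = flow.math.sqrt(x)    # backward: x_diff = y_diff / (2 * y)
z = flow.math.square(x)  # backward: x_diff = 2 * x * z_diff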
import oneflow as flow
import numpy as np


def _check(test_case, x, y, value, dtype=None):
    np_constant_like = np.full(x.shape, value)
    test_case.assertTrue(np.array_equal(np_constant_like, y))


def _run_test(test_case, x, value, dtype=None, device='gpu'):
    func_config = flow.FunctionConfig()
    func_config.default_data_type(flow.float)
    func_config.default_distribute_strategy(flow.distribute.consistent_strategy())

    @flow.function(func_config)
    def ConstantLikeJob(x=flow.FixedTensorDef(x.shape)):
        return flow.constant_like(x, value=value, dtype=dtype)

    y = ConstantLikeJob(x).get()
    _check(test_case, x, y.ndarray(), value, dtype=dtype)


def test_constant_like_gpu_float(test_case):
    x = np.random.rand(10, 3, 32, 1024).astype(np.float32)
    _run_test(test_case, x, 1.0, flow.float, 'gpu')

def test_constant_like_cpu_float(test_case):
    x = np.random.rand(10, 3, 32, 1024).astype(np.float32)
    _run_test(test_case, x, 2.0, flow.float, 'cpu')

def test_constant_like_gpu_double(test_case):
    x = np.random.rand(10, 3, 32, 1024).astype(np.float32)
    _run_test(test_case, x, 3.0, flow.double, 'gpu')

def test_constant_like_cpu_double(test_case):
    x = np.random.rand(10, 3, 32, 1024).astype(np.float32)
    _run_test(test_case, x, 4.0, flow.double, 'cpu')

def test_constant_like_gpu_int8(test_case):
    x = np.random.rand(10, 3, 32, 1024).astype(np.float32)
    _run_test(test_case, x, 5.0, flow.int8, 'gpu')

def test_constant_like_cpu_int8(test_case):
    x = np.random.rand(10, 3, 32, 1024).astype(np.float32)
    _run_test(test_case, x, 6.0, flow.int8, 'cpu')

def test_constant_like_gpu_int32(test_case):
    x = np.random.rand(10, 3, 32, 1024).astype(np.float32)
    _run_test(test_case, x, 7.0, flow.int32, 'gpu')

def test_constant_like_cpu_int32(test_case):
    x = np.random.rand(10, 3, 32, 1024).astype(np.float32)
    _run_test(test_case, x, 8.0, flow.int32, 'cpu')

def test_constant_like_gpu_int64(test_case):
    x = np.random.rand(10, 3, 32, 1024).astype(np.float32)
    _run_test(test_case, x, 9.0, flow.int64, 'gpu')

def test_constant_like_cpu_int64(test_case):
    x = np.random.rand(10, 3, 32, 1024).astype(np.float32)
    _run_test(test_case, x, 10.0, flow.int64, 'cpu')

def test_constant_like_gpu(test_case):
    x = np.random.rand(10, 3, 32, 1024).astype(np.float32)
    _run_test(test_case, x, 11.0, device='gpu')

def test_constant_like_cpu(test_case):
    x = np.random.rand(10, 3, 32, 1024).astype(np.float32)
    _run_test(test_case, x, 12.0, device='cpu')
import os
import numpy as np
import tensorflow as tf
import oneflow as flow
from collections import OrderedDict

from test_util import GenArgList
from test_util import GetSavePath
from test_util import Save


def compare_with_tensorflow(device_type, x_shape):
    assert device_type in ["gpu", "cpu"]
    flow.clear_default_session()
    func_config = flow.FunctionConfig()
    func_config.default_data_type(flow.float)
    func_config.train.primary_lr(1e-4)
    func_config.train.model_update_conf(dict(naive_conf={}))

    @flow.function(func_config)
    def SqrtJob():
        with flow.device_prior_placement(device_type, "0:0"):
            x = flow.get_variable(
                "x",
                shape=x_shape,
                dtype=flow.float,
                initializer=flow.random_uniform_initializer(minval=0, maxval=100),
                trainable=True,
            )
            loss = flow.math.sqrt(x)
            flow.losses.add_loss(loss)

            flow.watch(x, Save("x"))
            flow.watch_diff(x, Save("x_diff"))
            flow.watch(loss, Save("loss"))
            flow.watch_diff(loss, Save("loss_diff"))

            return loss

    # OneFlow
    check_point = flow.train.CheckPoint()
    check_point.init()
    of_out = SqrtJob().get()

    # TensorFlow
    with tf.GradientTape(persistent=True) as tape:
        x = tf.Variable(np.load(os.path.join(GetSavePath(), "x.npy")))
        tf_out = tf.math.sqrt(x)
    loss_diff = np.load(os.path.join(GetSavePath(), "loss_diff.npy"))
    tf_x_diff = tape.gradient(tf_out, x, loss_diff)

    assert np.allclose(of_out.ndarray(), tf_out.numpy(), rtol=1e-5, atol=1e-5)
    assert np.allclose(
        np.load(os.path.join(GetSavePath(), "x_diff.npy")), tf_x_diff.numpy(), rtol=1e-5, atol=1e-5
    )


def test_sqrt(test_case):
    arg_dict = OrderedDict()
    arg_dict["device_type"] = ["gpu"]
    arg_dict["x_shape"] = [(10, 20, 30)]
    for arg in GenArgList(arg_dict):
        compare_with_tensorflow(*arg)
import os
import numpy as np
import tensorflow as tf
import oneflow as flow
from collections import OrderedDict

from test_util import GenArgList
from test_util import GetSavePath
from test_util import Save


def compare_with_tensorflow(device_type, x_shape):
    assert device_type in ["gpu", "cpu"]
    flow.clear_default_session()
    func_config = flow.FunctionConfig()
    func_config.default_data_type(flow.float)
    func_config.train.primary_lr(1e-4)
    func_config.train.model_update_conf(dict(naive_conf={}))

    @flow.function(func_config)
    def SquareJob():
        with flow.device_prior_placement(device_type, "0:0"):
            x = flow.get_variable(
                "x",
                shape=x_shape,
                dtype=flow.float,
                initializer=flow.random_uniform_initializer(minval=-10, maxval=10),
                trainable=True,
            )
            loss = flow.math.square(x)
            flow.losses.add_loss(loss)

            flow.watch(x, Save("x"))
            flow.watch_diff(x, Save("x_diff"))
            flow.watch(loss, Save("loss"))
            flow.watch_diff(loss, Save("loss_diff"))

            return loss

    # OneFlow
    check_point = flow.train.CheckPoint()
    check_point.init()
    of_out = SquareJob().get()

    # TensorFlow
    with tf.GradientTape(persistent=True) as tape:
        x = tf.Variable(np.load(os.path.join(GetSavePath(), "x.npy")))
        tf_out = tf.math.square(x)
    loss_diff = np.load(os.path.join(GetSavePath(), "loss_diff.npy"))
    tf_x_diff = tape.gradient(tf_out, x, loss_diff)

    assert np.allclose(of_out.ndarray(), tf_out.numpy(), rtol=1e-5, atol=1e-5)
    assert np.allclose(
        np.load(os.path.join(GetSavePath(), "x_diff.npy")), tf_x_diff.numpy(), rtol=1e-5, atol=1e-5
    )


def test_square(test_case):
    arg_dict = OrderedDict()
    arg_dict["device_type"] = ["gpu"]
    arg_dict["x_shape"] = [(10, 20, 30)]
    for arg in GenArgList(arg_dict):
        compare_with_tensorflow(*arg)