Commit be606ba8 authored by mindspore-ci-bot, committed by Gitee

!5432 MindSpore parallel supports all element-wise operators

Merge pull request !5432 from yihuaijie/master
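With this change, any of the newly registered element-wise operators can be given an explicit sharding strategy under semi_auto_parallel, just like the operators already supported. A minimal sketch modeled on the tests in this pull request (the ModNet name and the 4 x 2 device mesh are illustrative, not part of the change):

import numpy as np
import mindspore as ms
import mindspore.nn as nn
from mindspore import Tensor, context
from mindspore.ops import operations as P

class ModNet(nn.Cell):
    def __init__(self, strategy):
        super().__init__()
        # Mod is one of the newly registered element-wise operators;
        # each tuple in the strategy shards one input across a 4 x 2 device mesh.
        self.mod = P.Mod().set_strategy(strategy)

    def construct(self, x, y):
        return self.mod(x, y)

context.set_auto_parallel_context(device_num=8, global_rank=0,
                                  parallel_mode="semi_auto_parallel")
net = ModNet(((4, 2), (4, 2)))
x = Tensor(np.ones([64, 64]), dtype=ms.float32)
y = Tensor(np.ones([64, 64]), dtype=ms.float32)
# Compile as in the tests below; actual execution requires 8 devices.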
......@@ -90,6 +90,7 @@ REGISTER(TensorAddInfo);
REGISTER(BiasAddInfo);
REGISTER(MulInfo);
REGISTER(DivInfo);
REGISTER(ModInfo);
REGISTER(RealDivInfo);
REGISTER(PowInfo);
REGISTER(ExpInfo);
......@@ -117,15 +118,56 @@ REGISTER(MaximumInfo);
REGISTER(MinimumInfo);
REGISTER(CastInfo);
REGISTER(GreaterInfo);
REGISTER(GreaterEqualInfo);
REGISTER(LessEqualInfo);
REGISTER(LessInfo);
REGISTER(ApproximateEqualInfo);
REGISTER(SparseSoftmaxCrossEntropyWithLogitsInfo);
REGISTER(AssignSubInfo);
REGISTER(FloorModInfo);
REGISTER(AssignInfo);
REGISTER(AssignAddInfo);
REGISTER(Atan2Info);
REGISTER(DivNoNanInfo);
REGISTER(LogicalAndInfo);
REGISTER(LogicalOrInfo);
REGISTER(EluInfo);
REGISTER(ReLUInfo);
REGISTER(ReLU6Info);
REGISTER(ReLUV2Info);
REGISTER(SoftplusInfo);
REGISTER(SoftsignInfo);
REGISTER(GatherV2Info);
REGISTER(SparseGatherV2Info);
REGISTER(SqrtInfo);
REGISTER(SigmoidInfo);
REGISTER(GetNextInfo);
REGISTER(NegInfo);
REGISTER(AbsInfo);
REGISTER(AcoshInfo);
REGISTER(AsinInfo);
REGISTER(AsinhInfo);
REGISTER(AtanInfo);
REGISTER(AtanhInfo);
REGISTER(CeilInfo);
REGISTER(CoshInfo);
REGISTER(Expm1Info);
REGISTER(Log1pInfo);
REGISTER(SinInfo);
REGISTER(SinhInfo);
REGISTER(TanInfo);
REGISTER(RsqrtInfo);
REGISTER(InvInfo);
REGISTER(ReciprocalInfo);
REGISTER(RoundInfo);
REGISTER(FloorInfo);
REGISTER(SignInfo);
REGISTER(ErfInfo);
REGISTER(ErfcInfo);
REGISTER(ZerosLikeInfo);
REGISTER(OnesLikeInfo);
REGISTER(BesselI0eInfo);
REGISTER(BesselI1eInfo);
REGISTER(BatchMatMulInfo);
REGISTER(ExpandDimsInfo);
REGISTER(SqueezeInfo);
......
......@@ -131,6 +131,13 @@ class LogSoftmaxInfo : public Softmax {
~LogSoftmaxInfo() override = default;
};
class EluInfo : public ActivationOther {
public:
EluInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, const PrimitiveAttrs &attrs)
: ActivationOther(name, inputs_shape, outputs_shape, attrs) {}
~EluInfo() override = default;
};
class ReLUInfo : public ActivationOther {
public:
ReLUInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
......@@ -139,6 +146,38 @@ class ReLUInfo : public ActivationOther {
~ReLUInfo() override = default;
};
class ReLU6Info : public ActivationOther {
public:
ReLU6Info(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
const PrimitiveAttrs &attrs)
: ActivationOther(name, inputs_shape, outputs_shape, attrs) {}
~ReLU6Info() override = default;
};
class ReLUV2Info : public ActivationOther {
public:
ReLUV2Info(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
const PrimitiveAttrs &attrs)
: ActivationOther(name, inputs_shape, outputs_shape, attrs) {}
~ReLUV2Info() override = default;
};
class SoftsignInfo : public ActivationOther {
public:
SoftsignInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
const PrimitiveAttrs &attrs)
: ActivationOther(name, inputs_shape, outputs_shape, attrs) {}
~SoftsignInfo() override = default;
};
class SoftplusInfo : public ActivationOther {
public:
SoftplusInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
const PrimitiveAttrs &attrs)
: ActivationOther(name, inputs_shape, outputs_shape, attrs) {}
~SoftplusInfo() override = default;
};
class CastInfo : public ActivationOther {
public:
CastInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
......
......@@ -82,6 +82,13 @@ class DivInfo : public ArithmeticBase {
~DivInfo() override = default;
};
class ModInfo : public ArithmeticBase {
public:
ModInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, const PrimitiveAttrs &attrs)
: ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<ArithmeticCost>(true)) {}
~ModInfo() override = default;
};
class RealDivInfo : public ArithmeticBase {
public:
RealDivInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
......@@ -98,6 +105,14 @@ class FloorDivInfo : public ArithmeticBase {
~FloorDivInfo() override = default;
};
class FloorModInfo : public ArithmeticBase {
public:
FloorModInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
const PrimitiveAttrs &attrs)
: ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<ArithmeticCost>(true)) {}
~FloorModInfo() override = default;
};
class PowInfo : public ArithmeticBase {
public:
PowInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, const PrimitiveAttrs &attrs)
......@@ -105,20 +120,28 @@ class PowInfo : public ArithmeticBase {
~PowInfo() override = default;
};
class AssignSubInfo : public ArithmeticBase {
 public:
  AssignSubInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
                const PrimitiveAttrs &attrs)
      : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<ArithmeticCost>(false)) {}
  ~AssignSubInfo() override = default;
};

class AssignInfo : public ArithmeticBase {
 public:
  AssignInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
             const PrimitiveAttrs &attrs)
      : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<ArithmeticCost>(false)) {}
  ~AssignInfo() override = default;
};

class AssignAddInfo : public ArithmeticBase {
 public:
  AssignAddInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
                const PrimitiveAttrs &attrs)
      : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<ArithmeticCost>(false)) {}
  ~AssignAddInfo() override = default;
};
// All dimensions can be split arbitrarily, but the split strategy of the logits must match that of the label.
......@@ -129,6 +152,38 @@ class SigmoidCrossEntropyWithLogitsInfo : public ArithmeticBase {
: ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<ArithmeticCost>(false)) {}
~SigmoidCrossEntropyWithLogitsInfo() override = default;
};
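The constraint in the comment above is this operator's only sharding rule: any dimension may be split, provided the logits and the label use the same strategy. A minimal sketch of a conforming configuration (the strategy value is illustrative):

from mindspore.ops import operations as P

# Both inputs (logits, label) must share one sharding strategy.
strategy = ((4, 2), (4, 2))
loss = P.SigmoidCrossEntropyWithLogits().set_strategy(strategy)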
class Atan2Info : public ArithmeticBase {
public:
Atan2Info(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
const PrimitiveAttrs &attrs)
: ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<ArithmeticCost>(false)) {}
~Atan2Info() override = default;
};
class DivNoNanInfo : public ArithmeticBase {
public:
DivNoNanInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
const PrimitiveAttrs &attrs)
: ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<ArithmeticCost>(true)) {}
~DivNoNanInfo() override = default;
};
class LogicalAndInfo : public ArithmeticBase {
public:
LogicalAndInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
const PrimitiveAttrs &attrs)
: ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<ArithmeticCost>(false)) {}
~LogicalAndInfo() override = default;
};
class LogicalOrInfo : public ArithmeticBase {
public:
LogicalOrInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
const PrimitiveAttrs &attrs)
: ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<ArithmeticCost>(false)) {}
~LogicalOrInfo() override = default;
};
} // namespace parallel
} // namespace mindspore
......
......@@ -36,6 +36,14 @@ class EqualInfo : public ArithmeticBase {
~EqualInfo() override = default;
};
class ApproximateEqualInfo : public ArithmeticBase {
public:
ApproximateEqualInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
const PrimitiveAttrs &attrs)
: ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<ArithmeticCost>(false)) {}
~ApproximateEqualInfo() override = default;
};
class NotEqualInfo : public ArithmeticBase {
public:
NotEqualInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
......@@ -59,6 +67,38 @@ class MinimumInfo : public ArithmeticBase {
: ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<ArithmeticCost>(true)) {}
~MinimumInfo() override = default;
};
class GreaterInfo : public ArithmeticBase {
public:
GreaterInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
const PrimitiveAttrs &attrs)
: ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<ArithmeticCost>(false)) {}
~GreaterInfo() override = default;
};
class GreaterEqualInfo : public ArithmeticBase {
public:
GreaterEqualInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
const PrimitiveAttrs &attrs)
: ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<ArithmeticCost>(false)) {}
~GreaterEqualInfo() override = default;
};
class LessInfo : public ArithmeticBase {
public:
LessInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
const PrimitiveAttrs &attrs)
: ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<ArithmeticCost>(false)) {}
~LessInfo() override = default;
};
class LessEqualInfo : public ArithmeticBase {
public:
LessEqualInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
const PrimitiveAttrs &attrs)
: ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared<ArithmeticCost>(false)) {}
~LessEqualInfo() override = default;
};
} // namespace parallel
} // namespace mindspore
......
......@@ -63,6 +63,202 @@ class LogicalNotInfo : public ActivationOther {
: ActivationOther(name, inputs_shape, outputs_shape, attrs) {}
~LogicalNotInfo() override = default;
};
class AbsInfo : public ActivationOther {
public:
AbsInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, const PrimitiveAttrs &attrs)
: ActivationOther(name, inputs_shape, outputs_shape, attrs) {}
~AbsInfo() override = default;
};
class SignInfo : public ActivationOther {
public:
SignInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
const PrimitiveAttrs &attrs)
: ActivationOther(name, inputs_shape, outputs_shape, attrs) {}
~SignInfo() override = default;
};
class FloorInfo : public ActivationOther {
public:
FloorInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
const PrimitiveAttrs &attrs)
: ActivationOther(name, inputs_shape, outputs_shape, attrs) {}
~FloorInfo() override = default;
};
class RoundInfo : public ActivationOther {
public:
RoundInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
const PrimitiveAttrs &attrs)
: ActivationOther(name, inputs_shape, outputs_shape, attrs) {}
~RoundInfo() override = default;
};
class ReciprocalInfo : public ActivationOther {
public:
ReciprocalInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
const PrimitiveAttrs &attrs)
: ActivationOther(name, inputs_shape, outputs_shape, attrs) {}
~ReciprocalInfo() override = default;
};
class InvInfo : public ActivationOther {
public:
InvInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, const PrimitiveAttrs &attrs)
: ActivationOther(name, inputs_shape, outputs_shape, attrs) {}
~InvInfo() override = default;
};
class RsqrtInfo : public ActivationOther {
public:
RsqrtInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
const PrimitiveAttrs &attrs)
: ActivationOther(name, inputs_shape, outputs_shape, attrs) {}
~RsqrtInfo() override = default;
};
class TanInfo : public ActivationOther {
public:
TanInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, const PrimitiveAttrs &attrs)
: ActivationOther(name, inputs_shape, outputs_shape, attrs) {}
~TanInfo() override = default;
};
class SinInfo : public ActivationOther {
public:
SinInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, const PrimitiveAttrs &attrs)
: ActivationOther(name, inputs_shape, outputs_shape, attrs) {}
~SinInfo() override = default;
};
class SinhInfo : public ActivationOther {
public:
SinhInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
const PrimitiveAttrs &attrs)
: ActivationOther(name, inputs_shape, outputs_shape, attrs) {}
~SinhInfo() override = default;
};
class Log1pInfo : public ActivationOther {
public:
Log1pInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
const PrimitiveAttrs &attrs)
: ActivationOther(name, inputs_shape, outputs_shape, attrs) {}
~Log1pInfo() override = default;
};
class Expm1Info : public ActivationOther {
public:
Expm1Info(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
const PrimitiveAttrs &attrs)
: ActivationOther(name, inputs_shape, outputs_shape, attrs) {}
~Expm1Info() override = default;
};
class CoshInfo : public ActivationOther {
public:
CoshInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
const PrimitiveAttrs &attrs)
: ActivationOther(name, inputs_shape, outputs_shape, attrs) {}
~CoshInfo() override = default;
};
class CeilInfo : public ActivationOther {
public:
CeilInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
const PrimitiveAttrs &attrs)
: ActivationOther(name, inputs_shape, outputs_shape, attrs) {}
~CeilInfo() override = default;
};
class AtanhInfo : public ActivationOther {
public:
AtanhInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
const PrimitiveAttrs &attrs)
: ActivationOther(name, inputs_shape, outputs_shape, attrs) {}
~AtanhInfo() override = default;
};
class AtanInfo : public ActivationOther {
public:
AtanInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
const PrimitiveAttrs &attrs)
: ActivationOther(name, inputs_shape, outputs_shape, attrs) {}
~AtanInfo() override = default;
};
class AsinInfo : public ActivationOther {
public:
AsinInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
const PrimitiveAttrs &attrs)
: ActivationOther(name, inputs_shape, outputs_shape, attrs) {}
~AsinInfo() override = default;
};
class AsinhInfo : public ActivationOther {
public:
AsinhInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
const PrimitiveAttrs &attrs)
: ActivationOther(name, inputs_shape, outputs_shape, attrs) {}
~AsinhInfo() override = default;
};
class AcoshInfo : public ActivationOther {
public:
AcoshInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
const PrimitiveAttrs &attrs)
: ActivationOther(name, inputs_shape, outputs_shape, attrs) {}
~AcoshInfo() override = default;
};
class ErfInfo : public ActivationOther {
public:
ErfInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, const PrimitiveAttrs &attrs)
: ActivationOther(name, inputs_shape, outputs_shape, attrs) {}
~ErfInfo() override = default;
};
class ErfcInfo : public ActivationOther {
public:
ErfcInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
const PrimitiveAttrs &attrs)
: ActivationOther(name, inputs_shape, outputs_shape, attrs) {}
~ErfcInfo() override = default;
};
class ZerosLikeInfo : public ActivationOther {
public:
ZerosLikeInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
const PrimitiveAttrs &attrs)
: ActivationOther(name, inputs_shape, outputs_shape, attrs) {}
~ZerosLikeInfo() override = default;
};
class OnesLikeInfo : public ActivationOther {
public:
OnesLikeInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
const PrimitiveAttrs &attrs)
: ActivationOther(name, inputs_shape, outputs_shape, attrs) {}
~OnesLikeInfo() override = default;
};
class BesselI0eInfo : public ActivationOther {
public:
BesselI0eInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
const PrimitiveAttrs &attrs)
: ActivationOther(name, inputs_shape, outputs_shape, attrs) {}
~BesselI0eInfo() override = default;
};
class BesselI1eInfo : public ActivationOther {
public:
BesselI1eInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
const PrimitiveAttrs &attrs)
: ActivationOther(name, inputs_shape, outputs_shape, attrs) {}
~BesselI1eInfo() override = default;
};
} // namespace parallel
} // namespace mindspore
......
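Every unary operator added above inherits ActivationOther unchanged, so each takes a single-tuple strategy covering its one input. A minimal sketch in the style of the tests in this pull request (the 4 x 2 mesh is illustrative):

from mindspore.ops import operations as P

# One input per operator, so the strategy is a one-tuple.
strategy = ((4, 2),)
rsqrt = P.Rsqrt().set_strategy(strategy)
erfc = P.Erfc().set_strategy(strategy)
ones_like = P.OnesLike().set_strategy(strategy)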
......@@ -122,17 +122,39 @@ def test_matmul_mul():
b = Tensor(np.ones([64, 64]), dtype=ms.float32)
compile_net(net, x, y, b)
def test_matmul_mod():
class Net(nn.Cell):
def __init__(self, strategy1, strategy2):
super().__init__()
self.matmul = P.MatMul().set_strategy(strategy1)
self.mod = P.Mod().set_strategy(strategy2)
def construct(self, x, y, b):
out = self.matmul(x, y)
out = self.mod(out, b)
return out
context.set_auto_parallel_context(device_num=8, global_rank=0)
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
strategy1 = ((2, 2), (2, 2))
strategy2 = ((4, 2), (4, 2))
net = GradWrap(NetWithLoss(Net(strategy1, strategy2)))
x = Tensor(np.ones([64, 32]), dtype=ms.float32)
y = Tensor(np.ones([32, 64]), dtype=ms.float32)
b = Tensor(np.ones([64, 64]), dtype=ms.float32)
compile_net(net, x, y, b)
def test_matmul_floormod():
class Net(nn.Cell):
def __init__(self, strategy1, strategy2):
super().__init__()
self.matmul = P.MatMul().set_strategy(strategy1)
self.floormod = P.FloorMod().set_strategy(strategy2)
def construct(self, x, y, b):
out = self.matmul(x, y)
out = self.floormod(out, b)
return out
context.set_auto_parallel_context(device_num=8, global_rank=0)
......@@ -147,16 +169,122 @@ def test_matmul_div():
compile_net(net, x, y, b)
def test_matmul_atan2():
class Net(nn.Cell):
def __init__(self, strategy1, strategy2):
super().__init__()
self.matmul = P.MatMul().set_strategy(strategy1)
self.atan2 = P.Atan2().set_strategy(strategy2)
def construct(self, x, y, b):
out = self.matmul(x, y)
out = self.atan2(out, b)
return out
context.set_auto_parallel_context(device_num=8, global_rank=0)
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
strategy1 = ((2, 2), (2, 2))
strategy2 = ((4, 2), (4, 2))
net = GradWrap(NetWithLoss(Net(strategy1, strategy2)))
x = Tensor(np.ones([64, 32]), dtype=ms.float32)
y = Tensor(np.ones([32, 64]), dtype=ms.float32)
b = Tensor(np.ones([64, 64]), dtype=ms.float32)
compile_net(net, x, y, b)
def test_matmul_divNoNan():
class Net(nn.Cell):
def __init__(self, strategy1, strategy2):
super().__init__()
self.matmul = P.MatMul().set_strategy(strategy1)
self.divNoNan = P.DivNoNan().set_strategy(strategy2)
def construct(self, x, y, b):
out = self.matmul(x, y)
out = self.divNoNan(out, b)
return out
context.set_auto_parallel_context(device_num=8, global_rank=0)
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
strategy1 = ((2, 2), (2, 2))
strategy2 = ((4, 2), (4, 2))
net = GradWrap(NetWithLoss(Net(strategy1, strategy2)))
x = Tensor(np.ones([64, 32]), dtype=ms.float32)
y = Tensor(np.ones([32, 64]), dtype=ms.float32)
b = Tensor(np.ones([64, 64]), dtype=ms.float32)
compile_net(net, x, y, b)
def test_matmul_logicaland():
class Net(nn.Cell):
def __init__(self, strategy1, strategy2):
super().__init__()
self.matmul = P.MatMul().set_strategy(strategy1)
self.equal = P.Equal().set_strategy(strategy2)
self.notequal = P.NotEqual().set_strategy(strategy2)
self.logical = P.LogicalAnd().set_strategy(strategy2)
def construct(self, x, y, b):
out = self.matmul(x, y)
out1 = self.equal(out, b)
out = self.matmul(x, y)
out2 = self.notequal(out, b)
out = self.logical(out1, out2)
return out
context.set_auto_parallel_context(device_num=8, global_rank=0)
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
strategy1 = ((2, 2), (2, 2))
strategy2 = ((4, 2), (4, 2))
net = GradWrap(NetWithLoss(Net(strategy1, strategy2)))
x = Tensor(np.ones([64, 32]), dtype=ms.float32)
y = Tensor(np.ones([32, 64]), dtype=ms.float32)
b = Tensor(np.ones([64, 64]), dtype=ms.float32)
compile_net(net, x, y, b)
def test_matmul_logicalor():
class Net(nn.Cell):
def __init__(self, strategy1, strategy2):
super().__init__()
self.matmul = P.MatMul().set_strategy(strategy1)
self.equal = P.Equal().set_strategy(strategy2)
self.notequal = P.NotEqual().set_strategy(strategy2)
self.logical = P.LogicalOr().set_strategy(strategy2)
def construct(self, x, y, b):
out = self.matmul(x, y)
out1 = self.equal(out, b)
out = self.matmul(x, y)
out2 = self.notequal(out, b)
out = self.logical(out1, out2)
return out
context.set_auto_parallel_context(device_num=8, global_rank=0)
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
strategy1 = ((2, 2), (2, 2))
strategy2 = ((4, 2), (4, 2))
net = GradWrap(NetWithLoss(Net(strategy1, strategy2)))
x = Tensor(np.ones([64, 32]), dtype=ms.float32)
y = Tensor(np.ones([32, 64]), dtype=ms.float32)
b = Tensor(np.ones([64, 64]), dtype=ms.float32)
compile_net(net, x, y, b)
def test_matmul_div():
class Net(nn.Cell):
def __init__(self, strategy1, strategy2):
super().__init__()
self.matmul = P.MatMul().set_strategy(strategy1)
self.div = P.Div().set_strategy(strategy2)
def construct(self, x, y, b):
out = self.matmul(x, y)
out = self.div(out, b)
return out
context.set_auto_parallel_context(device_num=8, global_rank=0)
......@@ -528,3 +656,97 @@ def test_assign_sub():
net = SubGradWrap(SubNetWithLoss(Net()))
x = Tensor(np.ones([128, 32]), dtype=ms.float32)
compile_sub_net(net, x)
def test_assign_add():
class Net(nn.Cell):
def __init__(self):
super().__init__()
            self.assign_add = P.AssignAdd()
self.mul = P.Mul()
self.mul_weight = Parameter(Tensor(np.full([128, 32],
0.5, dtype=np.float32)),
name="mul_weight")
            self.assignadd_weight = Parameter(Tensor(np.full([128, 32],
                                                             1.1, dtype=np.float32)),
                                              name="assignadd_weight")
def construct(self, x):
out = self.mul(x, self.mul_weight)
            out = self.assign_add(self.assignadd_weight, out)
return out
class SubNetWithLoss(nn.Cell):
def __init__(self, network):
super(SubNetWithLoss, self).__init__()
self.loss = VirtualLoss()
self.network = network
def construct(self, x):
predict = self.network(x,)
return self.loss(predict)
class SubGradWrap(nn.Cell):
def __init__(self, network):
super(SubGradWrap, self).__init__()
self.network = network
def construct(self, x):
return grad_all(self.network)(x)
def compile_sub_net(net, x):
net.set_auto_parallel()
_executor.compile(net, x)
context.set_auto_parallel_context(device_num=64, global_rank=15)
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
net = SubGradWrap(SubNetWithLoss(Net()))
x = Tensor(np.ones([128, 32]), dtype=ms.float32)
compile_sub_net(net, x)
def test_assign():
class Net(nn.Cell):
def __init__(self):
super().__init__()
            self.assign = P.Assign()
self.mul = P.Mul()
self.mul_weight = Parameter(Tensor(np.full([128, 32],
0.5, dtype=np.float32)),
name="mul_weight")
            self.assign_weight = Parameter(Tensor(np.full([128, 32],
                                                          1.1, dtype=np.float32)),
                                           name="assign_weight")
def construct(self, x):
out = self.mul(x, self.mul_weight)
            out = self.assign(self.assign_weight, out)
return out
class SubNetWithLoss(nn.Cell):
def __init__(self, network):
super(SubNetWithLoss, self).__init__()
self.loss = VirtualLoss()
self.network = network
def construct(self, x):
predict = self.network(x,)
return self.loss(predict)
class SubGradWrap(nn.Cell):
def __init__(self, network):
super(SubGradWrap, self).__init__()
self.network = network
def construct(self, x):
return grad_all(self.network)(x)
def compile_sub_net(net, x):
net.set_auto_parallel()
_executor.compile(net, x)
context.set_auto_parallel_context(device_num=64, global_rank=15)
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
net = SubGradWrap(SubNetWithLoss(Net()))
x = Tensor(np.ones([128, 32]), dtype=ms.float32)
compile_sub_net(net, x)
......@@ -98,6 +98,126 @@ def test_matmul_not_equal():
compile_net(net, x, y, b)
def test_matmul_approximateEqual():
class Net(nn.Cell):
def __init__(self, strategy1, strategy2):
super().__init__()
self.matmul = P.MatMul().set_strategy(strategy1)
self.approximateEqual = P.ApproximateEqual(tolerance=0.5).set_strategy(strategy2)
def construct(self, x, y, b):
out = self.matmul(x, y)
out = self.approximateEqual(out, b)
return out
context.set_auto_parallel_context(device_num=8, global_rank=0)
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
strategy1 = ((2, 2), (2, 2))
strategy2 = ((4, 2), (4, 2))
net = GradWrap(NetWithLoss(Net(strategy1, strategy2)))
x = Tensor(np.ones([64, 32]), dtype=ms.float32)
y = Tensor(np.ones([32, 64]), dtype=ms.float32)
b = Tensor(np.ones([64, 64]), dtype=ms.float32)
compile_net(net, x, y, b)
def test_matmul_greater():
class Net(nn.Cell):
def __init__(self, strategy1, strategy2):
super().__init__()
self.matmul = P.MatMul().set_strategy(strategy1)
self.greater = P.Greater().set_strategy(strategy2)
def construct(self, x, y, b):
out = self.matmul(x, y)
out = self.greater(out, b)
return out
context.set_auto_parallel_context(device_num=8, global_rank=0)
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
strategy1 = ((2, 2), (2, 2))
strategy2 = ((4, 2), (4, 2))
net = GradWrap(NetWithLoss(Net(strategy1, strategy2)))
x = Tensor(np.ones([64, 32]), dtype=ms.float32)
y = Tensor(np.ones([32, 64]), dtype=ms.float32)
b = Tensor(np.ones([64, 64]), dtype=ms.float32)
compile_net(net, x, y, b)
def test_matmul_greaterEqual():
class Net(nn.Cell):
def __init__(self, strategy1, strategy2):
super().__init__()
self.matmul = P.MatMul().set_strategy(strategy1)
self.greaterEqual = P.GreaterEqual().set_strategy(strategy2)
def construct(self, x, y, b):
out = self.matmul(x, y)
out = self.greaterEqual(out, b)
return out
context.set_auto_parallel_context(device_num=8, global_rank=0)
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
strategy1 = ((2, 2), (2, 2))
strategy2 = ((4, 2), (4, 2))
net = GradWrap(NetWithLoss(Net(strategy1, strategy2)))
x = Tensor(np.ones([64, 32]), dtype=ms.float32)
y = Tensor(np.ones([32, 64]), dtype=ms.float32)
b = Tensor(np.ones([64, 64]), dtype=ms.float32)
compile_net(net, x, y, b)
def test_matmul_less():
class Net(nn.Cell):
def __init__(self, strategy1, strategy2):
super().__init__()
self.matmul = P.MatMul().set_strategy(strategy1)
self.less = P.Less().set_strategy(strategy2)
def construct(self, x, y, b):
out = self.matmul(x, y)
out = self.less(out, b)
return out
context.set_auto_parallel_context(device_num=8, global_rank=0)
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
strategy1 = ((2, 2), (2, 2))
strategy2 = ((4, 2), (4, 2))
net = GradWrap(NetWithLoss(Net(strategy1, strategy2)))
x = Tensor(np.ones([64, 32]), dtype=ms.float32)
y = Tensor(np.ones([32, 64]), dtype=ms.float32)
b = Tensor(np.ones([64, 64]), dtype=ms.float32)
compile_net(net, x, y, b)
def test_matmul_lessEqual():
class Net(nn.Cell):
def __init__(self, strategy1, strategy2):
super().__init__()
self.matmul = P.MatMul().set_strategy(strategy1)
self.lessEqual = P.LessEqual().set_strategy(strategy2)
def construct(self, x, y, b):
out = self.matmul(x, y)
out = self.lessEqual(out, b)
return out
context.set_auto_parallel_context(device_num=8, global_rank=0)
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
strategy1 = ((2, 2), (2, 2))
strategy2 = ((4, 2), (4, 2))
net = GradWrap(NetWithLoss(Net(strategy1, strategy2)))
x = Tensor(np.ones([64, 32]), dtype=ms.float32)
y = Tensor(np.ones([32, 64]), dtype=ms.float32)
b = Tensor(np.ones([64, 64]), dtype=ms.float32)
compile_net(net, x, y, b)
def test_matmul_not_equal_repeated_calculation():
class Net(nn.Cell):
def __init__(self, strategy1, strategy2):
......