提交 5d0801ed 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!188 Support pow's second input could be tensor and fix bug in bprop of pow

Merge pull request !188 from zhangbuxue/fix_pow_bprop
...@@ -98,6 +98,13 @@ class FloorDivInfo : public ArithmeticBase { ...@@ -98,6 +98,13 @@ class FloorDivInfo : public ArithmeticBase {
~FloorDivInfo() override = default; ~FloorDivInfo() override = default;
}; };
class PowInfo : public ArithmeticBase {
public:
PowInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape, const PrimitiveAttrs& attrs)
: ArithmeticBase(name, inputs_shape, outputs_shape, attrs) {}
~PowInfo() override = default;
};
class GreaterInfo : public ArithmeticBase { class GreaterInfo : public ArithmeticBase {
public: public:
GreaterInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape, GreaterInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape,
......
/**
* Copyright 2019 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "parallel/ops_info/elementary_function_info.h"
namespace mindspore {
namespace parallel {
Status PowInfo::InferMirrorOps() {
mirror_ops_.clear();
Shape tensor_map = inputs_tensor_map_[0];
std::vector<Group> group;
if (CreateGroupByTensorMap(tensor_map, &group) != SUCCESS) {
MS_LOG(ERROR) << name_ << " : Create group failed.";
return FAILED;
}
OperatorVector mirror_op;
OperatorVector op_for_value;
if (group.empty()) {
MS_LOG(INFO) << name_ << " : The mirror ops is empty.";
return SUCCESS;
} else {
mirror_op = CreateMirrorOps(group[0].name(), group[0].GetDevNum());
mirror_ops_.push_back(mirror_op);
mirror_ops_.push_back(op_for_value);
std::string group_name = group[0].name();
MS_LOG(INFO) << name_ << " : Create the mirror ops success, the group name is " << group_name;
}
return SUCCESS;
}
} // namespace parallel
} // namespace mindspore
...@@ -27,16 +27,6 @@ ...@@ -27,16 +27,6 @@
namespace mindspore { namespace mindspore {
namespace parallel { namespace parallel {
class PowInfo : public ActivationOther {
public:
PowInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape, const PrimitiveAttrs& attrs)
: ActivationOther(name, inputs_shape, outputs_shape, attrs) {}
~PowInfo() override = default;
protected:
Status InferMirrorOps() override;
};
class ExpInfo : public ActivationOther { class ExpInfo : public ActivationOther {
public: public:
ExpInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape, const PrimitiveAttrs& attrs) ExpInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape, const PrimitiveAttrs& attrs)
......
...@@ -58,7 +58,7 @@ class _PoolNd(Cell): ...@@ -58,7 +58,7 @@ class _PoolNd(Cell):
pass pass
def extend_repr(self): def extend_repr(self):
return 'kernel_size={kernel_size}, strides={strides}, pad_mode={pad_mode}'.format(**self.__dict__) return 'kernel_size={kernel_size}, stride={stride}, pad_mode={pad_mode}'.format(**self.__dict__)
class MaxPool2d(_PoolNd): class MaxPool2d(_PoolNd):
......
...@@ -336,14 +336,13 @@ def get_bprop_log(self): ...@@ -336,14 +336,13 @@ def get_bprop_log(self):
@bprop_getters.register(P.Pow) @bprop_getters.register(P.Pow)
def get_bprop_pow(self): def get_bprop_pow(self):
"""Grad definition for `Pow` operation.""" """Grad definition for `Pow` operation."""
pow_ = P.Pow() pow_op = P.Pow()
cast = P.Cast() ln = P.Log()
dtype = P.DType()
def bprop(x, power, out, dout): def bprop(x, power, out, dout):
g = cast(F.tuple_to_array((power,)), dtype(x)) * pow_(x, power-1.0) dx = power * pow_op(x, power - 1.0) * dout
dx = g * dout dpower = pow_op(x, power) * ln(x) * dout
return dx, 0 return dx, dpower
return bprop return bprop
......
...@@ -1097,7 +1097,7 @@ class ArgMaxWithValue(PrimitiveWithInfer): ...@@ -1097,7 +1097,7 @@ class ArgMaxWithValue(PrimitiveWithInfer):
axis = self.axis axis = self.axis
x_rank = len(x_shape) x_rank = len(x_shape)
validator.check_int_range("axis", axis, -x_rank, x_rank, Rel.INC_LEFT) validator.check_int_range("axis", axis, -x_rank, x_rank, Rel.INC_LEFT)
ouput_shape = _infer_shape_reduce(x_shape, self.axis, self.keep_dims, self.prim_name()) ouput_shape = _infer_shape_reduce(x_shape, self.axis, self.keep_dims, self.name)
return ouput_shape, ouput_shape return ouput_shape, ouput_shape
def infer_dtype(self, x_dtype): def infer_dtype(self, x_dtype):
...@@ -1143,7 +1143,7 @@ class ArgMinWithValue(PrimitiveWithInfer): ...@@ -1143,7 +1143,7 @@ class ArgMinWithValue(PrimitiveWithInfer):
axis = self.axis axis = self.axis
x_rank = len(x_shape) x_rank = len(x_shape)
validator.check_int_range("axis", axis, -x_rank, x_rank, Rel.INC_LEFT) validator.check_int_range("axis", axis, -x_rank, x_rank, Rel.INC_LEFT)
ouput_shape = _infer_shape_reduce(x_shape, self.axis, self.keep_dims, self.prim_name()) ouput_shape = _infer_shape_reduce(x_shape, self.axis, self.keep_dims, self.name)
return ouput_shape, ouput_shape return ouput_shape, ouput_shape
def infer_dtype(self, x_dtype): def infer_dtype(self, x_dtype):
......
...@@ -74,7 +74,7 @@ class _BinaryOp(PrimitiveWithInfer): ...@@ -74,7 +74,7 @@ class _BinaryOp(PrimitiveWithInfer):
self.init_prim_io_names(inputs=['x', 'y'], outputs=['output']) self.init_prim_io_names(inputs=['x', 'y'], outputs=['output'])
def infer_shape(self, x_shape, y_shape): def infer_shape(self, x_shape, y_shape):
return _get_broadcast_shape(x_shape, y_shape, self.prim_name()) return _get_broadcast_shape(x_shape, y_shape, self.name)
class _MathBinaryOp(_BinaryOp): class _MathBinaryOp(_BinaryOp):
...@@ -89,7 +89,7 @@ class _MathBinaryOp(_BinaryOp): ...@@ -89,7 +89,7 @@ class _MathBinaryOp(_BinaryOp):
return x_dtype return x_dtype
def infer_dtype(self, x_dtype, y_dtype): def infer_dtype(self, x_dtype, y_dtype):
return _MathBinaryOp.do_infer_dtype(x_dtype, y_dtype, mstype.number_type, self.prim_name()) return _MathBinaryOp.do_infer_dtype(x_dtype, y_dtype, mstype.number_type, self.name)
class TensorAdd(_MathBinaryOp): class TensorAdd(_MathBinaryOp):
...@@ -158,7 +158,7 @@ class AssignAdd(PrimitiveWithInfer): ...@@ -158,7 +158,7 @@ class AssignAdd(PrimitiveWithInfer):
def infer_dtype(self, variable, value): def infer_dtype(self, variable, value):
args = {"value": value} args = {"value": value}
validator.check_scalar_or_tensor_type_same(args, mstype.number_type, self.prim_name()) validator.check_scalar_or_tensor_type_same(args, mstype.number_type, self.name)
return value return value
...@@ -201,7 +201,7 @@ class AssignSub(PrimitiveWithInfer): ...@@ -201,7 +201,7 @@ class AssignSub(PrimitiveWithInfer):
def infer_dtype(self, variable, value): def infer_dtype(self, variable, value):
args = {"value": value} args = {"value": value}
validator.check_scalar_or_tensor_type_same(args, mstype.number_type, self.prim_name()) validator.check_scalar_or_tensor_type_same(args, mstype.number_type, self.name)
return value return value
...@@ -222,16 +222,16 @@ class _Reduce(PrimitiveWithInfer): ...@@ -222,16 +222,16 @@ class _Reduce(PrimitiveWithInfer):
@prim_attr_register @prim_attr_register
def __init__(self, keep_dims=False): def __init__(self, keep_dims=False):
"""init Reduce""" """init Reduce"""
validator.check_value_type('keep_dims', keep_dims, [bool], self.prim_name()) validator.check_value_type('keep_dims', keep_dims, [bool], self.name)
self.init_prim_io_names(inputs=['input_x', 'axis'], outputs=['y']) self.init_prim_io_names(inputs=['input_x', 'axis'], outputs=['y'])
def do_infer(self, input_x, axis, valid_dtype=mstype.number_type): def do_infer(self, input_x, axis, valid_dtype=mstype.number_type):
axis_v = axis['value'] axis_v = axis['value']
input_shp = input_x['shape'] input_shp = input_x['shape']
args = {'input_x': input_x['dtype']} args = {'input_x': input_x['dtype']}
validator.check_tensor_type_same(args, valid_dtype, self.prim_name()) validator.check_tensor_type_same(args, valid_dtype, self.name)
input_shp = _infer_shape_reduce(input_shp, axis_v, self.keep_dims, self.prim_name()) input_shp = _infer_shape_reduce(input_shp, axis_v, self.keep_dims, self.name)
return {'shape': input_shp, return {'shape': input_shp,
'dtype': input_x['dtype'], 'dtype': input_x['dtype'],
'value': None} 'value': None}
...@@ -466,7 +466,7 @@ class CumProd(PrimitiveWithInfer): ...@@ -466,7 +466,7 @@ class CumProd(PrimitiveWithInfer):
""" """
@prim_attr_register @prim_attr_register
def __init__(self, exclusive=False, reverse=False): def __init__(self, exclusive=False, reverse=False):
cls_name = self.prim_name() cls_name = self.name
self.exclusive = validator.check_value_type("exclusive", exclusive, [bool], cls_name) self.exclusive = validator.check_value_type("exclusive", exclusive, [bool], cls_name)
self.reverse = validator.check_value_type("reverse", reverse, [bool], cls_name) self.reverse = validator.check_value_type("reverse", reverse, [bool], cls_name)
...@@ -474,7 +474,7 @@ class CumProd(PrimitiveWithInfer): ...@@ -474,7 +474,7 @@ class CumProd(PrimitiveWithInfer):
return x_shape return x_shape
def infer_dtype(self, x_type, axis_type): def infer_dtype(self, x_type, axis_type):
cls_name = self.prim_name() cls_name = self.name
validator.check_tensor_type_same({'x': x_type}, mstype.number_type, cls_name) validator.check_tensor_type_same({'x': x_type}, mstype.number_type, cls_name)
validator.check_subclass("axis", axis_type, mstype.int_, cls_name) validator.check_subclass("axis", axis_type, mstype.int_, cls_name)
return x_type return x_type
...@@ -510,7 +510,7 @@ class MatMul(PrimitiveWithInfer): ...@@ -510,7 +510,7 @@ class MatMul(PrimitiveWithInfer):
def __init__(self, transpose_a=False, transpose_b=False): def __init__(self, transpose_a=False, transpose_b=False):
self.init_prim_io_names(inputs=['x1', 'x2'], outputs=['output']) self.init_prim_io_names(inputs=['x1', 'x2'], outputs=['output'])
self.__setattr_flag__ = True self.__setattr_flag__ = True
cls_name = self.prim_name() cls_name = self.name
validator.check_value_type("transpose_a", transpose_a, [bool], cls_name) validator.check_value_type("transpose_a", transpose_a, [bool], cls_name)
validator.check_value_type("transpose_b", transpose_b, [bool], cls_name) validator.check_value_type("transpose_b", transpose_b, [bool], cls_name)
...@@ -521,7 +521,7 @@ class MatMul(PrimitiveWithInfer): ...@@ -521,7 +521,7 @@ class MatMul(PrimitiveWithInfer):
def infer_shape(self, x, y): def infer_shape(self, x, y):
self.check_shape_size(x, y) self.check_shape_size(x, y)
cls_name = self.prim_name() cls_name = self.name
# expected dimension of x, y, x:[...,a,b] y:[..., c,d], the dim size should be the same except the last two # expected dimension of x, y, x:[...,a,b] y:[..., c,d], the dim size should be the same except the last two
for i in range(len(x) - 2): for i in range(len(x) - 2):
if x[i] != y[i]: if x[i] != y[i]:
...@@ -546,7 +546,7 @@ class MatMul(PrimitiveWithInfer): ...@@ -546,7 +546,7 @@ class MatMul(PrimitiveWithInfer):
def infer_dtype(self, x, y): def infer_dtype(self, x, y):
args = {"x": x, "y": y} args = {"x": x, "y": y}
validator.check_tensor_type_same(args, mstype.float_type + mstype.int_type, self.prim_name()) validator.check_tensor_type_same(args, mstype.float_type + mstype.int_type, self.name)
return x return x
...@@ -590,7 +590,7 @@ class BatchMatMul(MatMul): ...@@ -590,7 +590,7 @@ class BatchMatMul(MatMul):
def __init__(self, transpose_a=False, transpose_b=False): def __init__(self, transpose_a=False, transpose_b=False):
self.init_prim_io_names(inputs=['x1', 'x2'], outputs=['output']) self.init_prim_io_names(inputs=['x1', 'x2'], outputs=['output'])
self.__setattr_flag__ = True self.__setattr_flag__ = True
cls_name = self.prim_name() cls_name = self.name
validator.check_value_type("transpose_a", transpose_a, [bool], cls_name) validator.check_value_type("transpose_a", transpose_a, [bool], cls_name)
validator.check_value_type("transpose_b", transpose_b, [bool], cls_name) validator.check_value_type("transpose_b", transpose_b, [bool], cls_name)
...@@ -628,13 +628,13 @@ class CumSum(PrimitiveWithInfer): ...@@ -628,13 +628,13 @@ class CumSum(PrimitiveWithInfer):
@prim_attr_register @prim_attr_register
def __init__(self, exclusive=False, reverse=False): def __init__(self, exclusive=False, reverse=False):
"""init cumsum""" """init cumsum"""
cls_name = self.prim_name() cls_name = self.name
validator.check_value_type('exclusive', exclusive, [bool], cls_name) validator.check_value_type('exclusive', exclusive, [bool], cls_name)
validator.check_value_type('reverse', reverse, [bool], cls_name) validator.check_value_type('reverse', reverse, [bool], cls_name)
self.init_prim_io_names(inputs=['x', 'axis'], outputs=['y']) self.init_prim_io_names(inputs=['x', 'axis'], outputs=['y'])
def __infer__(self, x, axis): def __infer__(self, x, axis):
cls_name = self.prim_name() cls_name = self.name
x_shp = x['shape'] x_shp = x['shape']
validator.check_value_type('axis', axis['value'], [int], cls_name) validator.check_value_type('axis', axis['value'], [int], cls_name)
valid_types = [mstype.uint8, mstype.int8, mstype.int32, mstype.float16, mstype.float32] valid_types = [mstype.uint8, mstype.int8, mstype.int32, mstype.float16, mstype.float32]
...@@ -679,7 +679,7 @@ class AddN(PrimitiveWithInfer): ...@@ -679,7 +679,7 @@ class AddN(PrimitiveWithInfer):
self.init_prim_io_names(inputs=["inputs"], outputs=["sum"]) self.init_prim_io_names(inputs=["inputs"], outputs=["sum"])
def infer_shape(self, inputs): def infer_shape(self, inputs):
cls_name = self.prim_name() cls_name = self.name
validator.check_integer("inputs", len(inputs), 1, Rel.GE, cls_name) validator.check_integer("inputs", len(inputs), 1, Rel.GE, cls_name)
self.add_prim_attr('n', len(inputs)) self.add_prim_attr('n', len(inputs))
shp0 = inputs[0] shp0 = inputs[0]
...@@ -688,7 +688,7 @@ class AddN(PrimitiveWithInfer): ...@@ -688,7 +688,7 @@ class AddN(PrimitiveWithInfer):
return shp0 return shp0
def infer_dtype(self, inputs): def infer_dtype(self, inputs):
cls_name = self.prim_name() cls_name = self.name
validator.check_value_type("inputs", inputs, [tuple, list], cls_name) validator.check_value_type("inputs", inputs, [tuple, list], cls_name)
validator.check_integer("inputs", len(inputs), 1, Rel.GE, cls_name) validator.check_integer("inputs", len(inputs), 1, Rel.GE, cls_name)
args = {} args = {}
...@@ -718,7 +718,7 @@ class Neg(PrimitiveWithInfer): ...@@ -718,7 +718,7 @@ class Neg(PrimitiveWithInfer):
return input_x return input_x
def infer_dtype(self, input_x): def infer_dtype(self, input_x):
validator.check_tensor_type_same({"input_x": input_x}, mstype.number_type, self.prim_name()) validator.check_tensor_type_same({"input_x": input_x}, mstype.number_type, self.name)
return input_x return input_x
...@@ -809,7 +809,7 @@ class Square(PrimitiveWithInfer): ...@@ -809,7 +809,7 @@ class Square(PrimitiveWithInfer):
return x_shape return x_shape
def infer_dtype(self, x_type): def infer_dtype(self, x_type):
validator.check_tensor_type_same({"x": x_type}, mstype.number_type, self.prim_name()) validator.check_tensor_type_same({"x": x_type}, mstype.number_type, self.name)
return x_type return x_type
...@@ -838,7 +838,7 @@ class Rsqrt(PrimitiveWithInfer): ...@@ -838,7 +838,7 @@ class Rsqrt(PrimitiveWithInfer):
return x_shape return x_shape
def infer_dtype(self, x_type): def infer_dtype(self, x_type):
validator.check_tensor_type_same({"x": x_type}, mstype.number_type, self.prim_name()) validator.check_tensor_type_same({"x": x_type}, mstype.number_type, self.name)
return x_type return x_type
...@@ -867,7 +867,7 @@ class Sqrt(PrimitiveWithInfer): ...@@ -867,7 +867,7 @@ class Sqrt(PrimitiveWithInfer):
return x_shape return x_shape
def infer_dtype(self, x_type): def infer_dtype(self, x_type):
validator.check_tensor_type_same({"x": x_type}, mstype.number_type, self.prim_name()) validator.check_tensor_type_same({"x": x_type}, mstype.number_type, self.name)
return x_type return x_type
...@@ -897,14 +897,29 @@ class Reciprocal(PrimitiveWithInfer): ...@@ -897,14 +897,29 @@ class Reciprocal(PrimitiveWithInfer):
return x return x
def infer_dtype(self, x): def infer_dtype(self, x):
validator.check_subclass("x", x, mstype.tensor, self.prim_name()) validator.check_subclass("x", x, mstype.tensor, self.name)
return x return x
class Pow(PrimitiveWithInfer): class Pow(_MathBinaryOp):
""" """
Computes a tensor to the power of the second input. Computes a tensor to the power of the second input.
The first input must be a tensor, and the second input should be a tensor or a number.
When the inputs are two tensors, the shapes of them could be broadcast,
and the data types of them should be the same.
When the inputs are one tensor and one scalar, the scalar could not be a parameter,
only could be a constant, and the type of the scalar is the same as the data type of the tensor.
Inputs:
- **input_x** (Union[Tensor]) - The first input is a tensor whose data type is number.
- **input_y** (Union[Tensor, Number]) - The second input is a tensor whose data type is same as 'input_x' or
a number.
Outputs:
Tensor, the shape is same as the shape after broadcasting, and the data type is same as 'input_x'.
Inputs: Inputs:
- **input_x** (Tensor) - The input tensor. - **input_x** (Tensor) - The input tensor.
- **input_y** (Union[Tensor, Number]) - The exponent part. If exponent is a tensor, its shape must be able to - **input_y** (Union[Tensor, Number]) - The exponent part. If exponent is a tensor, its shape must be able to
...@@ -927,17 +942,6 @@ class Pow(PrimitiveWithInfer): ...@@ -927,17 +942,6 @@ class Pow(PrimitiveWithInfer):
[1.0, 16.0, 64.0] [1.0, 16.0, 64.0]
""" """
@prim_attr_register
def __init__(self):
"""init Multiply"""
def infer_shape(self, x, power):
return x
def infer_dtype(self, x, power):
validator.check_tensor_type_same({"x": x}, mstype.number_type, self.prim_name())
return x
class Exp(PrimitiveWithInfer): class Exp(PrimitiveWithInfer):
""" """
...@@ -965,7 +969,7 @@ class Exp(PrimitiveWithInfer): ...@@ -965,7 +969,7 @@ class Exp(PrimitiveWithInfer):
return x_shape return x_shape
def infer_dtype(self, x_type): def infer_dtype(self, x_type):
validator.check_subclass("x", x_type, mstype.tensor, self.prim_name()) validator.check_subclass("x", x_type, mstype.tensor, self.name)
return x_type return x_type
...@@ -994,7 +998,7 @@ class Log(PrimitiveWithInfer): ...@@ -994,7 +998,7 @@ class Log(PrimitiveWithInfer):
return x return x
def infer_dtype(self, x): def infer_dtype(self, x):
validator.check_subclass("x", x, mstype.tensor, self.prim_name()) validator.check_subclass("x", x, mstype.tensor, self.name)
return x return x
...@@ -1176,7 +1180,7 @@ class Floor(PrimitiveWithInfer): ...@@ -1176,7 +1180,7 @@ class Floor(PrimitiveWithInfer):
return x_shape return x_shape
def infer_dtype(self, x_dtype): def infer_dtype(self, x_dtype):
validator.check_tensor_type_same({"x": x_dtype}, mstype.float_type, self.prim_name()) validator.check_tensor_type_same({"x": x_dtype}, mstype.float_type, self.name)
return x_dtype return x_dtype
...@@ -1231,7 +1235,7 @@ class Acosh(PrimitiveWithInfer): ...@@ -1231,7 +1235,7 @@ class Acosh(PrimitiveWithInfer):
return x return x
def infer_dtype(self, x): def infer_dtype(self, x):
validator.check_tensor_type_same({'x': x}, mstype.number_type, self.prim_name()) validator.check_tensor_type_same({'x': x}, mstype.number_type, self.name)
return x return x
...@@ -1247,7 +1251,7 @@ class _LogicBinaryOp(_BinaryOp): ...@@ -1247,7 +1251,7 @@ class _LogicBinaryOp(_BinaryOp):
return mstype.tensor_type(mstype.bool_) return mstype.tensor_type(mstype.bool_)
def infer_dtype(self, x_dtype, y_dtype): def infer_dtype(self, x_dtype, y_dtype):
return _LogicBinaryOp.do_infer_dtype(x_dtype, y_dtype, prim_name=self.prim_name()) return _LogicBinaryOp.do_infer_dtype(x_dtype, y_dtype, prim_name=self.name)
class Equal(_LogicBinaryOp): class Equal(_LogicBinaryOp):
...@@ -1283,7 +1287,7 @@ class Equal(_LogicBinaryOp): ...@@ -1283,7 +1287,7 @@ class Equal(_LogicBinaryOp):
""" """
def infer_dtype(self, x_dtype, y_dtype): def infer_dtype(self, x_dtype, y_dtype):
return _LogicBinaryOp.do_infer_dtype(x_dtype, y_dtype, mstype.number_type + (mstype.bool_,), self.prim_name()) return _LogicBinaryOp.do_infer_dtype(x_dtype, y_dtype, mstype.number_type + (mstype.bool_,), self.name)
class EqualCount(PrimitiveWithInfer): class EqualCount(PrimitiveWithInfer):
...@@ -1318,7 +1322,7 @@ class EqualCount(PrimitiveWithInfer): ...@@ -1318,7 +1322,7 @@ class EqualCount(PrimitiveWithInfer):
def infer_dtype(self, x_dtype, y_dtype): def infer_dtype(self, x_dtype, y_dtype):
args = {'x': x_dtype, 'y': y_dtype} args = {'x': x_dtype, 'y': y_dtype}
validator.check_tensor_type_same(args, mstype.number_type + (mstype.bool_,), self.prim_name()) validator.check_tensor_type_same(args, mstype.number_type + (mstype.bool_,), self.name)
return x_dtype return x_dtype
...@@ -1355,7 +1359,7 @@ class NotEqual(_LogicBinaryOp): ...@@ -1355,7 +1359,7 @@ class NotEqual(_LogicBinaryOp):
""" """
def infer_dtype(self, x_dtype, y_dtype): def infer_dtype(self, x_dtype, y_dtype):
return _LogicBinaryOp.do_infer_dtype(x_dtype, y_dtype, mstype.number_type + (mstype.bool_,), self.prim_name()) return _LogicBinaryOp.do_infer_dtype(x_dtype, y_dtype, mstype.number_type + (mstype.bool_,), self.name)
class Greater(_LogicBinaryOp): class Greater(_LogicBinaryOp):
...@@ -1491,7 +1495,7 @@ class LogicalNot(PrimitiveWithInfer): ...@@ -1491,7 +1495,7 @@ class LogicalNot(PrimitiveWithInfer):
return x_shape return x_shape
def infer_dtype(self, x_dtype): def infer_dtype(self, x_dtype):
validator.check_tensor_type_same({"x": x_dtype}, [mstype.bool_], self.prim_name()) validator.check_tensor_type_same({"x": x_dtype}, [mstype.bool_], self.name)
return mstype.tensor_type(mstype.bool_) return mstype.tensor_type(mstype.bool_)
...@@ -1521,7 +1525,7 @@ class LogicalAnd(_LogicBinaryOp): ...@@ -1521,7 +1525,7 @@ class LogicalAnd(_LogicBinaryOp):
""" """
def infer_dtype(self, x_dtype, y_dtype): def infer_dtype(self, x_dtype, y_dtype):
return _LogicBinaryOp.do_infer_dtype(x_dtype, y_dtype, (mstype.bool_,), self.prim_name()) return _LogicBinaryOp.do_infer_dtype(x_dtype, y_dtype, (mstype.bool_,), self.name)
class LogicalOr(_LogicBinaryOp): class LogicalOr(_LogicBinaryOp):
...@@ -1550,7 +1554,7 @@ class LogicalOr(_LogicBinaryOp): ...@@ -1550,7 +1554,7 @@ class LogicalOr(_LogicBinaryOp):
""" """
def infer_dtype(self, x_dtype, y_dtype): def infer_dtype(self, x_dtype, y_dtype):
return _LogicBinaryOp.do_infer_dtype(x_dtype, y_dtype, (mstype.bool_,), self.prim_name()) return _LogicBinaryOp.do_infer_dtype(x_dtype, y_dtype, (mstype.bool_,), self.name)
class IsNan(PrimitiveWithInfer): class IsNan(PrimitiveWithInfer):
""" """
...@@ -1699,13 +1703,13 @@ class NPUGetFloatStatus(PrimitiveWithInfer): ...@@ -1699,13 +1703,13 @@ class NPUGetFloatStatus(PrimitiveWithInfer):
self.add_prim_attr("_side_effect_flag", True) self.add_prim_attr("_side_effect_flag", True)
def infer_shape(self, x_shape): def infer_shape(self, x_shape):
cls_name = self.prim_name() cls_name = self.name
validator.check_integer("len(x_shape)", len(x_shape), 1, Rel.EQ, cls_name) validator.check_integer("len(x_shape)", len(x_shape), 1, Rel.EQ, cls_name)
validator.check_integer("x_shape[0]", x_shape[0], 8, Rel.EQ, cls_name) validator.check_integer("x_shape[0]", x_shape[0], 8, Rel.EQ, cls_name)
return [8] return [8]
def infer_dtype(self, x_dtype): def infer_dtype(self, x_dtype):
validator.check_tensor_type_same({'x': x_dtype}, [mstype.float32], self.prim_name()) validator.check_tensor_type_same({'x': x_dtype}, [mstype.float32], self.name)
return mstype.float32 return mstype.float32
...@@ -1741,13 +1745,13 @@ class NPUClearFloatStatus(PrimitiveWithInfer): ...@@ -1741,13 +1745,13 @@ class NPUClearFloatStatus(PrimitiveWithInfer):
self.add_prim_attr("_side_effect_flag", True) self.add_prim_attr("_side_effect_flag", True)
def infer_shape(self, x_shape): def infer_shape(self, x_shape):
cls_name = self.prim_name() cls_name = self.name
validator.check_integer("len(x_shape)", len(x_shape), 1, Rel.EQ, cls_name) validator.check_integer("len(x_shape)", len(x_shape), 1, Rel.EQ, cls_name)
validator.check_integer("x_shape[0]", x_shape[0], 8, Rel.EQ, cls_name) validator.check_integer("x_shape[0]", x_shape[0], 8, Rel.EQ, cls_name)
return [8] return [8]
def infer_dtype(self, x_dtype): def infer_dtype(self, x_dtype):
validator.check_tensor_type_same({'x': x_dtype}, [mstype.float32], self.prim_name()) validator.check_tensor_type_same({'x': x_dtype}, [mstype.float32], self.name)
return mstype.float32 return mstype.float32
...@@ -1775,7 +1779,7 @@ class Cos(PrimitiveWithInfer): ...@@ -1775,7 +1779,7 @@ class Cos(PrimitiveWithInfer):
return x return x
def infer_dtype(self, x): def infer_dtype(self, x):
validator.check_tensor_type_same({'x': x}, mstype.number_type, self.prim_name()) validator.check_tensor_type_same({'x': x}, mstype.number_type, self.name)
return x return x
...@@ -1803,7 +1807,7 @@ class ACos(PrimitiveWithInfer): ...@@ -1803,7 +1807,7 @@ class ACos(PrimitiveWithInfer):
return x return x
def infer_dtype(self, x): def infer_dtype(self, x):
validator.check_tensor_type_same({'x': x}, mstype.number_type, self.prim_name()) validator.check_tensor_type_same({'x': x}, mstype.number_type, self.name)
return x return x
...@@ -1831,7 +1835,7 @@ class Sin(PrimitiveWithInfer): ...@@ -1831,7 +1835,7 @@ class Sin(PrimitiveWithInfer):
return x return x
def infer_dtype(self, x): def infer_dtype(self, x):
validator.check_tensor_type_same({'x': x}, mstype.number_type, self.prim_name()) validator.check_tensor_type_same({'x': x}, mstype.number_type, self.name)
return x return x
...@@ -1876,11 +1880,11 @@ class NMSWithMask(PrimitiveWithInfer): ...@@ -1876,11 +1880,11 @@ class NMSWithMask(PrimitiveWithInfer):
@prim_attr_register @prim_attr_register
def __init__(self, iou_threshold=0.5): def __init__(self, iou_threshold=0.5):
"""Init NMSWithMask""" """Init NMSWithMask"""
validator.check_value_type("iou_threshold", iou_threshold, [float], self.prim_name()) validator.check_value_type("iou_threshold", iou_threshold, [float], self.name)
self.init_prim_io_names(inputs=['bboxes'], outputs=['selected_boxes', 'selected_idx', 'selected_mask']) self.init_prim_io_names(inputs=['bboxes'], outputs=['selected_boxes', 'selected_idx', 'selected_mask'])
def infer_shape(self, bboxes_shape): def infer_shape(self, bboxes_shape):
cls_name = self.prim_name() cls_name = self.name
validator.check_integer("bboxes rank", len(bboxes_shape), 2, Rel.EQ, cls_name) validator.check_integer("bboxes rank", len(bboxes_shape), 2, Rel.EQ, cls_name)
validator.check_integer("bboxes.shape()[0]", bboxes_shape[0], 0, Rel.GT, cls_name) validator.check_integer("bboxes.shape()[0]", bboxes_shape[0], 0, Rel.GT, cls_name)
validator.check_integer("bboxes.shape()[1]", bboxes_shape[1], 5, Rel.EQ, cls_name) validator.check_integer("bboxes.shape()[1]", bboxes_shape[1], 5, Rel.EQ, cls_name)
...@@ -1888,7 +1892,7 @@ class NMSWithMask(PrimitiveWithInfer): ...@@ -1888,7 +1892,7 @@ class NMSWithMask(PrimitiveWithInfer):
return (bboxes_shape, (num,), (num,)) return (bboxes_shape, (num,), (num,))
def infer_dtype(self, bboxes_dtype): def infer_dtype(self, bboxes_dtype):
validator.check_tensor_type_same({"bboxes": bboxes_dtype}, [mstype.float16, mstype.float32], self.prim_name()) validator.check_tensor_type_same({"bboxes": bboxes_dtype}, [mstype.float16, mstype.float32], self.name)
return (bboxes_dtype, mstype.int32, mstype.bool_) return (bboxes_dtype, mstype.int32, mstype.bool_)
...@@ -1917,7 +1921,7 @@ class Abs(PrimitiveWithInfer): ...@@ -1917,7 +1921,7 @@ class Abs(PrimitiveWithInfer):
return x_shape return x_shape
def infer_dtype(self, x_type): def infer_dtype(self, x_type):
validator.check_tensor_type_same({'x': x_type}, mstype.number_type, self.prim_name()) validator.check_tensor_type_same({'x': x_type}, mstype.number_type, self.name)
return x_type return x_type
def infer_value(self, x): def infer_value(self, x):
...@@ -1959,7 +1963,7 @@ class Sign(PrimitiveWithInfer): ...@@ -1959,7 +1963,7 @@ class Sign(PrimitiveWithInfer):
return x_shape return x_shape
def infer_dtype(self, x_dtype): def infer_dtype(self, x_dtype):
validator.check_tensor_type_same({'x': x_dtype}, mstype.number_type, self.prim_name()) validator.check_tensor_type_same({'x': x_dtype}, mstype.number_type, self.name)
return x_dtype return x_dtype
...@@ -1988,7 +1992,7 @@ class Round(PrimitiveWithInfer): ...@@ -1988,7 +1992,7 @@ class Round(PrimitiveWithInfer):
return x_shape return x_shape
def infer_dtype(self, x_type): def infer_dtype(self, x_type):
validator.check_tensor_type_same({'x': x_type}, mstype.number_type, self.prim_name()) validator.check_tensor_type_same({'x': x_type}, mstype.number_type, self.name)
return x_type return x_type
......
...@@ -194,9 +194,6 @@ class PrimitiveWithInfer(Primitive): ...@@ -194,9 +194,6 @@ class PrimitiveWithInfer(Primitive):
Primitive.__init__(self, name) Primitive.__init__(self, name)
self.set_prim_type(prim_type.py_infer_shape) self.set_prim_type(prim_type.py_infer_shape)
def prim_name(self):
return self.__class__.__name__
def _clone(self): def _clone(self):
""" """
Deeply clones the primitive object. Deeply clones the primitive object.
......
...@@ -19,7 +19,7 @@ ...@@ -19,7 +19,7 @@
#include <vector> #include <vector>
#include "common/common_test.h" #include "common/common_test.h"
#include "parallel/strategy.h" #include "parallel/strategy.h"
#include "parallel/ops_info/elementary_function_info.h" #include "parallel/ops_info/arithmetic_info.h"
#include "parallel/device_manager.h" #include "parallel/device_manager.h"
#include "parallel/step_parallel.h" #include "parallel/step_parallel.h"
...@@ -56,14 +56,14 @@ void TestPowInfo::SetUp() { ...@@ -56,14 +56,14 @@ void TestPowInfo::SetUp() {
std::unordered_map<std::string, ValuePtr> attr; std::unordered_map<std::string, ValuePtr> attr;
Shapes inputs_shape = {{32, 64, 128}}; Shapes inputs_shape = {{32, 64, 128}, {32, 64, 128}};
Shapes outputs_shape = {{32, 64, 128}}; Shapes outputs_shape = {{32, 64, 128}};
pow = std::make_shared<PowInfo>("pow_info", inputs_shape, outputs_shape, attr); pow = std::make_shared<PowInfo>("pow_info", inputs_shape, outputs_shape, attr);
} }
TEST_F(TestPowInfo, InferDevMatrixShape1) { TEST_F(TestPowInfo, InferDevMatrixShape1) {
std::vector<Dimensions> inputs = {{2, 4, 8}}; std::vector<Dimensions> inputs = {{2, 4, 8}, {2, 4, 8}};
StrategyPtr strategy = NewStrategy(0, inputs); StrategyPtr strategy = NewStrategy(0, inputs);
pow->Init(strategy); pow->Init(strategy);
...@@ -74,7 +74,7 @@ TEST_F(TestPowInfo, InferDevMatrixShape1) { ...@@ -74,7 +74,7 @@ TEST_F(TestPowInfo, InferDevMatrixShape1) {
} }
TEST_F(TestPowInfo, InferSliceShape1) { TEST_F(TestPowInfo, InferSliceShape1) {
std::vector<Dimensions> str = {{2, 4, 8}}; std::vector<Dimensions> str = {{2, 4, 8}, {2, 4, 8}};
StrategyPtr strategy = NewStrategy(0, str); StrategyPtr strategy = NewStrategy(0, str);
pow->Init(strategy); pow->Init(strategy);
...@@ -95,7 +95,7 @@ TEST_F(TestPowInfo, InferSliceShape1) { ...@@ -95,7 +95,7 @@ TEST_F(TestPowInfo, InferSliceShape1) {
} }
TEST_F(TestPowInfo, GetTensorLayout1) { TEST_F(TestPowInfo, GetTensorLayout1) {
std::vector<Dimensions> str = {{2, 4, 8}}; std::vector<Dimensions> str = {{2, 4, 8}, {2, 4, 8}};
StrategyPtr strategy = NewStrategy(0, str); StrategyPtr strategy = NewStrategy(0, str);
pow->Init(strategy); pow->Init(strategy);
...@@ -116,7 +116,7 @@ TEST_F(TestPowInfo, GetTensorLayout1) { ...@@ -116,7 +116,7 @@ TEST_F(TestPowInfo, GetTensorLayout1) {
} }
TEST_F(TestPowInfo, GetForwardOp1) { TEST_F(TestPowInfo, GetForwardOp1) {
std::vector<Dimensions> inputs = {{2, 4, 8}}; std::vector<Dimensions> inputs = {{2, 4, 8}, {2, 4, 8}};
StrategyPtr strategy = NewStrategy(0, inputs); StrategyPtr strategy = NewStrategy(0, inputs);
pow->Init(strategy); pow->Init(strategy);
...@@ -127,7 +127,7 @@ TEST_F(TestPowInfo, GetForwardOp1) { ...@@ -127,7 +127,7 @@ TEST_F(TestPowInfo, GetForwardOp1) {
} }
TEST_F(TestPowInfo, GetMirrorOPs1) { TEST_F(TestPowInfo, GetMirrorOPs1) {
std::vector<Dimensions> inputs = {{2, 4, 8}}; std::vector<Dimensions> inputs = {{2, 4, 8}, {2, 4, 8}};
StrategyPtr strategy = NewStrategy(0, inputs); StrategyPtr strategy = NewStrategy(0, inputs);
pow->Init(strategy); pow->Init(strategy);
...@@ -147,7 +147,7 @@ TEST_F(TestPowInfo, CheckStrategy1) { ...@@ -147,7 +147,7 @@ TEST_F(TestPowInfo, CheckStrategy1) {
} }
TEST_F(TestPowInfo, CheckStrategy2) { TEST_F(TestPowInfo, CheckStrategy2) {
std::vector<Dimensions> inputs = {{2, 4, 8, 16}}; std::vector<Dimensions> inputs = {{2, 4, 8, 16}, {2, 4, 8, 16}};
StrategyPtr strategy = NewStrategy(0, inputs); StrategyPtr strategy = NewStrategy(0, inputs);
Status ret = pow->Init(strategy); Status ret = pow->Init(strategy);
...@@ -155,7 +155,7 @@ TEST_F(TestPowInfo, CheckStrategy2) { ...@@ -155,7 +155,7 @@ TEST_F(TestPowInfo, CheckStrategy2) {
} }
TEST_F(TestPowInfo, CheckStrategy3) { TEST_F(TestPowInfo, CheckStrategy3) {
std::vector<Dimensions> inputs = {{2, 4, 8}}; std::vector<Dimensions> inputs = {{2, 4, 8}, {2, 4, 8}};
StrategyPtr strategy = NewStrategy(0, inputs); StrategyPtr strategy = NewStrategy(0, inputs);
Status ret = pow->Init(strategy); Status ret = pow->Init(strategy);
......
...@@ -82,9 +82,10 @@ def test_sqrt(): ...@@ -82,9 +82,10 @@ def test_sqrt():
def test_pow(): def test_pow():
""" test_pow """ """ test_pow """
input_tensor = Tensor(np.array([[2, 2], [3, 3]])) input_tensor = Tensor(np.array([[2, 2], [3, 3]]))
power = Tensor(np.array(3.0, np.int64))
testpow = P.Pow() testpow = P.Pow()
expect = np.array([[8, 8], [27, 27]]) expect = np.array([[8, 8], [27, 27]])
result = testpow(input_tensor, 3.0) result = testpow(input_tensor, power)
assert np.all(result.asnumpy() == expect) assert np.all(result.asnumpy() == expect)
......
...@@ -224,11 +224,15 @@ test_case_math_ops = [ ...@@ -224,11 +224,15 @@ test_case_math_ops = [
'block': P.Minimum(), 'block': P.Minimum(),
'desc_inputs': [[2, 3, 3, 5], [2, 3, 3, 5]], 'desc_inputs': [[2, 3, 3, 5], [2, 3, 3, 5]],
'desc_bprop': [[2, 3, 3, 5]]}), 'desc_bprop': [[2, 3, 3, 5]]}),
('Pow', { ('Pow_0', {
'block': P.Pow(), 'block': P.Pow(),
'desc_const': [2.0], 'desc_const': [2.0],
'desc_inputs': [[2, 3, 3, 5]], 'desc_inputs': [[2, 3, 3, 5]],
'desc_bprop': [[2, 3, 3, 5]]}), 'desc_bprop': [[2, 3, 3, 5]]}),
('Pow_1', {
'block': P.Pow(),
'desc_inputs': [[3, 5], [2, 3, 3, 5]],
'desc_bprop': [[2, 3, 3, 5]]}),
('Exp', { ('Exp', {
'block': P.Exp(), 'block': P.Exp(),
'desc_inputs': [[2, 3]], 'desc_inputs': [[2, 3]],
......
...@@ -59,7 +59,7 @@ def test_matmul_pow(): ...@@ -59,7 +59,7 @@ def test_matmul_pow():
context.set_auto_parallel_context(device_num=8, global_rank=0) context.set_auto_parallel_context(device_num=8, global_rank=0)
strategy1 = ((2, 2), (2, 2)) strategy1 = ((2, 2), (2, 2))
strategy2 = ((4, 2), ) strategy2 = ((4, 2), ())
net = GradWrap(NetWithLoss(Net(strategy1, strategy2))) net = GradWrap(NetWithLoss(Net(strategy1, strategy2)))
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel") context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
......
...@@ -117,6 +117,7 @@ def vm_impl_pow(self): ...@@ -117,6 +117,7 @@ def vm_impl_pow(self):
"""Generate vm_impl function for Pow.""" """Generate vm_impl function for Pow."""
def vm_impl(x, y): def vm_impl(x, y):
x = x.asnumpy() x = x.asnumpy()
y = y.asnumpy()
res = vm.power(x, y) res = vm.power(x, y)
return Tensor(res) return Tensor(res)
return vm_impl return vm_impl
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册