提交 5841fe01 编写于 作者: B buxue

Support pow's second input could be tensor and fix bug in bprop of pow

上级 7cec2852
......@@ -98,6 +98,13 @@ class FloorDivInfo : public ArithmeticBase {
~FloorDivInfo() override = default;
};
class PowInfo : public ArithmeticBase {
public:
PowInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape, const PrimitiveAttrs& attrs)
: ArithmeticBase(name, inputs_shape, outputs_shape, attrs) {}
~PowInfo() override = default;
};
class GreaterInfo : public ArithmeticBase {
public:
GreaterInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape,
......
/**
* Copyright 2019 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "parallel/ops_info/elementary_function_info.h"
namespace mindspore {
namespace parallel {
Status PowInfo::InferMirrorOps() {
mirror_ops_.clear();
Shape tensor_map = inputs_tensor_map_[0];
std::vector<Group> group;
if (CreateGroupByTensorMap(tensor_map, &group) != SUCCESS) {
MS_LOG(ERROR) << name_ << " : Create group failed.";
return FAILED;
}
OperatorVector mirror_op;
OperatorVector op_for_value;
if (group.empty()) {
MS_LOG(INFO) << name_ << " : The mirror ops is empty.";
return SUCCESS;
} else {
mirror_op = CreateMirrorOps(group[0].name(), group[0].GetDevNum());
mirror_ops_.push_back(mirror_op);
mirror_ops_.push_back(op_for_value);
std::string group_name = group[0].name();
MS_LOG(INFO) << name_ << " : Create the mirror ops success, the group name is " << group_name;
}
return SUCCESS;
}
} // namespace parallel
} // namespace mindspore
......@@ -27,16 +27,6 @@
namespace mindspore {
namespace parallel {
class PowInfo : public ActivationOther {
public:
PowInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape, const PrimitiveAttrs& attrs)
: ActivationOther(name, inputs_shape, outputs_shape, attrs) {}
~PowInfo() override = default;
protected:
Status InferMirrorOps() override;
};
class ExpInfo : public ActivationOther {
public:
ExpInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape, const PrimitiveAttrs& attrs)
......
......@@ -58,7 +58,7 @@ class _PoolNd(Cell):
pass
def extend_repr(self):
return 'kernel_size={kernel_size}, strides={strides}, pad_mode={pad_mode}'.format(**self.__dict__)
return 'kernel_size={kernel_size}, stride={stride}, pad_mode={pad_mode}'.format(**self.__dict__)
class MaxPool2d(_PoolNd):
......
......@@ -336,14 +336,13 @@ def get_bprop_log(self):
@bprop_getters.register(P.Pow)
def get_bprop_pow(self):
"""Grad definition for `Pow` operation."""
pow_ = P.Pow()
cast = P.Cast()
dtype = P.DType()
pow_op = P.Pow()
ln = P.Log()
def bprop(x, power, out, dout):
g = cast(F.tuple_to_array((power,)), dtype(x)) * pow_(x, power-1.0)
dx = g * dout
return dx, 0
dx = power * pow_op(x, power - 1.0) * dout
dpower = pow_op(x, power) * ln(x) * dout
return dx, dpower
return bprop
......
......@@ -1097,7 +1097,7 @@ class ArgMaxWithValue(PrimitiveWithInfer):
axis = self.axis
x_rank = len(x_shape)
validator.check_int_range("axis", axis, -x_rank, x_rank, Rel.INC_LEFT)
ouput_shape = _infer_shape_reduce(x_shape, self.axis, self.keep_dims, self.prim_name())
ouput_shape = _infer_shape_reduce(x_shape, self.axis, self.keep_dims, self.name)
return ouput_shape, ouput_shape
def infer_dtype(self, x_dtype):
......@@ -1143,7 +1143,7 @@ class ArgMinWithValue(PrimitiveWithInfer):
axis = self.axis
x_rank = len(x_shape)
validator.check_int_range("axis", axis, -x_rank, x_rank, Rel.INC_LEFT)
ouput_shape = _infer_shape_reduce(x_shape, self.axis, self.keep_dims, self.prim_name())
ouput_shape = _infer_shape_reduce(x_shape, self.axis, self.keep_dims, self.name)
return ouput_shape, ouput_shape
def infer_dtype(self, x_dtype):
......
......@@ -74,7 +74,7 @@ class _BinaryOp(PrimitiveWithInfer):
self.init_prim_io_names(inputs=['x', 'y'], outputs=['output'])
def infer_shape(self, x_shape, y_shape):
return _get_broadcast_shape(x_shape, y_shape, self.prim_name())
return _get_broadcast_shape(x_shape, y_shape, self.name)
class _MathBinaryOp(_BinaryOp):
......@@ -89,7 +89,7 @@ class _MathBinaryOp(_BinaryOp):
return x_dtype
def infer_dtype(self, x_dtype, y_dtype):
return _MathBinaryOp.do_infer_dtype(x_dtype, y_dtype, mstype.number_type, self.prim_name())
return _MathBinaryOp.do_infer_dtype(x_dtype, y_dtype, mstype.number_type, self.name)
class TensorAdd(_MathBinaryOp):
......@@ -158,7 +158,7 @@ class AssignAdd(PrimitiveWithInfer):
def infer_dtype(self, variable, value):
args = {"value": value}
validator.check_scalar_or_tensor_type_same(args, mstype.number_type, self.prim_name())
validator.check_scalar_or_tensor_type_same(args, mstype.number_type, self.name)
return value
......@@ -201,7 +201,7 @@ class AssignSub(PrimitiveWithInfer):
def infer_dtype(self, variable, value):
args = {"value": value}
validator.check_scalar_or_tensor_type_same(args, mstype.number_type, self.prim_name())
validator.check_scalar_or_tensor_type_same(args, mstype.number_type, self.name)
return value
......@@ -222,16 +222,16 @@ class _Reduce(PrimitiveWithInfer):
@prim_attr_register
def __init__(self, keep_dims=False):
"""init Reduce"""
validator.check_value_type('keep_dims', keep_dims, [bool], self.prim_name())
validator.check_value_type('keep_dims', keep_dims, [bool], self.name)
self.init_prim_io_names(inputs=['input_x', 'axis'], outputs=['y'])
def do_infer(self, input_x, axis, valid_dtype=mstype.number_type):
axis_v = axis['value']
input_shp = input_x['shape']
args = {'input_x': input_x['dtype']}
validator.check_tensor_type_same(args, valid_dtype, self.prim_name())
validator.check_tensor_type_same(args, valid_dtype, self.name)
input_shp = _infer_shape_reduce(input_shp, axis_v, self.keep_dims, self.prim_name())
input_shp = _infer_shape_reduce(input_shp, axis_v, self.keep_dims, self.name)
return {'shape': input_shp,
'dtype': input_x['dtype'],
'value': None}
......@@ -466,7 +466,7 @@ class CumProd(PrimitiveWithInfer):
"""
@prim_attr_register
def __init__(self, exclusive=False, reverse=False):
cls_name = self.prim_name()
cls_name = self.name
self.exclusive = validator.check_value_type("exclusive", exclusive, [bool], cls_name)
self.reverse = validator.check_value_type("reverse", reverse, [bool], cls_name)
......@@ -474,7 +474,7 @@ class CumProd(PrimitiveWithInfer):
return x_shape
def infer_dtype(self, x_type, axis_type):
cls_name = self.prim_name()
cls_name = self.name
validator.check_tensor_type_same({'x': x_type}, mstype.number_type, cls_name)
validator.check_subclass("axis", axis_type, mstype.int_, cls_name)
return x_type
......@@ -510,7 +510,7 @@ class MatMul(PrimitiveWithInfer):
def __init__(self, transpose_a=False, transpose_b=False):
self.init_prim_io_names(inputs=['x1', 'x2'], outputs=['output'])
self.__setattr_flag__ = True
cls_name = self.prim_name()
cls_name = self.name
validator.check_value_type("transpose_a", transpose_a, [bool], cls_name)
validator.check_value_type("transpose_b", transpose_b, [bool], cls_name)
......@@ -521,7 +521,7 @@ class MatMul(PrimitiveWithInfer):
def infer_shape(self, x, y):
self.check_shape_size(x, y)
cls_name = self.prim_name()
cls_name = self.name
# expected dimension of x, y, x:[...,a,b] y:[..., c,d], the dim size should be the same except the last two
for i in range(len(x) - 2):
if x[i] != y[i]:
......@@ -546,7 +546,7 @@ class MatMul(PrimitiveWithInfer):
def infer_dtype(self, x, y):
args = {"x": x, "y": y}
validator.check_tensor_type_same(args, mstype.float_type + mstype.int_type, self.prim_name())
validator.check_tensor_type_same(args, mstype.float_type + mstype.int_type, self.name)
return x
......@@ -590,7 +590,7 @@ class BatchMatMul(MatMul):
def __init__(self, transpose_a=False, transpose_b=False):
self.init_prim_io_names(inputs=['x1', 'x2'], outputs=['output'])
self.__setattr_flag__ = True
cls_name = self.prim_name()
cls_name = self.name
validator.check_value_type("transpose_a", transpose_a, [bool], cls_name)
validator.check_value_type("transpose_b", transpose_b, [bool], cls_name)
......@@ -628,13 +628,13 @@ class CumSum(PrimitiveWithInfer):
@prim_attr_register
def __init__(self, exclusive=False, reverse=False):
"""init cumsum"""
cls_name = self.prim_name()
cls_name = self.name
validator.check_value_type('exclusive', exclusive, [bool], cls_name)
validator.check_value_type('reverse', reverse, [bool], cls_name)
self.init_prim_io_names(inputs=['x', 'axis'], outputs=['y'])
def __infer__(self, x, axis):
cls_name = self.prim_name()
cls_name = self.name
x_shp = x['shape']
validator.check_value_type('axis', axis['value'], [int], cls_name)
valid_types = [mstype.uint8, mstype.int8, mstype.int32, mstype.float16, mstype.float32]
......@@ -679,7 +679,7 @@ class AddN(PrimitiveWithInfer):
self.init_prim_io_names(inputs=["inputs"], outputs=["sum"])
def infer_shape(self, inputs):
cls_name = self.prim_name()
cls_name = self.name
validator.check_integer("inputs", len(inputs), 1, Rel.GE, cls_name)
self.add_prim_attr('n', len(inputs))
shp0 = inputs[0]
......@@ -688,7 +688,7 @@ class AddN(PrimitiveWithInfer):
return shp0
def infer_dtype(self, inputs):
cls_name = self.prim_name()
cls_name = self.name
validator.check_value_type("inputs", inputs, [tuple, list], cls_name)
validator.check_integer("inputs", len(inputs), 1, Rel.GE, cls_name)
args = {}
......@@ -718,7 +718,7 @@ class Neg(PrimitiveWithInfer):
return input_x
def infer_dtype(self, input_x):
validator.check_tensor_type_same({"input_x": input_x}, mstype.number_type, self.prim_name())
validator.check_tensor_type_same({"input_x": input_x}, mstype.number_type, self.name)
return input_x
......@@ -809,7 +809,7 @@ class Square(PrimitiveWithInfer):
return x_shape
def infer_dtype(self, x_type):
validator.check_tensor_type_same({"x": x_type}, mstype.number_type, self.prim_name())
validator.check_tensor_type_same({"x": x_type}, mstype.number_type, self.name)
return x_type
......@@ -838,7 +838,7 @@ class Rsqrt(PrimitiveWithInfer):
return x_shape
def infer_dtype(self, x_type):
validator.check_tensor_type_same({"x": x_type}, mstype.number_type, self.prim_name())
validator.check_tensor_type_same({"x": x_type}, mstype.number_type, self.name)
return x_type
......@@ -867,7 +867,7 @@ class Sqrt(PrimitiveWithInfer):
return x_shape
def infer_dtype(self, x_type):
validator.check_tensor_type_same({"x": x_type}, mstype.number_type, self.prim_name())
validator.check_tensor_type_same({"x": x_type}, mstype.number_type, self.name)
return x_type
......@@ -897,14 +897,29 @@ class Reciprocal(PrimitiveWithInfer):
return x
def infer_dtype(self, x):
validator.check_subclass("x", x, mstype.tensor, self.prim_name())
validator.check_subclass("x", x, mstype.tensor, self.name)
return x
class Pow(PrimitiveWithInfer):
class Pow(_MathBinaryOp):
"""
Computes a tensor to the power of the second input.
The first input must be a tensor, and the second input should be a tensor or a number.
When the inputs are two tensors, the shapes of them could be broadcast,
and the data types of them should be the same.
When the inputs are one tensor and one scalar, the scalar could not be a parameter,
only could be a constant, and the type of the scalar is the same as the data type of the tensor.
Inputs:
- **input_x** (Union[Tensor]) - The first input is a tensor whose data type is number.
- **input_y** (Union[Tensor, Number]) - The second input is a tensor whose data type is same as 'input_x' or
a number.
Outputs:
Tensor, the shape is same as the shape after broadcasting, and the data type is same as 'input_x'.
Inputs:
- **input_x** (Tensor) - The input tensor.
- **input_y** (Union[Tensor, Number]) - The exponent part. If exponent is a tensor, its shape must be able to
......@@ -927,17 +942,6 @@ class Pow(PrimitiveWithInfer):
[1.0, 16.0, 64.0]
"""
@prim_attr_register
def __init__(self):
"""init Multiply"""
def infer_shape(self, x, power):
return x
def infer_dtype(self, x, power):
validator.check_tensor_type_same({"x": x}, mstype.number_type, self.prim_name())
return x
class Exp(PrimitiveWithInfer):
"""
......@@ -965,7 +969,7 @@ class Exp(PrimitiveWithInfer):
return x_shape
def infer_dtype(self, x_type):
validator.check_subclass("x", x_type, mstype.tensor, self.prim_name())
validator.check_subclass("x", x_type, mstype.tensor, self.name)
return x_type
......@@ -994,7 +998,7 @@ class Log(PrimitiveWithInfer):
return x
def infer_dtype(self, x):
validator.check_subclass("x", x, mstype.tensor, self.prim_name())
validator.check_subclass("x", x, mstype.tensor, self.name)
return x
......@@ -1176,7 +1180,7 @@ class Floor(PrimitiveWithInfer):
return x_shape
def infer_dtype(self, x_dtype):
validator.check_tensor_type_same({"x": x_dtype}, mstype.float_type, self.prim_name())
validator.check_tensor_type_same({"x": x_dtype}, mstype.float_type, self.name)
return x_dtype
......@@ -1231,7 +1235,7 @@ class Acosh(PrimitiveWithInfer):
return x
def infer_dtype(self, x):
validator.check_tensor_type_same({'x': x}, mstype.number_type, self.prim_name())
validator.check_tensor_type_same({'x': x}, mstype.number_type, self.name)
return x
......@@ -1247,7 +1251,7 @@ class _LogicBinaryOp(_BinaryOp):
return mstype.tensor_type(mstype.bool_)
def infer_dtype(self, x_dtype, y_dtype):
return _LogicBinaryOp.do_infer_dtype(x_dtype, y_dtype, prim_name=self.prim_name())
return _LogicBinaryOp.do_infer_dtype(x_dtype, y_dtype, prim_name=self.name)
class Equal(_LogicBinaryOp):
......@@ -1283,7 +1287,7 @@ class Equal(_LogicBinaryOp):
"""
def infer_dtype(self, x_dtype, y_dtype):
return _LogicBinaryOp.do_infer_dtype(x_dtype, y_dtype, mstype.number_type + (mstype.bool_,), self.prim_name())
return _LogicBinaryOp.do_infer_dtype(x_dtype, y_dtype, mstype.number_type + (mstype.bool_,), self.name)
class EqualCount(PrimitiveWithInfer):
......@@ -1318,7 +1322,7 @@ class EqualCount(PrimitiveWithInfer):
def infer_dtype(self, x_dtype, y_dtype):
args = {'x': x_dtype, 'y': y_dtype}
validator.check_tensor_type_same(args, mstype.number_type + (mstype.bool_,), self.prim_name())
validator.check_tensor_type_same(args, mstype.number_type + (mstype.bool_,), self.name)
return x_dtype
......@@ -1355,7 +1359,7 @@ class NotEqual(_LogicBinaryOp):
"""
def infer_dtype(self, x_dtype, y_dtype):
return _LogicBinaryOp.do_infer_dtype(x_dtype, y_dtype, mstype.number_type + (mstype.bool_,), self.prim_name())
return _LogicBinaryOp.do_infer_dtype(x_dtype, y_dtype, mstype.number_type + (mstype.bool_,), self.name)
class Greater(_LogicBinaryOp):
......@@ -1491,7 +1495,7 @@ class LogicalNot(PrimitiveWithInfer):
return x_shape
def infer_dtype(self, x_dtype):
validator.check_tensor_type_same({"x": x_dtype}, [mstype.bool_], self.prim_name())
validator.check_tensor_type_same({"x": x_dtype}, [mstype.bool_], self.name)
return mstype.tensor_type(mstype.bool_)
......@@ -1521,7 +1525,7 @@ class LogicalAnd(_LogicBinaryOp):
"""
def infer_dtype(self, x_dtype, y_dtype):
return _LogicBinaryOp.do_infer_dtype(x_dtype, y_dtype, (mstype.bool_,), self.prim_name())
return _LogicBinaryOp.do_infer_dtype(x_dtype, y_dtype, (mstype.bool_,), self.name)
class LogicalOr(_LogicBinaryOp):
......@@ -1550,7 +1554,7 @@ class LogicalOr(_LogicBinaryOp):
"""
def infer_dtype(self, x_dtype, y_dtype):
return _LogicBinaryOp.do_infer_dtype(x_dtype, y_dtype, (mstype.bool_,), self.prim_name())
return _LogicBinaryOp.do_infer_dtype(x_dtype, y_dtype, (mstype.bool_,), self.name)
class IsNan(PrimitiveWithInfer):
"""
......@@ -1699,13 +1703,13 @@ class NPUGetFloatStatus(PrimitiveWithInfer):
self.add_prim_attr("_side_effect_flag", True)
def infer_shape(self, x_shape):
cls_name = self.prim_name()
cls_name = self.name
validator.check_integer("len(x_shape)", len(x_shape), 1, Rel.EQ, cls_name)
validator.check_integer("x_shape[0]", x_shape[0], 8, Rel.EQ, cls_name)
return [8]
def infer_dtype(self, x_dtype):
validator.check_tensor_type_same({'x': x_dtype}, [mstype.float32], self.prim_name())
validator.check_tensor_type_same({'x': x_dtype}, [mstype.float32], self.name)
return mstype.float32
......@@ -1741,13 +1745,13 @@ class NPUClearFloatStatus(PrimitiveWithInfer):
self.add_prim_attr("_side_effect_flag", True)
def infer_shape(self, x_shape):
cls_name = self.prim_name()
cls_name = self.name
validator.check_integer("len(x_shape)", len(x_shape), 1, Rel.EQ, cls_name)
validator.check_integer("x_shape[0]", x_shape[0], 8, Rel.EQ, cls_name)
return [8]
def infer_dtype(self, x_dtype):
validator.check_tensor_type_same({'x': x_dtype}, [mstype.float32], self.prim_name())
validator.check_tensor_type_same({'x': x_dtype}, [mstype.float32], self.name)
return mstype.float32
......@@ -1775,7 +1779,7 @@ class Cos(PrimitiveWithInfer):
return x
def infer_dtype(self, x):
validator.check_tensor_type_same({'x': x}, mstype.number_type, self.prim_name())
validator.check_tensor_type_same({'x': x}, mstype.number_type, self.name)
return x
......@@ -1803,7 +1807,7 @@ class ACos(PrimitiveWithInfer):
return x
def infer_dtype(self, x):
validator.check_tensor_type_same({'x': x}, mstype.number_type, self.prim_name())
validator.check_tensor_type_same({'x': x}, mstype.number_type, self.name)
return x
......@@ -1831,7 +1835,7 @@ class Sin(PrimitiveWithInfer):
return x
def infer_dtype(self, x):
validator.check_tensor_type_same({'x': x}, mstype.number_type, self.prim_name())
validator.check_tensor_type_same({'x': x}, mstype.number_type, self.name)
return x
......@@ -1876,11 +1880,11 @@ class NMSWithMask(PrimitiveWithInfer):
@prim_attr_register
def __init__(self, iou_threshold=0.5):
"""Init NMSWithMask"""
validator.check_value_type("iou_threshold", iou_threshold, [float], self.prim_name())
validator.check_value_type("iou_threshold", iou_threshold, [float], self.name)
self.init_prim_io_names(inputs=['bboxes'], outputs=['selected_boxes', 'selected_idx', 'selected_mask'])
def infer_shape(self, bboxes_shape):
cls_name = self.prim_name()
cls_name = self.name
validator.check_integer("bboxes rank", len(bboxes_shape), 2, Rel.EQ, cls_name)
validator.check_integer("bboxes.shape()[0]", bboxes_shape[0], 0, Rel.GT, cls_name)
validator.check_integer("bboxes.shape()[1]", bboxes_shape[1], 5, Rel.EQ, cls_name)
......@@ -1888,7 +1892,7 @@ class NMSWithMask(PrimitiveWithInfer):
return (bboxes_shape, (num,), (num,))
def infer_dtype(self, bboxes_dtype):
validator.check_tensor_type_same({"bboxes": bboxes_dtype}, [mstype.float16, mstype.float32], self.prim_name())
validator.check_tensor_type_same({"bboxes": bboxes_dtype}, [mstype.float16, mstype.float32], self.name)
return (bboxes_dtype, mstype.int32, mstype.bool_)
......@@ -1917,7 +1921,7 @@ class Abs(PrimitiveWithInfer):
return x_shape
def infer_dtype(self, x_type):
validator.check_tensor_type_same({'x': x_type}, mstype.number_type, self.prim_name())
validator.check_tensor_type_same({'x': x_type}, mstype.number_type, self.name)
return x_type
def infer_value(self, x):
......@@ -1959,7 +1963,7 @@ class Sign(PrimitiveWithInfer):
return x_shape
def infer_dtype(self, x_dtype):
validator.check_tensor_type_same({'x': x_dtype}, mstype.number_type, self.prim_name())
validator.check_tensor_type_same({'x': x_dtype}, mstype.number_type, self.name)
return x_dtype
......@@ -1988,7 +1992,7 @@ class Round(PrimitiveWithInfer):
return x_shape
def infer_dtype(self, x_type):
validator.check_tensor_type_same({'x': x_type}, mstype.number_type, self.prim_name())
validator.check_tensor_type_same({'x': x_type}, mstype.number_type, self.name)
return x_type
......
......@@ -194,9 +194,6 @@ class PrimitiveWithInfer(Primitive):
Primitive.__init__(self, name)
self.set_prim_type(prim_type.py_infer_shape)
def prim_name(self):
return self.__class__.__name__
def _clone(self):
"""
Deeply clones the primitive object.
......
......@@ -19,7 +19,7 @@
#include <vector>
#include "common/common_test.h"
#include "parallel/strategy.h"
#include "parallel/ops_info/elementary_function_info.h"
#include "parallel/ops_info/arithmetic_info.h"
#include "parallel/device_manager.h"
#include "parallel/step_parallel.h"
......@@ -56,14 +56,14 @@ void TestPowInfo::SetUp() {
std::unordered_map<std::string, ValuePtr> attr;
Shapes inputs_shape = {{32, 64, 128}};
Shapes inputs_shape = {{32, 64, 128}, {32, 64, 128}};
Shapes outputs_shape = {{32, 64, 128}};
pow = std::make_shared<PowInfo>("pow_info", inputs_shape, outputs_shape, attr);
}
TEST_F(TestPowInfo, InferDevMatrixShape1) {
std::vector<Dimensions> inputs = {{2, 4, 8}};
std::vector<Dimensions> inputs = {{2, 4, 8}, {2, 4, 8}};
StrategyPtr strategy = NewStrategy(0, inputs);
pow->Init(strategy);
......@@ -74,7 +74,7 @@ TEST_F(TestPowInfo, InferDevMatrixShape1) {
}
TEST_F(TestPowInfo, InferSliceShape1) {
std::vector<Dimensions> str = {{2, 4, 8}};
std::vector<Dimensions> str = {{2, 4, 8}, {2, 4, 8}};
StrategyPtr strategy = NewStrategy(0, str);
pow->Init(strategy);
......@@ -95,7 +95,7 @@ TEST_F(TestPowInfo, InferSliceShape1) {
}
TEST_F(TestPowInfo, GetTensorLayout1) {
std::vector<Dimensions> str = {{2, 4, 8}};
std::vector<Dimensions> str = {{2, 4, 8}, {2, 4, 8}};
StrategyPtr strategy = NewStrategy(0, str);
pow->Init(strategy);
......@@ -116,7 +116,7 @@ TEST_F(TestPowInfo, GetTensorLayout1) {
}
TEST_F(TestPowInfo, GetForwardOp1) {
std::vector<Dimensions> inputs = {{2, 4, 8}};
std::vector<Dimensions> inputs = {{2, 4, 8}, {2, 4, 8}};
StrategyPtr strategy = NewStrategy(0, inputs);
pow->Init(strategy);
......@@ -127,7 +127,7 @@ TEST_F(TestPowInfo, GetForwardOp1) {
}
TEST_F(TestPowInfo, GetMirrorOPs1) {
std::vector<Dimensions> inputs = {{2, 4, 8}};
std::vector<Dimensions> inputs = {{2, 4, 8}, {2, 4, 8}};
StrategyPtr strategy = NewStrategy(0, inputs);
pow->Init(strategy);
......@@ -147,7 +147,7 @@ TEST_F(TestPowInfo, CheckStrategy1) {
}
TEST_F(TestPowInfo, CheckStrategy2) {
std::vector<Dimensions> inputs = {{2, 4, 8, 16}};
std::vector<Dimensions> inputs = {{2, 4, 8, 16}, {2, 4, 8, 16}};
StrategyPtr strategy = NewStrategy(0, inputs);
Status ret = pow->Init(strategy);
......@@ -155,7 +155,7 @@ TEST_F(TestPowInfo, CheckStrategy2) {
}
TEST_F(TestPowInfo, CheckStrategy3) {
std::vector<Dimensions> inputs = {{2, 4, 8}};
std::vector<Dimensions> inputs = {{2, 4, 8}, {2, 4, 8}};
StrategyPtr strategy = NewStrategy(0, inputs);
Status ret = pow->Init(strategy);
......
......@@ -82,9 +82,10 @@ def test_sqrt():
def test_pow():
""" test_pow """
input_tensor = Tensor(np.array([[2, 2], [3, 3]]))
power = Tensor(np.array(3.0, np.int64))
testpow = P.Pow()
expect = np.array([[8, 8], [27, 27]])
result = testpow(input_tensor, 3.0)
result = testpow(input_tensor, power)
assert np.all(result.asnumpy() == expect)
......
......@@ -224,11 +224,15 @@ test_case_math_ops = [
'block': P.Minimum(),
'desc_inputs': [[2, 3, 3, 5], [2, 3, 3, 5]],
'desc_bprop': [[2, 3, 3, 5]]}),
('Pow', {
('Pow_0', {
'block': P.Pow(),
'desc_const': [2.0],
'desc_inputs': [[2, 3, 3, 5]],
'desc_bprop': [[2, 3, 3, 5]]}),
('Pow_1', {
'block': P.Pow(),
'desc_inputs': [[3, 5], [2, 3, 3, 5]],
'desc_bprop': [[2, 3, 3, 5]]}),
('Exp', {
'block': P.Exp(),
'desc_inputs': [[2, 3]],
......
......@@ -59,7 +59,7 @@ def test_matmul_pow():
context.set_auto_parallel_context(device_num=8, global_rank=0)
strategy1 = ((2, 2), (2, 2))
strategy2 = ((4, 2), )
strategy2 = ((4, 2), ())
net = GradWrap(NetWithLoss(Net(strategy1, strategy2)))
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
......
......@@ -117,6 +117,7 @@ def vm_impl_pow(self):
"""Generate vm_impl function for Pow."""
def vm_impl(x, y):
x = x.asnumpy()
y = y.asnumpy()
res = vm.power(x, y)
return Tensor(res)
return vm_impl
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册