提交 ce71c179 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!645 auto parallel prelu operator support broadcast

Merge pull request !645 from yao_yf/auto_parallel_prelu_support_broadcast
......@@ -32,7 +32,7 @@ namespace parallel {
* prelu has 2 input
 * A: A float tensor of shape [NCHW] representing the output of the previous layer.
 * w: Float Tensor, w > 0. There are only two legitimate shapes: 1, or the number of channels of the input.
* the strategy of w should equal to the channel dimension of strategy of A
* the strategy of w should equal to the channel dimension of strategy of A, or equal to 1
*/
Status PReLUInfo::CheckStrategy(const StrategyPtr &strategy) {
if (CheckStrategyValue(strategy, inputs_shape_, is_auto_parallel_) != SUCCESS) {
......@@ -52,7 +52,7 @@ Status PReLUInfo::CheckStrategy(const StrategyPtr &strategy) {
}
return FAILED;
}
if (stra[0][PRELU_CHANNEL_INDEX] != stra[1][0]) {
if (stra[0][PRELU_CHANNEL_INDEX] != stra[1][0] && inputs_shape_[1][0] != 1) {
if (is_auto_parallel_) {
MS_LOG(DEBUG) << name_ << ": Invalid channel strategy.";
} else {
......@@ -107,7 +107,11 @@ Status PReLUInfo::InferTensorMap() {
}
TensorMap param_tensor_map;
param_tensor_map.push_back(input_tensor_map.at(1));
if (inputs_shape_[1][0] == 1) {
param_tensor_map.push_back(-1);
} else {
param_tensor_map.push_back(input_tensor_map.at(1));
}
inputs_tensor_map_.push_back(input_tensor_map);
inputs_tensor_map_.push_back(param_tensor_map);
outputs_tensor_map_.push_back(input_tensor_map);
......
......@@ -166,3 +166,21 @@ def test_prelu_parallel_success4():
w = Tensor(np.random.rand(16),dtype=ms.float32)
net = GradWrap(NetWithLoss(Net(strategy)))
_executor.compile(net, x, w)
def test_prelu_parallel_success5():
    """PReLU under semi_auto_parallel with a broadcast (single-element) weight.

    The input strategy splits the channel dimension by 4, while the weight
    strategy is (1,). This is the new case enabled by this change: when the
    weight has shape (1,), its strategy no longer has to match the channel
    strategy of the input, because the weight is broadcast across channels.
    The test passes if compilation succeeds (no strategy-check failure).
    """
    class Net(nn.Cell):
        def __init__(self, strategy):
            super().__init__()
            self.prelu = P.PReLU().set_strategy(strategy)

        def construct(self, x, y):
            out = self.prelu(x, y)
            return out

    context.reset_auto_parallel_context()
    context.set_auto_parallel_context(device_num=64, global_rank=0)
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
    # (2, 4, 4, 2) shards the NCHW input; (1,) is valid for the
    # single-element weight only because of the broadcast support.
    strategy = ((2, 4, 4, 2), (1,))
    x = Tensor(np.random.rand(4, 16, 32, 64), dtype=ms.float32)
    w = Tensor(np.random.rand(1), dtype=ms.float32)
    net = GradWrap(NetWithLoss(Net(strategy)))
    _executor.compile(net, x, w)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册