未验证 提交 75aaa08a 编写于 作者: Q qipengh 提交者: GitHub

[MLU]add elementwise_pow op (#44215)

上级 176a8832
......@@ -122,6 +122,7 @@ enum BINARY_FUNCTOR {
DIVNONAN,
MAXIMUM,
MINIMUM,
POW,
};
template <BINARY_FUNCTOR func>
......@@ -171,6 +172,18 @@ inline void MLUBinary<MINIMUM>(const framework::ExecutionContext& ctx,
MLUCnnl::Minimum(ctx, in1_desc, in1, in2_desc, in2, out_desc, out);
}
template <>
inline void MLUBinary<POW>(const framework::ExecutionContext& ctx,
cnnlComputationPreference_t prefer,
const cnnlTensorDescriptor_t x_desc,
const void* x,
const cnnlTensorDescriptor_t y_desc,
const void* y,
const cnnlTensorDescriptor_t out_desc,
void* out) {
MLUCnnl::Pow(ctx, prefer, x_desc, x, y_desc, y, out_desc, out);
}
template <BINARY_FUNCTOR Functor, typename T>
void MLUBinaryOp(const framework::ExecutionContext& ctx) {
auto* x = ctx.Input<Tensor>("X");
......
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/elementwise/elementwise_mlu.h"
#include "paddle/fluid/operators/elementwise/elementwise_op.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
template <typename T>
class ElementwisePowMLUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
MLUBinaryOp<POW, T>(ctx);
}
};
template <typename T>
class ElementwisePowGradMLUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* x = ctx.Input<Tensor>("X");
auto* y = ctx.Input<Tensor>("Y");
auto* dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
auto* dy = ctx.Output<Tensor>(framework::GradVarName("Y"));
int axis = ctx.Attr<int>("axis");
auto place = ctx.GetPlace();
auto x_dims = x->dims();
auto y_dims = y->dims();
axis =
(axis < 0 ? std::abs(x_dims.size() - y_dims.size()) + axis + 1 : axis);
int max_dim = std::max(x_dims.size(), y_dims.size());
std::vector<int> x_dims_array(max_dim);
std::vector<int> y_dims_array(max_dim);
std::vector<int> out_dims_array(max_dim);
GetBroadcastDimsArrays(x_dims,
y_dims,
x_dims_array.data(),
y_dims_array.data(),
out_dims_array.data(),
max_dim,
axis);
cnnlDataType_t data_type = ToCnnlDataType<T>();
MLUCnnlTensorDesc x_desc(max_dim, x_dims_array.data(), data_type);
MLUCnnlTensorDesc y_desc(max_dim, y_dims_array.data(), data_type);
MLUCnnlTensorDesc out_desc(max_dim, out_dims_array.data(), data_type);
auto dout_dims = dout->dims();
if (dx) {
// dx = dout * y * pow(x, y - 1);
Tensor one_dx(y->type());
one_dx.mutable_data<T>(phi::make_ddim(y_dims_array), place);
FillMLUTensorWithHostValue(ctx, static_cast<T>(1), &one_dx);
Tensor sub_dx(y->type());
sub_dx.mutable_data<T>(phi::make_ddim(y_dims_array), place);
MLUCnnlOpTensorDesc op_tensor_desc(
CNNL_OP_TENSOR_SUB, data_type, CNNL_NOT_PROPAGATE_NAN);
MLUCnnl::OpTensor(ctx,
op_tensor_desc.get(),
y_desc.get(),
GetBasePtr(y),
y_desc.get(),
GetBasePtr(&one_dx),
y_desc.get(),
GetBasePtr(&sub_dx),
data_type);
Tensor tmp_dx(x->type());
tmp_dx.mutable_data<T>(phi::make_ddim(out_dims_array), place);
MLUCnnl::Pow(ctx,
CNNL_COMPUTATION_HIGH_PRECISION,
x_desc.get(),
GetBasePtr(x),
y_desc.get(),
GetBasePtr(&sub_dx),
out_desc.get(),
GetBasePtr(&tmp_dx));
MLUCnnl::MulAx(ctx,
y_desc.get(),
GetBasePtr(y),
out_desc.get(),
GetBasePtr(&tmp_dx));
MLUCnnl::MulAx(ctx,
out_desc.get(),
GetBasePtr(dout),
out_desc.get(),
GetBasePtr(&tmp_dx));
if (x_dims != dout_dims) {
dx->mutable_data<T>(place);
std::vector<int> reduce_axes;
GetReduceAxes(axis, dout_dims, x_dims, &reduce_axes);
if (!reduce_axes.empty()) {
MLUCnnlReduceDesc reduction_desc(reduce_axes,
CNNL_REDUCE_ADD,
data_type,
CNNL_NOT_PROPAGATE_NAN,
CNNL_REDUCE_NO_INDICES,
CNNL_32BIT_INDICES);
MLUCnnlTensorDesc dx_desc(*dx);
MLUCnnl::Reduce(ctx,
true /*need_workspace*/,
reduction_desc.get(),
nullptr,
out_desc.get(),
GetBasePtr(&tmp_dx),
0,
nullptr,
nullptr,
dx_desc.get(),
GetBasePtr(dx));
}
} else {
dx->ShareDataWith(tmp_dx);
}
}
if (dy) {
// dy = dout * log(x) * pow(x, y)
Tensor tmp_dy(y->type());
tmp_dy.mutable_data<T>(phi::make_ddim(out_dims_array), place);
MLUCnnl::Pow(ctx,
CNNL_COMPUTATION_HIGH_PRECISION,
x_desc.get(),
GetBasePtr(x),
y_desc.get(),
GetBasePtr(y),
out_desc.get(),
GetBasePtr(&tmp_dy));
Tensor log_x(x->type());
log_x.mutable_data<T>(x->dims(), place);
MLUCnnl::Log(ctx,
CNNL_COMPUTATION_HIGH_PRECISION,
CNNL_LOG_E,
x_desc.get(),
GetBasePtr(x),
x_desc.get(),
GetBasePtr(&log_x));
MLUCnnl::MulAx(ctx,
x_desc.get(),
GetBasePtr(&log_x),
out_desc.get(),
GetBasePtr(&tmp_dy));
MLUCnnl::MulAx(ctx,
out_desc.get(),
GetBasePtr(dout),
out_desc.get(),
GetBasePtr(&tmp_dy));
if (y_dims != dout_dims) {
dy->mutable_data<T>(place);
std::vector<int> reduce_axes;
GetReduceAxes(axis, dout_dims, y_dims, &reduce_axes);
if (!reduce_axes.empty()) {
MLUCnnlReduceDesc reduction_desc(reduce_axes,
CNNL_REDUCE_ADD,
data_type,
CNNL_NOT_PROPAGATE_NAN,
CNNL_REDUCE_NO_INDICES,
CNNL_32BIT_INDICES);
MLUCnnlTensorDesc dy_desc(*dy);
MLUCnnl::Reduce(ctx,
true /*need_workspace*/,
reduction_desc.get(),
nullptr,
out_desc.get(),
GetBasePtr(&tmp_dy),
0,
nullptr,
nullptr,
dy_desc.get(),
GetBasePtr(dy));
}
} else {
dy->ShareDataWith(tmp_dy);
}
}
if (!dx && !dy) {
PADDLE_THROW(platform::errors::Unavailable(
"Not support all outputs to be empty."));
}
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_MLU_KERNEL(elementwise_pow,
ops::ElementwisePowMLUKernel<plat::float16>,
ops::ElementwisePowMLUKernel<float>);
REGISTER_OP_MLU_KERNEL(elementwise_pow_grad,
ops::ElementwisePowGradMLUKernel<plat::float16>,
ops::ElementwisePowGradMLUKernel<float>);
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import paddle.fluid as fluid
import paddle
import numpy as np
import unittest
import sys
sys.path.append("..")
from op_test import OpTest
paddle.enable_static()
SEED = 2022
def ComputeGrad(x, y, out, axis):
grad = 1 / out.size
shape_x = x.shape
shape_y = y.shape
shape_out = out.shape
reduce_axes_x = []
reduce_axes_y = []
if shape_x != shape_out:
if len(shape_x) < len(shape_out):
src_axis = axis
else:
src_axis = 0
for ax in range(len(shape_out)):
if (ax < src_axis or ax >= src_axis + len(shape_x)) or (
shape_out[ax] > 1 and shape_x[ax - src_axis] == 1):
reduce_axes_x.append(ax)
if shape_y != shape_out:
if len(shape_y) < len(shape_out):
src_axis = axis
else:
src_axis = 0
for ax in range(len(shape_out)):
if (ax < src_axis or ax >= src_axis + len(shape_y)) or (
shape_out[ax] > 1 and shape_y[ax - src_axis] == 1):
reduce_axes_y.append(ax)
if len(reduce_axes_x) > 0:
for i in reduce_axes_x:
x = np.expand_dims(x, axis=i)
if len(reduce_axes_y) > 0:
for i in reduce_axes_y:
y = np.expand_dims(y, axis=i)
dx = y * np.power(x, y - 1) * grad
dy = np.log(x) * np.power(x, y) * grad
if len(reduce_axes_x) > 0:
for i, element in enumerate(reduce_axes_x):
dx = np.add.reduce(dx, element - i)
if len(reduce_axes_y) > 0:
for i, element in enumerate(reduce_axes_y):
dy = np.add.reduce(dy, element - i)
return dx, dy
class TestElementwisePow(OpTest):
def setUp(self):
self.set_mlu()
self.op_type = "elementwise_pow"
self.init_dtype()
self.init_input_output()
self.init_axis()
self.inputs = {
'X': OpTest.np_dtype_to_fluid_dtype(self.x),
'Y': OpTest.np_dtype_to_fluid_dtype(self.y)
}
self.attrs = {'axis': self.axis}
self.outputs = {'Out': self.out}
def set_mlu(self):
self.__class__.use_mlu = True
self.place = paddle.device.MLUPlace(0)
def init_dtype(self):
self.dtype = np.float32
def test_check_output(self):
self.check_output_with_place(self.place)
def init_axis(self):
self.axis = -1
def init_input_output(self):
np.random.seed(SEED)
self.x = np.random.uniform(1, 2, [11, 17]).astype(self.dtype)
self.y = np.random.uniform(1, 2, [11, 17]).astype(self.dtype)
self.out = np.power(self.x, self.y)
def test_check_grad_normal(self):
dx, dy = ComputeGrad(self.x, self.y, self.out, self.axis)
self.check_grad_with_place(self.place, ['X', 'Y'],
'Out',
user_defined_grads=[dx, dy])
def test_check_grad_ingore_x(self):
_, dy = ComputeGrad(self.x, self.y, self.out, self.axis)
self.check_grad_with_place(self.place, ['Y'],
'Out',
no_grad_set=set("X"),
user_defined_grads=[dy])
def test_check_grad_ingore_y(self):
dx, _ = ComputeGrad(self.x, self.y, self.out, self.axis)
self.check_grad_with_place(self.place, ['X'],
'Out',
no_grad_set=set("Y"),
user_defined_grads=[dx])
class TestElementwisePowFp16(TestElementwisePow):
def init_input_output(self):
np.random.seed(SEED)
self.x = np.random.uniform(1, 2, [11, 17]).astype(self.dtype)
self.y = np.random.uniform(1, 2, [11, 17]).astype(self.dtype)
self.out = np.power(self.x, self.y)
def set_mlu(self):
self.__class__.use_mlu = True
# self.__class__.no_need_check_grad = True
self.place = paddle.device.MLUPlace(0)
def init_dtype(self):
self.dtype = np.float16
def test_check_output(self):
self.check_output_with_place(self.place, atol=1e-5)
class TestElementwisePowOp_broadcast_0(TestElementwisePow):
def init_axis(self):
self.axis = 1
def init_input_output(self):
np.random.seed(SEED)
self.x = np.random.uniform(1, 2, [11, 17]).astype(self.dtype)
self.y = np.random.uniform(1, 2, [1, 11, 17]).astype(self.dtype)
self.out = np.power(self.x, self.y)
def test_check_grad_normal(self):
dx, dy = ComputeGrad(self.x, self.y, self.out, self.axis)
self.check_grad_with_place(self.place, ['X', 'Y'],
'Out',
user_defined_grads=[dx, dy])
def test_check_grad_ingore_x(self):
_, dy = ComputeGrad(self.x, self.y, self.out, self.axis)
self.check_grad_with_place(self.place, ['Y'],
'Out',
no_grad_set=set("X"),
user_defined_grads=[dy])
def test_check_grad_ingore_y(self):
dx, _ = ComputeGrad(self.x, self.y, self.out, self.axis)
self.check_grad_with_place(self.place, ['X'],
'Out',
no_grad_set=set("Y"),
user_defined_grads=[dx])
class TestElementwisePowOp_broadcast_1(TestElementwisePow):
def init_axis(self):
self.axis = 1
def init_input_output(self):
np.random.seed(SEED)
self.x = np.random.uniform(1, 2, [2, 100, 1]).astype(self.dtype)
self.y = np.random.uniform(1, 2, [100]).astype(self.dtype)
self.out = np.power(self.x, self.y.reshape(1, 100, 1))
def test_check_grad_normal(self):
dx, dy = ComputeGrad(self.x, self.y, self.out, self.axis)
self.check_grad_with_place(self.place, ['X', 'Y'],
'Out',
user_defined_grads=[dx, dy])
def test_check_grad_ingore_x(self):
_, dy = ComputeGrad(self.x, self.y, self.out, self.axis)
self.check_grad_with_place(self.place, ['Y'],
'Out',
no_grad_set=set("X"),
user_defined_grads=[dy])
def test_check_grad_ingore_y(self):
dx, _ = ComputeGrad(self.x, self.y, self.out, self.axis)
self.check_grad_with_place(self.place, ['X'],
'Out',
no_grad_set=set("Y"),
user_defined_grads=[dx])
class TestElementwisePowOp_broadcast_2(TestElementwisePow):
def init_axis(self):
self.axis = 0
def init_input_output(self):
np.random.seed(SEED)
self.x = np.random.uniform(0.1, 1, [100, 3, 1]).astype(self.dtype)
self.y = np.random.uniform(0.1, 1, [100]).astype(self.dtype)
self.out = np.power(self.x, self.y.reshape(100, 1, 1))
def test_check_grad_normal(self):
dx, dy = ComputeGrad(self.x, self.y, self.out, self.axis)
self.check_grad_with_place(self.place, ['X', 'Y'],
'Out',
user_defined_grads=[dx, dy])
def test_check_grad_ingore_x(self):
_, dy = ComputeGrad(self.x, self.y, self.out, self.axis)
self.check_grad_with_place(self.place, ['Y'],
'Out',
no_grad_set=set("X"),
user_defined_grads=[dy])
def test_check_grad_ingore_y(self):
dx, _ = ComputeGrad(self.x, self.y, self.out, self.axis)
self.check_grad_with_place(self.place, ['X'],
'Out',
no_grad_set=set("Y"),
user_defined_grads=[dx])
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册