未验证 提交 02f66747 编写于 作者: 2 201716010711 提交者: GitHub

[AMP OP&Test] Fix scale kernel and perfect unit test (#50998)

上级 ad92a5c1
......@@ -21,22 +21,22 @@ limitations under the License. */
namespace phi {
template <typename InT>
template <typename DataT, typename ParamT>
struct ScaleFunctor {
InT bias;
InT scale;
ParamT bias;
ParamT scale;
bool bias_after_scale;
ScaleFunctor(InT scale_data, InT bias_data, bool is_bias_after_sacle)
ScaleFunctor(ParamT scale_data, ParamT bias_data, bool is_bias_after_sacle)
: bias(bias_data),
scale(scale_data),
bias_after_scale(is_bias_after_sacle) {}
__device__ __forceinline__ InT operator()(const InT x) const {
__device__ __forceinline__ DataT operator()(const DataT x) const {
if (bias_after_scale) {
return scale * x + bias;
return static_cast<DataT>(scale * static_cast<ParamT>(x) + bias);
} else {
return scale * (x + bias);
return static_cast<DataT>(scale * (static_cast<ParamT>(x) + bias));
}
}
};
......@@ -48,6 +48,7 @@ void ScaleKernel(const Context& dev_ctx,
float bias,
bool bias_after_scale,
DenseTensor* out) {
using MT = typename phi::dtype::MPTypeTrait<T>::Type;
std::vector<const DenseTensor*> inputs;
std::vector<DenseTensor*> outputs;
inputs.emplace_back(&x);
......@@ -60,7 +61,8 @@ void ScaleKernel(const Context& dev_ctx,
dev_ctx,
inputs,
&outputs,
ScaleFunctor<T>(scale.to<T>(), static_cast<T>(bias), bias_after_scale));
ScaleFunctor<T, MT>(
scale.to<MT>(), static_cast<MT>(bias), bias_after_scale));
}
} // namespace phi
......
......@@ -149,15 +149,11 @@ class TestScaleFp16Op(TestScaleOp):
def test_check_output(self):
place = core.CUDAPlace(0)
if core.is_float16_supported(place):
self.check_output_with_place(place, atol=0.002, check_eager=True)
self.check_output_with_place(place, check_eager=True)
def test_check_grad(self):
place = core.CUDAPlace(0)
if core.is_float16_supported(place):
self.check_grad_with_place(
place, ["X"], "Out", max_relative_error=0.05, check_eager=True
)
self.check_grad_with_place(place, ["X"], "Out", check_eager=True)
class TestScaleBF16Op(OpTest):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册