Unverified commit 02f66747 authored by 201716010711, committed by GitHub

[AMP OP&Test] Fix scale kernel and perfect unit test (#50998)

Parent ad92a5c1
@@ -21,22 +21,22 @@ limitations under the License. */

 namespace phi {

-template <typename InT>
+template <typename DataT, typename ParamT>
 struct ScaleFunctor {
-  InT bias;
-  InT scale;
+  ParamT bias;
+  ParamT scale;
   bool bias_after_scale;

-  ScaleFunctor(InT scale_data, InT bias_data, bool is_bias_after_sacle)
+  ScaleFunctor(ParamT scale_data, ParamT bias_data, bool is_bias_after_sacle)
       : bias(bias_data),
         scale(scale_data),
         bias_after_scale(is_bias_after_sacle) {}

-  __device__ __forceinline__ InT operator()(const InT x) const {
+  __device__ __forceinline__ DataT operator()(const DataT x) const {
     if (bias_after_scale) {
-      return scale * x + bias;
+      return static_cast<DataT>(scale * static_cast<ParamT>(x) + bias);
     } else {
-      return scale * (x + bias);
+      return static_cast<DataT>(scale * (static_cast<ParamT>(x) + bias));
     }
   }
 };
@@ -48,6 +48,7 @@ void ScaleKernel(const Context& dev_ctx,
                  float bias,
                  bool bias_after_scale,
                  DenseTensor* out) {
+  using MT = typename phi::dtype::MPTypeTrait<T>::Type;
   std::vector<const DenseTensor*> inputs;
   std::vector<DenseTensor*> outputs;
   inputs.emplace_back(&x);
@@ -60,7 +61,8 @@ void ScaleKernel(const Context& dev_ctx,
       dev_ctx,
       inputs,
       &outputs,
-      ScaleFunctor<T>(scale.to<T>(), static_cast<T>(bias), bias_after_scale));
+      ScaleFunctor<T, MT>(
+          scale.to<MT>(), static_cast<MT>(bias), bias_after_scale));
 }

 }  // namespace phi
......
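The fix follows the usual AMP pattern: for low-precision inputs, `scale` and `bias` are held in the higher-precision type given by `MPTypeTrait<T>` (`ParamT`, float for float16/bfloat16), the multiply-add runs in that type, and only the final result is cast back to the input dtype (`DataT`). A minimal NumPy sketch of the same semantics, for illustration only (the helper name `scale_mp` is not part of the commit):

import numpy as np

# Sketch of the kernel's mixed-precision semantics for float16 inputs:
# compute in float32 (ParamT, the MPType) and cast the result back to
# float16 (DataT). Function name and structure are illustrative only.
def scale_mp(x_fp16, scale, bias, bias_after_scale=True):
    x = x_fp16.astype(np.float32)   # static_cast<ParamT>(x)
    if bias_after_scale:
        y = scale * x + bias        # scale * x + bias, in float32
    else:
        y = scale * (x + bias)      # scale * (x + bias), in float32
    return y.astype(np.float16)     # static_cast<DataT>(...)

x = np.arange(8, dtype=np.float16) / 7
print(scale_mp(x, 2.0, 0.5))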
@@ -149,15 +149,11 @@ class TestScaleFp16Op(TestScaleOp):

     def test_check_output(self):
         place = core.CUDAPlace(0)
-        if core.is_float16_supported(place):
-            self.check_output_with_place(place, atol=0.002, check_eager=True)
+        self.check_output_with_place(place, check_eager=True)

     def test_check_grad(self):
         place = core.CUDAPlace(0)
-        if core.is_float16_supported(place):
-            self.check_grad_with_place(
-                place, ["X"], "Out", max_relative_error=0.05, check_eager=True
-            )
+        self.check_grad_with_place(place, ["X"], "Out", check_eager=True)


 class TestScaleBF16Op(OpTest):
......
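Because the kernel now accumulates in the multi-precision type, the float16 outputs track the float32 reference closely enough that the loosened tolerances (`atol=0.002`, `max_relative_error=0.05`) and the `is_float16_supported` guard can be dropped, leaving the default OpTest tolerances. For reference, a minimal usage sketch of the op these tests exercise (assumes a Paddle build with a CUDA device; values are illustrative):

import paddle

# Minimal usage sketch of paddle.scale on a float16 tensor.
x = paddle.rand([3, 4]).astype("float16")
y = paddle.scale(x, scale=2.0, bias=0.5, bias_after_scale=True)   # 2*x + 0.5
z = paddle.scale(x, scale=2.0, bias=0.5, bias_after_scale=False)  # 2*(x + 0.5)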