未验证 提交 e6d397e6 编写于 作者: Y ykkk2333 提交者: GitHub

fix transformer bug, test=kunlun (#45927)

上级 9f5b0831
......@@ -65,7 +65,7 @@ void AdamDenseKernel(const Context& dev_ctx,
const float* beta1_const_pow_ptr = nullptr;
if (beta1_pow.place() == CPUPlace()) {
DenseTensor xpu_beta1_pow;
phi::Copy(dev_ctx, beta1_pow, beta1_pow.place(), false, &xpu_beta1_pow);
phi::Copy(dev_ctx, beta1_pow, dev_ctx.GetPlace(), false, &xpu_beta1_pow);
if (xpu_beta1_pow.dtype() == DataType::FLOAT16)
funcs::GetDataPointer<Context, float>(
xpu_beta1_pow, &beta1_pow_ptr, dev_ctx);
......@@ -82,7 +82,7 @@ void AdamDenseKernel(const Context& dev_ctx,
const float* beta2_const_pow_ptr = nullptr;
if (beta2_pow.place() == CPUPlace()) {
DenseTensor xpu_beta2_pow;
phi::Copy(dev_ctx, beta2_pow, beta2_pow.place(), false, &xpu_beta2_pow);
phi::Copy(dev_ctx, beta2_pow, dev_ctx.GetPlace(), false, &xpu_beta2_pow);
if (xpu_beta2_pow.dtype() == DataType::FLOAT16)
funcs::GetDataPointer<Context, float>(
xpu_beta2_pow, &beta2_pow_ptr, dev_ctx);
......
......@@ -20,6 +20,18 @@ limitations under the License. */
namespace phi {
template <typename T, typename Context>
inline T GetValue(const Context& dev_ctx, const DenseTensor& x) {
T value = static_cast<T>(0);
if (x.place() != CPUPlace()) {
DenseTensor cpu_x;
Copy(dev_ctx, x, CPUPlace(), true, &cpu_x);
value = cpu_x.data<T>()[0];
} else {
value = x.data<T>()[0];
}
return value;
}
template <typename T, typename Context>
void ArangeKernel(const Context& dev_ctx,
const DenseTensor& start,
......@@ -29,19 +41,9 @@ void ArangeKernel(const Context& dev_ctx,
auto place = dev_ctx.GetPlace();
auto cpu_place = phi::CPUPlace();
DenseTensor n_cpu;
n_cpu.Resize({start.numel()});
T* n_cpu_data = dev_ctx.template HostAlloc<T>(&n_cpu);
paddle::memory::Copy(
cpu_place, n_cpu_data, place, start.data<T>(), sizeof(T) * start.numel());
T start_value = n_cpu_data[0];
paddle::memory::Copy(
cpu_place, n_cpu_data, place, end.data<T>(), sizeof(T) * end.numel());
T end_value = n_cpu_data[0];
paddle::memory::Copy(
cpu_place, n_cpu_data, place, step.data<T>(), sizeof(T) * step.numel());
T step_value = n_cpu_data[0];
T start_value = GetValue<T, Context>(dev_ctx, start);
T end_value = GetValue<T, Context>(dev_ctx, end);
T step_value = GetValue<T, Context>(dev_ctx, step);
int64_t size = 0;
phi::funcs::GetSize(start_value, end_value, step_value, &size);
......@@ -50,7 +52,9 @@ void ArangeKernel(const Context& dev_ctx,
DenseTensor out_cpu;
out_cpu.Resize({out->numel()});
T* out_cpu_data = dev_ctx.template HostAlloc<T>(&out_cpu);
dev_ctx.template HostAlloc<T>(&out_cpu);
T* out_cpu_data = out_cpu.data<T>();
T value = start_value;
for (int64_t i = 0; i < size; ++i) {
out_cpu_data[i] = value;
......@@ -63,4 +67,8 @@ void ArangeKernel(const Context& dev_ctx,
} // namespace phi
PD_REGISTER_KERNEL(
arange, XPU, ALL_LAYOUT, phi::ArangeKernel, float, double, int, int64_t) {}
arange, XPU, ALL_LAYOUT, phi::ArangeKernel, float, double, int, int64_t) {
kernel->InputAt(0).SetBackend(phi::Backend::ALL_BACKEND);
kernel->InputAt(1).SetBackend(phi::Backend::ALL_BACKEND);
kernel->InputAt(2).SetBackend(phi::Backend::ALL_BACKEND);
}
......@@ -42,7 +42,7 @@ void GaussianRandomKernel(const Context& ctx,
for (int64_t i = 0; i < size; ++i) {
data_cpu[i] = dist(*engine);
}
paddle::memory::Copy(phi::XPUPlace(),
paddle::memory::Copy(ctx.GetPlace(),
data,
phi::CPUPlace(),
reinterpret_cast<void*>(data_cpu.get()),
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册