Unverified · Commit e6c3f64f authored by Chen Weihang, committed by GitHub

Fix renorm op include error and format error (#38451)

* remove needless header

* remove needless header

* adjust header order
Parent bbe879fc
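
The cleanup keeps the op's own header first, drops the duplicate renorm_op.h include and the C-style "stdio.h" include, and pulls the standard headers in through their C++ spellings. A minimal sketch of the resulting include layout (illustrative only; the blank-line grouping is assumed, not taken verbatim from the file):

// renorm_op.cu -- include layout after the cleanup (sketch).
// Related header first, then C++ standard headers, then other
// project headers, each group kept in alphabetical order.
#include "paddle/fluid/operators/renorm_op.h"

#include <algorithm>
#include <cstdio>

#include "paddle/fluid/framework/op_registry.h"
// ... remaining paddle/fluid headers as in the diff below ...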
@@ -12,16 +12,16 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 #include "paddle/fluid/operators/renorm_op.h"
+#include <algorithm>
+#include <cstdio>
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h"
 #include "paddle/fluid/operators/reduce_ops/reduce_functor_op.h"
 #include "paddle/fluid/operators/reduce_ops/reduce_op.cu.h"
 #include "paddle/fluid/operators/reduce_ops/reduce_op.h"
-#include "paddle/fluid/operators/renorm_op.h"
 #include "paddle/fluid/operators/utils.h"
 #include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
-#include "stdio.h"
 namespace paddle {
 namespace operators {
@@ -60,7 +60,7 @@ __global__ void RenormKernelFunc3(int64_t size, T* dim_value, float p,
 }
 template <typename T>
-__global__ void RenormKernelFunc4(T* x_data, T* out_data, int64_t size,
+__global__ void RenormKernelFunc4(const T* x_data, T* out_data, int64_t size,
                                   T* dim_value, int64_t dimension_each,
                                   int64_t dim_divisor) {
   int64_t i = ((int64_t)blockIdx.x) * blockDim.x + threadIdx.x;
@@ -74,8 +74,8 @@ __global__ void RenormKernelFunc4(const T* x_data, T* out_data, int64_t size,
 }
 template <typename T>
-__global__ void RenormGradKernelFunc1(T* x_data, T* dout_data, T* pow_value,
-                                      T* mul_value, int64_t size,
+__global__ void RenormGradKernelFunc1(const T* x_data, const T* dout_data,
+                                      T* pow_value, T* mul_value, int64_t size,
                                       int64_t dimension_each, float p,
                                       int64_t dim_divisor) {
   int64_t i = ((int64_t)blockIdx.x) * blockDim.x + threadIdx.x;
@@ -87,8 +87,8 @@ __global__ void RenormGradKernelFunc1(const T* x_data, const T* dout_data,
 }
 template <typename T>
-__global__ void RenormGradKernelFunc2(T* x_data, T* dout_data, T* dx_data,
-                                      int64_t size, T* dim_value,
+__global__ void RenormGradKernelFunc2(const T* x_data, const T* dout_data,
+                                      T* dx_data, int64_t size, T* dim_value,
                                       T* dim_power_sum, T* weight_derivative,
                                       int64_t dimension_each, float p,
                                       float max_norm, int64_t dim_divisor) {
@@ -100,8 +100,9 @@ __global__ void RenormGradKernelFunc2(const T* x_data, const T* dout_data,
     if (temp > max_norm) {
       dim_power_sum[i] = pow(dim_value[i], (T)(-1.0 - 1.0 / p)) * -1 * max_norm;
       dim_value[i] = max_norm / temp;
-    } else
+    } else {
       dim_value[i] = 1.0;
+    }
   }
   __syncthreads();
   if (i < size) {
@@ -120,7 +121,7 @@ class CUDARenormKernel : public framework::OpKernel<T> {
     const Tensor* x = context.Input<Tensor>("X");
     Tensor* out = context.Output<Tensor>("Out");
     auto numel = x->numel();
-    T* x_data = (T*)x->data<T>();
+    const T* x_data = x->data<T>();
     auto input_dims = x->dims();
     float max_norm = context.Attr<float>("max_norm");
     float p = context.Attr<float>("p");
@@ -176,8 +177,8 @@ class CUDAGradRenormKernel : public framework::OpKernel<T> {
         ctx.Output<framework::Tensor>(framework::GradVarName("X"));
     auto numel = d_out->numel();
-    T* dout_data = (T*)d_out->data<T>();
-    T* x_data = (T*)x->data<T>();
+    const T* dout_data = d_out->data<T>();
+    const T* x_data = x->data<T>();
     auto input_dims = x->dims();
     float max_norm = ctx.Attr<float>("max_norm");
     float p = ctx.Attr<float>("p");
@@ -234,4 +235,4 @@ REGISTER_OP_CUDA_KERNEL(renorm, ops::CUDARenormKernel<float>,
                         ops::CUDARenormKernel<double>);
 REGISTER_OP_CUDA_KERNEL(renorm_grad, ops::CUDAGradRenormKernel<float>,
-                        ops::CUDAGradRenormKernel<double>);
\ No newline at end of file
+                        ops::CUDAGradRenormKernel<double>);
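
Besides the header cleanup, the kernels and the two OpKernel classes now take read-only inputs as const T*, so the (T*) casts of the Tensor::data<T>() results can be dropped and the compiler enforces that input buffers are never written. A standalone sketch of the pattern (generic CUDA, not PaddlePaddle code; kernel and variable names are invented for illustration):

#include <cstdint>
#include <cstdio>
#include <cuda_runtime.h>

// Read-only input is const-qualified; only the output pointer is mutable.
template <typename T>
__global__ void ScaleKernel(const T* x, T* out, int64_t size, T factor) {
  int64_t i = ((int64_t)blockIdx.x) * blockDim.x + threadIdx.x;
  if (i < size) {
    out[i] = x[i] * factor;  // writing through x would be a compile error
  }
}

int main() {
  const int64_t n = 8;
  float host[n] = {1, 2, 3, 4, 5, 6, 7, 8};
  float *x, *out;
  cudaMalloc(&x, n * sizeof(float));
  cudaMalloc(&out, n * sizeof(float));
  cudaMemcpy(x, host, n * sizeof(float), cudaMemcpyHostToDevice);
  ScaleKernel<float><<<1, 32>>>(x, out, n, 2.0f);  // x converts to const float*
  cudaMemcpy(host, out, n * sizeof(float), cudaMemcpyDeviceToHost);
  for (int64_t i = 0; i < n; ++i) printf("%g ", host[i]);  // prints 2 4 ... 16
  printf("\n");
  cudaFree(x);
  cudaFree(out);
  return 0;
}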