提交 913afe18 编写于 作者: Z Zhen Wang

update RowwiseAdd&ClearTensor

上级 b8553fdc
......@@ -135,7 +135,7 @@ template <typename T>
struct ClearTensor<CPU, T> {
void operator()(framework::Tensor *tensor) {
auto size = tensor->numel();
auto *tensor_data = tensor->data<float>();
auto *tensor_data = tensor->data<T>();
memset((void *)tensor_data, 0, sizeof(T) * size); // NOLINT
}
};
......@@ -151,9 +151,9 @@ struct RowwiseAdd<CPU, T> {
PADDLE_MOBILE_ENFORCE((output->dims() == in_dims),
"output->dims() must be equal to in_dims.");
auto *input_data = input.data<float>();
auto *out_data = output->data<float>();
auto *vec_data = vector.data<float>();
auto *input_data = input.data<T>();
auto *out_data = output->data<T>();
auto *vec_data = vector.data<T>();
for (int64_t i = 0; i < in_dims[0]; ++i) {
for (int64_t j = 0; j < size; ++j) {
out_data[i * size + j] = input_data[i * size + j] + vec_data[j];
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册