未验证 提交 ab6f8745 编写于 作者: W wuhuanzhou 提交者: GitHub

remove thrust include files (#32395)

* remove thrust includes, test=develop

* fix compilation error, test=develop

* fix compilation of truncated_gaussian_random_op, test=develop
上级 2194ad15
......@@ -14,16 +14,11 @@ limitations under the License. */
#pragma once
#include <glog/logging.h>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
#include <thrust/device_vector.h>
#include <thrust/host_vector.h>
#endif
#include <glog/logging.h>
#include "paddle/fluid/framework/ddim.h"
#include "paddle/fluid/framework/mixed_vector.h"
......
......@@ -12,6 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include <thrust/device_vector.h>
#include <thrust/host_vector.h>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/diag_embed_op.h"
......
......@@ -11,6 +11,8 @@ distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <thrust/device_vector.h>
#include <thrust/host_vector.h>
#include <thrust/random.h>
#include <thrust/transform.h>
#include "paddle/fluid/framework/generator.h"
......
......@@ -11,7 +11,9 @@ distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <thrust/device_vector.h>
#include <thrust/for_each.h>
#include <thrust/host_vector.h>
#include <thrust/tuple.h>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/modified_huber_loss_op.h"
......
......@@ -12,6 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include <thrust/device_vector.h>
#include <thrust/host_vector.h>
#include "paddle/fluid/operators/reduce_ops/cub_reduce.h"
#include "paddle/fluid/operators/trace_op.h"
......
......@@ -12,25 +12,28 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <thrust/device_vector.h>
#include <thrust/host_vector.h>
#include <thrust/random.h>
#include <thrust/transform.h>
#include <limits>
#include "paddle/fluid/framework/generator.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/operators/truncated_gaussian_random_op.h"
namespace paddle {
namespace operators {
template <typename T>
struct TruncatedNormal {
struct GPUTruncatedNormal {
T mean, std;
T a_normal_cdf;
T b_normal_cdf;
unsigned int seed;
T numeric_min;
__host__ __device__ TruncatedNormal(T mean, T std, T numeric_min, int seed)
__host__ __device__ GPUTruncatedNormal(T mean, T std, T numeric_min, int seed)
: mean(mean), std(std), seed(seed), numeric_min(numeric_min) {
a_normal_cdf = (1.0 + erff(-2.0 / sqrtf(2.0))) / 2.0;
b_normal_cdf = (1.0 + erff(2.0 / sqrtf(2.0))) / 2.0;
......@@ -110,10 +113,10 @@ class GPUTruncatedGaussianRandomKernel : public framework::OpKernel<T> {
TruncatedNormalOffset<T>(mean, std, std::numeric_limits<T>::min(),
seed_offset.first, gen_offset));
} else {
thrust::transform(
index_sequence_begin, index_sequence_begin + size,
thrust::device_ptr<T>(data),
TruncatedNormal<T>(mean, std, std::numeric_limits<T>::min(), seed));
thrust::transform(index_sequence_begin, index_sequence_begin + size,
thrust::device_ptr<T>(data),
GPUTruncatedNormal<T>(
mean, std, std::numeric_limits<T>::min(), seed));
}
}
};
......
......@@ -11,9 +11,10 @@ distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <thrust/device_vector.h>
#include <thrust/host_vector.h>
#include <thrust/random.h>
#include <thrust/transform.h>
#include "paddle/fluid/framework/generator.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册