提交 5cf395be 编写于 作者: Y Yu Yang

Fix bug in uts

上级 a6fbf7ec
......@@ -319,7 +319,9 @@ TEST(Tensor, FromAndToStream) {
TensorToStream(oss, gpu_tensor, gpu_ctx);
std::istringstream iss(oss.str());
TensorFromStream(iss, &dst_tensor, gpu_ctx);
TensorFromStream(
iss, &dst_tensor,
*platform::DeviceContextPool::Instance().Get(platform::CPUPlace()));
int* dst_ptr = dst_tensor.mutable_data<int>(platform::CPUPlace());
for (int i = 0; i < 6; ++i) {
......
......@@ -341,7 +341,7 @@ set(GLOB_OP_LIB ${OP_LIBRARY} CACHE INTERNAL "Global OP library")
set(GLOB_DISTRIBUTE_DEPS ${DISTRIBUTE_DEPS} CACHE INTERNAL "distributed dependency")
cc_test(gather_test SRCS gather_test.cc DEPS tensor)
cc_test(scatter_test SRCS scatter_test.cc DEPS tensor)
cc_test(scatter_test SRCS scatter_test.cc DEPS tensor math_function)
cc_test(beam_search_decode_op_test SRCS beam_search_decode_op_test.cc DEPS lod_tensor)
cc_test(beam_search_op_test SRCS beam_search_op_test.cc DEPS lod_tensor beam_search_op)
cc_test(strided_memcpy_test SRCS strided_memcpy_test.cc DEPS tensor memory)
......
......@@ -21,42 +21,38 @@ limitations under the License. */
#include "paddle/fluid/platform/place.h"
TEST(scatter, ScatterUpdate) {
// using namespace paddle::framework;
// using namespace paddle::platform;
// using namespace paddle::operators;
paddle::framework::Tensor* src = new paddle::framework::Tensor();
paddle::framework::Tensor* index = new paddle::framework::Tensor();
paddle::framework::Tensor* output = new paddle::framework::Tensor();
float* p_src = nullptr;
int* p_index = nullptr;
p_src = src->mutable_data<float>(paddle::framework::make_ddim({1, 4}),
paddle::platform::CPUPlace());
p_index = index->mutable_data<int>(paddle::framework::make_ddim({1}),
paddle::platform::CPUPlace());
for (size_t i = 0; i < 4; ++i) p_src[i] = static_cast<float>(i);
paddle::framework::Tensor src;
paddle::framework::Tensor index;
paddle::framework::Tensor output;
auto* p_src = src.mutable_data<float>(paddle::framework::make_ddim({1, 4}),
paddle::platform::CPUPlace());
auto* p_index = index.mutable_data<int>(paddle::framework::make_ddim({1}),
paddle::platform::CPUPlace());
for (size_t i = 0; i < 4; ++i) {
p_src[i] = static_cast<float>(i);
}
p_index[0] = 1;
float* p_output = output->mutable_data<float>(
auto* p_output = output.mutable_data<float>(
paddle::framework::make_ddim({4, 4}), paddle::platform::CPUPlace());
for (int64_t i = 0; i < output.numel(); ++i) {
p_output[i] = 0;
}
auto* cpu_place = new paddle::platform::CPUPlace();
paddle::platform::CPUDeviceContext ctx(*cpu_place);
paddle::operators::ScatterAssign<float>(ctx, *src, *index, output);
paddle::operators::ScatterAssign<float>(ctx, src, index, &output);
for (size_t i = 0; i < 4; ++i) EXPECT_EQ(p_output[i], 0.0f);
for (size_t i = 0; i < 4; ++i) EXPECT_EQ(output->data<float>()[i], 0.0f);
for (size_t i = 0; i < 4; ++i) EXPECT_EQ(output.data<float>()[i], 0.0f);
for (size_t i = 4; i < 8; ++i) {
EXPECT_EQ(p_output[i], static_cast<float>(i - 4));
}
for (size_t i = 4; i < 8; ++i)
EXPECT_EQ(output->data<float>()[i], static_cast<float>(i - 4));
EXPECT_EQ(output.data<float>()[i], static_cast<float>(i - 4));
for (size_t i = 8; i < 16; ++i) EXPECT_EQ(p_output[i], 0.0f);
for (size_t i = 8; i < 16; ++i) EXPECT_EQ(output->data<float>()[i], 0.0f);
delete src;
delete index;
delete output;
for (size_t i = 8; i < 16; ++i) EXPECT_EQ(output.data<float>()[i], 0.0f);
}
......@@ -18,8 +18,6 @@ limitations under the License. */
#include "paddle/fluid/platform/hostdevice.h"
#include "paddle/fluid/platform/transform.h"
namespace {
template <typename T>
class Scale {
public:
......@@ -36,8 +34,6 @@ class Multiply {
HOSTDEVICE T operator()(const T& a, const T& b) const { return a * b; }
};
} // namespace
using paddle::memory::Alloc;
using paddle::memory::Copy;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册