提交 5cf395be 编写于 作者: Y Yu Yang

Fix bug in uts

上级 a6fbf7ec
...@@ -319,7 +319,9 @@ TEST(Tensor, FromAndToStream) { ...@@ -319,7 +319,9 @@ TEST(Tensor, FromAndToStream) {
TensorToStream(oss, gpu_tensor, gpu_ctx); TensorToStream(oss, gpu_tensor, gpu_ctx);
std::istringstream iss(oss.str()); std::istringstream iss(oss.str());
TensorFromStream(iss, &dst_tensor, gpu_ctx); TensorFromStream(
iss, &dst_tensor,
*platform::DeviceContextPool::Instance().Get(platform::CPUPlace()));
int* dst_ptr = dst_tensor.mutable_data<int>(platform::CPUPlace()); int* dst_ptr = dst_tensor.mutable_data<int>(platform::CPUPlace());
for (int i = 0; i < 6; ++i) { for (int i = 0; i < 6; ++i) {
......
...@@ -341,7 +341,7 @@ set(GLOB_OP_LIB ${OP_LIBRARY} CACHE INTERNAL "Global OP library") ...@@ -341,7 +341,7 @@ set(GLOB_OP_LIB ${OP_LIBRARY} CACHE INTERNAL "Global OP library")
set(GLOB_DISTRIBUTE_DEPS ${DISTRIBUTE_DEPS} CACHE INTERNAL "distributed dependency") set(GLOB_DISTRIBUTE_DEPS ${DISTRIBUTE_DEPS} CACHE INTERNAL "distributed dependency")
cc_test(gather_test SRCS gather_test.cc DEPS tensor) cc_test(gather_test SRCS gather_test.cc DEPS tensor)
cc_test(scatter_test SRCS scatter_test.cc DEPS tensor) cc_test(scatter_test SRCS scatter_test.cc DEPS tensor math_function)
cc_test(beam_search_decode_op_test SRCS beam_search_decode_op_test.cc DEPS lod_tensor) cc_test(beam_search_decode_op_test SRCS beam_search_decode_op_test.cc DEPS lod_tensor)
cc_test(beam_search_op_test SRCS beam_search_op_test.cc DEPS lod_tensor beam_search_op) cc_test(beam_search_op_test SRCS beam_search_op_test.cc DEPS lod_tensor beam_search_op)
cc_test(strided_memcpy_test SRCS strided_memcpy_test.cc DEPS tensor memory) cc_test(strided_memcpy_test SRCS strided_memcpy_test.cc DEPS tensor memory)
......
...@@ -21,42 +21,38 @@ limitations under the License. */ ...@@ -21,42 +21,38 @@ limitations under the License. */
#include "paddle/fluid/platform/place.h" #include "paddle/fluid/platform/place.h"
TEST(scatter, ScatterUpdate) { TEST(scatter, ScatterUpdate) {
// using namespace paddle::framework; paddle::framework::Tensor src;
// using namespace paddle::platform; paddle::framework::Tensor index;
// using namespace paddle::operators; paddle::framework::Tensor output;
paddle::framework::Tensor* src = new paddle::framework::Tensor(); auto* p_src = src.mutable_data<float>(paddle::framework::make_ddim({1, 4}),
paddle::framework::Tensor* index = new paddle::framework::Tensor(); paddle::platform::CPUPlace());
paddle::framework::Tensor* output = new paddle::framework::Tensor(); auto* p_index = index.mutable_data<int>(paddle::framework::make_ddim({1}),
paddle::platform::CPUPlace());
float* p_src = nullptr;
int* p_index = nullptr; for (size_t i = 0; i < 4; ++i) {
p_src = src->mutable_data<float>(paddle::framework::make_ddim({1, 4}), p_src[i] = static_cast<float>(i);
paddle::platform::CPUPlace()); }
p_index = index->mutable_data<int>(paddle::framework::make_ddim({1}),
paddle::platform::CPUPlace());
for (size_t i = 0; i < 4; ++i) p_src[i] = static_cast<float>(i);
p_index[0] = 1; p_index[0] = 1;
float* p_output = output->mutable_data<float>( auto* p_output = output.mutable_data<float>(
paddle::framework::make_ddim({4, 4}), paddle::platform::CPUPlace()); paddle::framework::make_ddim({4, 4}), paddle::platform::CPUPlace());
for (int64_t i = 0; i < output.numel(); ++i) {
p_output[i] = 0;
}
auto* cpu_place = new paddle::platform::CPUPlace(); auto* cpu_place = new paddle::platform::CPUPlace();
paddle::platform::CPUDeviceContext ctx(*cpu_place); paddle::platform::CPUDeviceContext ctx(*cpu_place);
paddle::operators::ScatterAssign<float>(ctx, *src, *index, output); paddle::operators::ScatterAssign<float>(ctx, src, index, &output);
for (size_t i = 0; i < 4; ++i) EXPECT_EQ(p_output[i], 0.0f); for (size_t i = 0; i < 4; ++i) EXPECT_EQ(p_output[i], 0.0f);
for (size_t i = 0; i < 4; ++i) EXPECT_EQ(output->data<float>()[i], 0.0f); for (size_t i = 0; i < 4; ++i) EXPECT_EQ(output.data<float>()[i], 0.0f);
for (size_t i = 4; i < 8; ++i) { for (size_t i = 4; i < 8; ++i) {
EXPECT_EQ(p_output[i], static_cast<float>(i - 4)); EXPECT_EQ(p_output[i], static_cast<float>(i - 4));
} }
for (size_t i = 4; i < 8; ++i) for (size_t i = 4; i < 8; ++i)
EXPECT_EQ(output->data<float>()[i], static_cast<float>(i - 4)); EXPECT_EQ(output.data<float>()[i], static_cast<float>(i - 4));
for (size_t i = 8; i < 16; ++i) EXPECT_EQ(p_output[i], 0.0f); for (size_t i = 8; i < 16; ++i) EXPECT_EQ(p_output[i], 0.0f);
for (size_t i = 8; i < 16; ++i) EXPECT_EQ(output->data<float>()[i], 0.0f); for (size_t i = 8; i < 16; ++i) EXPECT_EQ(output.data<float>()[i], 0.0f);
delete src;
delete index;
delete output;
} }
...@@ -18,8 +18,6 @@ limitations under the License. */ ...@@ -18,8 +18,6 @@ limitations under the License. */
#include "paddle/fluid/platform/hostdevice.h" #include "paddle/fluid/platform/hostdevice.h"
#include "paddle/fluid/platform/transform.h" #include "paddle/fluid/platform/transform.h"
namespace {
template <typename T> template <typename T>
class Scale { class Scale {
public: public:
...@@ -36,8 +34,6 @@ class Multiply { ...@@ -36,8 +34,6 @@ class Multiply {
HOSTDEVICE T operator()(const T& a, const T& b) const { return a * b; } HOSTDEVICE T operator()(const T& a, const T& b) const { return a * b; }
}; };
} // namespace
using paddle::memory::Alloc; using paddle::memory::Alloc;
using paddle::memory::Copy; using paddle::memory::Copy;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册