提交 2ccaec4f 编写于 作者: Z zchen0211

gather scatter cond

上级 3d09a654
......@@ -126,8 +126,7 @@ void CondOp::PrepareDataForSubnet(
dim[0] = index_tensors[i].dims()[0];
tensor_child->mutable_data<float>(dim, platform::CPUPlace());
CPUGather<float>(dev_ctx.GetPlace(), tensor_parent, &index_tensors[i],
tensor_child);
CPUGather<float>(dev_ctx, tensor_parent, &index_tensors[i], tensor_child);
}
}
......@@ -188,7 +187,7 @@ void CondOp::MergeDataFromSubnet(const framework::Scope& scope,
Variable* var_child = sub_scopes[i]->FindVar(output);
PADDLE_ENFORCE_NOT_NULL(var_child);
auto* tensor_child = &var_child->Get<LoDTensor>();
ScatterAssign<float>(dev_ctx.GetPlace(), tensor_child, &index_tensors[i],
ScatterAssign<float>(dev_ctx, tensor_child, &index_tensors[i],
tensor_parent);
}
}
......
......@@ -32,11 +32,11 @@ namespace operators {
* return: output tensor
*/
template <typename T>
void CPUGather(const platform::Place& place,
void CPUGather(const platform::DeviceContext& ctx,
const paddle::framework::Tensor* src,
const paddle::framework::Tensor* index,
paddle::framework::Tensor* output) {
PADDLE_ENFORCE(platform::is_cpu_place(place));
PADDLE_ENFORCE(platform::is_cpu_place(ctx.GetPlace()));
// check index of shape 1-D
PADDLE_ENFORCE(index->dims().size() == 1);
int index_size = index->dims()[0];
......
......@@ -36,7 +36,7 @@ class GatherOpKernel : public framework::OpKernel<T> {
output->mutable_data<T>(ctx.GetPlace());
CPUGather<T>(ctx.GetPlace(), x, index, output);
CPUGather<T>(ctx.device_context(), x, index, output);
}
};
......@@ -56,7 +56,7 @@ class GatherGradientOpKernel : public framework::OpKernel<T> {
auto place = ctx.GetEigenDevice<platform::CPUPlace>();
dxt.device(place) = dxt.constant(static_cast<T>(0));
ScatterAssign<T>(ctx.GetPlace(), dO, Index, dX);
ScatterAssign<T>(ctx.device_context(), dO, Index, dX);
}
};
......
......@@ -41,7 +41,9 @@ TEST(Gather, GatherData) {
int* p_output = output->mutable_data<int>(make_ddim({2, 4}), CPUPlace());
CPUGather<int>(CPUPlace(), src, index, output);
auto* cpu_place = new paddle::platform::CPUPlace();
paddle::platform::CPUDeviceContext ctx(*cpu_place);
CPUGather<int>(ctx, src, index, output);
for (int i = 0; i < 4; ++i) EXPECT_EQ(p_output[i], i + 4);
for (int i = 4; i < 8; ++i) EXPECT_EQ(p_output[i], i - 4);
......
......@@ -33,11 +33,11 @@ using Tensor = framework::Tensor;
* return: output tensor
*/
template <typename T>
void ScatterAssign(const platform::Place& place,
void ScatterAssign(const platform::DeviceContext& ctx,
const paddle::framework::Tensor* src,
const paddle::framework::Tensor* index,
paddle::framework::Tensor* output) {
PADDLE_ENFORCE(platform::is_cpu_place(place));
PADDLE_ENFORCE(platform::is_cpu_place(ctx.GetPlace()));
// check index of shape 1-D
PADDLE_ENFORCE(index->dims().size() == 1);
int index_size = index->dims()[0];
......
......@@ -37,7 +37,7 @@ class ScatterOpKernel : public framework::OpKernel<T> {
// In place output: Out = Ref, Out[Index] += Updates
Out->ShareDataWith<T>(*Ref);
// Apply ScatterUpdate: Out[index] += Updates[:]
ScatterAssign<T>(ctx.GetPlace(), Updates, Index, Out);
ScatterAssign<T>(ctx.device_context(), Updates, Index, Out);
}
};
......@@ -56,7 +56,7 @@ class ScatterGradientOpKernel : public framework::OpKernel<T> {
dRef->ShareDataWith<T>(*dOut);
dUpdates->mutable_data<T>(ctx.GetPlace());
// Gradient by Gather: dUpdates += dO[Index]
CPUGather<T>(ctx.GetPlace(), dOut, Index, dUpdates);
CPUGather<T>(ctx.device_context(), dOut, Index, dUpdates);
}
};
......
......@@ -40,7 +40,9 @@ TEST(scatter, ScatterUpdate) {
float* p_output = output->mutable_data<float>(make_ddim({4, 4}), CPUPlace());
ScatterAssign<float>(CPUPlace(), src, index, output);
auto* cpu_place = new paddle::platform::CPUPlace();
paddle::platform::CPUDeviceContext ctx(*cpu_place);
ScatterAssign<float>(ctx, src, index, output);
for (size_t i = 0; i < 4; ++i) EXPECT_EQ(p_output[i], float(0));
for (size_t i = 0; i < 4; ++i) EXPECT_EQ(output->data<float>()[i], float(0));
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册