未验证 提交 5728dffb 编写于 作者: C Chen Weihang 提交者: GitHub

fix assign typo (#41005)

上级 55f9b71a
...@@ -27,7 +27,7 @@ namespace operators { ...@@ -27,7 +27,7 @@ namespace operators {
using Tensor = framework::Tensor; using Tensor = framework::Tensor;
template <typename T> template <typename T>
typename std::enable_if<std::is_same<T, bool>::value>::type CopyVecotorToTensor( typename std::enable_if<std::is_same<T, bool>::value>::type CopyVectorToTensor(
const char* value_name, framework::Tensor* out, const char* value_name, framework::Tensor* out,
const framework::ExecutionContext& ctx) { const framework::ExecutionContext& ctx) {
// If attribute value dtype is vector<bool>, it will be converted to // If attribute value dtype is vector<bool>, it will be converted to
...@@ -48,8 +48,8 @@ typename std::enable_if<std::is_same<T, bool>::value>::type CopyVecotorToTensor( ...@@ -48,8 +48,8 @@ typename std::enable_if<std::is_same<T, bool>::value>::type CopyVecotorToTensor(
} }
template <typename T> template <typename T>
typename std::enable_if<!std::is_same<T, bool>::value>::type typename std::enable_if<!std::is_same<T, bool>::value>::type CopyVectorToTensor(
CopyVecotorToTensor(const char* value_name, framework::Tensor* out, const char* value_name, framework::Tensor* out,
const framework::ExecutionContext& ctx) { const framework::ExecutionContext& ctx) {
auto values = ctx.Attr<std::vector<T>>(value_name); auto values = ctx.Attr<std::vector<T>>(value_name);
framework::TensorFromVector(values, ctx.device_context(), out); framework::TensorFromVector(values, ctx.device_context(), out);
...@@ -83,7 +83,7 @@ class AssignValueKernel : public framework::OpKernel<T> { ...@@ -83,7 +83,7 @@ class AssignValueKernel : public framework::OpKernel<T> {
dtype)); dtype));
break; break;
} }
CopyVecotorToTensor<T>(value_name, out, ctx); CopyVectorToTensor<T>(value_name, out, ctx);
out->Resize(phi::make_ddim(shape)); out->Resize(phi::make_ddim(shape));
} }
}; };
......
...@@ -157,7 +157,7 @@ void SetValueCompute(const framework::ExecutionContext& ctx, ...@@ -157,7 +157,7 @@ void SetValueCompute(const framework::ExecutionContext& ctx,
value_t.mutable_data<T>(value_dims, place); value_t.mutable_data<T>(value_dims, place);
auto value_name = GetValueName(dtype); auto value_name = GetValueName(dtype);
CopyVecotorToTensor<T>(value_name.c_str(), &value_t, ctx); CopyVectorToTensor<T>(value_name.c_str(), &value_t, ctx);
value_t.Resize(value_dims); value_t.Resize(value_dims);
ElementwiseComputeEx<SubFunctor<T>, DeviceContext, T>( ElementwiseComputeEx<SubFunctor<T>, DeviceContext, T>(
ctx, &slice_tensor, &value_t, -1, SubFunctor<T>(), &slice_tensor); ctx, &slice_tensor, &value_t, -1, SubFunctor<T>(), &slice_tensor);
......
...@@ -141,7 +141,7 @@ class SetValueNPUKernel : public framework::OpKernel<T> { ...@@ -141,7 +141,7 @@ class SetValueNPUKernel : public framework::OpKernel<T> {
value_t.mutable_data<T>(value_dims, ctx.GetPlace()); value_t.mutable_data<T>(value_dims, ctx.GetPlace());
auto value_name = auto value_name =
GetValueName(framework::TransToProtoVarType(in->dtype())); GetValueName(framework::TransToProtoVarType(in->dtype()));
CopyVecotorToTensor<T>(value_name.c_str(), &value_t, ctx); CopyVectorToTensor<T>(value_name.c_str(), &value_t, ctx);
value_t.Resize(value_dims); value_t.Resize(value_dims);
} }
......
...@@ -47,7 +47,7 @@ void AssignArrayKernel(const Context& dev_ctx, ...@@ -47,7 +47,7 @@ void AssignArrayKernel(const Context& dev_ctx,
} }
template <typename T, typename Context> template <typename T, typename Context>
typename std::enable_if<std::is_same<T, bool>::value>::type CopyVecotorToTensor( typename std::enable_if<std::is_same<T, bool>::value>::type CopyVectorToTensor(
const Context& dev_ctx, const Context& dev_ctx,
const std::vector<Scalar>& values, const std::vector<Scalar>& values,
DenseTensor* out) { DenseTensor* out) {
...@@ -72,8 +72,8 @@ typename std::enable_if<std::is_same<T, bool>::value>::type CopyVecotorToTensor( ...@@ -72,8 +72,8 @@ typename std::enable_if<std::is_same<T, bool>::value>::type CopyVecotorToTensor(
} }
template <typename T, typename Context> template <typename T, typename Context>
typename std::enable_if<!std::is_same<T, bool>::value>::type typename std::enable_if<!std::is_same<T, bool>::value>::type CopyVectorToTensor(
CopyVecotorToTensor(const Context& dev_ctx, const Context& dev_ctx,
const std::vector<Scalar>& values, const std::vector<Scalar>& values,
DenseTensor* out) { DenseTensor* out) {
std::vector<T> assign_values; std::vector<T> assign_values;
...@@ -98,7 +98,7 @@ void AssignValueKernel(const Context& dev_ctx, ...@@ -98,7 +98,7 @@ void AssignValueKernel(const Context& dev_ctx,
"argument dtype is %s, kernel dtype is %s.", "argument dtype is %s, kernel dtype is %s.",
dtype, dtype,
template_dtype)); template_dtype));
CopyVecotorToTensor<T>(dev_ctx, values, out); CopyVectorToTensor<T>(dev_ctx, values, out);
out->Resize(phi::make_ddim(shape)); out->Resize(phi::make_ddim(shape));
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册