默认情况下,CsvReader自动匹配csv文件中的数据类型:

1
let df = CsvReader::from_path(filecsv)?.has_header(true).finish()?; Ok(df)

当csv文件中的某列出现前导0时,默认的读取会将这一列的类型设置为整数类型,且抹除前导0,这并不是我们想要的,因为前导0的列往往需要被解析为字符str类型。

有两种解决方案。

部分转换

将一些(不是全部)字段转换为 str 类型,这样做的前提是你事先知道要指定将哪些列转换。

项目的Cargo.toml依赖如下:

1
2
3
[dependencies]
polars = { version = "0.33.2", features = ["lazy","ndarray"] }
smartstring = "1.0.1"

代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
use polars::datatypes::DataType::Utf8;
use polars::prelude::*;
use smartstring::SmartString;
use std::sync::Arc;
fn main() {
let mut schema = Schema::new();
schema.with_column(SmartString::from("some_columns"), Utf8);
let df_csv = CsvReader::from_path("some_input.csv")
.unwrap()
.infer_schema(None)
.has_header(true)
.with_dtypes(Some(Arc::new(schema)))
.finish()
.unwrap();
println!("{}", df_csv);
}

通过with_column设置schema,传入需要自定义的列和类型,这里我需要的是将其转换为polars::datatypes::DataType::Utf8类型。最后通过with_dtypes即可应用schema

全部转换

在一些情况下,部分转换并不方便,比如:

  • 事先不知道csv中有哪些字段
  • 字段数量很多,使用with_column不方便
  • 想将所有列(无论它原本是什么类型)都读取为str

经过进一步的实验,我找到了另一种方法。这需要使用 csv 库。

项目的Cargo.toml依赖如下:

1
2
3
[dependencies]
polars = { version = "0.33.2", features = ["lazy","ndarray"] }
csv = "1.3.0"

代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
use polars::datatypes::DataType::Utf8;
use polars::prelude::*;
use std::sync::Arc;

fn main() {

let mut rdr = csv::Reader::from_path("some_input.csv").unwrap();

let column_names = rdr.headers().unwrap().iter().map(|item| Field::new(item,Utf8));
let schema = Schema::from_iter(column_names);
let df_csv = CsvReader::from_path("some_input.csv")
.unwrap()
.infer_schema(None)
.has_header(true)
.with_dtypes(Some(Arc::new(schema)))
.finish()
.unwrap();
println!("{}", df_csv);
}

原理很简单,我们事先读取csv header确定要转换的所有列名,然后使用 rdr.headers() 读取 CSV headers 并将它们映射到迭代器中。然后,使用 from_iter 创建schema,其余的操作与之前相同。

缺点

这种方法的缺点是需要读取文件两次,效率可能会很低。但很遗憾,我目前还没有找到合适的 API 来避免额外的开销。

目前它至少满足了需求,但也许有更好的解决方案。

更优的方案

经过尝试后,我发现其实并没有这么麻烦。只需要使用 infer_schema(Some(0)) 。这将使用 0 行来推断架构,这会将所有列都被“推断”为字符串(默认、最通用的类型)。

对于 LazyCsvReader ,相应的方法是 with_infer_schema_length