从CSV文件中初始化Tensor张量类
CSV(逗号分隔值)文件是一种特殊的文件类型,可在 Excel 中创建或编辑。
CSV文件采用逗号分隔的形式来存储文本和数字信息,总体来说,这种文件格式具有扩展性好、移植性强的特点。
作用:
- 对比推理结果
  - 把PyTorch的推理结果输出到CSV文件中,由KuiperInfer读取后再进行对比
- 导入模型权值
  - 把PyTorch的模型权值输出到CSV文件中,由KuiperInfer读取后进行推理
接口定义:
1 2 3 4 5 6 7 8 9 10
| class CSVDataLoader { public: static std::shared_ptr<Tensor<float >> LoadData(const std::string &file_path, char split_char = ',');
static std::shared_ptr<Tensor<float >> LoadDataWithHeader(const std::string &file_path, std::vector<std::string> &headers, char split_char = ',');
private: static std::pair<size_t, size_t> GetMatrixSize(std::ifstream &file, char split_char); };
|
实现
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135
|
/// Scans the stream to find the dimensions of the delimited matrix it holds.
///
/// Rows are counted until the first empty line or until the stream stops being
/// good(). The column count is the largest number of `split_char`-separated
/// fields seen on any counted row. Before returning, the stream's state flags
/// are cleared and its read position is moved back to where it was on entry,
/// so the caller can read the same data again.
///
/// @param file        Input stream to scan.
/// @param split_char  Field delimiter.
/// @return {rows, cols}; {0, 0} when the stream was not good() on entry.
std::pair<size_t, size_t> CSVDataLoader::GetMatrixSize(std::ifstream &file,
                                                       char split_char) {
  // Record whether the stream was usable before its error flags are reset.
  const bool load_ok = file.good();
  file.clear();

  size_t fn_rows = 0;
  size_t fn_cols = 0;
  // Current input position, restored once the scan is finished.
  const std::ifstream::pos_type start_pos = file.tellg();

  std::string line_str;
  std::string row_token;
  std::stringstream line_stream;

  while (file.good() && load_ok) {
    std::getline(file, line_str);  // read one line
    if (line_str.empty()) {
      break;  // an empty line (or nothing left to read) ends the matrix
    }

    // Reset the eof/fail flags left over from the previous line before reuse.
    line_stream.clear();
    line_stream.str(line_str);

    size_t line_cols = 0;
    while (line_stream.good()) {
      std::getline(line_stream, row_token, split_char);
      ++line_cols;
    }
    if (line_cols > fn_cols) {
      fn_cols = line_cols;
    }

    ++fn_rows;
  }

  file.clear();           // drop the eof/fail flags set by the scan
  file.seekg(start_pos);  // rewind to the position saved on entry
  return {fn_rows, fn_cols};
}
std::shared_ptr<Tensor<float >> CSVDataLoader::LoadData(const std::string &file_path, char split_char) { CHECK(!file_path.empty()) << "File path is empty!"; std::ifstream in(file_path); CHECK(in.is_open() && in.good()) << "File open failed! " << file_path;
std::string line_str; std::stringstream line_stream;
const auto &[rows, cols] = CSVDataLoader::GetMatrixSize(in, split_char); std::shared_ptr<Tensor<float>> input_tensor = std::make_shared<Tensor<float>>(1, rows, cols); arma::fmat &data = input_tensor->at(0);
size_t row = 0; while (in.good()) { std::getline(in, line_str); if (line_str.empty()) { break; }
std::string token; line_stream.clear(); line_stream.str(line_str);
size_t col = 0; while (line_stream.good()) { std::getline(line_stream, token, split_char); try { data.at(row, col) = std::stof(token); } catch (std::exception &e) { LOG(ERROR) << "Parse CSV File meet error: " << e.what(); continue; } col += 1; CHECK(col <= cols) << "There are excessive elements on the column"; }
row += 1; CHECK(row <= rows) << "There are excessive elements on the row"; } return input_tensor; }
std::shared_ptr<Tensor<float>> CSVDataLoader::LoadDataWithHeader(const std::string &file_path, std::vector<std::string> &headers, char split_char) { CHECK(!file_path.empty()) << "File path is empty!"; std::ifstream in(file_path); CHECK(in.is_open() && in.good()) << "File open failed! " << file_path;
std::string line_str; std::stringstream line_stream;
const auto &[rows, cols] = CSVDataLoader::GetMatrixSize(in, split_char); CHECK(rows >= 1); std::shared_ptr<Tensor<float>> input_tensor = std::make_shared<Tensor<float>>(1, rows - 1, cols); arma::fmat &data = input_tensor->at(0);
size_t row = 0; while (in.good()) { std::getline(in, line_str); if (line_str.empty()) { break; }
std::string token; line_stream.clear(); line_stream.str(line_str);
size_t col = 0; while (line_stream.good()) { std::getline(line_stream, token, split_char); try { if(row == 0) headers.push_back(token); else data.at(row-1, col) = std::stof(token); } catch (std::exception &e) { LOG(ERROR) << "Parse CSV File meet error: " << e.what(); continue; } col += 1; CHECK(col <= cols) << "There are excessive elements on the column"; }
row += 1; CHECK(row <= rows) << "There are excessive elements on the row"; } return input_tensor; }
|
使用
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44
| TEST(test_data_load, load_csv1) { using namespace kuiper_infer;
const std::string &file_path = "../tmp/data1.csv"; std::shared_ptr<Tensor<float>> data = CSVDataLoader::LoadData(file_path, ','); uint32_t index = 1; uint32_t rows = data->rows(); uint32_t cols = data->cols(); ASSERT_EQ(rows, 3); ASSERT_EQ(cols, 6); for (uint32_t r = 0; r < rows; ++r) { for (uint32_t c = 0; c < cols; ++c) { ASSERT_EQ(data->at(0, r, c), index); index += 1; } } }
// Loads ../tmp/data2.csv, expecting the header row ROW1,ROW2,ROW3 followed by
// a 3x3 matrix whose elements, visited row by row, are 1, 2, 3, ...
TEST(test_data_load, load_csv_with_head1) {
  using namespace kuiper_infer;

  const std::string file_path{"../tmp/data2.csv"};
  std::vector<std::string> headers;
  const std::shared_ptr<Tensor<float>> data =
      CSVDataLoader::LoadDataWithHeader(file_path, headers, ',');

  const uint32_t rows = data->rows();
  const uint32_t cols = data->cols();
  // NOTE(review): this streams the shared_ptr itself, which prints the stored
  // pointer value rather than the tensor contents - confirm that is intended.
  LOG(INFO) << "\n" << data;

  ASSERT_EQ(rows, 3);
  ASSERT_EQ(cols, 3);
  ASSERT_EQ(headers.size(), 3);

  ASSERT_EQ(headers.at(0), "ROW1");
  ASSERT_EQ(headers.at(1), "ROW2");
  ASSERT_EQ(headers.at(2), "ROW3");

  // Value expected at the element currently being visited.
  uint32_t expected = 1;
  for (uint32_t r = 0; r != rows; ++r) {
    for (uint32_t c = 0; c != cols; ++c) {
      ASSERT_EQ(data->at(0, r, c), expected);
      ++expected;
    }
  }
}
|