我有两个要在ANN训练之前进行预处理的文件。每个文件的大小约为3GB,因此我决定使用Dask。输入文件的形状为(500000,410),输出文件的形状为(500000,695)。
我需要:
- 删除有错误的行(在输出数据的第一列中有“ 0”);
- 删除最后一列;
- 将数据拆分为训练,测试和验证数据集;
- 缩放数据集;
- 保存到.hdf5文件。
代码:
def preprocessing(path,random_seed):
np.random.seed(random_seed)
random.seed(random_seed)
input_file = dd.read_csv(
os.path.join(path,'input_values.csv'),header = None,sep = ';',dtype = np.float64)
output_file = dd.read_csv(
os.path.join(path,'output_values.csv'),dtype = np.float64)
# Delete errors
errors = output_file[0] == 0
errors = errors.compute().reset_index()[0]
input_data = input_file[~errors].iloc[:,:-1].to_dask_array(lengths=True)
output_data = output_file[~errors].iloc[:,:-1].to_dask_array(lengths=True)
# Split to datasets
input_train,input_test,output_train,output_test = train_test_split(
input_data,output_data,random_state = random_seed,shuffle = True,test_size = 0.2)
input_train,input_val,output_val = train_test_split(
input_train,test_size = 0.25)
# Scaling
scaler = StandardScaler()
scaler.fit(input_train)
data_scaled = {
'input_train': scaler.transform(input_train),'input_test': scaler.transform(input_test),'input_val': scaler.transform(input_val),'output_train': output_train,'output_test': output_test,'output_val': output_val
}
# Save data to hdf5
file_name = os.path.basename(os.path.normpath(path))
path_to_output = os.path.join(path,file_name) + '.hdf5'
with h5py.File(path_to_output,'w') as file:
for dataset in data_scaled.keys():
data = data_scaled[dataset]
file.create_dataset(dataset,data.shape,data = data)
metadata = {
'scaler_mean': scaler.mean_,'scaler_scale': scaler.scale_,}
file.attrs.update(metadata)
它给出了'IndexError:索引74442超出轴0尺寸13987的范围'
但是如果我将'train_test_split'更改为:
sep_train = int(round(0.6 * input_data.shape[0]))
sep_test = int(round(0.8 * input_data.shape[0]))
input_train = input_data[:sep_train,:]
input_test = input_data[sep_train:sep_test,:]
input_val = input_data[sep_test:,:]
output_train = output_data[:sep_train,:]
output_test = output_data[sep_train:sep_test,:]
output_val = output_data[sep_test:,:]
它将成功完成。
为什么会这样?以这种方式删除有错误的行是否正确?因为它抛出“ UserWarning:布尔系列键将被重新索引以匹配DataFrame索引。”在每个步骤中。