因此,最终我找到了帖子中提到的第二种方法的解决方案。使用稀疏矩阵可以避免在尝试将矩阵与大数据(类别和/或观察值)相乘时可能发生的内存问题。
我编写了此函数,该函数返回原始数据帧,并附加了所有所需的分类变量的嵌入矢量。
def get_embeddings(model: keras.models.Model,cat_vars: List[str],df: pd.DataFrame,dict: Dict[str,Dict[str,int]]) -> pd.DataFrame:
df_list: List[pd.DataFrame] = [df]
for var_name in cat_vars:
df_1vec: pd.DataFrame = df.loc[:,var_name]
enc = OneHotEncoder()
sparse_mat = enc.fit_transform(df_1vec.values.reshape(-1,1))
sparse_mat = sparse.csr_matrix(sparse_mat,dtype='uint8')
orig_dict = dict[var_name]
match_to_arr = np.empty(
(sparse_mat.shape[1],model.get_layer(f'embedding_{var_name}').get_weights()[0].shape[1]))
match_to_arr[:] = np.nan
unknown_cat = model.get_layer(f'embedding_{var_name}').get_weights()[0].shape[0] - 1
for i,col in enumerate(tqdm.tqdm(enc.categories_[0])):
if col in orig_dict.keys():
val = orig_dict[col]
match_to_arr[i,:] = model.get_layer(f'embedding_{var_name}').get_weights()[0][val,:]
else:
match_to_arr[i,:] = (model.get_layer(f'embedding_{var_name}')
.get_weights()[0][unknown_cat,:])
a = sparse_mat.dot(match_to_arr)
a = pd.DataFrame(a,columns=[f'{var_name}_{i}' for i in range(1,match_to_arr.shape[1] + 1)])
df_list.append(a)
df_final = pd.concat(df_list,axis=1)
return df_final
dict
是词典的字典,即为我预先编码的每个分类变量保存一个字典,键为类别名称和值整数。请注意,每个类别都用num_values + 1
编码,最后一个类别保留给未知类别。
基本上我在做什么是要求每个类别值是否在字典中。如果是的话,我将一个临时数组中的对应行(因此,如果这是第一类别,则是第一行)分配给嵌入矩阵中的对应行,其中行号对应于其类别名称被编码为的值。
如果不在字典中,那么我将对应未知类别的嵌入矩阵中的最后一行分配给该行(第i
行)。
,
这是我在评论中介绍的
df = pd.DataFrame({'int':np.random.uniform(0,1,10),'cat':np.random.randint(0,333,10)}) # cat are encoded
## define embedding model,you can also use multiple input source
inp = Input((1))
emb = Embedding(input_dim=10000+2,output_dim=50,name='embedding')(inp)
out = Dense(10)(emb)
model = Model(inp,out)
# model.compile(...)
# model.fit(...)
## get cat embeddings
extractor = Model(model.input,Flatten()(model.get_layer('embedding').output))
## concat embedding in the orgiginal df
df = pd.concat([df,pd.DataFrame(extractor.predict(df.cat.values))],axis=1)
df
本文链接:https://www.f2er.com/2332442.html