{
"cells": [
{
"cell_type": "markdown",
"id": "b416b922",
"metadata": {},
"source": [
"\n",
"\n",
"## 1. Training\n",
"\n",
"### 1.1 Import the MD&A data\n",
"\n",
"Read the **Management Discussion and Analysis (MD&A)** data for 2001-2022."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c8bc5c7f",
"metadata": {},
"outputs": [],
"source": [
"import pandas as pd\n",
"\n",
"df = pd.read_excel('mda01-22.xlsx')\n",
"df.head()"
]
},
{
"cell_type": "markdown",
"id": "c2e95324",
"metadata": {},
"source": [
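"\n",
"### 1.2 Build the corpus\n",
"\n",
"Concatenate the `text` column of the DataFrame into a single plain-text corpus file."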
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e0f17c1b",
"metadata": {},
"outputs": [],
"source": [
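"# Open in 'w' mode so that re-running this cell overwrites the corpus instead of appending a duplicate copy\n",
"with open('mda01-22.txt', 'w', encoding='utf-8') as f:\n",
"    # Drop missing rows and cast to str so that join() does not fail on NaN values\n",
"    text = ''.join(df['text'].dropna().astype(str))\n",
"    f.write(text)"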
]
},
{
"cell_type": "markdown",
"id": "2da82ea6",
"metadata": {},
"source": [
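"\n",
"### 1.3 Set up the cntext environment\n",
"\n",
"This tutorial uses version 2.0.0 of the cntext library (this version is not open source yet and must be purchased). Place the **cntext-2.0.0-py3-none-any.whl** file you receive on your desktop. On Windows open **cmd** (on Mac open Terminal) and enter the following command to switch the working directory to the desktop:\n",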
"\n",
"```\n",
"cd desktop\n",
"```\n",
"\n",
"\n",
"If this does not work for some Windows users, try ``cd Desktop``.\n",
"\n",
"Then run the following command in cmd (Terminal) to install cntext 2.0.0:\n",
"\n",
"```\n",
"pip3 install cntext-2.0.0-py3-none-any.whl \n",
"```\n",
"\n",
"\n"
]
},
{
"cell_type": "markdown",
"id": "ce1f72b8",
"metadata": {},
"source": [
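"\n",
"### 1.4 Train word2vec\n",
"\n",
"Model parameter settings, which also appear in the name of the output file (mda01-22.200.6.bin):\n",
"\n",
"- mda01-22: the model is trained on the 2001-2022 MD&A data\n",
"- 200: the embedding dimension, i.e. each word vector has length 200\n",
"- 6: the context window around each word is 6\n"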
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f3754e9b",
"metadata": {},
"outputs": [],
"source": [
"%%time\n",
"# %%time must be alone on the first line of the cell; the total run time is reported when the cell finishes\n",
"import cntext as ct\n",
"\n",
"w2v = ct.W2VModel(corpus_file='mda01-22.txt')\n",
"w2v.train(vector_size=200, window_size=6, min_count=6, save_dir='Word2Vec')\n"
]
},
{
"cell_type": "markdown",
"id": "cb58ded0",
"metadata": {},
"source": [
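"After less than two hours of training, we obtain a word-embedding model for China's A-share market with a vocabulary of 914,058 words and a model file size of 1.49 GB. The model can be widely used to build or expand concept (sentiment) dictionaries in economics, management, and related fields. The model files are:\n",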
"\n",
"- **mda01-22.200.6.bin**\n",
"- **mda01-22.200.6.bin.syn1neg.npy**\n",
"- **mda01-22.200.6.bin.wv.vectors.npy**"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "69f3c2d8",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "markdown",
"id": "54204078",
"metadata": {},
"source": [
"\n",
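"## 2. Load the model\n",
"Two custom functions are needed: `load_w2v` and `expand_dictionary`. Their source code is long, so it is placed at the end of this notebook for readability. Be sure to run that cell to define both functions before using them."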
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "db78149a",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Loading word2vec model...\n"
]
},
{
"data": {
"text/plain": [
""
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# First run the cell at the end of this notebook that defines load_w2v and expand_dictionary\n",
"\n",
"# Load the model file\n",
"w2v_model = load_w2v(w2v_path='mda01-22.200.6.bin')\n",
"w2v_model"
]
},
{
"cell_type": "markdown",
"id": "bd0971bd",
"metadata": {},
"source": [
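"## 3. Using the w2v model\n",
"- Check the vocabulary size\n",
"- Look up the vector of a single word\n",
"- Compute the mean vector of several words\n",
"\n",
"For more, see the gensim documentation."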
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "9ec6fd4b",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"914058"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Vocabulary size\n",
"len(w2v_model.wv.index_to_key)"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "08694746",
"metadata": {
"collapsed": true
},
"outputs": [
{
"data": {
"text/plain": [
"array([-1.36441350e-01, -2.02002168e+00, -1.49168205e+00, 2.65202689e+00,\n",
" 1.49721682e+00, 2.14851022e+00, -1.54925853e-01, -2.25241160e+00,\n",
" -3.58773202e-01, 1.54530525e+00, -7.62950361e-01, -9.77181852e-01,\n",
" 6.70365512e-01, -3.20203233e+00, 3.18079638e+00, 1.66510820e+00,\n",
" 9.80131567e-01, 1.62199986e+00, 1.80585206e+00, 4.08179426e+00,\n",
" -1.26518166e+00, 3.75929743e-01, 5.72038591e-01, 1.16134119e+00,\n",
" 2.55617023e+00, -2.25110960e+00, -2.61538339e+00, -5.71992218e-01,\n",
" 8.70356798e-01, -1.85045290e+00, -2.85597444e-01, -9.15628672e-01,\n",
" -2.03667688e+00, 2.11716801e-01, 2.94088912e+00, -2.32688546e+00,\n",
" 2.20858502e+00, 8.81347775e-01, -7.99135566e-01, -8.61206651e-01,\n",
" -4.45446587e+00, -1.73757005e+00, -3.36678886e+00, -2.82611530e-02,\n",
" -1.62726247e+00, -8.49750221e-01, 4.13731128e-01, -1.62519825e+00,\n",
" 3.03865957e+00, -1.39746085e-01, 8.22233260e-01, -7.97697455e-02,\n",
" 1.72468078e+00, 2.94929433e+00, 9.72453177e-01, -1.12741642e-01,\n",
" 8.18425417e-01, -9.05264139e-01, 2.61516261e+00, 8.02830994e-01,\n",
" 2.40420485e+00, 8.85799348e-01, -1.08665645e+00, 8.21912348e-01,\n",
" -4.39456075e-01, -2.57663131e+00, 2.38062453e+00, -4.58515882e-01,\n",
" 2.12767506e+00, -2.01356173e-01, 2.71096081e-01, 9.51708496e-01,\n",
" -3.05705309e+00, -6.06385887e-01, -1.38406023e-01, 2.36809158e+00,\n",
" -2.49158549e+00, 2.71105647e+00, -3.07211792e-03, 1.04273570e+00,\n",
" 1.44201803e+00, -5.65704823e-01, 2.85488725e-01, 1.43495277e-01,\n",
" -1.39421299e-01, 9.24086392e-01, 4.25374925e-01, -1.56690669e+00,\n",
" 1.67641795e+00, -1.03729677e+00, -1.45472065e-01, -2.11022258e+00,\n",
" -1.81541741e+00, -8.66766050e-02, 8.72350857e-02, 1.17173791e+00,\n",
" -3.07721123e-02, 5.84330797e-01, 1.47265148e+00, -1.76913440e+00,\n",
" -8.48391712e-01, -3.25056529e+00, 7.14846313e-01, -2.98076987e-01,\n",
" 1.13966620e+00, -1.42698896e+00, 6.93505168e-01, -2.04717040e+00,\n",
" -1.53559577e+00, 1.01942134e+00, -1.58283603e+00, 9.08654630e-01,\n",
" -1.90529859e+00, -9.43309963e-01, 4.12964225e-01, -2.50713086e+00,\n",
" -4.24056143e-01, -4.10613680e+00, 3.60615468e+00, -4.19765860e-01,\n",
" -2.41174579e+00, 6.80675328e-01, 2.99834704e+00, 1.05610855e-01,\n",
" -7.84325838e-01, 3.24065971e+00, -1.85072863e+00, -2.12448812e+00,\n",
" -2.83468294e+00, -5.77759802e-01, -3.13433480e+00, -6.91670418e-01,\n",
" 2.99401569e+00, -5.16145706e-01, 9.09552336e-01, -5.52680910e-01,\n",
" -2.88360894e-01, 1.11991334e+00, -1.11737549e+00, 1.15479147e+00,\n",
" -4.63319182e-01, 1.38351321e+00, -3.02179503e+00, 1.24334955e+00,\n",
" 1.93393975e-01, -8.27962995e-01, -2.37227559e+00, -9.26931739e-01,\n",
" 6.72517180e-01, 1.27736795e+00, 1.98695862e+00, 1.41960573e+00,\n",
" -3.73892736e+00, -3.14201683e-01, -7.19093859e-01, 1.86080355e-02,\n",
" -2.68105698e+00, 1.04344964e+00, 9.46133554e-01, -2.06151366e+00,\n",
" -2.84214950e+00, 1.17004764e+00, 1.24577022e+00, -1.10806060e+00,\n",
" 9.93207514e-01, 8.46789181e-01, -3.09691691e+00, 2.12616014e+00,\n",
" -1.49274826e+00, -1.53214395e+00, -9.95470941e-01, 1.23463202e+00,\n",
" -2.18907285e+00, -4.94913310e-01, 2.80939412e+00, 1.68149090e+00,\n",
" 1.48991072e+00, 3.83729649e+00, 4.72325265e-01, 1.37606680e+00,\n",
" 2.14257884e+00, 3.18186909e-01, 5.98093605e+00, 1.46744043e-01,\n",
" -2.37729326e-01, 1.20463884e+00, -1.55812174e-01, -5.03088772e-01,\n",
" 4.53981996e-01, 1.95544350e+00, -2.32564354e+00, -4.09389853e-01,\n",
" 1.89125270e-01, 2.62835431e+00, 9.81123984e-01, -9.51041043e-01,\n",
" -1.14294410e-01, 1.10983588e-01, 9.30419266e-02, -9.84693542e-02],\n",
" dtype=float32)"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Look up the vector of a single word\n",
"w2v_model.wv.get_vector('创新')"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "b6817432",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([ 0.03019853, -0.01928307, -0.05371316, 0.00053774, 0.02516318,\n",
" 0.10103251, -0.03914721, -0.08307559, 0.00444389, 0.09456791,\n",
" -0.05761364, -0.03459097, 0.04394419, -0.10181106, 0.1418381 ,\n",
" 0.05334964, 0.01820264, 0.01493831, 0.01626587, 0.17402864,\n",
" -0.02859601, 0.04538149, 0.03768233, 0.05431981, 0.15405464,\n",
" -0.03632693, -0.08566202, -0.00595666, 0.08378439, -0.11071078,\n",
" -0.05904576, -0.06451955, -0.1076955 , 0.05141645, 0.11710279,\n",
" -0.09403889, 0.08633652, -0.06743232, 0.00328483, 0.01589498,\n",
" -0.11226317, -0.05367877, -0.057222 , -0.00685401, -0.04531868,\n",
" -0.02090884, 0.01426806, -0.04787309, 0.1325518 , -0.00498158,\n",
" 0.01912023, -0.02292867, 0.08855374, 0.07697155, 0.01407153,\n",
" -0.02378988, 0.03745927, 0.00889686, 0.12555045, 0.04007044,\n",
" 0.06247196, 0.04912657, -0.06158784, 0.06346396, 0.00197599,\n",
" -0.04995281, 0.05125345, -0.01584197, 0.07572784, 0.02580263,\n",
" -0.02904062, -0.0008835 , -0.08365948, -0.05539802, -0.07523517,\n",
" 0.04622741, -0.12007375, 0.05453204, -0.02054051, 0.02937108,\n",
" 0.10272598, -0.0089594 , 0.05172383, 0.00588922, -0.0010917 ,\n",
" 0.02603476, -0.01580217, -0.07810815, 0.06964722, -0.04709972,\n",
" -0.0316673 , -0.05055645, -0.05096703, 0.02772727, -0.03495743,\n",
" 0.09567484, -0.0071935 , -0.01266821, 0.00074132, -0.07593331,\n",
" -0.02928162, -0.12574387, 0.02437552, -0.0228716 , -0.03047204,\n",
" -0.03948782, 0.07722469, -0.07440004, -0.00951135, 0.05531401,\n",
" -0.03240326, 0.00389662, -0.05632257, -0.05030375, 0.02883579,\n",
" -0.06157173, 0.00584065, -0.16594191, 0.1108149 , -0.00243916,\n",
" -0.09964953, 0.02029083, 0.03522225, -0.01167114, -0.04048527,\n",
" 0.08301719, -0.04682562, -0.0714631 , -0.07355815, -0.0496731 ,\n",
" -0.05303175, -0.03625978, 0.06879813, -0.09117774, 0.0323513 ,\n",
" -0.01808765, -0.01746182, 0.02472609, -0.00873791, -0.00951474,\n",
" -0.02176155, 0.02394484, -0.07035318, 0.10963078, 0.01004294,\n",
" -0.02269555, -0.09929934, -0.02897175, 0.02157164, 0.05608977,\n",
" 0.09083252, -0.00525982, -0.09866816, -0.02736895, -0.02923711,\n",
" 0.05582205, -0.04462272, 0.01932517, 0.04468061, 0.00317996,\n",
" -0.04182415, 0.03061792, 0.04278665, 0.02939183, 0.03475334,\n",
" -0.00898206, -0.08902986, 0.08294971, -0.00942507, -0.02125597,\n",
" -0.01008157, 0.04477865, -0.08366893, -0.00074587, 0.08328778,\n",
" 0.02653155, 0.04581301, 0.10532658, -0.04637942, 0.04722971,\n",
" 0.06853952, -0.00235328, 0.18312256, -0.0457427 , 0.00874868,\n",
" 0.08945092, -0.01135547, -0.04203002, 0.02408407, 0.0594779 ,\n",
" -0.05467811, 0.01946783, 0.07095537, 0.04226222, -0.0018304 ,\n",
" -0.00086302, 0.04624099, 0.01009499, 0.04783599, 0.02535392],\n",
" dtype=float32)"
]
},
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Compute the mean vector of several words\n",
"w2v_model.wv.get_mean_vector(['创新', '研发'])"
]
},
{
"cell_type": "markdown",
"id": "7b9483e3",
"metadata": {},
"source": [
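"With a vector for each word or concept, you can combine these vectors with the attitude and bias measures for monolingual models available in older versions of cntext.\n",
"\n",
"## 4. Expand the dictionary\n",
"For dictionary-based text analysis, the most important thing is having your own domain dictionary. Previously, limited by the technical difficulty, I (coming from a liberal-arts background) relied on general-purpose, adjective-based sentiment dictionaries. With word2vec, manual dictionary construction can now be made both more accurate and more efficient.\n",
"\n",
"Below are dictionary-expansion tests on mda01-22.200.6.bin. The function `expand_dictionary` selects the topn words most similar to the seed words; the returned list begins with the seed words themselves."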
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "79ab7d3c",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"['抓紧',\n",
" '立刻',\n",
" '月底',\n",
" '年底',\n",
" '年终',\n",
" '争取',\n",
" '力争',\n",
" '争取',\n",
" '力争',\n",
" '年底',\n",
" '月底',\n",
" '3月底',\n",
" '尽快',\n",
" '上半年',\n",
" '努力争取',\n",
" '年内实现',\n",
" '抓紧',\n",
" '工作争取',\n",
" '尽早',\n",
" '6月底',\n",
" '工作力争',\n",
" '7月份',\n",
" '年底完成',\n",
" '确保',\n",
" '早日',\n",
" '有望',\n",
" '全力',\n",
" '创造条件',\n",
" '3月份',\n",
" '加紧',\n",
" '力争实现',\n",
" '力争今年',\n",
" '月底前',\n",
" '10月底',\n",
" '4月份',\n",
" '继续',\n",
" '月初']"
]
},
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Experiment: short-termism (myopia) words\n",
"expand_dictionary(wv=w2v_model.wv, \n",
" seedwords=['抓紧', '立刻', '月底', '年底', '年终', '争取', '力争'],\n",
" topn=30)"
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "962d5dbb",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"['团结',\n",
" '拼搏',\n",
" '克服',\n",
" '勇攀高峰',\n",
" '友善',\n",
" '进取',\n",
" '拼搏',\n",
" '艰苦奋斗',\n",
" '团结拼搏',\n",
" '勇于担当',\n",
" '锐意进取',\n",
" '勇气',\n",
" '团结',\n",
" '团结奋进',\n",
" '团结一致',\n",
" '顽强拼搏',\n",
" '上下一心',\n",
" '实干',\n",
" '拼搏进取',\n",
" '积极进取',\n",
" '奋力拼搏',\n",
" '奋进',\n",
" '坚定信念',\n",
" '团结一心',\n",
" '精诚团结',\n",
" '顽强',\n",
" '踏实',\n",
" '团结协作',\n",
" '求真务实',\n",
" '团结奋斗',\n",
" '奋发有为',\n",
" '同心协力',\n",
" '脚踏实地',\n",
" '开拓进取',\n",
" '进取',\n",
" '勇于']"
]
},
"execution_count": 14,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"expand_dictionary(wv=w2v_model.wv, \n",
" seedwords=['团结', '拼搏', '克服', '勇攀高峰', '友善', '进取'],\n",
" topn=30)"
]
},
{
"cell_type": "code",
"execution_count": 15,
"id": "1d9b3f87",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"['创新',\n",
" '科技',\n",
" '研发',\n",
" '技术',\n",
" '标准',\n",
" '技术创新',\n",
" '技术研发',\n",
" '先进技术',\n",
" '关键技术',\n",
" '创新性',\n",
" '前沿技术',\n",
" '科技创新',\n",
" '技术应用',\n",
" '产品开发',\n",
" '自主创新',\n",
" '新技术',\n",
" '科研',\n",
" '产品研发',\n",
" '自主研发',\n",
" '技术开发',\n",
" '工艺技术',\n",
" '技术标准',\n",
" '基础研究',\n",
" '集成创新',\n",
" '核心技术',\n",
" '成熟技术',\n",
" '研发创新',\n",
" '理论技术',\n",
" '前沿技术研发',\n",
" '工艺',\n",
" '科技成果',\n",
" '技术研究',\n",
" '标准制定',\n",
" '技术装备',\n",
" '技术相结合']"
]
},
"execution_count": 15,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"expand_dictionary(wv=w2v_model.wv, \n",
" seedwords=['创新', '科技', '研发', '技术', '标准'],\n",
" topn=30)"
]
},
{
"cell_type": "code",
"execution_count": 16,
"id": "8ad2605a",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"['竞争',\n",
" '竞争力',\n",
" '竞争能力',\n",
" '市场竞争',\n",
" '竞争优势',\n",
" '市场竞争力',\n",
" '竞',\n",
" '竞争实力',\n",
" '激烈竞争',\n",
" '参与市场竞争',\n",
" '国际竞争',\n",
" '市场竞争能力',\n",
" '竞争态势',\n",
" '市场竞争优势',\n",
" '行业竞争',\n",
" '综合竞争力',\n",
" '竞争对手',\n",
" '未来市场竞争',\n",
" '产品竞争力',\n",
" '之间竞争',\n",
" '核心竞争力',\n",
" '参与竞争',\n",
" '核心竞争能力',\n",
" '竞争日趋激烈',\n",
" '国际化竞争',\n",
" '国际竞争力',\n",
" '竟争力',\n",
" '市场化竞争',\n",
" '同质化竞争',\n",
" '竞争力关键',\n",
" '价格竞争',\n",
" '整体竞争力']"
]
},
"execution_count": 16,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"expand_dictionary(wv=w2v_model.wv, \n",
" seedwords=['竞争', '竞争力'],\n",
" topn=30)"
]
},
{
"cell_type": "code",
"execution_count": 17,
"id": "368a6631",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"['疫情',\n",
" '扩散',\n",
" '防控',\n",
" '反复',\n",
" '冲击',\n",
" '蔓延',\n",
" '疫情',\n",
" '疫情爆发',\n",
" '疫情冲击',\n",
" '新冠疫情',\n",
" '肆虐',\n",
" '新冠肺炎',\n",
" '疫情蔓延',\n",
" '本次疫情',\n",
" '散发',\n",
" '疫情扩散',\n",
" '疫情影响',\n",
" '疫情反复',\n",
" '疫情传播',\n",
" '肺炎疫情',\n",
" '国内疫情',\n",
" '击',\n",
" '各地疫情',\n",
" '疫情全球',\n",
" '疫情多点',\n",
" '全球疫情',\n",
" '持续蔓延',\n",
" '多点散发',\n",
" '疫情导致',\n",
" '疫情暴发',\n",
" '病毒疫情',\n",
" '疫情持续',\n",
" '疫情初期',\n",
" '疫情出现',\n",
" '防控措施']"
]
},
"execution_count": 17,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"expand_dictionary(wv=w2v_model.wv, \n",
" seedwords=['疫情', '扩散', '防控', '反复', '冲击'],\n",
" topn=30)"
]
},
{
"cell_type": "code",
"execution_count": 18,
"id": "03dc896b",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"['旧',\n",
" '老',\n",
" '后',\n",
" '落后',\n",
" '老',\n",
" '旧',\n",
" '陈旧',\n",
" '老旧',\n",
" '淘汰',\n",
" '低效率',\n",
" '低效',\n",
" '部分老旧',\n",
" '进行改造',\n",
" '老旧设备',\n",
" '工艺落后',\n",
" '设备陈旧',\n",
" '能耗高',\n",
" '更新改造',\n",
" '落后工艺',\n",
" '技术落后',\n",
" '改造',\n",
" '翻新',\n",
" '简陋',\n",
" '旧设备',\n",
" '拆除',\n",
" '现象严重',\n",
" '原有',\n",
" '相对落后',\n",
" '产能淘汰',\n",
" '加快淘汰',\n",
" '搬',\n",
" '替换',\n",
" '大批',\n",
" '迁']"
]
},
"execution_count": 18,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"expand_dictionary(wv=w2v_model.wv, \n",
" seedwords=['旧', '老', '后', '落后'],\n",
" topn=30)"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "e96f77ed",
"metadata": {},
"outputs": [],
"source": [
"from gensim.models import KeyedVectors\n",
"from pathlib import Path\n",
"\n",
"\n",
"def load_w2v(w2v_path):\n",
"    \"\"\"\n",
"    Load word2vec model\n",
"\n",
"    Args:\n",
"        w2v_path (str): path of word2vec model\n",
"\n",
"    Returns:\n",
"        model: word2vec model\n",
"    \"\"\"\n",
"    print('Loading word2vec model...')\n",
"    model = KeyedVectors.load(w2v_path)\n",
"    return model\n",
"\n",
"\n",
"def expand_dictionary(wv, seedwords, topn=100):\n",
"    \"\"\"\n",
"    Expand a list of seed words with the topn words whose semantics are most similar to the seed words.\n",
"\n",
"    Args:\n",
"        wv (KeyedVectors): the word vectors of the trained model\n",
"        seedwords (list): the seed words\n",
"        topn (int, optional): the number of most similar words to retrieve. Defaults to 100.\n",
"\n",
"    Returns:\n",
"        list: the seed words followed by the topn most similar words (seed words may reappear among them)\n",
"    \"\"\"\n",
"    simidx_scores = []\n",
"\n",
"    similars_candidate_idxs = []  # the candidate words of seedwords\n",
"    dictionary = wv.key_to_index\n",
"    seedidxs = []  # transform word to index; seed words missing from the vocabulary are skipped\n",
"    for seed in seedwords:\n",
"        if seed in dictionary:\n",
"            seedidx = dictionary[seed]\n",
"            seedidxs.append(seedidx)\n",
"    for seedidx in seedidxs:\n",
"        # sims_words such as [('by', 0.99984), ('or', 0.99982), ('an', 0.99981), ('up', 0.99980)]\n",
"        sims_words = wv.similar_by_word(seedidx, topn=topn)\n",
"        # Convert words to index and store them\n",
"        similars_candidate_idxs.extend([dictionary[sim[0]] for sim in sims_words])\n",
"    similars_candidate_idxs = set(similars_candidate_idxs)\n",
"\n",
"    # Score each candidate by its similarity to the whole set of seed words\n",
"    for idx in similars_candidate_idxs:\n",
"        score = wv.n_similarity([idx], seedidxs)\n",
"        simidx_scores.append((idx, score))\n",
"    simidxs = [w[0] for w in sorted(simidx_scores, key=lambda k: k[1], reverse=True)]\n",
"\n",
"    simwords = [str(wv.index_to_key[idx]) for idx in simidxs][:topn]\n",
"\n",
"    resultwords = []\n",
"    resultwords.extend(seedwords)\n",
"    resultwords.extend(simwords)\n",
"\n",
"    return resultwords"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "8e78a807",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.5"
}
},
"nbformat": 4,
"nbformat_minor": 5
}