Simple Transformers 库基于 HuggingFace 的 Transformers 库,可让您快速训练和评估 Transformer 模型, 初始化训练评估模型只需要 3 行代码。

安装

pip3 install simpletransformers

Simple Transformer 模型在构建时考虑了特定的自然语言处理 (NLP) 任务。 每个这样的模型都配备了旨在最适合它们打算执行的任务的特性和功能。 使用 Simple Transformers 模型的高级过程遵循相同的模式。

  1. 初始化一个特定于任务的模型 2.用train_model()训练模型
  2. 使用 eval_model() 评估模型
  3. 使用 predict() 对(未标记的)数据进行预测

但是,不同模型之间存在必要的差异,以确保它们非常适合其预期任务。 关键差异通常是输入/输出数据格式和任何任务特定功能/配置选项的差异。 这些都可以在每个任务的文档部分中找到。

当前实现的特定于任务的“Simple Transformer”模型及其任务如下所示。

Task Model
Binary and multi-class text classification文本二分类、多分类 ClassificationModel
Conversational AI (chatbot training)对话机器人训练 ConvAIModel
Language generation语言生成 LanguageGenerationModel
Language model training/fine-tuning语言模型训练、微调 LanguageModelingModel
Multi-label text classification多类别文本分类 MultiLabelClassificationModel
Multi-modal classification (text and image data combined)多模态分类 MultiModalClassificationModel
Named entity recognition命名实体识别 NERModel
Question answering问答 QuestionAnsweringModel
Regression回归 ClassificationModel
Sentence-pair classification句对分类 ClassificationModel
Text Representation Generation文本表征生成 RepresentationModel
Document Retrieval文档抽取 RetrievalModel
  • 有关如何使用这些模型的更多信息,请参阅 docs 中的相关部分。
  • 示例脚本可以在 examples 目录中找到。
  • 有关项目的最新更改,请参阅 Changelog

生成句子嵌入

使用huggingface网站https://huggingface.co/ 提供的模型

  • 英文模型 bert-base-uncased
  • 中文模型 bert-base-chinese
from simpletransformers.language_representation import RepresentationModel
sentences = ["Machine Learning and Deep Learning are part of AI", 
             "Data Science will excel in future"] #it should always be a list


model = RepresentationModel(
        model_type="bert",
        model_name="bert-base-uncased", #英文模型
        use_cuda=False)

sentence_vectors = model.encode_sentences(sentences, combine_strategy="mean")

print(sentence_vectors.shape)
print(sentence_vectors)

Run

(2, 768)

array([[-0.10800573,  0.19615649, -0.10756102, ..., -0.26362818,
         0.56403756, -0.30985302],
       [ 0.0201617 , -0.19381572,  0.4360792 , ..., -0.2979438 ,
         0.04984972, -0.702381  ]], dtype=float32)


广而告之