Transformers 加載預訓練模型

加載Google AI或OpenAI預訓練權重或PyTorch轉儲

from_pretrained()方法

要加載Google AI、OpenAI的預訓練模型或PyTorch保存的模型(用torch.save()保存的BertForPreTraining實例),PyTorch模型類和tokenizer可以被from_pretrained()實例化:

<code>model = BERT_CLASS.from_pretrained(PRE_TRAINED_MODEL_NAME_OR_PATH, cache_dir=None, from_tf=False, state_dict=None, *input, **kwargs)/<code>

其中

  • BERT_CLASS要麼是用於加載詞彙表的tokenizer(BertTokenizer或OpenAIGPTTokenizer類),要麼是加載八個BERT或三個OpenAI GPT PyTorch模型類之一(用於加載預訓練權重):BertModel ,BertForMaskedLM,BertForNextSentencePrediction,BertForPreTraining,BertForSequenceClassification,BertForTokenClassification,BertForMultipleChoice ,BertForQuestionAnswering,OpenAIGPTModel,OpenAIGPTLMHeadModel或OpenAIGPTDoubleHeadsModel
  • PRE_TRAINED_MODEL_NAME_OR_PATH為: Google AI或OpenAI的預定義的快捷名稱列表,其中的模型都是已經訓練好的模型: bert-base-uncased:12個層,768個隱藏節點,12個heads,110M參數量。 bert-large-uncased:24個層,1024個隱藏節點,16個heads,340M參數量。 bert-base-cased:12個層,768個隱藏節點,12個heads,110M參數量。 bert-large-cased:24個層,1024個隱藏節點,16個heads,340M參數量。 bert-base-multilingual-uncased:(原始,不推薦)12個層,768個隱藏節點,12個heads,110M參數量。 bert-base-multilingual-cased:(新的,推薦)12個層,768個隱藏節點,12個heads,110M參數量。 bert-base-chinese:簡體中文和繁體中文,12個層,768個隱藏節點,12個heads,110M參數量。 bert-base-german-cased:僅針對德語數據訓練,12個層,768個隱藏節點,12個heads,110M參數量。性能評估(https://deepset.ai/german-bert) bert-large-uncased-whole-word-masking:24個層,1024個隱藏節點,16個heads,340M參數量。經過Whole Word Masking模式訓練(該單詞對應的標記全部掩碼處理) bert-large-cased-whole-word-masking:24個層,1024個隱藏節點,16個heads,340M參數量。經過Whole Word Masking模式訓練(該單詞對應的標記全部掩碼處理) bert-large-uncased-whole-word-masking-finetuned-squad:在SQuAD上微調的bert-large-uncased-whole-word-masking模型(使用run_bert_squad.py)。結果:exact_match:86.91579943235573,f1:93.1532499015869 bert-base-german-dbmdz-cased:僅針對德語數據訓練,12個層,768個隱藏節點,12個heads,110M參數量。性能評估(https://deepset.ai/german-bert) bert-base-german-dbmdz-uncased:僅針對德語數據(無大小寫),12個層,768個隱藏節點,12個heads,110M參數量。性能評估(https://github.com/dbmdz/german-bert) openai-gpt:OpenAI GPT英文模型,12個層,768個隱藏節點,12個heads,110M參數量。 gpt2:OpenAI GPT-2英語模型,12個層,768個隱藏節點,12個heads,117M參數量。 gpt2-medium:OpenAI GPT-2英語模型,24個層,1024個隱藏節點、16個heads,345M參數量。 transfo-xl-wt103:使用Transformer-XL英語模型在wikitext的-103上訓練的模型,24個層,1024個隱藏節點、16個heads,257M參數量。 一個路徑或URL包含一個預訓練模型: bert_config.json或openai_gpt_config.json是用於模型的配置文件 pytorch_model.bin是BertForPreTraining保存的OpenAIGPTModel,TransfoXLModel 和GPT2LMHeadModel的預訓練實例的PyTorch轉儲。(使用常用的torch.save()保存) 如果PRE_TRAINED_MODEL_NAME_OR_PATH是快捷名稱,則將從AWS S3下載預訓練權重。可以參見鏈接(https://github.com/huggingface/transformers/blob/master/transformers/modeling_bert.py)並存儲在緩存文件夾中以避免以後需要下載(可以在`~/.pytorch_pretrained_bert/`中找到該緩存文件夾)。 cache_dir可以是特定目錄的可選路徑,以下載和緩存預先訓練的模型權重。該選項在使用分佈式訓練時特別有用:為避免同時訪問相同的權重,你可以設置例如cache_dir='./pretrained_model_{}'.format(args.local_rank)。)。 from_tf :我們應該從本地保存的TensorFlow checkpoint加載權重 state_dict :可選狀態字典(collections.OrderedDict對象),而不是使用Google的預訓練模式 *inputs,** kwargs:特定Bert類的附加輸入(例如:BertForSequenceClassification的num_labels)

Uncased表示在WordPiece標記化之前,文本已小寫,例如,John Smith變為john smith。Uncased模型還會刪除任何重音標記。Cased表示保留了真實的大小寫和重音標記。通常,除非你知道案例信息對於你的任務很重要(例如,命名實體識別或詞性標記),否則Uncased模型會更好。有關多語言和中文模型的信息,請參見(https://github.com/google-research/bert/blob/master/multilingual.md)或原始的TensorFlow存儲庫。

當使用Uncased的模型時,請確保將--do_lower_case傳遞給示例訓練腳本(如果使用自己的腳本,則將do_lower_case=True傳遞給FullTokenizer))。

示例:

<code># BERT
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', do_lower_case=True, do_basic_tokenize=True)
model = BertForSequenceClassification.from_pretrained('bert-base-uncased')

# OpenAI GPT
tokenizer = OpenAIGPTTokenizer.from_pretrained('openai-gpt')
model = OpenAIGPTModel.from_pretrained('openai-gpt')

# Transformer-XL
tokenizer = TransfoXLTokenizer.from_pretrained('transfo-xl-wt103')
model = TransfoXLModel.from_pretrained('transfo-xl-wt103')

# OpenAI GPT-2
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
model = GPT2Model.from_pretrained('gpt2')/<code>

緩存目錄

pytorch_pretrained_bert將預訓練權重保存在緩存目錄中(位於此優先級):

  • cache_dir是from_pretrained()方法的可選參數(見上文),
  • shell環境變量PYTORCH_PRETRAINED_BERT_CACHE,
  • PyTorch緩存目錄+/pytorch_pretrained_bert/ ,其中PyTorch緩存目錄由(按此順序定義): 外殼環境變量ENV_TORCH_HOME shell環境變量ENV_XDG_CACHE_HOME +/torch/) 默認值:~/.cache/torch/

通常,如果你未設置任何特定的環境變量pytorch_pretrained_bert緩存將位於~/.cache/torch/pytorch_pretrained_bert/中。

你可以始終安全地刪除pytorch_pretrained_bert緩存,但是必須從我們的S3重新下載預訓練模型權重和詞彙文件。


分享到:


相關文章: