[Natural Language Processing (NLP)] Chinese Automatic Summarization Based on ERNIE-GEN

Posted: 2022-10-17 11:23:32



About the author: an undergraduate student, Huawei Cloud expert, Alibaba Cloud expert blogger, Tencent Cloud Pioneer (TDP) member, lead of the Yunxi Zhihua project, volunteer of the National Expert Committee on Computer Teaching and Industry Practice Resource Construction in Higher Education (TIPCC), and programming enthusiast, hoping to learn and improve together with everyone. Blog homepage: ぃ灵彧が的学习日志. Column for this article: Artificial Intelligence. Column motto: if you decide to shine, no mountain can block you and no sea can stop you.



Preface

(1) Task Description

ERNIE-GEN is a pre-training and fine-tuning framework for generation tasks. It is the first to add a span-by-span generation task to the pre-training stage, so that the model generates a semantically complete span at each step. During pre-training and fine-tuning, an infilling generation mechanism and a noise-aware generation mechanism are used to mitigate the exposure bias problem. In addition, ERNIE-GEN adopts a multi-span, multi-granularity target-text sampling strategy, which strengthens the correlation between the source and target text and the interaction between the encoder and decoder. Thanks to these strategies, ERNIE-GEN has achieved state-of-the-art results on a number of generation tasks.

For more details, see the paper ERNIE-GEN: An Enhanced Multi-Flow Pre-training and Fine-tuning Framework for Natural Language Generation.

PaddleNLP currently provides three ERNIE-GEN generation models: ernie-gen-base-en, ernie-gen-large-en and ernie-gen-large-en-430g. It also supports loading the parameters of any non-generation pre-trained model from the PaddleNLP transformers library as a warm start. Since this article performs Chinese automatic summarization, we warm-start from the ernie-1.0 Chinese model.


(2) Environment Setup

This example is based on version 2.0 of the PaddlePaddle open-source framework.
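As a quick sanity check (a small sketch added here for convenience), you can print the installed versions to confirm they match:

import paddle
import paddlenlp

print(paddle.__version__)     # expected to be a 2.x release
print(paddlenlp.__version__)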

import paddle
import paddlenlp
from paddlenlp.transformers import ErnieForGeneration

# paddle.set_device('gpu')
model = ErnieForGeneration.from_pretrained("ernie-1.0")

The output is shown in Figure 1.


I. Data Preparation

The data follows the CNN/Daily Mail style (the original dataset requires you to obtain authorization yourself before loading it), or you can use a similar news-summarization dataset. Each line is formatted as: text + "\t" + summary.
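To make the expected file layout concrete, here is a minimal sketch that writes a tiny toy file in the same format: a header line followed by one text-TAB-summary pair per line. The file name and the two sample pairs are invented for illustration only.

# A toy stand-in for the news summary file: one header line, then
# one "text<TAB>summary" pair per line. The contents are made-up examples.
toy_rows = [
    "text\tsummary",  # header line, skipped by read() below via next(f)
    "昨日,某市召开新闻发布会,介绍了今年上半年经济运行情况……\t某市发布上半年经济运行情况",
    "科研团队宣布在新型电池材料方面取得进展,相关论文已发表……\t新型电池材料研究取得进展",
]
with open("toy_news_summary.txt", "w", encoding="utf-8") as f:
    f.write("\n".join(toy_rows) + "\n")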


(1) Loading the Dataset

from paddlenlp.datasets import load_dataset

def read(data_path):
    with open(data_path, 'r', encoding='utf-8') as f:
        # skip the header line (column names)
        next(f)
        for line in f:
            words, labels = line.strip('\n').split('\t')
            words = "\002".join(list(words))
            labels = "\002".join(list(labels))
            yield {'tokens': words, 'labels': labels}

# data_path is forwarded to the read() function above
train_dataset = load_dataset(read, data_path='data/data83012/news_summary.txt',lazy=False)
dev_dataset = load_dataset(read, data_path='data/data83012/news_summary_toy.txt',lazy=False)


# Example
print(train_dataset[0]['tokens'])
print(train_dataset[0]['labels'])

The output is shown in Figure 2.


(2) Data Preprocessing

This stage converts the raw data into a format the model can consume.

The input of ERNIE-GEN is similar to BERT's, so we need a tokenizer to turn plain text into the corresponding ids.

PaddleNLP ships with ErnieTokenizer; calling its encode method directly returns the input_ids and segment_ids (token_type_ids).
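As a quick illustration (a minimal sketch; the sample sentence is made up), you can call encode on a short string and inspect the returned ids:

from paddlenlp.transformers import ErnieTokenizer

_tok = ErnieTokenizer.from_pretrained("ernie-1.0")
_demo = _tok.encode("今天天气不错", max_seq_len=16)
print(_demo["input_ids"])       # [CLS] + token ids + [SEP]
print(_demo["token_type_ids"])  # all zeros for a single (source) segment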


from copy import deepcopy
import numpy as np
from paddlenlp.transformers import ErnieTokenizer

tokenizer = ErnieTokenizer.from_pretrained("ernie-1.0")
# ERNIE-GEN fills [ATTN] tokens at the positions to be predicted; since ERNIE 1.0 has no such token, we use [MASK] as the placeholder instead
attn_id = tokenizer.vocab['[MASK]']
tgt_type_id = 1

# maximum input and output lengths
max_encode_len = 200
max_decode_len = 30

def convert_example(example):
    """convert an example into necessary features"""

    encoded_src = tokenizer.encode(
        example['tokens'], max_seq_len=max_encode_len, pad_to_max_seq_len=False)
    src_ids, src_sids = encoded_src["input_ids"], encoded_src["token_type_ids"]
    src_pids = np.arange(len(src_ids))

    encoded_tgt = tokenizer.encode(
        example['labels'],
        max_seq_len=max_decode_len,
        pad_to_max_seq_len=False)
    tgt_ids, tgt_sids = encoded_tgt["input_ids"], encoded_tgt[
        "token_type_ids"]
    tgt_ids = np.array(tgt_ids)
    tgt_sids = np.array(tgt_sids) + tgt_type_id
    tgt_pids = np.arange(len(tgt_ids)) + len(src_ids)

    attn_ids = np.ones_like(tgt_ids) * attn_id
    tgt_labels = tgt_ids

    return (src_ids, src_pids, src_sids, tgt_ids, tgt_pids, tgt_sids,
            attn_ids, tgt_labels)

# apply the preprocessing to both datasets
train_dataset = train_dataset.map(convert_example)
dev_dataset = dev_dataset.map(convert_example)

The output is shown in Figure 3.


(3) Batching

Next we need to build batches and prepare the additional attention mask matrices that ERNIE-GEN requires.


from paddle.io import DataLoader
from paddlenlp.data import Stack, Tuple, Pad


def gen_mask(batch_ids, mask_type='bidi', query_len=None, pad_value=0):
    if query_len is None:
        query_len = batch_ids.shape[1]
    if mask_type != 'empty':
        mask = (batch_ids != pad_value).astype(np.float32)
        mask = np.tile(np.expand_dims(mask, 1), [1, query_len, 1])
        if mask_type == 'causal':
            assert query_len == batch_ids.shape[1]
            mask = np.tril(mask)
        elif mask_type == 'causal_without_diag':
            assert query_len == batch_ids.shape[1]
            mask = np.tril(mask, -1)
        elif mask_type == 'diag':
            assert query_len == batch_ids.shape[1]
            mask = np.stack([np.diag(np.diag(m)) for m in mask], 0)
    else:
        mask = np.zeros_like(batch_ids).astype(np.float32)
        mask = np.tile(np.expand_dims(mask, 1), [1, query_len, 1])
    return mask


def after_padding(args):
    '''
    attention mask:
    ***  src,  tgt, attn
    src  00,   01,   11
    tgt  10,   11,   12
    attn 20,   21,   22

    ***   s1, s2 | t1 t2 t3| attn1 attn2 attn3
    s1    1,  1  | 0, 0, 0,| 0,    0,    0,
    s2    1,  1  | 0, 0, 0,| 0,    0,    0,
    -
    t1    1,  1, | 1, 0, 0,| 0,    0,    0,
    t2    1,  1, | 1, 1, 0,| 0,    0,    0,
    t3    1,  1, | 1, 1, 1,| 0,    0,    0,
    -
    attn1 1,  1, | 0, 0, 0,| 1,    0,    0,
    attn2 1,  1, | 1, 0, 0,| 0,    1,    0,
    attn3 1,  1, | 1, 1, 0,| 0,    0,    1,

    for details, see Fig3. https://arxiv.org/abs/2001.11314
    '''
    src_ids, src_pids, src_sids, tgt_ids, tgt_pids, tgt_sids, attn_ids, tgt_labels = args
    src_len = src_ids.shape[1]
    tgt_len = tgt_ids.shape[1]
    mask_00 = gen_mask(src_ids, 'bidi', query_len=src_len)
    mask_01 = gen_mask(tgt_ids, 'empty', query_len=src_len)
    mask_02 = gen_mask(attn_ids, 'empty', query_len=src_len)

    mask_10 = gen_mask(src_ids, 'bidi', query_len=tgt_len)
    mask_11 = gen_mask(tgt_ids, 'causal', query_len=tgt_len)
    mask_12 = gen_mask(attn_ids, 'empty', query_len=tgt_len)

    mask_20 = gen_mask(src_ids, 'bidi', query_len=tgt_len)
    mask_21 = gen_mask(tgt_ids, 'causal_without_diag', query_len=tgt_len)
    mask_22 = gen_mask(attn_ids, 'diag', query_len=tgt_len)

    mask_src_2_src = mask_00
    mask_tgt_2_srctgt = np.concatenate([mask_10, mask_11], 2)
    mask_attn_2_srctgtattn = np.concatenate([mask_20, mask_21, mask_22], 2)

    raw_tgt_labels = deepcopy(tgt_labels)
    tgt_labels = tgt_labels[np.where(tgt_labels != 0)]
    return (src_ids, src_sids, src_pids, tgt_ids, tgt_sids, tgt_pids, attn_ids,
            mask_src_2_src, mask_tgt_2_srctgt, mask_attn_2_srctgtattn,
            tgt_labels, raw_tgt_labels)

# fn pads the ids at the corresponding positions of the samples returned by convert_example; after_padding is then called to build the attention mask matrices
batchify_fn = lambda samples, fn=Tuple(
        Pad(axis=0, pad_val=tokenizer.pad_token_id),  # src_ids
        Pad(axis=0, pad_val=tokenizer.pad_token_id),  # src_pids
        Pad(axis=0, pad_val=tokenizer.pad_token_type_id),  # src_sids
        Pad(axis=0, pad_val=tokenizer.pad_token_id),  # tgt_ids
        Pad(axis=0, pad_val=tokenizer.pad_token_id),  # tgt_pids
        Pad(axis=0, pad_val=tokenizer.pad_token_type_id),  # tgt_sids
        Pad(axis=0, pad_val=tokenizer.pad_token_id),  # attn_ids
        Pad(axis=0, pad_val=tokenizer.pad_token_id),  # tgt_labels
    ): after_padding(fn(samples))

batch_size = 16

train_data_loader = DataLoader(
        dataset=train_dataset,
        batch_size=batch_size,
        shuffle=True,
        collate_fn=batchify_fn,
        return_list=True)

dev_data_loader = DataLoader(
        dataset=dev_dataset,
        batch_size=batch_size,
        shuffle=False,
        collate_fn=batchify_fn,
        return_list=True)

The output is shown in Figure 4.
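As an additional sanity check (an illustrative sketch, not part of the original notebook), the snippet below builds a small causal mask with gen_mask and prints the shapes of one collated batch:

# Toy causal mask: batch of 1, sequence length 4, no padding (pad_value=0).
toy_ids = np.array([[11, 12, 13, 14]])
print(gen_mask(toy_ids, 'causal'))  # lower-triangular 1/0 matrix of shape [1, 4, 4]

# Shapes of one collated batch (12 tensors, in after_padding's return order).
one_batch = next(iter(train_data_loader))
for name, tensor in zip(
        ["src_ids", "src_sids", "src_pids", "tgt_ids", "tgt_sids", "tgt_pids",
         "attn_ids", "mask_src_2_src", "mask_tgt_2_srctgt",
         "mask_attn_2_srctgtattn", "tgt_labels", "raw_tgt_labels"], one_batch):
    print(name, tensor.shape)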


II. Model Configuration

(1) Optimizer and Training Loop

Once everything is ready, we can feed the data to the model and keep updating its parameters. During training, the logger object provided by PaddleNLP can be used to print logs with timestamps.
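The training loop below references num_epochs, lr_scheduler and optimizer, which are not defined elsewhere in this article. The following is a minimal sketch of one way to set them up; the hyperparameter values (number of epochs, learning rate, warmup proportion, weight decay) are illustrative assumptions, not the settings used to produce the figures.

from paddlenlp.transformers import LinearDecayWithWarmup

# Illustrative hyperparameters (assumptions); tune them for your own data.
num_epochs = 12
learning_rate = 2e-5
warmup_proportion = 0.1
weight_decay = 0.1

max_steps = len(train_data_loader) * num_epochs
lr_scheduler = LinearDecayWithWarmup(learning_rate, max_steps, warmup_proportion)

# Exclude bias and LayerNorm parameters from weight decay, as is common for ERNIE/BERT.
decay_params = [
    p.name for n, p in model.named_parameters()
    if not any(nd in n for nd in ["bias", "norm"])
]
optimizer = paddle.optimizer.AdamW(
    learning_rate=lr_scheduler,
    parameters=model.parameters(),
    weight_decay=weight_decay,
    apply_decay_param_fun=lambda x: x in decay_params)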


import os
import time

from paddlenlp.utils.log import logger

global_step = 1
logging_steps = 100
save_steps = 1000
output_dir = "save_dir"
tic_train = time.time()
for epoch in range(num_epochs):
    for step, batch in enumerate(train_data_loader, start=1):
        (src_ids, src_sids, src_pids, tgt_ids, tgt_sids, tgt_pids, attn_ids,
            mask_src_2_src, mask_tgt_2_srctgt, mask_attn_2_srctgtattn,
            tgt_labels, _) = batch
        # Encode the source sequence once and cache its key/value states.
        _, __, info = model(
            src_ids,
            sent_ids=src_sids,
            pos_ids=src_pids,
            attn_bias=mask_src_2_src,
            encode_only=True)
        cached_k, cached_v = info['caches']
        # Encode the target sequence, attending to the cached source states.
        _, __, info = model(
            tgt_ids,
            sent_ids=tgt_sids,
            pos_ids=tgt_pids,
            attn_bias=mask_tgt_2_srctgt,
            past_cache=(cached_k, cached_v),
            encode_only=True)
        cached_k2, cached_v2 = info['caches']
        past_cache_k = [
            paddle.concat([k, k2], 1) for k, k2 in zip(cached_k, cached_k2)
        ]
        past_cache_v = [
            paddle.concat([v, v2], 1) for v, v2 in zip(cached_v, cached_v2)
        ]
        # Predict the tokens at the [ATTN] placeholder positions and compute the loss.
        loss, _, __ = model(
            attn_ids,
            sent_ids=tgt_sids,
            pos_ids=tgt_pids,
            attn_bias=mask_attn_2_srctgtattn,
            past_cache=(past_cache_k, past_cache_v),
            tgt_labels=tgt_labels,
            tgt_pos=paddle.nonzero(attn_ids == attn_id))

        if global_step % logging_steps == 0:
            logger.info(
                "global step %d, epoch: %d, batch: %d, loss: %f, speed: %.2f step/s, lr: %.3e"
                % (global_step, epoch, step, loss, logging_steps /
                    (time.time() - tic_train), lr_scheduler.get_lr()))
            tic_train = time.time()

        loss.backward()
        optimizer.step()
        lr_scheduler.step()
        optimizer.clear_grad()
        if global_step % save_steps == 0:
            # Save each checkpoint into its own subdirectory instead of
            # repeatedly nesting and overwriting output_dir.
            save_dir = os.path.join(output_dir, "model_%d" % global_step)
            if not os.path.exists(save_dir):
                os.makedirs(save_dir)
            model.save_pretrained(save_dir)
            tokenizer.save_pretrained(save_dir)

        global_step += 1

Part of the training log is shown in Figure 5.


(2) Decoding Logic

ERNIE-GEN performs prediction by infilling generation, so we need to implement this mechanism when decoding. Here we decode with greedy search; if you need beam search, please refer to the example.


def gen_bias(encoder_inputs, decoder_inputs, step):
    decoder_bsz, decoder_seqlen = decoder_inputs.shape[:2]
    encoder_bsz, encoder_seqlen = encoder_inputs.shape[:2]
    attn_bias = paddle.reshape(
        paddle.arange(
            0, decoder_seqlen, 1, dtype='float32') + 1, [1, -1, 1])
    decoder_bias = paddle.cast(
        (paddle.matmul(
            attn_bias, 1. / attn_bias, transpose_y=True) >= 1.),
        'float32')  # [1, decoderlen, decoderlen]
    encoder_bias = paddle.unsqueeze(
        paddle.cast(paddle.ones_like(encoder_inputs), 'float32'),
        [1])  # [bsz, 1, encoderlen]
    encoder_bias = paddle.expand(
        encoder_bias,
        [encoder_bsz, decoder_seqlen, encoder_seqlen])  # [bsz, decoderlen, encoderlen]
    decoder_bias = paddle.expand(
        decoder_bias,
        [decoder_bsz, decoder_seqlen, decoder_seqlen])  # [bsz, decoderlen, decoderlen]
    if step > 0:
        bias = paddle.concat([
            encoder_bias,
            paddle.ones([decoder_bsz, decoder_seqlen, step], 'float32'),
            decoder_bias
        ], -1)
    else:
        bias = paddle.concat([encoder_bias, decoder_bias], -1)
    return bias


@paddle.no_grad()
def greedy_search_infilling(model,
                            q_ids,
                            q_sids,
                            sos_id,
                            eos_id,
                            attn_id,
                            pad_id,
                            unk_id,
                            vocab_size,
                            max_encode_len=640,
                            max_decode_len=100,
                            tgt_type_id=3):
    # Encode the source once and cache its key/value states.
    _, logits, info = model(q_ids, q_sids)
    d_batch, d_seqlen = q_ids.shape
    seqlen = paddle.sum(paddle.cast(q_ids != 0, 'int64'), 1, keepdim=True)
    has_stopped = np.zeros([d_batch], dtype=np.bool_)
    gen_seq_len = np.zeros([d_batch], dtype=np.int64)
    output_ids = []

    past_cache = info['caches']

    cls_ids = paddle.ones([d_batch], dtype='int64') * sos_id
    attn_ids = paddle.ones([d_batch], dtype='int64') * attn_id
    ids = paddle.stack([cls_ids, attn_ids], -1)
    for step in range(max_decode_len):
        bias = gen_bias(q_ids, ids, step)
        pos_ids = paddle.to_tensor(
            np.tile(
                np.array(
                    [[step, step + 1]], dtype=np.int64), [d_batch, 1]))
        pos_ids += seqlen
        _, logits, info = model(
            ids,
            paddle.ones_like(ids) * tgt_type_id,
            pos_ids=pos_ids,
            attn_bias=bias,
            past_cache=past_cache)

        # Forbid out-of-vocabulary ids, [PAD], [UNK] and the [ATTN] placeholder
        # from being generated.
        if logits.shape[-1] > vocab_size:
            logits[:, :, vocab_size:] = 0
        logits[:, :, pad_id] = 0
        logits[:, :, unk_id] = 0
        logits[:, :, attn_id] = 0

        gen_ids = paddle.argmax(logits, -1)

        past_cached_k, past_cached_v = past_cache
        cached_k, cached_v = info['caches']
        cached_k = [
            paddle.concat([pk, k[:, :1, :]], 1)
            for pk, k in zip(past_cached_k, cached_k)
        ]  # concat cached
        cached_v = [
            paddle.concat([pv, v[:, :1, :]], 1)
            for pv, v in zip(past_cached_v, cached_v)
        ]
        past_cache = (cached_k, cached_v)

        # Keep the token predicted at the [ATTN] position and feed it back in.
        gen_ids = gen_ids[:, 1]
        ids = paddle.stack([gen_ids, attn_ids], 1)

        gen_ids = gen_ids.numpy()
        has_stopped |= (gen_ids == eos_id).astype(np.bool_)
        gen_seq_len += (1 - has_stopped.astype(np.int64))
        output_ids.append(gen_ids.tolist())
        if has_stopped.all():
            break
    output_ids = np.array(output_ids).transpose([1, 0])
    return output_ids


III. Model Evaluation

The evaluation stage calls the decoding logic to generate predictions and then scores them to measure model quality. paddlenlp.metrics includes metrics such as Rouge1 and Rouge2; here we use Rouge1.


from tqdm import tqdm

from paddlenlp.metrics import Rouge1


rouge1 = Rouge1()

vocab = tokenizer.vocab
eos_id = vocab[tokenizer.sep_token]
sos_id = vocab[tokenizer.cls_token]
pad_id = vocab[tokenizer.pad_token]
unk_id = vocab[tokenizer.unk_token]
vocab_size = len(vocab)

evaluated_sentences_ids = []
reference_sentences_ids = []

logger.info("Evaluating...")
model.eval()
for data in tqdm(dev_data_loader):
    (src_ids, src_sids, src_pids, _, _, _, _, _, _, _, _,
        raw_tgt_labels) = data  # never use target when infer
    output_ids = greedy_search_infilling(
        model,
        src_ids,
        src_sids,
        eos_id=eos_id,
        sos_id=sos_id,
        attn_id=attn_id,
        pad_id=pad_id,
        unk_id=unk_id,
        vocab_size=vocab_size,
        max_decode_len=max_decode_len,
        max_encode_len=max_encode_len,
        tgt_type_id=tgt_type_id)

    for ids in output_ids.tolist():
        if eos_id in ids:
            ids = ids[:ids.index(eos_id)]
        evaluated_sentences_ids.append(ids)

    for ids in raw_tgt_labels.numpy().tolist():
        ids = ids[1:ids.index(eos_id)]
        reference_sentences_ids.append(ids)

score = rouge1.score(evaluated_sentences_ids, reference_sentences_ids)

logger.info("Rouge-1: %.5f" % (score * 100))

The output is shown in Figure 6.


IV. Model Prediction

For generation tasks, evaluation metrics do not fully reflect model quality, so let's look directly at the model's predictions.


evaluated_sentences = []
reference_sentences = []
for ids in reference_sentences_ids[:5]:
    reference_sentences.append(''.join(vocab.to_tokens(ids)))
for ids in evaluated_sentences_ids[:5]:
    evaluated_sentences.append(''.join(vocab.to_tokens(ids)))
logger.info(reference_sentences)
logger.info(evaluated_sentences)

The output is shown in Figure 7.
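To summarize a brand-new piece of text end to end, you can reuse the tokenizer and greedy_search_infilling defined above. The sketch below is illustrative only: summarize is a hypothetical helper, the input sentence is made up, and it assumes the fine-tuned model is still in memory (or has been reloaded from a saved checkpoint).

def summarize(text):
    model.eval()  # make sure dropout is disabled
    # Mirror read(): join the characters with "\002", as the dataset does.
    src_text = "\002".join(list(text))
    encoded = tokenizer.encode(
        src_text, max_seq_len=max_encode_len, pad_to_max_seq_len=False)
    src_ids = paddle.to_tensor([encoded["input_ids"]])
    src_sids = paddle.to_tensor([encoded["token_type_ids"]])
    out_ids = greedy_search_infilling(
        model, src_ids, src_sids,
        sos_id=sos_id, eos_id=eos_id, attn_id=attn_id,
        pad_id=pad_id, unk_id=unk_id, vocab_size=vocab_size,
        max_decode_len=max_decode_len, max_encode_len=max_encode_len,
        tgt_type_id=tgt_type_id)
    ids = out_ids.tolist()[0]
    if eos_id in ids:
        ids = ids[:ids.index(eos_id)]
    return ''.join(vocab.to_tokens(ids))

# A made-up input sentence, for illustration only.
print(summarize("记者昨日从市交通部门获悉,新的公交线路将于下月开通,覆盖多个新建居民区。"))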


Summary

This series of articles consists of notes and reflections based on Natural Language Processing in Practice (《自然语言处理实践》, Tsinghua University Press); all of the code is developed with Baidu PaddlePaddle. If anything here is improper or infringes on any rights, please message me privately and I will actively cooperate to resolve it.

Finally, quoting a line from this event as the closing of this article ~( ̄▽ ̄~)~:

**The biggest reason to learn is to escape mediocrity: start one day earlier and you gain one more day of brilliance in life; start one day later and you endure one more day of mediocrity.**
