查看原文
其他

NLP.TM[33] | 纠错:pycorrector的错误检测

机智的叉烧 CS的陋室 2022-08-08

【NLP.TM】


本人有关自然语言处理和文本挖掘方面的学习和笔记,欢迎大家关注。


往期回顾


纠错是NLP中的一个看着不是很火但其实在现实应用中非常重要的一个部分,在一个强NLP以来的项目(如搜索)发展至中期,纠错就会成为一个效果提升的新增长点,经过统计,在微博等新媒体领域中,文本出错概率在2%左右,在语音识别领域中,出错率最高可达8-10%(数据来自:https://zhuanlan.zhihu.com/p/159101860),从这个比例来看,如果能修正这些错误,对效果的提升无疑是巨大的,那么我们来看看,纠错任务是怎么做的。

文章较长,懒人目录再现:

  • pycorrector简介
  • pycorrector的纠错思路
  • 混淆词典
  • 未登录词检测
  • 语言模型
  • 结果输出
  • 小结

pycorrector简介

pycorrector是非常基础的纠错模块工具,里面已经实现了一些非常通用的纠错方法,用里面的方法来做基线其实其实非常方便。

连接先放在这里:https://github.com/shibing624/pycorrector

他的使用方法其实也比较简单:

import pycorrector

corrected_sent, detail = pycorrector.correct('少先队员因该为老人让坐')
print(corrected_sent, detail)

这是一个非常简单的官方case,详情还是可以去github里面去看看。

pycorrect的纠错思路

其实pycorrect里面造了很多飞机,不过实质上正式使用的还是非常经典的方法,来看看它的主函数具体思路是什么样的。

def correct(self, text, include_symbol=True, num_fragment=1, threshold=57, **kwargs):
    """
    句子改错
    :param text: str, query 文本
    :param include_symbol: bool, 是否包含标点符号
    :param num_fragment: 纠错候选集分段数, 1 / (num_fragment + 1)
    :param threshold: 语言模型纠错ppl阈值
    :param kwargs: ...
    :return: text (str)改正后的句子, list(wrong, right, begin_idx, end_idx)
    """

    text_new = ''
    details = []
    self.check_corrector_initialized()
    # 编码统一,utf-8 to unicode
    text = convert_to_unicode(text)
    # 长句切分为短句
    blocks = self.split_2_short_text(text, include_symbol=include_symbol)
    for blk, idx in blocks:
        maybe_errors = self.detect_short(blk, idx)
        for cur_item, begin_idx, end_idx, err_type in maybe_errors:
            # 纠错,逐个处理
            before_sent = blk[:(begin_idx - idx)]
            after_sent = blk[(end_idx - idx):]

            # 困惑集中指定的词,直接取结果
            if err_type == ErrorType.confusion:
                corrected_item = self.custom_confusion[cur_item]
            else:
                # 取得所有可能正确的词
                candidates = self.generate_items(cur_item, fragment=num_fragment)
                if not candidates:
                    continue
                corrected_item = self.get_lm_correct_item(cur_item, candidates, before_sent, after_sent,
                                                          threshold=threshold)
            # output
            if corrected_item != cur_item:
                blk = before_sent + corrected_item + after_sent
                detail_word = [cur_item, corrected_item, begin_idx, end_idx]
                details.append(detail_word)
        text_new += blk
    details = sorted(details, key=operator.itemgetter(2))
    return text_new, details

这里面其实还是比较明确的:

  • 分句。一个长句分成多个断句。
  • 对每个短句进行错误检测detect_short
  • 错误点召回可能正确的词。
  • 召回后筛选最佳结果。

在这个框架下,来看看具体pycorrect的错误检测是怎么做的。

混淆词典

直接看源码:

# 自定义混淆集加入疑似错误词典
for confuse in self.custom_confusion:
    idx = sentence.find(confuse)
    if idx > -1:
        maybe_err = [confuse, idx + start_idx, idx + len(confuse) + start_idx, ErrorType.confusion]
        self._add_maybe_error_item(maybe_err, maybe_errors)

这块其实还是比较简单的,其实就是用户自定义了一个词典,这个词典作者叫做混淆词典,我更愿意叫做改写词典,遇到了key,就去找v,直接做这种改写。

不过个人感觉这种遍历整个整个词典然后find的方法复杂度可能比较高,如果是我我还是比较喜欢最大逆向匹配的方式来查字典。

未登录词检测

同样上代码:

if self.is_word_error_detect:
    # 切词
    tokens = self.tokenizer.tokenize(sentence)
    # 未登录词加入疑似错误词典
    for token, begin_idx, end_idx in tokens:
        # pass filter word
        if self.is_filter_token(token):
            continue
        # pass in dict
        if token in self.word_freq:
            continue
        maybe_err = [token, begin_idx + start_idx, end_idx + start_idx, ErrorType.word]
        self._add_maybe_error_item(maybe_err, maybe_errors)

注释其实还是非常友好的,其实就这几个步骤:

  • 切词。
  • 跳过特定词汇的检测。
  • 查字典看是否有低频词(未登录词)出现。
  • 结果整理。

首先就是切词,这里的切词是一个函数,我们也来看看他具体切词是怎么切的:

class Tokenizer(object):
    def __init__(self, dict_path='', custom_word_freq_dict=None, custom_confusion_dict=None):
        self.model = jieba
        self.model.default_logger.setLevel(logging.ERROR)
        # 初始化大词典
        if os.path.exists(dict_path):
            self.model.set_dictionary(dict_path)
        # 加载用户自定义词典
        if custom_word_freq_dict:
            for w, f in custom_word_freq_dict.items():
                self.model.add_word(w, freq=f)

        # 加载混淆集词典、
        if custom_confusion_dict:
            for k, word in custom_confusion_dict.items():
                # 添加到分词器的自定义词典中
                self.model.add_word(k)
                self.model.add_word(word)

    def tokenize(self, unicode_sentence, mode="search"):
        """
        切词并返回切词位置, search mode用于错误扩召回
        :param unicode_sentence: query
        :param mode: search, default, ngram
        :param HMM: enable HMM
        :return: (w, start, start + width) model='default'
        """

        if mode == 'ngram':
            n = 2
            result_set = set()
            tokens = self.model.lcut(unicode_sentence)
            tokens_len = len(tokens)
            start = 0
            for i in range(0, tokens_len):
                w = tokens[i]
                width = len(w)
                result_set.add((w, start, start + width))
                for j in range(i, i + n):
                    gram = "".join(tokens[i:j + 1])
                    gram_width = len(gram)
                    if i + j > tokens_len:
                        break
                    result_set.add((gram, start, start + gram_width))
                start += width
            results = list(result_set)
            result = sorted(results, key=lambda x: x[-1])
        else:
            result = list(self.model.tokenize(unicode_sentence, mode=mode))
        return result

看着很高端,稍微看看源码其实就可以发现用的是以jieba为基础的操作,只不过多了一种n-gram切词而已,其实就是切词以后按照n-gram拼装而已。

切完词后,就是过滤一些不需要检测的词汇,主要是一些数字之类的,来看看具体有哪些:

@staticmethod
def is_filter_token(token):
    result = False
    # pass blank
    if not token.strip():
        result = True
    # pass num
    if token.isdigit():
        result = True
    # pass alpha
    if is_alphabet_string(token.lower()):
        result = True
    # pass not chinese
    if not is_chinese_string(token):
        result = True
    return result
  • 空字符串
  • 数字
  • 字母
  • 非中文

然后就是判断是否是低频词,这个就比较容易,他是构建了一个词典,直接判断是否在里面就好了。

语言模型

NLP领域最基础的东西就要数语言模型了,这里的假设其实是人输入的语言大都是常用的,如果出现了不太常用的东西,其实说明是有错的,带着这个假设,我们来看看利用这个方法是怎么判错的。

# 语言模型检测疑似错误字
try:
    ngram_avg_scores = []
    for n in [23]:
        scores = []
        for i in range(len(sentence) - n + 1):
            word = sentence[i:i + n]
            score = self.ngram_score(list(word))
            scores.append(score)
        if not scores:
            continue
        # 移动窗口补全得分
        for _ in range(n - 1):
            scores.insert(0, scores[0])
            scores.append(scores[-1])
        avg_scores = [sum(scores[i:i + n]) / len(scores[i:i + n]) for i in range(len(sentence))]
        ngram_avg_scores.append(avg_scores)

    if ngram_avg_scores:
        # 取拼接后的n-gram平均得分
        sent_scores = list(np.average(np.array(ngram_avg_scores), axis=0))
        # 取疑似错字信息
        for i in self._get_maybe_error_index(sent_scores):
            token = sentence[i]
            # pass filter word
            if self.is_filter_token(token):
                continue
            # pass in stop word dict
            if token in self.stopwords:
                continue
            # token, begin_idx, end_idx, error_type
            maybe_err = [token, i + start_idx, i + start_idx + 1,
                         ErrorType.char]
            self._add_maybe_error_item(maybe_err, maybe_errors)
except IndexError as ie:
    logger.warn("index error, sentence:" + sentence + str(ie))
except Exception as e:
    logger.warn("detect error, sentence:" + sentence + str(e))

首先这个是基于字来判断的,所以不需要切词,直接把字符串一个一个的拼接成n-gram即可。

要分析整个句子中每个位点字合理,是需要看上下文的,这里分别采用了2-gram和3-gram进行了分析,分别计算了一个叫做ngram_score的东西,具体是这样的:

def ngram_score(self, chars):
    """
    取n元文法得分
    :param chars: list, 以词或字切分
    :return:
    """

    self.check_detector_initialized()
    return self.lm.score(' '.join(chars), bos=False, eos=False)

这里使用的是kenlm来训练的语言模型,然后用score进行得分计算,这个得分实质上就是分析这个句子组合产生的可能性,概率当然就是在之间了,然后取对数,因此这个得分就是一个非正数了,越接近0,说明这个组合出现的可能性越大,越不可能有错了。

另外,为了保证整个句子的完整性,是需要padding的,代码里做了一个移动窗口的处理,直接看可能有些难懂,但是知道了padding,应该会好明白一些:

# 移动窗口补全得分
for _ in range(n - 1):
    scores.insert(0, scores[0])
    scores.append(scores[-1])

然后就对分数进行根据句子长度的均值计算,计算完之后分别保存了每个字的2-gram得分和3-gram得分,然后后续取了这两个分数的均值,这里的代码这么看:

avg_scores = [sum(scores[i:i + n]) / len(scores[i:i + n]) for i in range(len(sentence))]
ngram_avg_scores.append(avg_scores)

if ngram_avg_scores:
    # 取拼接后的n-gram平均得分
    sent_scores = list(np.average(np.array(ngram_avg_scores), axis=0))

然后就会开始对这个分数进行分析,最终抽取可能有问题的位点,使用的函数就是_get_maybe_error_index

@staticmethod
def _get_maybe_error_index(scores, ratio=0.6745, threshold=2):
    """
    取疑似错字的位置,通过平均绝对离差(MAD)
    :param scores: np.array
    :param ratio: 正态分布表参数
    :param threshold: 阈值越小,得到疑似错别字越多
    :return: 全部疑似错误字的index: list
    """

    result = []
    scores = np.array(scores)
    if len(scores.shape) == 1:
        scores = scores[:, None]
    median = np.median(scores, axis=0)  # get median of all scores
    margin_median = np.abs(scores - median).flatten()  # deviation from the median
    # 平均绝对离差值
    med_abs_deviation = np.median(margin_median)
    if med_abs_deviation == 0:
        return result
    y_score = ratio * margin_median / med_abs_deviation
    # 打平
    scores = scores.flatten()
    maybe_error_indices = np.where((y_score > threshold) & (scores < median))
    # 取全部疑似错误字的index
    result = list(maybe_error_indices[0])
    return result

思路其实大概说了,就是基于平均离差来算,这其实就是常用异常检测的MAD。说白了就是整个句子,大部分情况是不会出错的,正常情况下打分就会在特定的一个范围内,但是出错的位置的打分会距离这个打分很远(可以理解为和常规语境和语言水平差别很大),我们需要把这几个打分比较远的对应位置提取出来。

另外这里蛮有意思的是,可以看到作者对numpy比较熟悉,可以看看里面这些操作。

结果输出

然后就是一些整理结果输出的操作了,基本的数据处理还是比较容易的,直接看看最终的输出格式吧

import pycorrector

idx_errors = pycorrector.detect('少先队员因该为老人让坐')
print(idx_errors)

# 输出:[['因该', 4, 6, 'word'], ['坐', 10, 11, 'char']]

会把他定的位置和错误类型给指出来,最终只需要整理出这个格式就行。

小结

这里给大家介绍的是pycorrector内baseline的检测方法,让大家理解最基本的错误识别方式。


您可能也对以下帖子感兴趣

文章有问题?点此查看未经处理的缓存