热门标签 | HotTags
当前位置:  开发笔记 > 编程语言 > 正文

fasttext文本分类python实现_基于FastText进行文本分类

liunx版本下操作:$gitclonehttps:github.comfacebookresearchfastText.git$cdfastText$pipinst

liunx版本下操作:

$ git clone https://github.com/facebookresearch/fastText.git

$ cd fastText

$ pip install .

安装成功后的导入:

新建test.py文件,写入:

import fastText.FastText as fasttext(可能会瞟红线)

新增:最近发现fasttext的github更新了,引入方式发生了变化,如果上述引入报错,改成 import fasttext.FastText as fasttext

新增:现在安装直接 pip install fasttext,导入直接 import fasttext 就行

保存后退出并运行:

python3 test.py

没报错说明安装成功第二步:准备数据集我这里用的是清华的新闻数据集(由于完整数据集较大,这里只取部分数据)

数据链接:点击获取网盘数据 提取码:byoi(data.txt为数据集,stopwords.txt为停用词)

下载好后的数据格式为:

对应的标签分别为(由于只是用小部分数据,所以data.txt只包含部分标签):

mapper_tag = {

'财经': 'Finance',

'彩票': 'Lottery',

'房产': 'Property',

'股票': 'Shares',

'家居': 'Furnishing',

'教育': 'Education',

'科技': 'Technology',

'社会': 'Sociology',

'时尚': 'Fashion',

'时政': 'Affairs',

'体育': 'Sports',

'星座': 'Constellation',

'游戏': 'Game',

'娱乐': 'Entertainment'

}第三步:数据预处理由于data.txt已经经过了分词和去停用词的处理,所以这里只需要对数据进行切割为训练集和测试集即可。

分词和去停用词的工具代码(运行时不需要执行此部分代码):

import re

from types import MethodType, FunctionType

import jieba

def clean_txt(raw):

fil = re.compile(r"[^0-9a-zA-Z\u4e00-\u9fa5]+")

return fil.sub(' ', raw)

def seg(sentence, sw, apply=None):

if isinstance(apply, FunctionType) or isinstance(apply, MethodType):

sentence = apply(sentence)

return ' '.join([i for i in jieba.cut(sentence) if i.strip() and i not in sw])

def stop_words():

with open('stop_words.txt', 'r', encoding='utf-8') as swf:

return [line.strip() for line in swf]

# 对某个sentence进行处理:

content = '上海天然橡胶期价周三再创年内新高,主力合约突破21000元/吨重要关口。'

res = seg(content.lower().replace('\n', ''), stop_words(), apply=clean_txt)切割数据(这里我是先将txt文件转换成csv文件,方便后面的计算)

from random import shuffle

import pandas as pd

class _MD(object):

mapper = {

str: '',

int: 0,

list: list,

dict: dict,

set: set,

bool: False,

float: .0

}

def __init__(self, obj, default=None):

self.dict = {}

assert obj in self.mapper, \

'got a error type'

self.t = obj

if default is None:

return

assert isinstance(default, obj), \

f'default ({default}) must be {obj}'

self.v = default

def __setitem__(self, key, value):

self.dict[key] = value

def __getitem__(self, item):

if item not in self.dict and hasattr(self, 'v'):

self.dict[item] = self.v

return self.v

elif item not in self.dict:

if callable(self.mapper[self.t]):

self.dict[item] = self.mapper[self.t]()

else:

self.dict[item] = self.mapper[self.t]

return self.dict[item]

return self.dict[item]

def defaultdict(obj, default=None):

return _MD(obj, default)

class TransformData(object):

def to_csv(self, handler, output, index=False):

dd = defaultdict(list)

for line in handler:

label, content = line.split(',', 1)

dd[label.strip('__label__').strip()].append(content.strip())

df = pd.DataFrame()

for key in dd.dict:

col = pd.Series(dd[key], name=key)

df = pd.concat([df, col], axis=1)

return df.to_csv(output, index=index, encoding='utf-8')

def split_train_test(source, auth_data=False):

if not auth_data:

train_proportion = 0.8

else:

train_proportion = 0.98

basename = source.rsplit('.', 1)[0]

train_file = basename + '_train.txt'

test_file = basename + '_test.txt'

handel = pd.read_csv(source, index_col=False, low_memory=False)

train_data_set = []

test_data_set = []

for head in list(handel.head()):

train_num = int(handel[head].dropna().__len__() * train_proportion)

sub_list = [f'__label__{head} , {item.strip()}\n' for item in handel[head].dropna().tolist()]

train_data_set.extend(sub_list[:train_num])

test_data_set.extend(sub_list[train_num:])

shuffle(train_data_set)

shuffle(test_data_set)

with open(train_file, 'w', encoding='utf-8') as trainf,\

open(test_file, 'w', encoding='utf-8') as testf:

for tds in train_data_set:

trainf.write(tds)

for i in test_data_set:

testf.write(i)

return train_file, test_file

# 转化成csv

td = TransformData()

handler = open('data.txt')

td.to_csv(handler, 'data.csv')

handler.close()

# 将csv文件切割,会生成两个文件(data_train.txt和data_test.txt)

train_file, test_file = split_train_test('data.csv', auth_data=True)第四步:训练模型

import fastText.FastText as fasttext

def train_model(ipt=None, opt=None, model='', dim=100, epoch=5, lr=0.1, loss='softmax'):

np.set_printoptions(suppress=True)

if os.path.isfile(model):

classifier = fasttext.load_model(model)

else:

classifier = fasttext.train_supervised(ipt, label='__label__', dim=dim, epoch=epoch,

lr=lr, wordNgrams=2, loss=loss)

"""训练一个监督模型, 返回一个模型对象@param input: 训练数据文件路径@param lr: 学习率@param dim: 向量维度@param ws: cbow模型时使用@param epoch: 次数@param minCount: 词频阈值, 小于该值在初始化时会过滤掉@param minCountLabel: 类别阈值,类别小于该值初始化时会过滤掉@param minn: 构造subword时最小char个数@param maxn: 构造subword时最大char个数@param neg: 负采样@param wordNgrams: n-gram个数@param loss: 损失函数类型, softmax, ns: 负采样, hs: 分层softmax@param bucket: 词扩充大小, [A, B]: A语料中包含的词向量, B不在语料中的词向量@param thread: 线程个数, 每个线程处理输入数据的一段, 0号线程负责loss输出@param lrUpdateRate: 学习率更新@param t: 负采样阈值@param label: 类别前缀@param verbose: ??@param pretrainedVectors: 预训练的词向量文件路径, 如果word出现在文件夹中初始化不再随机@return model object"""

classifier.save_model(opt)

return classifier

dim = 100

lr = 5

epoch = 5

model = f'data_dim{str(dim)}_lr0{str(lr)}_iter{str(epoch)}.model'

classifier = train_model(ipt='data_train.txt',

opt=model,

model=model,

dim=dim, epoch=epoch, lr=0.5

)

result = classifier.test('data_test.txt')

print(result)

# 整体的结果为(测试数据量,precision,recall):

(9885, 0.9740010116337886, 0.9740010116337886)可以看出结果相当高,由于上面是将整体作为测试,fasttext只给出整体的结果,precision和recall是相同的,下面我们测试每个标签的precision、recall和F1值。

def cal_precision_and_recall(file='data_test.txt'):

precision = defaultdict(int, 1)

recall = defaultdict(int, 1)

total = defaultdict(int, 1)

with open(file) as f:

for line in f:

label, content = line.split(',', 1)

total[label.strip().strip('__label__')] += 1

labels2 = classifier.predict([seg(sentence=content.strip(), sw='', apply=clean_txt)])

pre_label, sim = labels2[0][0][0], labels2[1][0][0]

recall[pre_label.strip().strip('__label__')] += 1

if label.strip() == pre_label.strip():

precision[label.strip().strip('__label__')] += 1

print('precision', precision.dict)

print('recall', recall.dict)

print('total', total.dict)

for sub in precision.dict:

pre = precision[sub] / total[sub]

rec = precision[sub] / recall[sub]

F1 = (2 * pre * rec) / (pre + rec)

print(f"{sub.strip('__label__')} precision: {str(pre)} recall: {str(rec)} F1: {str(F1)}")结果:

precision {'Technology': 983, 'Education': 972, 'Shares': 988, 'Affairs': 975, 'Entertainment': 991, 'Financ': 982, 'Furnishing': 975, 'Gam': 841, 'Sociology': 946, 'Sports': 978}

recall {'Technology': 992, 'Education': 1013, 'Shares': 1007, 'Affairs': 995, 'Entertainment': 1022, 'Financ': 1001, 'Furnishing': 997, 'Gam': 854, 'Sociology': 1025, 'Sports': 989}

total {'Technology': 1001, 'Education': 1001, 'Shares': 1001, 'Affairs': 1001, 'Entertainment': 1001, 'Financ': 1001, 'Furnishing': 1001, 'Gam': 876, 'Sociology': 1001, 'Sports': 1001, 'Property': 11}

Technology precision: 0.9820179820179821 recall: 0.9909274193548387 F1: 0.9864525840441545

Education precision: 0.971028971028971 recall: 0.9595261599210266 F1: 0.9652432969215492

Shares precision: 0.987012987012987 recall: 0.9811320754716981 F1: 0.9840637450199202

Affairs precision: 0.974025974025974 recall: 0.9798994974874372 F1: 0.9769539078156312

Entertainment precision: 0.99000999000999 recall: 0.9696673189823874 F1: 0.9797330696984675

Financ precision: 0.981018981018981 recall: 0.981018981018981 F1: 0.981018981018981

Furnishing precision: 0.974025974025974 recall: 0.9779338014042126 F1: 0.975975975975976

Gam precision: 0.9600456621004566 recall: 0.9847775175644028 F1: 0.9722543352601155

Sociology precision: 0.945054945054945 recall: 0.9229268292682927 F1: 0.9338598223099703

Sports precision: 0.977022977022977 recall: 0.9888776541961577 F1: 0.9829145728643216

可以看出结果非常可观,fasttext很强大...整合后的代码:

def main(source):

basename = source.rsplit('.', 1)[0]

csv_file = basename + '.csv'

td = TransformData()

handler = open(source)

td.to_csv(handler, csv_file)

handler.close()

train_file, test_file = split_train_test(csv_file)

dim = 100

lr = 5

epoch = 5

model = f'data/data_dim{str(dim)}_lr0{str(lr)}_iter{str(epoch)}.model'

classifier = train_model(ipt=train_file,

opt=model,

model=model,

dim=dim, epoch=epoch, lr=0.5

)

result = classifier.test(test_file)

print(result)

cal_precision_and_recall(test_file)

if __name__ == '__main__':

main('data.txt')



推荐阅读
  • 掌握PHP编程必备知识与技巧——全面教程在当今的PHP开发中,了解并运用最新的技术和最佳实践至关重要。本教程将详细介绍PHP编程的核心知识与实用技巧。首先,确保你正在使用PHP 5.3或更高版本,最好是最新版本,以充分利用其性能优化和新特性。此外,我们还将探讨代码结构、安全性和性能优化等方面的内容,帮助你成为一名更高效的PHP开发者。 ... [详细]
  • 深入浅出 webpack 系列(二):实现 PostCSS 代码的编译与优化
    在前一篇文章中,我们探讨了如何通过基础配置使 Webpack 完成 ES6 代码的编译。本文将深入讲解如何利用 Webpack 实现 PostCSS 代码的编译与优化,包括配置相关插件和加载器,以提升开发效率和代码质量。我们将详细介绍每个步骤,并提供实用示例,帮助读者更好地理解和应用这些技术。 ... [详细]
  • Python 程序转换为 EXE 文件:详细解析 .py 脚本打包成独立可执行文件的方法与技巧
    在开发了几个简单的爬虫 Python 程序后,我决定将其封装成独立的可执行文件以便于分发和使用。为了实现这一目标,首先需要解决的是如何将 Python 脚本转换为 EXE 文件。在这个过程中,我选择了 Qt 作为 GUI 框架,因为之前对此并不熟悉,希望通过这个项目进一步学习和掌握 Qt 的基本用法。本文将详细介绍从 .py 脚本到 EXE 文件的整个过程,包括所需工具、具体步骤以及常见问题的解决方案。 ... [详细]
  • PHP预处理常量详解:如何定义与使用常量 ... [详细]
  • 基于 Bottle 框架构建的幽默应用 —— Python 实践 ... [详细]
  • 利用树莓派畅享落网电台音乐体验
    最近重新拾起了闲置已久的树莓派,这台小巧的开发板已经沉寂了半年多。上个月闲暇时间较多,我决定将其重新启用。恰逢落网电台进行了改版,回忆起之前在树莓派论坛上看到有人用它来播放豆瓣音乐,便萌生了同样的想法。通过一番调试,终于实现了在树莓派上流畅播放落网电台音乐的功能,带来了全新的音乐享受体验。 ... [详细]
  • 本文探讨了BERT模型在自然语言处理领域的应用与实践。详细介绍了Transformers库(曾用名pytorch-transformers和pytorch-pretrained-bert)的使用方法,涵盖了从模型加载到微调的各个环节。此外,还分析了BERT在文本分类、情感分析和命名实体识别等任务中的性能表现,并讨论了其在实际项目中的优势和局限性。 ... [详细]
  • 本文介绍了如何利用Shell脚本高效地部署MHA(MySQL High Availability)高可用集群。通过详细的脚本编写和配置示例,展示了自动化部署过程中的关键步骤和注意事项。该方法不仅简化了集群的部署流程,还提高了系统的稳定性和可用性。 ... [详细]
  • 为了确保iOS应用能够安全地访问网站数据,本文介绍了如何在Nginx服务器上轻松配置CertBot以实现SSL证书的自动化管理。通过这一过程,可以确保应用始终使用HTTPS协议,从而提升数据传输的安全性和可靠性。文章详细阐述了配置步骤和常见问题的解决方法,帮助读者快速上手并成功部署SSL证书。 ... [详细]
  • 在 Linux 环境下,多线程编程是实现高效并发处理的重要技术。本文通过具体的实战案例,详细分析了多线程编程的关键技术和常见问题。文章首先介绍了多线程的基本概念和创建方法,然后通过实例代码展示了如何使用 pthreads 库进行线程同步和通信。此外,还探讨了多线程程序中的性能优化技巧和调试方法,为开发者提供了宝贵的实践经验。 ... [详细]
  • 在使用 `requests` 库进行 HTTP 请求时,如果遇到 `requests.exceptions.SSLError: HTTPSConnectionPool` 错误,通常是因为 SSL 证书验证失败。解决这一问题的方法包括:检查目标网站的 SSL 证书是否有效、更新本地的 CA 证书库、禁用 SSL 验证(不推荐用于生产环境)或使用自定义的 SSL 上下文。此外,确保 `requests` 库和相关依赖项已更新到最新版本,以避免潜在的安全漏洞。 ... [详细]
  • 利用 Python 管道实现父子进程间高效通信 ... [详细]
  • 在Windows环境下离线安装PyTorch GPU版时,首先需确认系统配置,例如本文作者使用的是Win8、CUDA 8.0和Python 3.6.5。用户应根据自身Python和CUDA版本,在PyTorch官网查找并下载相应的.whl文件。此外,建议检查系统环境变量设置,确保CUDA路径正确配置,以避免安装过程中可能出现的兼容性问题。 ... [详细]
  • 如何在Python中高效运用requests模块:详细使用指南与技巧分享
    在Python中,`requests`模块是处理URL请求的强大工具,作为一个第三方库,需要单独安装。本文将详细介绍如何高效地使用`requests`模块,涵盖从基础功能到高级技巧的各个方面,帮助开发者更好地掌握其应用方法,提高开发效率和代码质量。 ... [详细]
  • 如何在 Node.js 环境中将 CSV 数据转换为标准的 JSON 文件格式? ... [详细]
author-avatar
茫茫人海啊啊啊_574
这个家伙很懒,什么也没留下!
PHP1.CN | 中国最专业的PHP中文社区 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved | 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有