用机器学习过滤微博

2011年12月01日 14:52

@自扯自蛋的写的扯经系列很有意思。作者的博客有早期的扯经的合集。为了看扯经有段时间还天天翻新浪微博。不过后来看不到扯经更新了,以为是作者不写了,直到有一天,作者说在网易微博上还在写!这不是坑爹吗!!!害我追了大半天,结果在别的地方开坑了!总不能又翻一遍网易微博,在作者的碎碎念中找一个个扯经吧,所以一直想把网易微博上的扯经过滤出来。最近一直在做 Stanford 的在线的 Machine Learning 课程 programming exercise 6 中给出了如何用支持向量机(support vector machine,SVM)来过滤垃圾邮件的例子,这不正好用上了,遂有此文,是以为记。

首先要得到作者的所有3560条微博,想用 api 爬,结果发现这也需要 oauth 登录。使用 oauth.net/code/ 上推荐的库 python-oauth2 没成功。直接上网易微博自己的 python 库 t4py 。按照说明改 t4py\tblog\constants.py 里的 CONSUMER_KEY 和 CONSUMER_SECRET 以及 在ACCESS_TOKEN_FILE 所指的文件中加上自己的名字,Acess token 和 Acess token secret。

按照 example 写了这样一个脚本得到自扯自蛋的所有微博:

# -*- coding: utf-8 -*-

import json

from t4py.tblog.tblog import TBlog
from t4py.tblog.constants import CONSUMER_KEY
from t4py.tblog.constants import CONSUMER_SECRET
from t4py.http.oauth import OAuthToken
from t4py.utils.token_util import TokenUtil

util = TokenUtil()
str = util.get_token_str('scturtle')
t = TBlog(CONSUMER_KEY, CONSUMER_SECRET)
t._request_handler.access_token = OAuthToken.from_string(str)

all=[]
last_id=None

tweets = t.statuses_user_timeline({'name':'自扯自蛋','count':200,'trim_user':'true'})
tweets = json.read(tweets)
for i in tweets:
    all.append(i['text'])
print 'So far:',len(all)
last_id = tweets[-1]['cursor_id']

for page in range(20):
    tweets = t.statuses_user_timeline({'name':'自扯自蛋','count':200,'trim_user':'true', 'since_id':last_id})
    tweets = json.read(tweets)
    try:
        for i in tweets:
            all.append(i['text'])
    except:
        print 'error:',tweets
        break
    print 'So far:',len(all)
    if len(tweets) == 0:
        break
    last_id = tweets[-1]['cursor_id']

print 'Got:',len(all)
with file('alltweets.txt','w') as f:
    f.write(json.write(all))

print 'Done!'

在网易的库里发现一个简单的 json 库,很好用,下面就用它而不是 python 自带的库了。

然后该提取 features 了。个人考虑,重要关键字基本上都是两个字的,比如“扯经”,“小北”,“师父”等等。于是就把所有微博按两个字一词地统计了一下词频。

dicts={}

for t in tweets:
    for i in range(len(t)-1):
        dicts[t[i:i+2]]=dicts.get(t[i:i+2],0)+1

看了一下,想要的关键词差不多都包含在前200个词里面,于是没有继续筛选,直接上前200个词作为 features。

接着该找 training set 和 test set 。从所有微博中 sample 出300个来,80%作为 training set ,20%作为 test set 。training set 和 test set 都要按照 features 转化成01向量,即微博中有某 feature 则对应向量中的哪一位为1否则为0。同时还写了个脚本人工判断结果,否则就不能 train 和 test 了。

res = [ [t,[],0] for t in tset ]

for i,r in enumerate(res):
    print 'Training set:',i
    print '<<',r[0].decode('utf8') ,'>>'
    res[i][1]=map(lambda f: 1 if f[0] in r[0] else 0, features)
    ans = raw_input('Yes? ')
    res[i][2]=1 if ans else 0

接着就是 octave 的戏份了。调用 exercise 中提供的 octave 脚本(用到的有svmTrain.m, linearKernel.m 和 svmPredict.m)训练和预测就可以了。

%% Initialization
clear ; close all; clc

fprintf('Loading training set and test set ...\n');

allX = load('t_x.mat');
ally = load('t_y.mat');

m=length(ally);

training_percent = 0.8;

X = allX(1:ceil(m * training_percent),:);
y = ally(1:ceil(m * training_percent),:);

Xtest = allX(ceil(m * training_percent):end,:);
ytest = ally(ceil(m * training_percent):end,:);

C = 0.1;

% training set 
model = svmTrain(X, y, C, @linearKernel);

p = svmPredict(model, X);
fprintf('Training Accuracy: %f\n', mean(double(p == y)) * 100);

fprintf('Program paused. Press enter to continue.\n');
pause;

% test set
p = svmPredict(model, Xtest);
fprintf('Test Accuracy: %f\n', mean(double(p == ytest)) * 100);

fprintf('Program paused. Press enter to continue.\n');
pause;

% sort and save features
fprintf('Sorting features ...\n');
[weight, idx] = sort(model.w, 'descend');

fprintf('Saving features ...\n');
out = fopen('feature.json','w');
fprintf(out,'[ %d',idx(1));
for i=2:length(idx),
	fprintf(out,', %d',idx(i));
end
fprintf(out,']');
fclose(out);

fprintf('Program paused. Press enter to continue.\n');
pause;

% predict all
fprintf('Predict for all ...\n');

Xall = load('all_x.mat');
p = svmPredict(model, Xall);

% to json
out = fopen('predict.json','w');
fprintf(out,'[%d',p(1));
for i=2:length(p),
	fprintf(out,', %d',p(i));
end
fprintf(out,']');
fclose(out);

fprintf('Done! Press enter to exit.\n')
pause;

结果按 json 格式保存出来再用 python 处理一下即可。

训练中可以看到 training set 和 test set 的准确率都达到了惊人的100%。打印出权重最高的几个 feature 可以看到符合预期,包含了“扯经”,#号,【】号,“小北”,“师父”等重要关键字:

扯经
#扯
小北
 #
】【
?】
】
已经
师父
【师
怎么
,一
,就
。】

分类后的微博也非常理想。扯经类里一溜儿的扯经。搜索非扯经类里,没有“小北”,含“师父”的一条并不是扯经,有几个含“扯经”的扯经没有分对,但是也有几个含“扯经”的确实只是一些讨论。总体效果很理想。

总结一下,python 很给力,无论是网络部分还是字符串处理部分,要是会 numpy 和 scipy 的话说不定 octave 的部分也能包办了。ML 课程的练习质量很高,既提供了很多好思路又提供了一些实际可用的脚本,不会让人上完课后还感到无从下手。还有在脚本之间用 json 格式的文件传递信息真的是非常的方便啊。

评论(0) 阅读(3209)

粒子群优化算法演示

2011年11月12日 19:14

话说想写这个好久了,看大牛做过,直到昨天在twitter上看到有人在iPad上用Codify(Codea)实现了,确实好玩,这才随便找了点资料开工。

算法意想之外的简单,模拟鸟群的行动,每个鸟根据离目标最近的那个鸟的位置调整自己的方向。实现的效果是点追随鼠标汇聚。效果没有意想中的那么好,不知是不是有bug,参数设置的不好的时候不能汇聚或者其他诡异行为,╮(╯_╰)╭。

update at 2011.11.13:设置了最小最大速率参数后,果然效果好多了

代码很简单:

# coding: utf-8

import pygame
from pygame.locals import *
from myvector import Vector, distance # 向量类
from random import random

pygame.init()
SIZE=(640, 480)
screen = pygame.display.set_mode(SIZE)
pygame.display.set_caption('pso test')
clock = pygame.time.Clock()
FPS = 60     # 帧率
speed = 0.1  # 调整速度
w = 0.7      # 惯性权重
c1 = c2 = 2  # 学习因子
minv = 10    # 最小速度
maxv = 50    # 最大速度

def sig(x): return x/abs(x)

#======================================================#
class Bird:
    def __init__(self):
        self.v=Vector(random()-0.5,random()-0.5)
        self.pos=Vector(SIZE[0]*random(),SIZE[1]*random())
        self.bestpos=self.pos
    
    def update(self, dt, mpos, gbestpos):
        global speed,w,c1,c2,minv,maxv

        # 更新速度和位置的公式
        self.v = self.v*w + ((self.bestpos-self.pos)*random()*c1 +\
                                 (gbestpos-self.pos)*random()*c2)

        # 限制最小最大速度
        self.v = Vector(sig(self.v[0]) * max(abs(self.v[0]), minv),\
                        sig(self.v[1]) * max(abs(self.v[1]), minv))
        self.v = Vector(sig(self.v[0]) * min(abs(self.v[0]), maxv),\
                        sig(self.v[1]) * min(abs(self.v[1]), maxv))

        # 更新位置
        self.pos += self.v * speed * dt /(1000.0/FPS)

        # 更新个体最优位置
        if distance(self.pos, mpos) < distance(self.bestpos, mpos):
            self.bestpos = self.pos

    def paint(self,screen):
        pygame.draw.circle(screen,(255,0,0), map(int,self.pos), 2)
#======================================================#

birds=[ Bird() for i in xrange(200)]

pygame.event.set_allowed([QUIT])
while True:
    for event in pygame.event.get():
        if event.type == QUIT:
            exit()

    dt = clock.tick(FPS)
    # 目标位置
    mpos = pygame.mouse.get_pos()
    mpos = Vector(*mpos)

    # 更新种群最优位置
    gbestpos = min(birds,key=lambda b: distance(b.pos,mpos)).pos

    screen.fill((255, 255, 255))
    # 更新每个个体
    for b in birds:
        b.update(dt ,mpos, gbestpos)
        b.paint(screen)
    pygame.display.update()

为了实现方便的写的向量类:

class Vector:
    def __init__(self,x=0,y=0):
        self.val=[float(x),float(y)]

    def __getitem__(self,key):
        return self.val[key]

    def __setitem__(self,key,value):
        self.val[key]=value

    def __str__(self):
        return '('+str(self[0])+', '+str(self[1])+')'

    def __add__(self,v):
        return Vector(self[0]+v[0],self[1]+v[1])

    def __sub__(self,v):
        return Vector(self[0]-v[0],self[1]-v[1])

    def __div__(self,n):
        return Vector(self[0]/n,self[1]/n)

    def __mul__(self,n):
        return Vector(self[0]*n,self[1]*n)

    def __iadd__(self,v):
        self[0]+=v[0]
        self[1]+=v[1]
        return self

    def __isub__(self,v):
        self[0]-=v[0]
        self[1]-=v[1]
        return self

    def __idiv__(self,n):
        self[0]/=n
        self[1]/=n
        return self

    def __imul__(self,n):
        self[0]*=n
        self[1]*=n
        return self

def distance(v1,v2):
    return ((v1[0]-v2[0])*(v1[0]-v2[0]) +\
            (v1[1]-v2[1])*(v1[1]-v2[1]))

if __name__ == '__main__':
    v1=Vector(1,2)
    v2=Vector(2,1)
    print v1+v2

评论(0) 阅读(3647)

poj 2117 求割点的tarjan算法

2011年10月20日 17:43

大四真闲,也不突击比赛了,可以拿出大把的时间慢慢看点儿算法,研究一下证明,tarjan大神的三个图论算法总算是都了解了。

割点是无向图中去掉后能把图割开的点。dfs时用dfn(u)记录u的访问时间,用low(u)数组记录u和u的子孙能追溯到的最早的节点(dfn值最小)。由于无向图的dfs只有回边和树边,且以第一次dfs时的方向作为边的方向,故有:
low=min{
dfn(u),
dfn(v),若(u,v)为回边(非树边的逆边)
low(v),若(u,v)为树边
}

顶点u是割点当且仅当其满足(1)或者(2):
(1) 若u是树根,且u的孩子数sons>1。因为没有u后,以这些孩子为根的子树间互相就不连通了,所以去掉u后得到sons个分支。
(2) 若u不是树根,且存在树边(u,v)使 low(v)>=dfn(u)。low值说明以v为根的子树不能到达u的祖先也就是去掉u后不能和原图联通,所以得到{这样的v的个数+1}个分支。

这个题是求无向图(不一定联通)中,去掉一个顶点可以形成的最多的分支数,对所有分支tarjan一下就知道了去掉哪个多了,注意孤立点的情况。求low时其实不用判断树边的逆边的情况,仔细琢磨一下,对结果没有影响,又能省很多代码。

#include<cstdio>
#include<cstring>
#include<algorithm>
#include<vector>
using namespace std;
#define maxn 10001
typedef vector<int>::iterator it;
vector<int> g[maxn];
int dfn[maxn],low[maxn];
bool vit[maxn];
int n,idx,sons;
int ans;

void dfs(int u,bool root)
{
    vit[u]=1;
    dfn[u]=low[u]=++idx;
    int child=0;
    for(it i=g[u].begin();i!=g[u].end();++i)
    {
        int v=*i;
        if(!dfn[v])
        {
            dfs(v,false);
            low[u]=min(low[u],low[v]);
            if(root)
                sons++;
            else if(low[v]>=dfn[u])
                child++;
        }
        else low[u]=min(low[u],dfn[v]);
    }
    ans=max(ans,child+1);
}

int tarjan(int root)
{
    if(g[root].size()==0) return 0;
    memset(dfn,0,sizeof(dfn));
    ans=idx=sons=0;
    dfs(root,true);
    if(sons>1) ans=max(ans,sons);
    return ans;
}

int main()
{
#ifndef ONLINE_JUDGE
    freopen("in","r",stdin);
    freopen("out","w",stdout);
#endif
    int m,u,v;
    while(scanf("%d%d",&n,&m)!=EOF && n)
    {
        for(int i=0;i<n;i++)
            g[i].clear();
        while(m-->0)
        {
            scanf("%d%d",&u,&v);
            g[u].push_back(v);
            g[v].push_back(u);
        }
        memset(vit,0,sizeof(vit));
        int ma=0,total=0;
        for(int i=0;i<n;i++)
            if(!vit[i])
                total++,ma=max(ma,tarjan(i));
        printf("%d\n",total+ma-1);
    }
}

评论(1) 阅读(5939)

LCA的tarjan算法的理解

2011年10月08日 11:08

tarjan算法的步骤是(当dfs到节点u时):
1 在并查集中建立仅有u的集合,设置该集合的祖先为u
1 对u的每个孩子v:
   1.1 tarjan之
   1.2 合并v到父节点u的集合,确保集合的祖先是u
2 设置u为已遍历
3 处理关于u的查询,若查询(u,v)中的v已遍历过,则LCA(u,v)=v所在的集合的祖先
 
举例说明(非证明):


假设遍历完10的孩子,要处理关于10的请求了
取根节点到当前正在遍历的节点的路径为关键路径,即1-3-8-10
集合的祖先便是关键路径上距离集合最近的点
比如此时:
    1,2,5,6为一个集合,祖先为1,集合中点和10的LCA为1
    3,7为一个集合,祖先为3,集合中点和10的LCA为3
    8,9,11为一个集合,祖先为8,集合中点和10的LCA为8
    10,12为一个集合,祖先为10,集合中点和10的LCA为10
你看,集合的祖先便是LCA吧,所以第3步是正确的
道理很简单,LCA(u,v)便是根至u的路径上到节点v最近的点

为什么要用祖先而且每次合并集合后都要确保集合的祖先正确呢?
因为集合是用并查集实现的,为了提高速度,当然要平衡加路径压缩了,所以合并后谁是根就不确定了,所以要始终保持集合的根的祖先是正确的
关于查询和遍历孩子的顺序:
wikipedia上就是上文中的顺序,很多人的代码也是这个顺序
但是网上的很多讲解却是查询在前,遍历孩子在后,对比上文,会不会漏掉u和u的子孙之间的查询呢?
不会的
如果在刚dfs到u的时候就设置u为visited的话,本该回溯到u时解决的那些查询,在遍历孩子时就会解决掉了
这个顺序问题就是导致我头大看了很久这个算法的原因,也是絮絮叨叨写了本文的原因,希望没有理解错= =

最后,为了符合本blog风格,还是贴代码吧:

int f[maxn],fs[maxn];//并查集父节点 父节点个数
bool vit[maxn];
int anc[maxn];//祖先
vector<int> son[maxn];//保存树
vector<int> qes[maxn];//保存查询
typedef vector<int>::iterator IT;

int Find(int x)
{
    if(f[x]==x) return x;
    else return f[x]=Find(f[x]);
}
void Union(int x,int y)
{
    x=Find(x);y=Find(y);
    if(x==y) return;
    if(fs[x]<=fs[y]) f[x]=y,fs[y]+=fs[x];
    else f[y]=x,fs[x]+=fs[y];
}

void lca(int u)
{
    anc[u]=u;
    for(IT v=son[u].begin();v!=son[u].end();++v)
    {
        lca(*v);
        Union(u,*v);
        anc[Find(u)]=u;
    }
    vit[u]=true;
    for(IT v=qes[u].begin();v!=qes[u].end();++v)
    {
        if(vit[*v])
            printf("LCA(%d,%d):%d\n",u,*v,anc[Find(*v)]);
    }
}

ref:
http://purety.jp/akisame/oi/TJU/
http://en.wikipedia.org/wiki/Tarjan%27s_off-line_least_common_ancestors_algorithm
http://techfield.us/blog/2008/11/lowest_common_ancester_tarjan_alogrithm/

评论(12) 阅读(24734)

poj 2104 2761 树状数据结构

2011年10月07日 00:51

两道题都是裸裸的求区间第k大数(kth number),先给出总的区间,每个查询给出小区间[l,r]和k。
用这两道题学习了不少数据结构,感觉有些非常巧妙,记录一下。

1 平衡树

平衡树就是平衡的二叉搜索树(BST),对每个区间建立平衡树并维护size域(以此节点为根的树的节点数),便可很容易的实现select(找第k大数)和rank(查某数是第几大)两个操作。因为要反复针对区间建树,所以先对区间排好序会比较好,这样后一个区间的树就是对前一个区间的树的增删操作了。

1.1 Treap
Tree就是Tree(BST)+Heap。 首先要满足BST的性质,节点的key要左小于中小于右。然后每个节点加上了一个附加域fix,这个域没什么意义,就是个随机数,但是需要保持树上的这个域符合堆的性质,比如最小堆,中小于左右。 怎么能又保持BST又保持Heap呢?这就要靠一个神奇的BST的操作了!BST中的左旋和右旋是不会破坏BST的性质的!于是先保证树是BST,再用左旋和右旋把整棵树调整成Heap。于是就可以看到添加节点时,先加到最下面,然后,扭啊扭啊扭上去,删除时把节点扭啊扭啊扭下去再删掉~ 显然,fix和堆的存在就是为了搅乱整棵树的,这样就不会对特定的数据恶化了。而且扭的过程中还可以降低树的高度,所以Treap是随机平衡的树,各操作均摊复杂度才是O(nlogn),不像AVL和红黑树那样是近似绝对平衡的树。但是Treap编程难度低啊,所以很好用,我觉得,这个把BST和Heap结合起来的扭来扭去的树和Dancing links一样优美~ 对size域的维护放到旋转时就可以了。具体资料可见《随机平衡二叉查找树 Treap 的分析与应用 by 郭家宝》。

#include <cstdio>
#include <cstdlib>

#define nil 0
#define MAX 1000010
int key[MAX],left[MAX],right[MAX],size[MAX],fix[MAX];
int root,node;

inline void Left_Rotate(int &x) {
    int k = right[x];
    right[x] = left[k];
    left[k] = x;
    size[k] = size[x];
    size[x] = size[left[x]] + size[right[x]] + 1;
    x = k;
}

inline void Right_Rotate(int &y) {
    int k = left[y];
    left[y] = right[k];
    right[k] = y;
    size[k] = size[y];
    size[y] = size[left[y]] + size[right[y]] + 1;
    y = k;
}

void Insert(int &T,int v) {
    if(T == nil) {
        key[T = ++node]= v;
        fix[T]=rand(); size[T]=1;
    } else {
        size[T]++;
        if(v < key[T]) {
            Insert(left[T], v);
            if(fix[left[T]]<fix[T])
                Right_Rotate(T);
        } else {
            Insert(right[T], v);
            if(fix[right[T]]<fix[T])
                Left_Rotate(T);
        }
    }
}

void Delete(int &T,int v) {
    if(!T) return;
    size[T]--;
    if(v == key[T]){
        if(!left[T] || !right[T])
            T=left[T]+right[T];
        else if(fix[left[T]]<fix[right[T]]) {
            Right_Rotate(T);
            Delete(right[T],v);
        } else {
            Left_Rotate(T);
            Delete(left[T],v);
        }
    } else if (v < key[T])
        Delete(left[T],v);
    else
        Delete(right[T],v);
}

int Select(int &T,int k)
{
    int r = size[left[T]] + 1;
    if(k == r)
        return key[T];
    else if(k < r)
        return Select(left[T], k);
    else
        return Select(right[T], k - r);
}

int arr[100010];
typedef struct {int l,r,k,org;} Q;
Q q[50010];
int ans[50010];

int cmp(const void *a,const void *b)
{
    Q *qa=(Q*)a,*qb=(Q*)b;
    if(qa->l!=qb->l) return qa->l-qb->l;
    else return qa->r-qb->r;
}

int main()
{
#ifndef ONLINE_JUDGE
    freopen("in","r",stdin);
    freopen("out","w",stdout);
#endif
    //init
    //srand(1984);
    root=0;
    int n,m; scanf("%d%d",&n,&m);
    for(int i=1;i<=n;i++)
        scanf("%d",arr+i);
    for(int i=1;i<=m;i++)
    {
        int l,r,k;
        scanf("%d%d%d",&l,&r,&k);
        q[i].l=l; q[i].r=r; q[i].k=k; q[i].org=i;
    }
    qsort(q+1,m,sizeof(Q),cmp);
    int lastl=1,lastr=0;
    for(int t=1;t<=m;t++)
    {
        int l=q[t].l,r=q[t].r,k=q[t].k;
        if(lastr<l)
        {
            for(int i=lastl;i<=lastr;i++)
                Delete(root,arr[i]);
            for(int i=l;i<=r;i++)
                Insert(root,arr[i]);
        }
        else if(lastr<=r)
        {
            for(int i=lastl;i<l;i++)
                Delete(root,arr[i]);
            for(int i=lastr+1;i<=r;i++)
                Insert(root,arr[i]);
        }
        else
        {
            for(int i=lastl;i<l;i++)
                Delete(root,arr[i]);
            for(int i=r+1;i<=lastr;i++)
                Delete(root,arr[i]);
        }
        lastl=l;lastr=r;
        ans[q[t].org]=Select(root,k);
    }
    for(int i=1;i<=m;i++)
        printf("%d\n",ans[i]);
    
}


1.2 Splay
Splay和Treap很像,也是扭来扭去的。Splay树的独特之处就是splay操作,把某一个节点扭到root处,在扭的过程中同样可以搅乱整棵树,降低树高。Splay操作也仅仅是用单旋和双旋两种操作,应对节点和节点父亲及节点的父亲的父亲之间的多种情况,快速的把节点向上提升。 Splay树中的插入操作是在普通BST插入操作后,把插入节点splay到根。删除操作有多种实现,我是先把要删除节点splay到根,再把左子树中最大的节点splay到根,然后,其右节点即为要删除的节点,而此时他只有右子树了,所以很容易就删除了。Splay也是随机的平衡树,可以证明的是splay操作均摊复杂度为O(1),所以各操作的均摊复杂度为O(nlogn)了。size域的维护在单旋和删除操作时维护即可。资料参见《伸展树的基本操作与应用 by 杨思雨》、《伸展树操作详解  By Ma Shuo》和《The Magical Splay by sqybi》。

#include <cstdio>
#include <cstdlib>

struct Node 
{
    Node *pre,*ch[2]; int size,key;
    Node(){pre=ch[0]=ch[1]=NULL;size=1;}
} node[1000100];
int tot=0;
Node *root;

void print(Node *x)
{
    if(!x) return;
    print(x->ch[0]);
    printf("%d ",x->key);
    print(x->ch[1]);
}

void Rotate(Node *x)
{
    bool c=(x->pre->ch[0]==x);//0:left(zag) 1:right(zig)
    Node *y=x->pre; 
    y->ch[!c]=x->ch[c];
    if(x->ch[c]!=NULL) x->ch[c]->pre=y; 
    x->pre = y->pre; 
    if(y->pre!=NULL) 
        y->pre->ch[y->pre->ch[0]!=y]=x; 
    x->ch[c]=y; y->pre=x; 
    x->size=y->size; y->size=1;
    for(int i=0;i<2;i++)
        if(y->ch[i]) y->size+=y->ch[i]->size;
    if(y==root) root=x;
}

// up x until x->pre = f, f==NULL means root 
void Splay(Node *x,Node *f=NULL)
{
    while(x->pre!=f)
        if(x->pre->pre==f)
            Rotate(x);
        else
        {
            Node *y=x->pre,*z=y->pre;
            if((z->ch[0]==y)==(y->ch[0]==x))
                    Rotate(y), Rotate(x);
                else
                    Rotate(x), Rotate(x);
        }
    if(f==NULL) root=x;
}

Node* Search(Node *x,int k)//return the last found
{
    if(!x) return NULL;
    while(k != x->key)
    {
        bool c=(k > x->key);
        if(x->ch[c]==NULL) break;
        x=x->ch[c];
    }
    return x;
}

void Insert(int k)
{
    Node *x=root,*y=NULL;
    while(x)// && k != x->key)
    {
        x->size++;
        y=x; x=x->ch[k > x->key];
    }
    Node *z=&node[tot++]; z->key=k;
    z->pre=y;
    if(y==NULL)
        root=z;
    else
        y->ch[k > y->key]=z;
    Splay(z);
}

Node* Extreme(Node *x,bool c)
{
    while(x && x->ch[c]!=NULL)
        x=x->ch[c];
    return x;
}

bool Delete(int k)
{
    Node *x=Search(root,k);
    if(!x || x->key!=k) return false;
    Splay(x);//make x the root
    if(!x->ch[0] || !x->ch[1])// one or none child
    {
        bool c=(x->ch[0])?0:1;
        if(x->ch[c]) x->ch[c]->pre=NULL;
        root=x->ch[c];
    }
    else //combine x's childs
    {
        Node *l=x->ch[0],*r=x->ch[1];
        Splay(Extreme(l,1));
        root->ch[1]=r; r->pre=root;
        root->size--;
    }
    //delete(x);
    return true;
}

int Select(Node *x,int k)
{
    int r=1+((x->ch[0])?x->ch[0]->size:0);
    if(k==r)
        return x->key;
    else if(k < r)
        return Select(x->ch[0],k);
    else
        return Select(x->ch[1],k-r);
}


int arr[1000500];
struct Q{int l,r,k,org;} q[50050];
int ans[50050];

int cmp(const void *a,const void *b)
{
    Q *qa=(Q*)a,*qb=(Q*)b;
    if(qa->l!=qb->l) return qa->l-qb->l;
    else return qa->r-qb->r;
}

int main()
{
#ifndef ONLINE_JUDGE
    freopen("in","r",stdin);
    freopen("out","w",stdout);
#endif
    int n,m; scanf("%d%d",&n,&m);
    for(int i=1;i<=n;i++)
        scanf("%d",arr+i);
    for(int i=1;i<=m;i++)
    {
        int l,r,k;
        scanf("%d%d%d",&l,&r,&k);
        q[i].l=l; q[i].r=r; q[i].k=k; q[i].org=i;
    }
    qsort(q+1,m,sizeof(Q),cmp);
    int lastl=1,lastr=0;
    for(int t=1;t<=m;t++)
    {
        int l=q[t].l,r=q[t].r,k=q[t].k;
        if(lastr<l)
        {
            for(int i=lastl;i<=lastr;i++)
                Delete(arr[i]);
            for(int i=l;i<=r;i++)
                Insert(arr[i]);
        }
        else if(lastr<r)
        {
            for(int i=lastl;i<l;i++)
                Delete(arr[i]);
            for(int i=lastr+1;i<=r;i++)
                Insert(arr[i]);
        }
        else
        {
            for(int i=lastl;i<l;i++)
                Delete(arr[i]);
            for(int i=r+1;i<=lastr;i++)
                Delete(arr[i]);
        }
        lastl=l;lastr=r;
        ans[q[t].org]=Select(root,k);
        //printf("l:%d r:%d k:%d :: ",l,r,k);
        //print(root);puts("");
    }
    for(int i=1;i<=m;i++)
        printf("%d\n",ans[i]);
}

1.3 SBT(Size Balanced Tree)
名字大亮,"傻叉树" or "Super BT"树。如其本名,是直接用size域实现近似绝对平衡的树。较复杂,不介绍了,参见《Size Balanced Tree by Chen Qifeng(此树作者)》、《由二叉查找树到容均树 by 田劲锋》(写的大好啊,我就是用这个里面的代码做模板的)。

#include <cstdio>
#include <cstdlib>

int arr[100010];
typedef struct {int l,r,k,org;} Q;
Q q[50100];
int ans[50100];

#define nil 0
const int MAX = 100010;
int key[MAX], left[MAX], right[MAX], size[MAX];
int T, node;
int record; // This is used for the commented Delete
inline void Left_Rotate(int &x) {
    int k = right[x];
    right[x] = left[k];
    left[k] = x;
    size[k] = size[x];
    size[x] = size[left[x]] + size[right[x]] + 1;
    x = k;
}
inline void Right_Rotate(int &y) {
    int k = left[y];
    left[y] = right[k];
    right[k] = y;
    size[k] = size[y];
    size[y] = size[left[y]] + size[right[y]] + 1;
    y = k;
}
void Maintain(int &T, bool flag);
void Insert(int &T, int v) {
    if(T == nil) {
        key[T = ++node] = v;
        size[T] = 1;
    } else {
        size[T]++;
        if(v < key[T])
            Insert(left[T], v);
        else
            Insert(right[T], v);
        Maintain(T, v >= key[T]);
    }
}
int Delete(int &T, int v) {
    size[T]--;
    if( (v == key[T]) || (v < key[T] && left[T] == nil) || (v > key
                [T] && right[T] == nil) ) {
        int r = key[T];
        if(left[T] == nil || right[T] == nil)
            T = left[T] + right[T];
        else
            key[T] = Delete(left[T], key[T] + 1);
        return r;
    } else {
        if(v < key[T])
            return Delete(left[T], v);
        else
            return Delete(right[T], v);
    }
}
void Maintain(int &T, bool flag) {
    if(flag == false) {
        if(size[left[left[T]]] > size[right[T]])
            Right_Rotate(T);
        else {
            if(size[right[left[T]]] > size[right[T]]) {
                Left_Rotate(left[T]);
                Right_Rotate(T);
            } else return;
        }
    } else {
        if(size[right[right[T]]] > size[left[T]])
            Left_Rotate(T);
        else {
            if(size[left[right[T]]] > size[left[T]]) {
                Right_Rotate(right[T]);
                Left_Rotate(T);
            } else return;
        }
    }
    Maintain(left[T], false);
    Maintain(right[T], true);
    Maintain(T, true);
    Maintain(T, false);
}
int Select(int T, int k) {
    int r = size[left[T]] + 1;
    if(k == r)
        return key[T];
    else if(k < r)
        return Select(left[T], k);
    else
        return Select(right[T], k - r);
}
int cmp(const void *a,const void *b)
{
    Q *qa=(Q*)a,*qb=(Q*)b;
    if(qa->l!=qb->l) return qa->l-qb->l;
    else return qa->r-qb->r;
}
int main() 
{
#ifndef ONLINE_JUDGE
    freopen("in","r",stdin);
    freopen("out","w",stdout);
#endif
    int n,m,lastl=1,lastr=0;
    scanf("%d%d",&n,&m);
    for(int i=1;i<=n;i++)
        scanf("%d",arr+i);
    //sort query
    for(int i=1;i<=m;i++)
    {
        int l,r,k; scanf("%d%d%d",&l,&r,&k);
        q[i].l=l; q[i].r=r; q[i].k=k; q[i].org=i;
    }
    qsort(q+1,m,sizeof(Q),cmp);
    //do
    for(int t=1;t<=m;t++)
    {
        int l=q[t].l,r=q[t].r,k=q[t].k;
        if(lastr<l)
        {
            for(int i=lastl;i<=lastr;i++)
                Delete(T,arr[i]);
            for(int i=l;i<=r;i++)
                Insert(T,arr[i]);
        }
        else if(lastr<=r)
        {
            for(int i=lastl;i<l;i++)
                Delete(T,arr[i]);
            for(int i=lastr+1;i<=r;i++)
                Insert(T,arr[i]);
        }
        else
        {
            for(int i=lastl;i<l;i++)
                Delete(T,arr[i]);
            for(int i=r+1;i<=lastr;i++)
                Delete(T,arr[i]);
        }
        lastl=l;lastr=r;
        ans[q[t].org]=Select(T,k);
    }
    for(int i=1;i<=m;i++)
        printf("%d\n",ans[i]);
}

2 线段树

wr大牛说线段树也是一种平衡树,不过至少在这个问题的应用上还是有很大区别的。线段树的解法不再是构造每个查询区间的树了,而是构造好整个区间的(线段)树,再从上到下查询。

2.1 划分树
据说是专门用来解决kth number问题的。划分树的原理基于快排,线段树的每一层都是快排的一级,最上层就是整个数组,从上往下建立。为了保持绝对平衡,使用中位数做支点,左右平均,所以左右两边可能有相同的元素,在建树时要注意。建树时同时建立好num_left数组,记录每个线段树节点的区间中,在其左侧小于等于某元素的元素数。整个查找过程是从上到下分析num_left并选择向左还是向右的过程。具体参见这里

#include <cstdio>
#include <cstdlib>
#include <algorithm>
using namespace std;
#define max 1000010

struct Tree { int l,r; } tree[max*3];
int leftsum[30][max],seg[30][max],arr[max];

void parti_build(int r,int s,int t,int d)
{
    tree[r].l=s;tree[r].r=t;
    if(t==s) return;
    int mid=(s+t)>>1;
    int lsame=mid-s+1;
    for(int i=s;i<=t;i++)
        if(seg[d][i]<arr[mid])
            lsame--;
    int lpos=s,rpos=mid+1;
    for(int i=s;i<=t;i++)
    {
        leftsum[d][i] = (i==s)?0:leftsum[d][i-1];
        if(seg[d][i]<arr[mid])
        {
            leftsum[d][i]++;
            seg[d+1][lpos++]=seg[d][i];
        }
        else if(seg[d][i]>arr[mid])
            seg[d+1][rpos++]=seg[d][i];
        else
        {
            if(lsame)
            {
                lsame--; leftsum[d][i]++;
                seg[d+1][lpos++]=seg[d][i];
            }
            else
                seg[d+1][rpos++]=seg[d][i];
        }
    }
    parti_build(r*2,s,mid,d+1);
    parti_build(r*2+1,mid+1,t,d+1);
}

int find(int r,int s,int t,int d,int k)
{
    int tl=tree[r].l,tr=tree[r].r;
    if(tr==tl) return seg[d][s];
    int s_left_sum=(s==tl)?0:leftsum[d][s-1];
    int s_to_t_sum=leftsum[d][t]-s_left_sum;
    if(s_to_t_sum >=k)
        return find(r*2,tl+s_left_sum,tl+leftsum[d][t]-1,d+1,k);
    else
    {
        int mid=(tl+tr)>>1;
        int s_right_sum=s-tl-s_left_sum;
        int s_to_t_r_sum=t-s+1-s_to_t_sum;
        return find(r*2+1,mid+s_right_sum+1,mid+s_right_sum+s_to_t_r_sum,d+1,k-s_to_t_sum);
    }
}

int main()
{
#ifndef ONLINE_JUDGE
    freopen("in","r",stdin);
    freopen("out","w",stdout);
#endif
    int n,m;
    scanf("%d%d",&n,&m);
    for(int i=1;i<=n;i++)
    {
        scanf("%d",arr+i);
        seg[1][i]=arr[i];
    }
    sort(arr+1,arr+1+n);
    parti_build(1,1,n,1);
    while(m--)
    {
        int s,t,k;
        scanf("%d%d%d",&s,&t,&k);
        printf("%d\n",find(1,s,t,1,k));

    }
}

2.2 归并树
和划分树不同,归并树用的是归并排序的思想,每一层都是归并排序的一级,从下到上构建,每一个节点区间内都是排好序的(和划分树不同)。线段树是用来查某数的rank的,比较典型的线段树查询操作,不用附加域,而是用二分的思想,因为节点区间内都是排好序的,注意因为有重复值的可能所以二分求的是区间内小于某数的数的个数。对每个查询的结果是用其在所给小区间rank值二分的,注意二分是求使rank值为k的最大的数,原因不太好想。介绍可看这里,代码可看这里

#include <cstdio>
#include <cstdlib>

#define MAX 1000010

struct Tree {int l,r,lev;} tree[MAX*3];
int seg[22][MAX];
int n,m;
int arr[MAX],max,min;

void build(int root,int lev,int l,int r)
{
    tree[root].l=l;tree[root].r=r;tree[root].lev=lev;
    if(l==r)
    {
        seg[lev][l]=arr[l];
        return;
    }
    
    int mid=(l+r)/2;
    build(root*2,lev+1,l,mid);
    build(root*2+1,lev+1,mid+1,r);

    int p=l,p1=l,p2=mid+1;
    while(p<=r)
    {
        if(p1>mid)
            seg[lev][p++]=seg[lev+1][p2++];
        else if(p2>r)
            seg[lev][p++]=seg[lev+1][p1++];
        else if(seg[lev+1][p1]<=seg[lev+1][p2])
            seg[lev][p++]=seg[lev+1][p1++];
        else
            seg[lev][p++]=seg[lev+1][p2++];
    }
}

//the count in a that a[i]<x
int count(int *a,int l,int r,int x)
{
    int m,orgl=l;
    if(a[l]>=x) return 0;
    else if(a[r]<x) return r-l+1;
    while(l+1<r) //assert: a[l]<x && x<=a[r]
    {
        m=(l+r)>>1;
        if(a[m]<x) l=m; else r=m;
    }
    return l-orgl+1;
}

int rank(int root,int l,int r,int x)
{
    if(tree[root].l==l && tree[root].r==r)
        return count(seg[tree[root].lev],l,r,x);

    int mid=(tree[root].l+tree[root].r)/2;

    if(r<=mid)
        return rank(root*2,l,r,x);
    else if(l>mid)
        return rank(root*2+1,l,r,x);
    else 
        return rank(root*2,l,mid,x)+
               rank(root*2+1,mid+1,r,x);
}

int query(int l,int r,int k)
{
    int s=seg[1][0],t=seg[1][n-1]+1,m;
    while(s+1<t) // [s,t) <= ans
    {
        m=(s+t)/2;
        int rk=rank(1,l,r,m)+1;
        if(rk>k) t=m;
        else s=m;
    }
    return s;
}

int main()
{
#ifndef ONLINE_JUDGE
    freopen("in","r",stdin);
    freopen("out","w",stdout);
#endif
    scanf("%d%d",&n,&m);
    for(int i=0;i<n;i++)
        scanf("%d",arr+i);
    build(1,1,0,n-1);
    int l,r,k;
    while(m--)
    {
        scanf("%d%d%d",&l,&r,&k);
        printf("%d\n",query(l-1,r-1,k));
    }
}

评论(1) 阅读(3370)