JOIN
Get Time
forums   
Search | Watch Thread  |  My Post History  |  My Watches  |  User Settings
View: Flat (newest first)  | Threaded  | Tree
Previous Thread  |  Next Thread
pku 1986 Distance Queries TLEEE | Reply
I need help :P I keep getting TLE (Time Limit Exceeded).
I tried to solve this using a segment tree in the following way.
First, I read the tree and run a BFS to relabel the vertices so that a vertex with a smaller label than another vertex is at the same depth or above it.
Then I do a DFS traversal, as described in the tutorial, and create two arrays. The first array, recvertex, holds the labels of the vertices in the order the traversal visits them. The second array, recedge, holds the lengths of the edges traversed: an edge's value is positive when it is traversed forward (down the tree) and negative when it is traversed while backtracking. The first value of recedge is 0, because reaching vertex 0 (the root) has no cost. Besides, I have an array occurs that stores, for each vertex, the index of its first occurrence in recvertex.
After that, I build a segment tree over recvertex, which gives the lowest common ancestor of two vertices.
Finally, for each query, I read the two vertices, get their labels, find their lowest common ancestor, and compute the distance between them.

Well, I keep getting TLE. I have tested with some input data from USACO and the program seems to produce correct answers, but the TLE is still there.

Here is my code:

#include<cstdio>
#include<iostream>
#include<vector>
#include<algorithm>
#include<cstdio>
#include<queue>
using namespace std;
// NOTE: despite its name, `ll` is plain int; main prints it with "%d".
typedef int ll;
// Capacity limits. NN bounds the number of vertices n. An Euler tour of a
// tree with n vertices has 2n-1 entries, so every array indexed by tour
// position is sized from N = 2*NN, and the segment tree over the tour
// (children of node k are 2k+1 and 2k+2) gets 4*N nodes. Previously `bit`
// was sized by the vertex bound and `segmentmin` by N, so the program
// silently overflowed for n above 50000 / 62500 respectively.
#define NN 100000
#define N (2*NN)
std::vector<std::vector<int> > adj2;   // adjacency lists after BFS relabelling
std::vector<std::vector<int> > wadj2;  // wadj2[u][i] = length of edge u -- adj2[u][i]
// recvertex[p] / recedge[p]: vertex and signed edge length at tour position p
// occurs[v]: first tour position of vertex v (-1 means not visited yet)
// segmentmin: min segment tree over recvertex
// bit[p]: prefix sum recedge[0..p] (plain prefix sums, not a Fenwick tree)
int recvertex[N],recedge[N],occurs[NN],segmentmin[4*N],bit[N];
int recvertexsz=0,recedgesz=0;
 
// Point update of the min segment tree stored in segmentmin.
// Starting at `node`, which covers positions [lo,hi], walk down to the leaf for
// `pos`, lowering every node on that path to at most `by`.
// The left child of node k is 2k+1 and covers [lo,mid]; the right child is 2k+2.
void updatemin(int pos,int by,int node=0,int lo=0,int hi=recvertexsz-1){
 for(;;){
  if(by<segmentmin[node]) segmentmin[node]=by;
  if(!(lo<hi)) return;            // reached a leaf
  const int mid=(lo+hi)/2;
  if(pos<=mid){
   node=2*node+1;
   hi=mid;
  }else{
   node=2*node+2;
   lo=mid+1;
  }
 }
}
 
// Range-minimum query on the segment tree stored in segmentmin.
// Returns min(recvertex[from..to]) for a query range inside [lo,hi], the
// segment covered by `node`. An empty range (from > to), which arises when the
// query lies entirely in one half, yields the sentinel 1<<30.
int querymin(int from,int to,int node=0,int lo=0,int hi=recvertexsz-1){
 if(to<from) return 1<<30;
 if(lo==from && hi==to) return segmentmin[node];   // exact cover
 const int mid=(lo+hi)/2;
 const int leftBest=querymin(from,std::min(to,mid),2*node+1,lo,mid);
 const int rightBest=querymin(std::max(from,mid+1),to,2*node+2,mid+1,hi);
 return leftBest<rightBest ? leftBest : rightBest;
}
 
// Euler-tour traversal of the relabelled tree (adj2 / wadj2) starting at u.
// Appends to recvertex every vertex as it is entered and again each time the
// walk returns to it. Appends to recedge the signed length of the edge just
// traversed: +w going down, -w coming back up, and an extra leading 0 when the
// vertex entered is 0. occurs[v] is set to v's first tour position; a vertex
// whose occurs entry is not -1 is treated as visited and is not descended into.
//
// Uses an explicit stack instead of recursion. The recursive version needed
// one call frame per tree level, i.e. up to n-1 frames for a path-shaped tree,
// which can overflow a small (e.g. 1 MB) judge stack. The emitted sequence is
// identical to the recursive version's.
void dfs(int u){
 std::vector<int> path;    // vertices whose subtrees are still being explored
 std::vector<int> cursor;  // cursor[k] = next index into adj2[path[k]] to examine

 recvertex[recvertexsz++]=u;
 if(u==0) recedge[recedgesz++]=0;
 occurs[u]=recvertexsz-1;
 path.push_back(u);
 cursor.push_back(0);

 while(!path.empty()){
  const int x=path.back();
  const int i=cursor.back();
  if(i<(int)adj2[x].size()){
   const int v=adj2[x][i];
   if(occurs[v]!=-1){        // already visited (in a tree: the parent)
    cursor.back()++;
    continue;
   }
   recedge[recedgesz++]=wadj2[x][i];   // forward edge x -> v
   recvertex[recvertexsz++]=v;
   if(v==0) recedge[recedgesz++]=0;
   occurs[v]=recvertexsz-1;
   path.push_back(v);
   cursor.push_back(0);
  }else{
   // x is finished: walk back up the edge that led to it.
   path.pop_back();
   cursor.pop_back();
   if(!path.empty()){
    const int p=path.back();
    const int j=cursor.back();          // index of edge p -> x in adj2[p]
    recedge[recedgesz++]=-wadj2[p][j];  // backward edge
    recvertex[recvertexsz++]=p;
    cursor.back()++;
   }
  }
 }
}
 
int main(){
 int n,m;
 cin>>n>>m;
 vector<vector<int> > adj(n);//adjacence list
 vector<vector<int> > wadj(n);//weigth adjacence list
 for(int i=0;i<n;i++) occurs[i]=-1;
 
 adj2.resize(n);
 wadj2.resize(n);
 
 for(int i=0;i<m;i++){
  int a,b,c; char d;
  scanf("%d %d %d %c",&a,&b,&c,&d);
  a--; b--;
  adj[a].push_back(b);
  adj[b].push_back(a);
  wadj[a].push_back(c);
  wadj[b].push_back(c);
 }
 
 int label[n];//labels
 for(int i=0;i<n;i++) label[i]=-1;
 
 //bfs that generates a depth label for each vertex
 int nlab=0;
 int q[NN]; int front=0,back=0; 
 q[back++]=0;
 label[0]=nlab++;
 
 while(front!=back){
  int u=q[front++];
  for(int i=0;i<adj[u].size();i++) if(label[adj[u][i]]==-1){
   label[adj[u][i]]=nlab++;
   q[back++]=adj[u][i];
  }
 }
 
 //relabel each vertex and generate two new adjacence lists
 for(int i=0;i<n;i++) for(int j=0;j<adj[i].size();j++){
  adj2[label[i]].push_back(label[adj[i][j]]);
  wadj2[label[i]].push_back(wadj[i][j]);
 }
 
 //do the dfs traversal
 dfs(0);
 
 for(int i=0;i<4*recvertexsz;i++) segmentmin[i]=1<<30;
 for(int i=0;i<recvertexsz;i++) bit[i]=0; // bit has the accumulate sum of recedge
 
 for(int i=0;i<recvertexsz;i++) updatemin(i,recvertex[i]);//create the segment tree
 for(int i=1;i<recvertexsz;i++) bit[i]=bit[i-1]+recedge[i];
 //for(int i=0;i<n;i++) cout<<occurs[i]<<" ";cout<<endl;
 int k; cin>>k;
 
 //each query
 while(k--){
   int a,b; scanf("%d %d",&a,&b); a--;b--;
   a=label[a];
   b=label[b];
   if(occurs[a]>occurs[b]) swap(a,b);
   int low=querymin(occurs[a],occurs[b]);
   
   ll dist=bit[occurs[a]]-2*bit[occurs[low]]+bit[occurs[b]];
   printf("%d\n",dist);
 }
}


I need your help, please xD. This is an important problem for learning, and I will be very happy if I solve it.
Thanks in advance =D
RSS