【HDU-1147】Pick-up sticks【线段相交】

【HDU-1147】Pick-up sticks【线段相交】

问题链接:https://vjudge.net/problem/HDU-1147

Solution:

这道题的时间限制十分严格,一开始手动实现小顶堆+扫描线,一直TLE:

#include <climits>
#include <cstdio>
#include <iostream>
#include <map>
#include <set>
#include <vector>

// 扫描线版本,TLE。

using namespace std;

// Larger of two doubles; the second argument is returned when neither is greater.
inline double maxi(double a, double b) {
    if (a > b) return a;
    return b;
}
// Smaller of two doubles; the second argument is returned when neither is smaller.
inline double mini(double a, double b) {
    if (a < b) return a;
    return b;
}
// Larger of two ints.
inline int maxi(int a, int b) {
    if (a > b) return a;
    return b;
}
// Smaller of two ints.
inline int mini(int a, int b) {
    if (a < b) return a;
    return b;
}

// A 2D point (also used as a vector) tagged with the id of the stick it belongs to.
struct Point {
    double x, y;
    int id;

    // Orders points by x coordinate only; y and id do not take part.
    bool operator < (const Point &other) const {
        return other.x > x;
    }

    // Component-wise difference, interpreted as the vector from `other` to *this.
    // The id of the result is 0.
    Point operator - (const Point &other) const {
        Point diff;
        diff.x = x - other.x;
        diff.y = y - other.y;
        diff.id = 0;
        return diff;
    }

    // 2D cross product (z component) of *this and v.
    double operator ^ (const Point &v) const {
        return x * v.y - v.x * y;
    }
};

// Size bound used for the heap storage: 1e5 + 10, i.e. 100010 after conversion to int.
const int LIM = 1e5 + 10;
// Global endpoint list. main() pushes the two endpoints of each stick back to back,
// and the heap comparator (Cmp) looks points up in it by index.
vector<Point> ps;
// Exchanges the values of two ints in place (safe when both refer to the same object).
inline void swap(int &a, int &b) {
    const int old_b = b;
    b = a;
    a = old_b;
}

// Strict-weak-ordering comparator over indices into the global `ps`:
// index a sorts before index b when the point ps[a] has the smaller x coordinate.
struct Cmp {
    // Declared const so the comparator can also be invoked through const
    // objects/references (as standard containers and algorithms require).
    bool operator()(int a, int b) const {
        return ps[a].x < ps[b].x;
    }
};

// Binary min-heap of indices into the global `ps`, ordered by Cmp (point x coordinate).
// Storage is 1-based: slot 0 is unused and `heap_pos` is the next free slot.
//
// Capacity is 2 * LIM because main() inserts two endpoints per stick; with a
// capacity of only LIM, doInsert() silently rejected endpoints once more than
// about LIM / 2 sticks were read. The buffer lives on the heap (std::vector)
// so that a MinHeap local variable does not occupy hundreds of kilobytes of stack.
class MinHeap {
private:
    int heap_pos;            // next free slot; element count is heap_pos - 1
    int cap;                 // doInsert() refuses once heap_pos reaches cap
    std::vector<int> data;   // heap array, valid slots are [1, heap_pos)
    Cmp cmp;
public:
    // Members are initialised in declaration order (heap_pos, cap, data).
    MinHeap() : heap_pos(1), cap(2 * LIM), data(2 * LIM) {}
    // Removes and returns the smallest key, or INT_MAX when the heap is empty.
    int extractMin();
    // Inserts key; returns false (and stores nothing) when the heap is full.
    bool doInsert(int key);
    // Restores the heap property for the subtree rooted at idx by sifting down.
    void minHeapify(int idx);
    // Number of keys currently stored.
    int size();
};

// Pops the root of the heap.
// Returns INT_MAX as a sentinel when the heap holds no keys.
int MinHeap::extractMin() {
    if (heap_pos == 1)
        return INT_MAX;

    const int smallest = data[1];
    --heap_pos;
    // When keys remain, move the last one to the root and sift it down.
    if (heap_pos > 1) {
        data[1] = data[heap_pos];
        minHeapify(1);
    }
    return smallest;
}

// Appends key at the first free slot and sifts it up towards the root.
// Returns false without storing anything when the heap is already at capacity.
bool MinHeap::doInsert(int key) {
    if (heap_pos == cap)
        return false;

    int child = heap_pos++;
    data[child] = key;
    while (child > 1) {
        const int parent = child >> 1;
        // Stop as soon as the new key no longer sorts before its parent.
        if (!cmp(data[child], data[parent]))
            break;
        swap(data[parent], data[child]);
        child = parent;
    }
    return true;
}

// Sifts the key at idx down until neither child sorts before it.
void MinHeap::minHeapify(int idx) {
    for (;;) {
        const int left = idx << 1;
        const int right = left | 1;
        int best = idx;
        if (left < heap_pos && cmp(data[left], data[best]))
            best = left;
        if (right < heap_pos && cmp(data[right], data[best]))
            best = right;
        if (best == idx)
            return;            // heap property holds at this node
        swap(data[best], data[idx]);
        idx = best;            // continue in the subtree that received the key
    }
}

// Number of keys stored: slot 0 is never used, so heap_pos is one past the count.
int MinHeap::size() {
    const int stored = heap_pos - 1;
    return stored;
}

// Result codes of getOrientation: negative cross product, positive cross product, zero.
enum {CLOCK, CONTER, COL};

// Classifies the sign of the cross product (b - a) x (c - a):
// CLOCK when it is negative, COL when it is exactly zero, CONTER otherwise.
int getOrientation(Point a, Point b, Point c) {
    const double turn = (b - a) ^ (c - a);
    if (turn < 0)
        return CLOCK;
    return turn == 0 ? COL : CONTER;
}

// True when p lies inside the axis-aligned bounding box of segment s-e (borders included).
// This equals "p is on the segment" only if p is already known to be collinear with s and e.
bool onSegment(Point s, Point e, Point p) {
    const double loX = mini(s.x, e.x);
    const double hiX = maxi(s.x, e.x);
    const double loY = mini(s.y, e.y);
    const double hiY = maxi(s.y, e.y);
    const bool insideX = p.x >= loX && p.x <= hiX;
    const bool insideY = p.y >= loY && p.y <= hiY;
    return insideX && insideY;
}

// Returns true when the closed segments s1-e1 and s2-e2 share at least one point.
//
// Standard orientation test: the segments intersect when each one's endpoints
// are not on the same side of the other's supporting line, or when an endpoint
// is collinear with the other segment and lies within it.
//
// Fix: each collinear flag is now paired with the endpoint it was computed for.
// Previously oa (about s2) was checked against e2 and so on, which reported an
// intersection for e.g. (0,0)-(4,4) vs (10,10)-(1,3): s2 is collinear but far
// away, while e2 merely falls inside the first segment's bounding box.
bool isIntersect(Point s1, Point e1, Point s2, Point e2) {
    int oa = getOrientation(s1, s2, e1);   // side of s2 relative to line s1-e1
    int ob = getOrientation(s1, e2, e1);   // side of e2 relative to line s1-e1
    int oc = getOrientation(s2, s1, e2);   // side of s1 relative to line s2-e2
    int od = getOrientation(s2, e1, e2);   // side of e1 relative to line s2-e2

    // General case: each segment's endpoints are split by the other's line.
    if (oa != ob && oc != od) return true;

    // Special cases: a collinear endpoint lying on the other segment.
    if (oa == COL && onSegment(s1, e1, s2)) return true;
    if (ob == COL && onSegment(s1, e1, e2)) return true;
    if (oc == COL && onSegment(s2, e2, s1)) return true;
    if (od == COL && onSegment(s2, e2, e1)) return true;
    return false;
}

int main(void) {
    int n;
    Point s, e, crt;
    while (~scanf("%d", &n) && n) {
        ps.clear();
        MinHeap mh;
        for (int i = 0; i < n; i++) {
            scanf("%lf%lf%lf%lf", &s.x, &s.y, &e.x, &e.y);
            s.id = e.id = i + 1;
            ps.push_back(s);
            ps.push_back(e);
            mh.doInsert(i * 2);
            mh.doInsert(i * 2 + 1);
        }
        set<int> active;
        set<int> ans;
        vector<int> toDel;
        while (mh.size()) {
            crt = ps[mh.extractMin()];
            if (!toDel.empty() && crt.x != ps[toDel[0]].x) {
                for (int x : toDel) {
                    active.erase(x);
                }
                toDel.clear();
            }
            if (active.count(crt.id)) toDel.push_back(crt.id);
            else active.insert(crt.id);
            for (int x : active) {
                for (int y : active) {
                    if (x > y) continue;
                    else if (x == y) {
                        ans.insert(x);
                        continue;
                    }
                    if (isIntersect(ps[(x - 1) * 2], ps[(x - 1) * 2 + 1],
                        ps[(y - 1) * 2], ps[(y - 1) * 2 + 1])) {
                        int over = maxi(x, y);
                        int suf = mini(x, y);
                        if (ans.count(suf)) ans.erase(suf);
                        ans.insert(over);
                    } else {
                        ans.insert(x);
                        ans.insert(y);
                    }
                }
            }
        }
        printf("Top sticks: ");
        set<int>::iterator ite;
        int i;
        for (i = 0, ite = ans.begin(); ite != ans.end(); ite++, i++) {
            if (i) printf(", ");
            printf("%d", *ite);
        }
        printf(".\n");
    }
    return 0;
}

扫描线的复杂度理论上是\(O(n\log n)\),最坏情况下为\(O(n^2)\),而且需要额外维护一个激活集合;这道题时间卡得很紧,需要一个更贴合本题的解法。

这道题的解法其实也挺简单。表面上我们用的是一个\(O(n^2)\)的算法,但实际上在最好情况下能够降到\(O(n)\),这可能就是这种方法能够AC的原因。简单说一下,其实就是一个筛选法:另开一个数组保存候选木棍,被筛掉的直接扔掉,从而降低里层循环的规模;同时维护一个紧凑的线性存储,充分利用CPU的Cache,速度会有明显提升。

#include <iostream>

using namespace std;

// One stick: the segment from (x1, y1) to (x2, y2) plus an integer id.
struct Seg {
    double x1;
    double y1;
    double x2;
    double y2;
    int id;
};

// 2D cross product of the vectors (x1, y1) and (x2, y2).
// Positive when the second vector is counter-clockwise from the first,
// negative when clockwise, zero when they are parallel.
inline double crossPro(double x1, double y1, double x2, double y2) {
    return x1 * y2 - x2 * y1;
}

// Capacity of the stick arrays: 1e5 + 10, i.e. 100010 after conversion to int.
const int LIM = 1e5 + 10;
// Sticks of the current test case in input order; main() gives ps[i] the id i + 1.
Seg ps[LIM];
// Candidate top sticks. main() starts with a copy of every stick and compacts
// the survivors to the front as covered sticks are filtered out.
Seg ans[LIM];
// Number of live entries at the front of `ans`.
int gblcnt;

// Larger of two doubles; the second argument is returned when neither is greater.
inline double maxi(double a, double b) {
    if (a > b) return a;
    return b;
}
// Smaller of two doubles; the second argument is returned when neither is smaller.
inline double mini(double a, double b) {
    if (a < b) return a;
    return b;
}
// True when (px, py) lies inside the axis-aligned bounding box of s (borders included).
// This equals "the point is on the segment" only for points already known to be
// collinear with s.
inline bool onSegment(Seg s, double px, double py) {
    const double hiX = maxi(s.x1, s.x2);
    const double loX = mini(s.x1, s.x2);
    const double hiY = maxi(s.y1, s.y2);
    const double loY = mini(s.y1, s.y2);
    const bool insideX = hiX >= px && px >= loX;
    const bool insideY = hiY >= py && py >= loY;
    return insideX && insideY;
}


// Sign of x: -1 when negative, 1 when positive, 0 otherwise.
//
// The parameter is a double because callers pass cross products of real
// coordinates; with an int parameter any value in (-1, 1) was truncated to 0
// and misreported as "collinear". Integer arguments still convert implicitly.
inline int sgn(double x) {
    if (x < 0) return -1;
    if (x > 0) return 1;
    return 0;
}
// Returns true when the closed segments s1 and s2 share at least one point.
//
// oa / ob are the cross products locating s2's first / second endpoint relative
// to the line through s1; oc / od do the same for s1's endpoints relative to s2.
//
// Fixes:
//  * Each zero cross product is now checked against the endpoint it was
//    computed for. Previously oa (first endpoint of s2) was paired with the
//    second endpoint and so on, which reported an intersection for e.g.
//    (0,0)-(4,4) vs (10,10)-(1,3).
//  * Signs are compared directly on the doubles, so small non-zero cross
//    products are no longer truncated to zero on the way to a sign test.
bool isIntersect(Seg s1, Seg s2) {
    const double dx1 = s1.x2 - s1.x1, dy1 = s1.y2 - s1.y1;
    const double dx2 = s2.x2 - s2.x1, dy2 = s2.y2 - s2.y1;
    const double oa = crossPro(dx1, dy1, s2.x1 - s1.x1, s2.y1 - s1.y1);
    const double ob = crossPro(dx1, dy1, s2.x2 - s1.x1, s2.y2 - s1.y1);
    const double oc = crossPro(dx2, dy2, s1.x1 - s2.x1, s1.y1 - s2.y1);
    const double od = crossPro(dx2, dy2, s1.x2 - s2.x1, s1.y2 - s2.y1);

    // Proper crossing: each segment's endpoints lie strictly on opposite
    // sides of the other segment's line.
    const bool s2Straddles = (oa < 0 && ob > 0) || (oa > 0 && ob < 0);
    const bool s1Straddles = (oc < 0 && od > 0) || (oc > 0 && od < 0);
    if (s2Straddles && s1Straddles) return true;

    // Touching / overlapping: an endpoint collinear with the other segment
    // and inside its bounding box lies on that segment.
    if (oa == 0 && onSegment(s1, s2.x1, s2.y1)) return true;
    if (ob == 0 && onSegment(s1, s2.x2, s2.y2)) return true;
    if (oc == 0 && onSegment(s2, s1.x1, s1.y1)) return true;
    if (od == 0 && onSegment(s2, s1.x2, s1.y2)) return true;
    return false;
}

int main(void) {
    int n;
    while (~scanf("%d", &n) && n) {
        gblcnt = n;
        for (int i = 0; i < n; i++) {
            scanf("%lf%lf%lf%lf", &ps[i].x1, &ps[i].y1, &ps[i].x2, &ps[i].y2);
            ps[i].id = i + 1;
            ans[i] = ps[i];
        }
        for (int i = n - 1; i >= 0; i--) {
            int tmpcnt = 0;
            for (int j = 0; j < gblcnt; j++) {
                if (ps[i].id <= ans[j].id || !isIntersect(ps[i], ans[j])) {
                    ans[tmpcnt++] = ans[j];
                }
            }
            gblcnt = tmpcnt;
        }
        printf("Top sticks: ");
        for (int i = 0; i < gblcnt; i++)
            if (i) printf(", %d", ans[i].id);
            else printf("%d", ans[i].id);
        printf(".\n");
    }
    return 0;
}