博客

DFS & BFS Example

DFS
easy:
Oil Deposits https://vjudge.net/problem/HDU-1241#author=prayerhgq

normal:
Connect https://vjudge.net/problem/CodeForces-1130C#author=XMIWmayIsWorld
大臣的旅费 http://www.dotcpp.com/oj/problem1438.html
Sum in the tree https://vjudge.net/problem/CodeForces-1098A

BFS
easy:
迷宫问题 https://vjudge.net/problem/POJ-3984
Red and Black https://vjudge.net/problem/POJ-1979#author=s19435631
学霸的迷宫 http://www.dotcpp.com/oj/problem1923.html

normal:
Eight II https://vjudge.net/problem/HDU-3567

 

DFS

int dfs(int crt) {
    // 进入状态前
    if (/*如果已经记录了数据*/) // 返回存储了的数据(记忆化)

    // 处理状态中
    // 检查是不是答案
    // 可行性检查、最优性检查等

    // 离开状态后
    // 寻找之后的状态
    for (/*iterate every choice*/) {
        if (/*可达性判断*/) {
            /*vis等标记处理*/
            // 进入相邻的状态
            dfs(next_state);
            /*无效标记还原(backtrack)*/
        }
    }
}

 

 

BFS

void bfs() {
    queue<T> q;
    q.push(/*构建初始队列*/);

    while (!q.empty()) {
        // 状态前,获得当前状态属性等操作
        crt = q.front();
        q.pop();

        // 处理状态中
        if (/*达到解*/) {
            /*找到解后的操作*/
        }

        // 状态后,将相邻状态加入队列
        for (/*枚举每一种相邻状态*/) {
            if (/*可达性、最优性、有效性等的检测*/) {
                /*更新标记*/
                q.push(/*新状态*/);
            }
        }

    }
}

 

2018Java小作业

自己写的,由于题目要求十分基础,整个作业下来基本就是体力活了。

下面是代码:

import java.awt.event.ActionEvent;
import java.awt.event.ActionListener;
import java.awt.event.KeyEvent;
import java.awt.event.KeyListener;
import java.io.*;
import java.nio.file.NoSuchFileException;
import java.text.SimpleDateFormat;
import java.util.*;
import java.awt.*;
import javax.swing.*;
import javax.swing.text.AbstractDocument;

public class Additor {
    public static void main(String[] args) {
        new MainFrame();
    }

    /**
     * 控制台调试用代码
     */
    private static void testCode() {
        FixedPointNumber fpn = new FixedPointNumber(-12300, 3);
        System.out.println(fpn);
        FixedPointNumber b = new FixedPointNumber(246, 1);
        System.out.println(fpn.subtract(b));
        System.out.println(fpn.add(b));
        System.out.println(b.subtract(fpn));
        try {
            System.out.println(fpn.format(2));
            System.out.println(fpn.format(4));
            System.out.println(fpn.format(0));
            System.out.println(fpn.format(-1));
        } catch (NotValidFormatException e) {
            e.printStackTrace();
        }
        ArrayList<FixedPointNumber[]> list = DataGenerator.generate(10);
        for (FixedPointNumber[] x : list) {
            System.out.println(x[0] + " + " + x[1] + " = " + x[2]);
        }
        try {
            list = DataGenerator.loadData("D:\\Program Files\\Additor_1_0\\asset\\ques\\20181120-07-31-00.quesdat");
            for (FixedPointNumber[] x : list) {
                System.out.println(x[0] + " + " + x[1] + " = " + x[2]);
            }
        } catch (FileNotFoundException e) {
            e.printStackTrace();
        } catch (IOException e) {
            e.printStackTrace();
        } catch (ClassNotFoundException e) {
            e.printStackTrace();
        }
        DataGenerator.update(100, 10);
        int[] dat = DataGenerator.queryProfile();
        System.out.println("Num : " + dat[0] + "\ncorrect: " + dat[1]);
        DataGenerator.update(2, 2);
        dat = DataGenerator.queryProfile();
        System.out.println("Num : " + dat[0] + "\ncorrect: " + dat[1]);
    }
}

/**
 * 定点数类,实现定点数的构造、随机生成、加法、减法、格式化。
 */
class FixedPointNumber implements Comparable<FixedPointNumber>, Serializable {
    private int value;
    private int pointPos;
    private static int seed;
    private boolean isNegative;

    /*
     * 静态初始化块初始化随机数种子
     */
    static {
        seed = (int) (Math.random() * 1000 + 1);
    }

    /**
     * 伪随机数生成器。生成含有pointPos位小数的定点数。 使用伪随机数公式 x = (a * seed + offset) % mod生成随机数。
     * Note: 此版本仅支持生成非负数
     *
     * @param pointPos
     *            小数点位置
     * @param lowerBound
     *            生成数的下界
     * @param upperBound
     *            生成数的上界
     * @return FixedPointNumber 返回构造出的定点数
     */
    public static FixedPointNumber random(int pointPos, int lowerBound, int upperBound) {
        int mod = upperBound - lowerBound;
        for (int i = 0; i < pointPos; i++) {
            mod *= 10;
            lowerBound *= 10;
        }
        seed = (17 * seed + 13) % mod + lowerBound;
        return new FixedPointNumber(seed, pointPos);
    }

    private FixedPointNumber() {
        this.pointPos = 0;
        this.value = 0;
        this.isNegative = false;
    }

    /**
     * 拷贝构造函数。实现深层复制。
     *
     * @param fpn
     *            要被拷贝的对象。
     */
    FixedPointNumber(FixedPointNumber fpn) {
        this.value = fpn.value;
        this.isNegative = fpn.isNegative;
        this.pointPos = fpn.pointPos;
    }

    /**
     * 含有两个参数的构造方法,用来构造一个定点数。 pointPos = 0表示小数点在最右端,此时这个数就是一个整数。
     *
     * @param value
     *            存储的数
     * @param pointPos
     *            小数点位置
     */
    FixedPointNumber(int value, int pointPos) {
        if (value < 0) {
            value = -value;
            isNegative = true;
        } else {
            isNegative = false;
        }
        this.value = value;
        this.pointPos = pointPos;
    }

    /**
     * 使用字符串表示的小数来构造定点数对象的构造器。
     *
     * @param str
     *            使用字符串表示的小数
     * @throws NotValidPointNumberException
     *             当作为参数的字符串不是一个合法的小数时抛出此异常
     */
    FixedPointNumber(String str) throws NotValidPointNumberException {
        int startPos = 0;
        if (str.charAt(0) == '-') {
            this.isNegative = true;
            startPos = 1;
        } else {
            this.isNegative = false;
        }
        // 检查数据有效性
        // if the str is "-"
        if (startPos == str.length())
            throw new NotValidPointNumberException();
        for (int i = startPos; i < str.length(); i++) {
            if (!Character.isDigit(str.charAt(i)) && str.charAt(i) != '.')
                throw new NotValidPointNumberException();
        }
        int pos = str.indexOf('.');
        // 如果存在两个及以上数目小数点,也要抛出异常
        // e.g. "12.3.56"
        if (str.lastIndexOf('.') != pos)
            throw new NotValidPointNumberException();
        this.pointPos = pos >= 0 ? str.length() - pos - 1 : 0;
        if (pos == startPos)
            str = str.substring(1);
        else if (pos == str.length() - 1)
            str = str.substring(startPos, str.length() - 1);
        else if (pos >= 1)
            str = str.substring(startPos, pos) + str.substring(pos + 1, str.length());
        this.value = Integer.parseInt(str);
    }

    /**
     * 计算两个定点数的相加。 需要将小数点位置对齐。
     *
     * @param fpn
     *            加数
     * @return FixedPointNumber 返回两数相加的结果
     * @throws NullPointerException
     *             当传入参数为null时抛出
     */
    FixedPointNumber add(FixedPointNumber fpn) throws NullPointerException {
        if (fpn == null)
            throw new NullPointerException();
        FixedPointNumber ret = new FixedPointNumber();
        FixedPointNumber a = new FixedPointNumber(this);
        FixedPointNumber b = new FixedPointNumber(fpn);
        ret.pointPos = Math.max(a.pointPos, b.pointPos);
        if (a.pointPos < ret.pointPos) {
            int t = ret.pointPos - a.pointPos;
            for (int i = 0; i < t; i++)
                a.value *= 10;
        }
        if (b.pointPos < ret.pointPos) {
            int t = ret.pointPos - b.pointPos;
            for (int i = 0; i < t; i++)
                b.value *= 10;
        }
        if (a.isNegative)
            a.value = -a.value;
        if (b.isNegative)
            b.value = -b.value;
        ret.value = a.value + b.value;
        if (ret.value < 0) {
            ret.value = -ret.value;
            ret.isNegative = true;
        } else {
            ret.isNegative = false;
        }
        return ret;
    }

    /**
     * 实现两个定点数的减法 需要将小数点位置对齐
     *
     * @param fpn
     *            减数
     * @return FixedPointNumber 返回两数相减的结果
     * @throws NullPointerException
     *             减数传入null时抛出此异常
     */
    FixedPointNumber subtract(FixedPointNumber fpn) throws NullPointerException {
        if (fpn == null)
            throw new NullPointerException();
        FixedPointNumber ret = new FixedPointNumber();
        FixedPointNumber a = new FixedPointNumber(this);
        FixedPointNumber b = new FixedPointNumber(fpn);
        ret.pointPos = Math.max(a.pointPos, b.pointPos);
        if (a.pointPos < ret.pointPos) {
            int t = ret.pointPos - a.pointPos;
            for (int i = 0; i < t; i++)
                a.value *= 10;
        }
        if (b.pointPos < ret.pointPos) {
            int t = ret.pointPos - b.pointPos;
            for (int i = 0; i < t; i++)
                b.value *= 10;
        }
        if (a.isNegative)
            a.value = -a.value;
        if (b.isNegative)
            b.value = -b.value;
        ret.value = a.value - b.value;
        if (ret.value < 0) {
            ret.value = -ret.value;
            ret.isNegative = true;
        } else {
            ret.isNegative = false;
        }
        return ret;
    }

    /**
     * 小数位数格式化方法
     *
     * @param pointPos
     * @return FixedPointNumber 返回格式化后的定点数
     * @throws NotValidFormatException
     *             格式不规范,参数小于0时抛出此异常
     */
    public FixedPointNumber format(int pointPos) throws NotValidFormatException {
        if (pointPos < 0)
            throw new NotValidFormatException();
        if (pointPos == this.pointPos) {
            return new FixedPointNumber(this);
        } else if (pointPos < this.pointPos) {
            int t = this.pointPos - pointPos;
            int newVal = this.value;
            while (t-- != 0)
                newVal /= 10;
            return new FixedPointNumber(newVal, pointPos);
        } else {
            int t = pointPos - this.pointPos;
            int newVal = this.value;
            while (t-- != 0)
                newVal *= 10;
            return new FixedPointNumber(newVal, pointPos);
        }
    }

    /**
     * 重载equals方法。
     *
     * @param fpn
     *            被比较的对象
     * @return Boolean 相同返回真,不同返回假
     */
    @Override
    public boolean equals(Object fpn) {
        if (fpn == null)
            return false;
        if (fpn.getClass() != this.getClass())
            return false;
        FixedPointNumber tmp = new FixedPointNumber((FixedPointNumber) fpn);
        FixedPointNumber self = new FixedPointNumber(this);
        if (tmp.isNegative != self.isNegative)
            return false;
        while (tmp.value % 10 == 0) {
            tmp.value /= 10;
            tmp.pointPos--;
        }
        return this.pointPos == tmp.pointPos && this.value == tmp.value;
    }

    /**
     * 重载hash值方法。 保证对于所有调用equals方法返回为真的对象具有相同的hash值。 反之具有不同的hash值。
     *
     * @return int 返回此对象的hash值。
     */
    @Override
    public int hashCode() {
        int tmpvalue = this.value;
        int tmppos = this.pointPos;
        while (tmpvalue % 10 == 0) {
            tmpvalue /= 10;
            tmppos--;
        }
        return Objects.hash(tmpvalue, tmppos, isNegative);
    }

    /**
     * 重载toString方法
     *
     * @return String 此对象的字符串表示。
     */
    @Override
    public String toString() {
        if (pointPos == 0)
            return (this.isNegative ? "-" : "") + value;
        int rate = 1;
        for (int i = 0; i < pointPos; i++) {
            rate *= 10;
        }
        String ret = this.isNegative ? "-" : "";
        ret += value / rate + ".";
        String tmp = "";
        int decimal = value % rate;
        int cnt = 0;
        while (decimal != 0) {
            tmp += decimal % 10;
            decimal /= 10;
            cnt++;
        }
        for (int i = cnt; i < pointPos; i++)
            ret += "0";
        for (int i = tmp.length() - 1; i >= 0; i--)
            ret += tmp.charAt(i);
        return ret;
    }

    /**
     * 重载比较函数
     */
    @Override
    public int compareTo(FixedPointNumber fpn) {
        FixedPointNumber a = new FixedPointNumber(this);
        FixedPointNumber b = new FixedPointNumber(fpn);
        int pos = Math.max(a.pointPos, b.pointPos);
        if (a.pointPos < pos) {
            int t = pos - a.pointPos;
            for (int i = 0; i < t; i++)
                a.value *= 10;
        }
        if (b.pointPos < pos) {
            int t = pos - b.pointPos;
            for (int i = 0; i < t; i++)
                b.value *= 10;
        }
        if (a.isNegative)
            a.value = -a.value;
        if (b.isNegative)
            b.value = -b.value;
        return Integer.compare(a.value, b.value);
    }

    public int getValue() {
        return (this.isNegative ? -1 : 1) * value;
    }
}

class NotValidPointNumberException extends Exception {}
class NotValidFormatException extends Exception {}

class DataGenerator {

    /**
     * 隐藏构造器,禁止构造此类对象
     */
    private DataGenerator() {}

    /**
     * 用来随机生成问题
     * @param num 生成问题的数量
     * @return 使用ArrayList存储的问题题集
     */
    public static ArrayList<FixedPointNumber[]> generate(int num) {
        ArrayList ret = new ArrayList<FixedPointNumber[]>();
        // 题目应当是一个三元组,位置0和1上是题目的两个加数,2上式正确结果
        FixedPointNumber[] dat = new FixedPointNumber[3];
        try {
            while (num-- != 0) {
                dat[2] = FixedPointNumber.random(2, 20, 100).format(2);
                dat[0] = FixedPointNumber.random(2, 10, dat[2].getValue() / 100 / 2).format(2);
                dat[1] = dat[2].subtract(dat[0]).format(2);
                ret.add(dat.clone());
            }
        } catch (NotValidFormatException e) {
            e.printStackTrace();
        }
        saveData(ret);
        return ret;
    }

    /**
     * 保存题组
     * @param list 需要序列化的题组
     * @return 为真表示成功保存,为假表示保存失败
     */
    public static boolean saveData(ArrayList<FixedPointNumber[]> list) {
        try {
            String to = "D:\\Program Files\\Additor_1_0\\asset\\ques\\";
            File f = new File(to);
            if (!f.exists()) f.mkdirs();
            to += new SimpleDateFormat("yyyyMMdd-hh-mm-ss").format(new Date())
                    + ".quesdat";
            f = new File(to);
            if (!f.exists()) f.createNewFile();
            FileOutputStream fos = new FileOutputStream(f);
            ObjectOutputStream oos = new ObjectOutputStream(fos);
            oos.writeObject(list);
            oos.close();
            fos.close();
            return true;
        } catch (IOException e) {
            e.printStackTrace();
            return false;
        }
    }

    /**
     * 题集载入
     * @param src
     * @return ArraList存储的题集
     * @throws java.io.FileNotFoundException 文件不存在
     * @throws java.io.IOException
     * @throws ClassNotFoundException
     */
    public static ArrayList<FixedPointNumber[]> loadData(String src)
            throws java.io.FileNotFoundException, java.io.IOException, ClassNotFoundException {
        File f = new File(src);
        FileInputStream fis = new FileInputStream(src);
        ObjectInputStream ois = new ObjectInputStream(fis);
        ArrayList<FixedPointNumber[]> ret = (ArrayList<FixedPointNumber[]>)ois.readObject();
        return ret;
    }

    /**
     * 更新当前成绩到文件中
     * @param num 新增加的题数
     * @param correct 新正确的题数
     * @return
     */
    public static boolean update(int num, int correct) {
        int prenum = 0, precnt = 0;
        String to = "D:\\Program Files\\Additor_1_0\\asset";
        try {
            File f = new File(to);
            if (!f.exists()) f.mkdirs();
            to += "\\profile.dat";
            f = new File(to);
            if (!f.exists()) {
                f.createNewFile();
            } else {
                int[] dat = queryProfile();
                prenum = dat[0];
                precnt = dat[1];
            }
            prenum += num;
            precnt += correct;
            FileWriter fw = new FileWriter(f);
            fw.append("All: ");
            fw.append(new Integer(prenum).toString());
            fw.append("\r\nCorrect: ");
            fw.append(new Integer(precnt).toString());
            fw.append("\r\n");
            fw.flush();
            fw.close();
        } catch (IOException e) {
            e.printStackTrace();
        }
        return false;
    }

    /**
     * 读取以前的成绩
     * @return 两个元素的数组,0号位置存储总题数,1号位置存储正确数
     */
    public static int[] queryProfile() {
        int[] ret = new int[2];
        try {
            Scanner in = new Scanner(new File("D:\\Program Files\\Additor_1_0\\asset\\profile.dat"));
            in.next();
            ret[0] = in.nextInt();
            in.next();
            ret[1] = in.nextInt();
            in.close();
        } catch (FileNotFoundException e) {
            e.printStackTrace();
        }
        return ret;
    }
}

class PaintPanel extends JComponent implements Runnable, KeyListener {

    /**
     * 大小
     */
    private int xWindowSize;
    private int yWindowSize;
    /**
     * 两个背景的水平位置坐标,以及其图片资源的引用
     */
    private int cloudPos[];
    private int cloudY[];
    private int soilY;
    private int soilHeight;
    private int grassY;
    private int grassHeight;
    private int boardX;
    private int boardY;
    private Color sky;
    private Color cloudNorm;
    private Color cloudDark;
    private Color soilNorm;
    private Color soilDark;
    private Color grassNorm;
    private Color grassMid;
    private Color grassDark;
    private Color board;
    private Color wordColor;
    private Color maskColor;
    private int maskAlpha;
    private Font font;
    private Font hint;
    private int fontAX;
    private int fontAY;
    private int fontBX;
    private int fontBY;
    private int fontCX;
    private int fontCY;
    private int lifeX;
    private int lifeY;
    private int lifeHeight;
    private int lifeWidth;
    private int crtLifeLen;
    private int lifeDiff;
    // questions
    private int correct;
    private int quesCnt;
    private ArrayList<FixedPointNumber[]> list;
    private String a, b, ans;
    private int quesPos;
    private int hp;
    private int tmphp;
    private int maskDiff;
    private int correctFlag;
    private int wrongFlag;
    private MainFrame frame;
    private Thread myThread;
    private boolean notFinish;

    PaintPanel(int xWindowSize, int yWindowSize, int hp, int quesCnt, MainFrame frame, ArrayList<FixedPointNumber[]> list) {
        super();
        this.frame = frame;
        notFinish = true;
        // init the window size
        this.xWindowSize = xWindowSize;
        this.yWindowSize = yWindowSize;
        // load the background
        Toolkit tk = Toolkit.getDefaultToolkit();
        sky = new Color(78, 173, 245);
        cloudDark = new Color(149, 210, 240);
        cloudNorm = new Color(245, 245, 245);
        soilNorm = new Color(156, 91, 73);
        soilDark = new Color(103, 59, 50);
        grassNorm = new Color(139, 200, 60);
        grassMid = new Color(109, 164, 37);
        grassDark = new Color(67, 118, 13);
        board = new Color(245, 245, 245, 75);
        wordColor = new Color(255, 215, 0);
        maskAlpha = 0;
        maskColor = new Color(0, 0, 0, maskAlpha);
        maskDiff = 80 / (hp - 1);
        font = new Font("SAO UI", Font.PLAIN, 70);
        hint = new Font("SAO UI", Font.BOLD, 90);
        cloudPos = new int[3];
        cloudPos[0] = 0;
        cloudPos[1] = xWindowSize / 3;
        cloudPos[2] = xWindowSize / 4 * 3;
        cloudY = new int[3];
        int skyAreaHeight = yWindowSize / 4 * 3;
        cloudY[0] = skyAreaHeight / 5 * 3;
        cloudY[1] = skyAreaHeight / 3;
        cloudY[2] = skyAreaHeight / 6;
        soilY = skyAreaHeight;
        soilHeight = yWindowSize - soilY;
        grassY = soilY;
        grassHeight = soilHeight / 4;
        boardX = xWindowSize / 4;
        boardY = yWindowSize / 6;
        // init question
        this.quesCnt = quesCnt;
        this.correct = 0;
        if (list == null) {
            this.list = DataGenerator.generate(10);
        } else {
            this.list = list;
        }
        this.quesPos = 0;
        a = list.get(quesPos)[0].toString();
        b = list.get(quesPos)[1].toString();
        ans = "";
        this.hp = hp;
        this.tmphp = hp;
        this.lifeHeight = boardY / 3;
        this.lifeWidth = xWindowSize / 3;
        this.lifeX = xWindowSize / 3;
        this.lifeY = boardY / 3;
        this.crtLifeLen = lifeWidth;
        lifeDiff = lifeWidth / hp;
        correctFlag = 0;
        wrongFlag = 0;

        setVisible(true);
        myThread = new Thread(this);
        myThread.start();
        this.setFocusable(true);
        this.addKeyListener(this);
    }

    private void drawState() {
        renderState();
        updateState();
    }

    private void updateState() {
        // update the cloud position
        for (int i = 0; i < cloudPos.length; i++) {
            cloudPos[i] -= 2 * (i + 1);
            if (cloudPos[i] <= -150) cloudPos[i] = xWindowSize;
        }
    }

    private void renderState() {
        repaint();
    }

    @Override
    public void paint(Graphics g) {
        Graphics2D g2 = (Graphics2D)g;
        g2.setColor(this.sky);
        g2.fillRect(0, 0, xWindowSize, yWindowSize);
        // draw the clouds
        for (int i = 0; i < cloudPos.length; i++) {
            g2.setColor(cloudNorm);
            g2.fillRect(cloudPos[i] + 50, cloudY[i] + 30, 50, 50);
            g2.fillRect(cloudPos[i] + 20, cloudY[i] + 50, 30, 35);
            g2.fillRect(cloudPos[i] + 50, cloudY[i] + 60, 10, 30);
            g2.fillRect(cloudPos[i] + 100, cloudY[i] + 45, 30, 45);
            g2.setColor(cloudDark);
            g2.fillRect(cloudPos[i] + 20, cloudY[i] + 50, 5, 35);
            g2.fillRect(cloudPos[i] + 25, cloudY[i] + 60, 5, 25);
            g2.fillRect(cloudPos[i] + 30, cloudY[i] + 70, 5, 15);
            g2.fillRect(cloudPos[i] + 35, cloudY[i] + 80, 15, 5);
            g2.fillRect(cloudPos[i] + 50, cloudY[i] + 30, 8, 40);
            g2.fillRect(cloudPos[i] + 50, cloudY[i] + 30, 30, 5);
            g2.setColor(sky);
            g2.fillRect(cloudPos[i] + 20, cloudY[i] + 70, 8, 15);
            g2.fillRect(cloudPos[i] + 120, cloudY[i] + 85, 10, 5);
        }
        g2.setColor(soilNorm);
        g2.fillRect(0, soilY, xWindowSize, soilHeight);
        g2.setColor(soilDark);
        g2.fillRect(0, soilY, xWindowSize, soilHeight / 3);
        g2.setColor(grassNorm);
        g2.fillRect(0, grassY, xWindowSize, grassHeight);
        g2.setColor(grassMid);
        g2.fillRect(0, grassY + grassHeight / 5 * 2, xWindowSize, grassHeight / 5 * 3);
        g2.setColor(grassDark);
        g2.fillRect(0, grassY + grassHeight / 5 * 4, xWindowSize, grassHeight / 5 * 2);
        // draw the board
        g2.setColor(board);
        g2.fillRect(boardX, boardY, boardX * 2, boardY * 4);
        g2.setFont(font);
        g2.setColor(wordColor);
        // paint the word
        FontMetrics fm = g2.getFontMetrics();
        int areaHeight = boardY * 20 / 6;
        int areaWidth = boardX / 2 * 3;
        int strWidth = fm.stringWidth("XX.XX");
        this.fontAX = (areaWidth - strWidth) / 2 + (boardX * 2 - areaWidth) / 2 + boardX;
        this.fontAY = (boardY * 4 - areaHeight) / 2 + boardY + fm.getAscent() / 2;
        int diffY = fontAY - boardY + fm.getAscent() / 3;
        this.fontBX = this.fontAX;
        this.fontBY = this.fontAY + diffY;
        this.fontCX = this.fontBX;
        this.fontCY = this.fontBY + diffY * 2;
        g2.drawString(this.a, fontAX, fontAY);
        g2.drawString(this.b, fontBX, fontBY);
        g2.drawString(this.ans, fontCX, fontCY);
        // draw hint
        g2.setFont(hint);
        fm = g2.getFontMetrics();
        if (correctFlag > 0) {
            g2.setColor(new Color(127, 255, 0));
            int wid = fm.stringWidth("Correct                Answer");
            int tx = (xWindowSize - wid) / 2;
            int ty = (yWindowSize - fm.getDescent()) / 2;
            g2.drawString("Correct                Answer", tx, ty);
            correctFlag--;
            System.out.println(correctFlag);
        }
        if (wrongFlag > 0) {
            g2.setColor(new Color(227, 38, 54));
            int wid = fm.stringWidth("Wrong                Answer");
            int tx = (xWindowSize - wid) / 2;
            int ty = (yWindowSize - fm.getDescent()) / 2;
            g2.drawString("Wrong                Answer", tx, ty);
            wrongFlag--;
        }
        if (correctFlag < 0) correctFlag = 0;
        if (wrongFlag < 0) wrongFlag = 0;
        // draw life
        g2.setColor(Color.RED);
        g2.drawRect(lifeX, lifeY, lifeWidth, lifeHeight);
        g2.fillRect(lifeX, lifeY, crtLifeLen, lifeHeight);
        // draw mask
        maskColor = new Color(0, 0, 0, maskAlpha);
        g2.setColor(maskColor);
        g2.fillRect(0, 0, xWindowSize, yWindowSize);
        g2.finalize();
    }

    @Override
    public void run() {
        while (notFinish) {
            drawState();
            try {
                // 60 FPS
                Thread.sleep(16);
            } catch (InterruptedException e) {
                e.printStackTrace();
            }
        }
    }

    @Override
    public Dimension getPreferredSize() {
        return new Dimension(xWindowSize, yWindowSize);
    }

    private static String mapList[] = {"0", "1", "2", "3", "4", "5",
        "6", "7", "8", "9"};
    @Override
    public void keyPressed(KeyEvent e) {
        int code = e.getKeyCode();
        // backSpace is 8
        if (code == 8) {
            ans = "";
            tmphp--;
            maskAlpha += maskDiff;
            crtLifeLen -= lifeDiff;
        } else if (code >= 48 && code <= 57) {
            ans += mapList[code - 48];
        } else if (code >= 96 && code <= 105) {
            ans += mapList[code - 96];
        }
        if (ans.length() == 2) ans += ".";
        if (ans.length() == 5) {
//            System.out.println(ans);
//            System.out.println(list.get(quesPos)[2]);
            if (ans.equals(list.get(quesPos)[2].toString())) {
//                System.out.println("Here");
                correctFlag = 100;
                correct++;
                tmphp = hp;
                quesPos++;
                // finish all question
                if (quesPos == quesCnt) {
                    frame.showMenu();
                    notFinish = false;
                    return;
                }
                a = list.get(quesPos)[0].toString();
                b = list.get(quesPos)[1].toString();
                maskAlpha = 0;
                crtLifeLen = lifeWidth;
            } else {
                maskAlpha += maskDiff;
                tmphp--;
                crtLifeLen -= lifeDiff;
            }
            ans = "";
        }
        if (tmphp == 0) {
            wrongFlag = 100;
            System.out.println("Here");
            quesPos++;
            // finish all question
            if (quesPos == quesCnt) {
                frame.showMenu();
                notFinish = false;
                return;
            }
            a = list.get(quesPos)[0].toString();
            b = list.get(quesPos)[1].toString();
            tmphp = hp;
            maskAlpha = 0;
            crtLifeLen = lifeWidth;
        }
        System.out.println(maskAlpha);
    }

    @Override
    public void keyTyped(KeyEvent e) {

    }

    @Override
    public void keyReleased(KeyEvent e) {

    }
}

class MenuPanel extends JPanel {
    private int xWindowSize;
    private int yWindowSize;
    private MainFrame frame;
    // start
    private JButton button0;
    // load data
    private JButton button1;
    // show profile
    private JButton button2;
    MenuPanel(int xWindowSize, int yWIndowSize, MainFrame frame) {
        this.xWindowSize = xWindowSize;
        this.yWindowSize = yWIndowSize;
        button0 = new JButton("新的测试");
        button1 = new JButton("加载记录");
        button2 = new JButton("个人成绩");
        button0.addActionListener(new ActionListener() {
            @Override
            public void actionPerformed(ActionEvent e) {
                frame.newTest();
            }
        });
        button1.addActionListener(new ActionListener() {
            @Override
            public void actionPerformed(ActionEvent e) {
                JFileChooser jfc = new JFileChooser("D:\\Program Files\\Additor_1_0\\asset\\ques");
                jfc.setFileSelectionMode(JFileChooser.FILES_ONLY);
                int result = jfc.showOpenDialog(null);
                if (result == JFileChooser.APPROVE_OPTION) {
                    File file = jfc.getSelectedFile();
                    ArrayList<FixedPointNumber[]> list;
                    try {
                        list = DataGenerator.loadData(file.getPath());
                        frame.loadTest(list);
                    } catch (Exception ee) {
                        ee.printStackTrace();
                    }
                }
            }
        });
        // set the position of buttons
        int buttonWidth = xWindowSize / 4;
        int buttonHeight = yWindowSize / 7;
        int buttonX = (xWindowSize - buttonWidth) / 2;
        int buttonY = buttonHeight;
        button0.setBounds(buttonX, buttonY, buttonWidth, buttonHeight);
        buttonY += buttonHeight * 2;
        button1.setBounds(buttonX, buttonY, buttonWidth, buttonHeight);
        buttonY += buttonHeight * 2;
        button2.setBounds(buttonX, buttonY, buttonWidth, buttonHeight);
        this.setLayout(null);
        this.add(button0);
        this.add(button1);
        this.add(button2);
        this.setVisible(true);
    }

    @Override
    public Dimension getPreferredSize() {
        return new Dimension(xWindowSize, yWindowSize);
    }
}

class MainFrame extends JFrame {
    private Container frameContainer;
    private int xWindowSize;
    private int yWindowSize;
    private int xScreenSize;
    private int yScreenSize;
    private Image icon;
    public MainFrame() {
        Toolkit tk = Toolkit.getDefaultToolkit();
        icon = tk.getImage("D:\\Program Files\\Additor_1_0\\asset\\pics\\Author.png");
        this.xScreenSize = tk.getScreenSize().width;
        this.yScreenSize = tk.getScreenSize().height;
        this.xWindowSize = this.xScreenSize / 2;
        this.yWindowSize = this.yScreenSize / 2;
        this.setBounds(this.xWindowSize / 2, this.yWindowSize / 2, this.xWindowSize, this.yWindowSize);
        this.setDefaultCloseOperation(JFrame.EXIT_ON_CLOSE);
        this.setTitle("Additor V1.0");
        this.setIconImage(this.icon);
        this.setResizable(false);
        frameContainer = this.getContentPane();
        frameContainer.add(new MenuPanel(xWindowSize, yWindowSize, this));
        this.setDefaultCloseOperation(JFrame.EXIT_ON_CLOSE);
        this.setVisible(true);
    }

    public void newTest() {
        JDialog hpQuest = new JDialog();
        hpQuest.setTitle("提示");
        hpQuest.setLayout(null);
        JTextField words = new JTextField(10);
        JButton yes = new JButton();
        yes.setText("确定");
        words.setBounds(320 / 3, 150 / 5, 320 / 3, 150 / 5);
        yes.setBounds(320 / 5 * 2, 150 / 7 * 3, 320 / 5, 150 / 7);
        hpQuest.add(words);
        hpQuest.add(yes);
        hpQuest.setBounds((xScreenSize - 320) / 2, (yScreenSize - 150) / 2, 320, 150);
        hpQuest.setSize(320, 150);
        hpQuest.setVisible(true);
        yes.addActionListener(new startListener(this, hpQuest, words, null));
    }

    public void loadTest(ArrayList<FixedPointNumber[]> list) {
        JDialog hpQuest = new JDialog();
        hpQuest.setTitle("提示");
        hpQuest.setLayout(null);
        JTextField words = new JTextField(10);
        JButton yes = new JButton();
        yes.setText("确定");
        words.setBounds(320 / 3, 150 / 5, 320 / 3, 150 / 5);
        yes.setBounds(320 / 5 * 2, 150 / 7 * 3, 320 / 5, 150 / 7);
        hpQuest.add(words);
        hpQuest.add(yes);
        hpQuest.setBounds((xScreenSize - 320) / 2, (yScreenSize - 150) / 2, 320, 150);
        hpQuest.setSize(320, 150);
        hpQuest.setVisible(true);
        yes.addActionListener(new startListener(this, hpQuest, words, list));
    }

    public void showMenu() {
        frameContainer.removeAll();
        frameContainer.add(new MenuPanel(xWindowSize, yWindowSize, this));
        this.repaint();
        this.validate();
    }

    public int getXWindowSize() {
        return xWindowSize;
    }
    public int getYWindowSize() {
        return yWindowSize;
    }
}

class startListener implements ActionListener {
    private MainFrame frame;
    private JDialog hpQuest;
    private JTextField words;
    private ArrayList<FixedPointNumber[]> list;
    startListener(MainFrame frame, JDialog hpQuest,
                  JTextField words, ArrayList<FixedPointNumber[]> list) {
        this.frame = frame;
        this.hpQuest = hpQuest;
        this.words = words;
        this.list = list;
    }
    @Override
    public void actionPerformed(ActionEvent e) {
        int hp = Integer.parseInt(words.getText());
        hpQuest.dispose();
        frame.getContentPane().removeAll();
        if (list != null)
            frame.getContentPane().add(new PaintPanel(frame.getXWindowSize(), frame.getYWindowSize(), hp, 10, frame, list));
        else
            frame.getContentPane().add(new PaintPanel(frame.getXWindowSize(), frame.getYWindowSize(), hp, 10, frame, null));
        frame.repaint();
        frame.validate();
    }
}

 

MPICH笔记(五):聚合通信

MPICH笔记(五):聚合通信

官方文档索引:

https://www.mpich.org/static/docs/v3.2/

概述:

大体有三类聚合通信\(^{[1]}\):

1) 同步。

2)数据迁移。如:MPI_Bcast广播、MPI_Scatter散射、MPI_Gather聚合、MPI_Allgather聚合结果存到每个进程、AlltoAll数据全交换(有点像矩阵的转职)。

3)聚合运算。如归约(MPI_Reduce)、扫描(MPI_Scan)。

代码:

#include <iostream>
#include "mpi.h"

using namespace std;

int main(void) {
    MPI_Init(nullptr, nullptr);
    int rank;
    MPI_Comm_rank(MPI_COMM_WORLD, &rank);
    printf("Rank %d > Reach %d\n", rank, __LINE__);
    MPI_Barrier(MPI_COMM_WORLD);
    printf("Rank %d > Reach %d\n", rank, __LINE__);
    MPI_Finalize();
    return 0;
}

 

执行命令:

mpiexec -n 8 ./文件名

 

输出:

Rank 1 > Reach 10
Rank 2 > Reach 10
Rank 6 > Reach 10
Rank 3 > Reach 10
Rank 4 > Reach 10
Rank 7 > Reach 10
Rank 0 > Reach 10
Rank 5 > Reach 10
Rank 0 > Reach 12
Rank 1 > Reach 12
Rank 2 > Reach 12
Rank 3 > Reach 12
Rank 4 > Reach 12
Rank 5 > Reach 12
Rank 6 > Reach 12
Rank 7 > Reach 12

去掉MPI_Barrier(MPI_COMM_WORLD)之后

输出:

Rank 2 > Reach 10
Rank 2 > Reach 12
Rank 3 > Reach 10
Rank 3 > Reach 12
Rank 0 > Reach 10
Rank 0 > Reach 12
Rank 5 > Reach 10
Rank 5 > Reach 12
Rank 7 > Reach 10
Rank 7 > Reach 12
Rank 4 > Reach 10
Rank 4 > Reach 12
Rank 6 > Reach 10
Rank 6 > Reach 12
Rank 1 > Reach 10
Rank 1 > Reach 12

可以看到,由于Barrier将所有进程进行了同步,所以使用Barrier的代码所有进程保证都先执行了第10行的内容,之后才是第12行的内容,而不使用Barrier的代码自然没有这个限制。

 

参考资料:

[1]: 《并行计算的编程模型(Programming Models for Parallel Computing)》第一版,Pavan Balaji [美]、美国阿贡国家实验室编著,张云泉、李士刚、逄仁波、袁良译,机械工业出版社。

MPICH笔记(四):非阻塞通信

MPICH笔记(四):非阻塞通信

官方文档链接:

int MPI_Isend(const void *buf, int count, MPI_Datatype datatype, int dest, int tag, MPI_Comm comm, MPI_Request *request)

int MPI_Irecv(void *buf, int count, MPI_Datatype datatype, int source, int tag, MPI_Comm comm, MPI_Request *request)

int MPI_Wait(MPI_Request *request, MPI_Status *status)

int MPI_Waitall(int count, MPI_Request array_of_requests[],  MPI_Status array_of_statuses[])

int MPI_Request_free(MPI_Request *request)

概述:

MPI的非阻塞通信的函数的命名规则一般是在对应的阻塞式函数的第二个单词处开头添加大写I,表示这个函数时非阻塞的,并且函数的参数会有一定的调整(如MPI_Irecv()比MPI_Recv()少了输出MPI_Status类型数据的参数),并都添加上了输出MPI_Request类型数据的参数,用来在之后的程序中进行判断通信是否完成。

代码:

#include <iostream>
#include "mpi.h"

using namespace std;

typedef long long LL;

int main(void) {
    MPI_Init(nullptr, nullptr);
    int rank, size;
    int dat;
    MPI_Request *r;
    MPI_Comm_rank(MPI_COMM_WORLD, &rank);
    MPI_Comm_size(MPI_COMM_WORLD, &size);
    if (!rank) {
        r = new MPI_Request[size - 1];
        dat = 233;
        /*
        * 进行非阻塞式点对点发送
        * 与具有rank i的进程的通信状态存储在r[i - 1]中
        * 即地址r + i - 1上     
        */
        for (int i = 1; i < size; i++) {
            MPI_Isend(&dat, 1, MPI_INT, i, 0, MPI_COMM_WORLD, r + i - 1);
        }
        /*
        * 进行一个比较耗时的计算,这里使用循环来暴力计算10到100的阶乘之和
        * 对1000000007取模的值
        */
        LL mod = 1000000007ll;
        LL sum = 0, ret;
        for (int i = 10; i <= 100; i++) {
            ret = 1;
            for (int j = 1; j <= i; j++) {
                ret = (ret * i) % mod;
            }
            sum = (sum + ret) % mod;
            printf("Rank 0 > do operation %d!\n", i);
        }
        MPI_Waitall(size - 1, r, MPI_STATUSES_IGNORE);
        delete [] r;
    } else {
        r = new MPI_Request();
        MPI_Irecv(&dat, 1, MPI_INT, 0, MPI_ANY_TAG, MPI_COMM_WORLD, r);
        for (int i = 0; i < 1000; i++)
            for (int j = 0; j < 1000000; j++);
        MPI_Wait(r, MPI_STATUS_IGNORE);
        printf("Rank %d > Done!\n", rank);
        delete r;
    }
    MPI_Finalize();
    return 0;
}

执行指令:

mpiexec -n 8 ./文件名

输出:

Rank 0 > do operation 10!
Rank 0 > do operation 11!
Rank 0 > do operation 12!
Rank 0 > do operation 13!
Rank 0 > do operation 14!
Rank 0 > do operation 15!
Rank 0 > do operation 16!
Rank 0 > do operation 17!
Rank 0 > do operation 18!
Rank 0 > do operation 19!
Rank 0 > do operation 20!
Rank 0 > do operation 21!
Rank 0 > do operation 22!
Rank 0 > do operation 23!
Rank 0 > do operation 24!
Rank 0 > do operation 25!
Rank 0 > do operation 26!
Rank 0 > do operation 27!
Rank 0 > do operation 28!
Rank 0 > do operation 29!
Rank 0 > do operation 30!
Rank 0 > do operation 31!
Rank 0 > do operation 32!
Rank 0 > do operation 33!
Rank 0 > do operation 34!
Rank 0 > do operation 35!
Rank 0 > do operation 36!
Rank 0 > do operation 37!
Rank 0 > do operation 38!
Rank 0 > do operation 39!
Rank 0 > do operation 40!
Rank 0 > do operation 41!
Rank 0 > do operation 42!
Rank 0 > do operation 43!
Rank 0 > do operation 44!
Rank 0 > do operation 45!
Rank 0 > do operation 46!
Rank 0 > do operation 47!
Rank 0 > do operation 48!
Rank 0 > do operation 49!
Rank 0 > do operation 50!
Rank 0 > do operation 51!
Rank 0 > do operation 52!
Rank 0 > do operation 53!
Rank 0 > do operation 54!
Rank 0 > do operation 55!
Rank 0 > do operation 56!
Rank 0 > do operation 57!
Rank 0 > do operation 58!
Rank 0 > do operation 59!
Rank 0 > do operation 60!
Rank 0 > do operation 61!
Rank 0 > do operation 62!
Rank 0 > do operation 63!
Rank 0 > do operation 64!
Rank 0 > do operation 65!
Rank 0 > do operation 66!
Rank 0 > do operation 67!
Rank 0 > do operation 68!
Rank 0 > do operation 69!
Rank 0 > do operation 70!
Rank 0 > do operation 71!
Rank 0 > do operation 72!
Rank 0 > do operation 73!
Rank 0 > do operation 74!
Rank 0 > do operation 75!
Rank 0 > do operation 76!
Rank 0 > do operation 77!
Rank 0 > do operation 78!
Rank 0 > do operation 79!
Rank 0 > do operation 80!
Rank 0 > do operation 81!
Rank 0 > do operation 82!
Rank 0 > do operation 83!
Rank 0 > do operation 84!
Rank 0 > do operation 85!
Rank 0 > do operation 86!
Rank 0 > do operation 87!
Rank 0 > do operation 88!
Rank 0 > do operation 89!
Rank 0 > do operation 90!
Rank 0 > do operation 91!
Rank 0 > do operation 92!
Rank 0 > do operation 93!
Rank 0 > do operation 94!
Rank 0 > do operation 95!
Rank 0 > do operation 96!
Rank 0 > do operation 97!
Rank 0 > do operation 98!
Rank 0 > do operation 99!
Rank 0 > do operation 100!
Rank 1 > Done!
Rank 4 > Done!
Rank 6 > Done!
Rank 3 > Done!
Rank 5 > Done!
Rank 2 > Done!
Rank 7 > Done!

非阻塞式通信通常会结合MPI_Wait()或MPI_Waitall()来使用,这样一方面可以避免阻塞式通信带来的延时,可以直接去执行通信环节之后的与前面通信结果无关的操作,又可以在需要使用到前面通信结果之前调用MPI_Wait()或MPI_Waitall()来阻塞,确保通信的完成。

代码概述:

MPI_Waitall()

这个函数的第三个参数是输出参数,是一个具有与第一个参数指定的元素个数的MPI_Status数组,因此,在不需要的时候,这里需要传递的是预定义常量MPI_STATUSES_IGNORE,而不是MPI_STATUS_IGNORE。

MPI_Request_free()

用于释放掉通常由如上类似的非阻塞函数传递的MPI_Request对象,注意,一旦释放后,将不能使用它传递给Wait函数来进行阻塞。

MPICH笔记(三):数据类型

MPICH笔记(三):数据类型

官方文档索引:

int MPI_Type_free(MPI_Datatype *datatype)

int MPI_Type_indexed(int count, const int *array_of_blocklengths, const int *array_of_displacements, MPI_Datatype oldtype, MPI_Datatype *newtype)

int MPI_Type_create_struct(int count, const int array_of_blocklengths[], const MPI_Aint array_of_displacements[], const MPI_Datatype array_of_types[], MPI_Datatype *newtype)

int MPI_Type_commit(MPI_Datatype *datatype)

int MPI_Type_vector(int count, int blocklength, int stride,  MPI_Datatype oldtype, MPI_Datatype *newtype)

 

代码:

#include <iostream>
#include "mpi.h"

using namespace std;

int main(void) {
    MPI_Init(nullptr, nullptr);
    int rank;
    int *dat = new int[9];
    MPI_Datatype *myType = new MPI_Datatype();
    MPI_Comm_rank(MPI_COMM_WORLD, &rank);
    MPI_Type_vector(3, 1, 3, MPI_INT, myType);
    MPI_Type_commit(myType);
    if (rank == 1) {
        for (int i = 0; i < 9; i++) dat[i] = 0;
        MPI_Recv(dat, 1, *myType, MPI_ANY_SOURCE, MPI_ANY_TAG, MPI_COMM_WORLD, MPI_STATUS_IGNORE);
        printf("Rank 1 >\n");
        for (int i = 0; i < 3; i++) {
            for (int j = 0; j < 3; j++) {
                if (j) putchar(' ');
                printf("%d", dat[i * 3 + j]);
            }
            putchar('\n');
        }
    } else if (rank == 0) {
        for (int i = 0; i < 9; i++) {
            dat[i] = i + 1;
        }
        printf("Rank 0 >\n");
        for (int i = 0; i < 3; i++) {
            for (int j = 0; j < 3; j++) {
                if (j) putchar(' ');
                printf("%d", dat[i * 3 + j]);
            }
            putchar('\n');
        }
        MPI_Send(dat + 1, 1, *myType, 1, 0, MPI_COMM_WORLD);
    }
    MPI_Type_free(myType);
    delete myType;
    delete [] dat;
    MPI_Finalize();
    return 0;
}

运行指令:

mpiexec -n 2 ./文件名

输出信息:

Rank 0 >
1 2 3
4 5 6
7 8 9
Rank 1 >
2 0 0
5 0 0
8 0 0

代码概述:

上面的代码只操作两个进程,进程0生成一个3 * 3的矩阵,并将其第二列发送给进程1,进程1将接收到的这一列放在其自身的第一列上。

MPI_Type_vector()

简单的用来创建新的MPI数据类型的函数,其最后一个参数MPI_Datatype *newtype是要保存新创建好的数据类型的空间的地址。它用一个四元组来描述新的数据类型,依次是新类型包含了几个旧数据类型块、新类型的每个旧数据类型块中要用几个连续的旧类型数据、旧类型数据块的大小、旧类型。

MPI_Type_commit()

用来启用新的数据类型,创建新的数据类型后,需要在使用前将新数据类型的地址传递给这个函数,这个函数会对新构造的数据类型进行分析并在使用新的数据类型进行消息通信前对非连续的数据进行通信优化\(^{[1]}\)。

MPI_Type_free()

用来清除传递给它的地址上存放的新数据类型。

 

 

参考资料:

[1]: 《并行计算的编程模型(Programming Models for Parallel Computing)》第一版,Pavan Balaji [美]、美国阿贡国家实验室编著,张云泉、李士刚、逄仁波、袁良译,机械工业出版社。

MPICH笔记(二):阻塞通信

MPICH笔记(二):阻塞通信

MPICH支持点对点通信,主要使用:

MPI_Send()MPI_Recv(),两个函数一个是发送,一个是接收,官方的文档十分详细,单击前边的这两个函数即可跳转到对应页面。

需要注意的是:MPI_Recv()是阻塞接收函数,保证在接收到消息前不会执行之后的语句;然而,根据文档的描述,MPI_Send()是阻塞发送函数,但是却可能会不发生阻塞,也就是说,函数可能在发送的消息被确认接收前返回。

另一点需要注意的是,因为MPI_Recv()会进行阻塞,所以需要考虑避免死锁现象,即同一时间所有的进程都进行MPI_Recv()操作。如下面的代码将0号进程设置为先发送再接收,而其它进程都是先接收,再发送消息,这样来避免死锁。

代码:

#include <iostream>
#include "mpi.h"

using namespace std;

int main(void) {
    int rank, size;
    int message;
    int from, to;
    MPI_Init(nullptr, nullptr);
    MPI_Comm_rank(MPI_COMM_WORLD, &rank);
    MPI_Comm_size(MPI_COMM_WORLD, &size);
    from = ((rank - 1) % size + size) % size;
    to = (rank + 1) % size;
    if (rank) {
        MPI_Recv(&message, 1, MPI_INT, from, 0, MPI_COMM_WORLD, MPI_STATUS_IGNORE);
        printf("Rank %d > Hello from %d, and the message is %d\n", rank, from, message);
        MPI_Send(&rank, 1, MPI_INT, to, 0, MPI_COMM_WORLD);
    } else {
        MPI_Send(&rank, 1, MPI_INT, to, 0, MPI_COMM_WORLD);
        MPI_Recv(&message, 1, MPI_INT, from, 0, MPI_COMM_WORLD, MPI_STATUS_IGNORE);
        printf("Rank %d > Hello from %d, and the message is %d\n", rank, from, message);
    }
    MPI_Finalize();
    return 0;
}

 

输出:

Rank 1 > Hello from 0, and the message is 0
Rank 2 > Hello from 1, and the message is 1
Rank 3 > Hello from 2, and the message is 2
Rank 4 > Hello from 3, and the message is 3
Rank 5 > Hello from 4, and the message is 4
Rank 6 > Hello from 5, and the message is 5
Rank 7 > Hello from 6, and the message is 6
Rank 0 > Hello from 7, and the message is 7

注:这里一共开启了8个进程。

 

代码概述:

int MPI_Recv(void *buf, int count, MPI_Datatype datatype, int source, int tag, MPI_Comm comm, MPI_Status *status)

前三个参数是表述消息信息的三元组,依此是存储位置、数量(在接收函数中,这个参数表示最大可能的数据长度,需要保证这个大小要不小于实际传递信息的小才可以成功接收数据)、单位大小(数据类型)。source是来源进程的rank,如果不限定来源,可以使用预定义量MPI_ANY_SOURCE。tag表示消息的编号,因为可能有多个来自同一进程的不同作用的信息,如果不限制tag值,可以使用预定义常量MPI_ANY_TAG。status参数是用来获取消息的状态信息,如果不需要关于此消息的状态信息,可以使用预定义常量MPI_STATUS_IGNORE。

关于MPI_Status的详情见:https://docs.microsoft.com/en-us/message-passing-interface/mpi-status-structure

int MPI_Send(const void *buf, int count, MPI_Datatype datatype, int dest, int tag, MPI_Comm comm)

前三个参数仍然是描述信息的三元组,其中count需要是传输信息的实际数量。而后dest是接收进程在comm通信子内的rank。tag要与接收者设置的tag一致,注意,不可以使用MPI_ANY_TAG。

注:关于阻塞,两个阻塞式通信函数都会一直阻塞直到内存申请成功,关于具体细节,此处占坑,日后补充。

MPICH笔记(一):通信子

MPICH笔记:简单使用MPICH(一)

使用MPICH需要导入的头文件:

mpi.h

MPICH函数的命名规则:

MPI_Xxxx_xxx…即以MPI开头第二个单词的首字母大写,后面单词的首字母全部小写,单词之间用_隔开。

MPICH预定义常量命名规则:

MPI_XXXX_XXXX…即以MPI开头,每个字母都大写。

 

示例代码:

#include "mpi.h"
#include <cstdio>
#include <iostream>

using namespace std;

int main(int argc, char **argv) {
    int wrank, wsize, mrank;
    MPI_Init(&argc, &argv);
    /*
    * MPI_COMM_WORLD 和 MPI_COMM_SELF 是两个预定义的通信子(或者叫通信域, communicator),
    * MPI_COMM_WORLD 包含了所有进程,
    * MPI_COMM_SELF 仅包含自身进程。
    */
    MPI_Comm_rank(MPI_COMM_WORLD, &wrank);
    MPI_Comm_size(MPI_COMM_WORLD, &wsize);
    MPI_Comm_rank(MPI_COMM_SELF, &mrank);
    printf("World rank %d, world size %d, self rank %d\n", wrank, wsize, mrank);
    MPI_Finalize();
    return 0;
}

输出:

World rank 2, world size 4, self rank 0
World rank 0, world size 4, self rank 0
World rank 3, world size 4, self rank 0
World rank 1, world size 4, self rank 0

注:输出顺序可能有所不同,因为是由多个进程同时执行,执行先后会有差异。

代码简述:

通信子(通信域,communicator):

MPI的通信子包含一组进程以及一个(隐藏的)通信文本,通信文本的作用在于保证消息和库的一致性\(^{[1]}\)。

MPICH有两个预定义的communicator,分别是MPI_COMM_WORLD和MPI_COMM_SELF,前者包含所有进程,后者只包含自身进程。

MPI_Comm_rank(MPI_Comm, int *)

获得进程在通信子comm中的编号。保存在int *参数所指向的空间中。

MPI_Comm_size(MPI_Comm, int *)

获得通信子comm中进程个数,保存在int *参数所指向的空间中。

MPI_Init(int *, char **)

初始化。

MPI_Finalize()

结束。

注:大多数MPI函数都需要在MPI_Init(int *, char **)之后,且在MPI_Finalize()之前使用。

 

手动创建通信子:

#include "mpi.h"
#include <iostream>

using namespace std;

int main(void) {
    int mrank, wrank, wsize;
    MPI_Comm *evenAndOddComm = new MPI_Comm();
    MPI_Init(nullptr, nullptr);
    MPI_Comm_rank(MPI_COMM_WORLD, &wrank);
    MPI_Comm_size(MPI_COMM_WORLD, &wsize);
    /*
    * create a new commonicator
    */
    if (wrank & 1) MPI_Comm_split(MPI_COMM_WORLD, 1, wrank >> 1, evenAndOddComm);
    else MPI_Comm_split(MPI_COMM_WORLD, 2, wrank >> 1, evenAndOddComm);
    MPI_Comm_rank(*evenAndOddComm, &mrank);
    printf("Message from world rank %d -> world size %d, myColor is %d, and mrank is %d\n",
        wrank, wsize, wrank & 1 ? 1 : 2, mrank);
    MPI_Comm_free(evenAndOddComm);
    delete evenAndOddComm;
    MPI_Finalize();
    return 0;
}

代码简述:

MPI_Comm_split(MPI_Comm, int color, int key, MPI_Comm *out)

得到comm中进程的划分,具有相同color值的通信子会被划分到同一个新的通信子中,out指向的空间存放新的通信子相关的信息,key将是调用该函数的进程在新的通信子中的序号。特别地,color可以指定为MPI_UNDEFINED,但是这样调用函数后得到的out将不可用,如调用MPI_Comm_rank()函数时会产生异常。

MPI_Comm_free(MPI_Comm *)

释放通信子。

 

参考资料:

[1]: 《并行计算的编程模型(Programming Models for Parallel Computing)》第一版,Pavan Balaji [美]、美国阿贡国家实验室编著,张云泉、李士刚、逄仁波、袁良译,机械工业出版社。

LCA(Lowest Common Ancestor)笔记(一)

LCA(Lowest Common Ancestor)笔记(一)

二叉搜索树(Binary Searching Tree)情形:

O(h)的方案,根据LCA的性质我们可以推测出,对于每一个结点,所要寻找的LCA结点一定是一下四种之一:

1. 当前结点本身

2.当前结点的左子树里

3.当前结点的右子树里

4.不存在这样的LCA(当且仅当输入的两个结点至少一个不存在于这棵树中)

对于二叉搜索树(Binary Searching Tree),两个结点分别在当前结点两侧(即一个结点小于等于当前结点,另一个结点大于等于当前结点)时,LCA恰好就是当前结点本身。当两个结点都小于当前结点,则根据BST的性质,可以知道LCA在当前结点的左子树中,反之在右子树中,如果深入到了叶子结点还没有找到,则说明这棵树中不存在当前结点。

递归实现:

#include <iostream>

using namespace std;

struct Node {
    int val;
    Node *left, *right;
    Node(int v = 0) : val(v) {
        left = right = nullptr;
    }
};

Node * lca(Node *root, int n1, int n2) {
    if (!root) return NULL;
    if (n1 < root->val && n2 < root->val)
        return lca(root->left, n1, n2);
    else if (n1 > root->val && n2 > root->val)
        return lca(root->right, n1, n2);
    else
        return root;
}

void delTree(Node *root) {
    if (!root) return;
    delTree(root->left);
    delTree(root->right);
    delete root;
}

int main(void) {
    // init the tree
    Node *root = new Node(20);
    root->left = new Node(8);
    root->right = new Node(22);
    root->left->left = new Node(4);
    root->left->right = new Node(12);
    root->left->right->left = new Node(10);
    root->left->right->right = new Node(14);

    // get LCA
    int n1 = 10, n2 = 14;
    Node *t = lca(root, n1, n2);
    printf("LCA of %d and %d is %d \n", n1, n2, t->val);

    n1 = 14, n2 = 8;
    t = lca(root, n1, n2);
    printf("LCA of %d and %d is %d \n", n1, n2, t->val);

    n1 = 10, n2 = 22;
    t = lca(root, n1, n2);
    printf("LCA of %d and %d is %d \n", n1, n2, t->val);

    // destroy tree
    delTree(root);

    return 0;
}

递推实现:

#include <iostream>

using namespace std;

struct Node {
    int val;
    Node *left, *right;
    Node(int v = 0) : val(v) {
        left = right = nullptr;
    }
};

Node * lca(Node *root, int n1, int n2) {
    while (root) {
        if (n1 < root->val && n2 < root->val)
            root = root->left;
        else if (n1 > root->val && n2 > root->val)
            root = root->right;
        else
            break;
    }
    return root;
}

void delTree(Node *root) {
    if (!root) return;
    delTree(root->left);
    delTree(root->right);
    delete root;
}

int main(void) {
    // init the tree
    Node *root = new Node(20);
    root->left = new Node(8);
    root->right = new Node(22);
    root->left->left = new Node(4);
    root->left->right = new Node(12);
    root->left->right->left = new Node(10);
    root->left->right->right = new Node(14);

    // get LCA
    int n1 = 10, n2 = 14;
    Node *t = lca(root, n1, n2);
    printf("LCA of %d and %d is %d \n", n1, n2, t->val);

    n1 = 14, n2 = 8;
    t = lca(root, n1, n2);
    printf("LCA of %d and %d is %d \n", n1, n2, t->val);

    n1 = 10, n2 = 22;
    t = lca(root, n1, n2);
    printf("LCA of %d and %d is %d \n", n1, n2, t->val);

    // destroy tree
    delTree(root);

    return 0;
}

 

对于更为一般的二叉树,我们的策略如下:

第一种,利用从根结点到两个目标结点的路径。利用这一性质的方案,大都具有O(h)复杂度,有些还需要额外存储空间。

(1)我们可以先走两次路径,分别是从根节点深搜到结点1,以及从根节点深搜到结点2,将两个路径记录到两个数组中,之后从这两个数组最开始的位置同时往后查,直到找到第一个不一样的点,这说明,在从根到两个结点的路径上,从这个结点开始以及之后的所有结点都不相同,而前面路径上的结点都相同,那么LCA就是这样第一个不相同的这组结点之前的一个结点。

(2)我们可以记录parent的位置,由于每个结点有一个独特的存储位置(如果是动态分配存储,这个位置就是地址,而如果使用静态数组进行存储,这个位置就是数组序号了),那么我们可以从第一个结点倒推会根节点,使用STL的map建立一个从结点存储位置值到表示该结点是否存在于路径中的布尔变量的一个映射(相当于一个Hash 表),利用这个,我们再从第二个结点倒推回根节点,路径上每个结点都尝试判断它是否被加入到了map中,找到的第一个存在于map中的结点,它便是LCA,使用set可以达到同样的效果。

代码:

#include <iostream>
#include <map>
#include <vector>

using namespace std;

struct Node {
    int val;
    Node *left, *right, *parent;
    Node(int v = 0, Node *p = nullptr) : val(v), parent(p) {
        left = right = nullptr;
    }
};

Node *lca(Node *n1, Node *n2) {
    map<Node *, bool> vis;
    while (n1) {
        vis[n1] = true;
        n1 = n1->parent;
    }
    while (n2) {
        if (vis[n2]) return n2;
        n2 = n2->parent;
    }
    return nullptr;
}

void delTree(Node *root) {
    if (!root) return;
    delTree(root->left);
    delTree(root->right);
    delete root;
}

int main(void) {
    // init the tree
    Node *root = new Node(1);
    root->left = new Node(2, root);
    root->right = new Node(3, root);
    root->left->left = new Node(4, root->left);
    root->left->right = new Node(5, root->left);
    root->right->left = new Node(6, root->right);
    root->right->right = new Node(7, root->right);

    // get LCA
    Node *n1 = root->left->left, *n2 = root->left->right;
    Node *ans;
    printf("LCA of %d and %d is %d\n", n1 ? n1->val : -1, n2 ? n2->val : -1, (ans = lca(n1, n2)) ? ans->val : -1);

    n1 = root->right->left, n2 = root->left->right;
    ans = NULL;
    printf("LCA of %d and %d is %d\n", n1 ? n1->val : -1, n2 ? n2->val : -1, (ans = lca(n1, n2)) ? ans->val : -1);

    n1 = nullptr, n2 = nullptr;
    ans = NULL;
    printf("LCA of %d and %d is %d\n", n1 ? n1->val : -1, n2 ? n2->val : -1, (ans = lca(n1, n2)) ? ans->val : -1);

    n1 = root->left, n2 = root->left->right;
    ans = NULL;
    printf("LCA of %d and %d is %d\n", n1 ? n1->val : -1, n2 ? n2->val : -1, (ans = lca(n1, n2)) ? ans->val : -1);

    n1 = root, n2 = root->left->left;
    ans = NULL;
    printf("LCA of %d and %d is %d\n", n1->val, n2->val, (ans = lca(n1, n2)) ? ans->val : -1);

    // destroy tree
    delTree(root);

    return 0;
}

 

(3)我们甚至可以不需要使用map这样的额外存储,当存储有parent位置后,我们需要的仅仅是找到它们路径上第一个共同的点是哪个,遇到问题仅仅是,由于两个点深度不同,倒推回根节点的过程中,它们“路过”那个要找的LCA结点的“时间”可能不同,这个问题其实很好解决,我们只要让它们的深度预处理到相同的就可以,我们只要将较深的结点先向前推几层,这个层数取决于两个结点在树中的层数差值,之后两个结点同时倒推,遇到的第一个公共点就是LCA结点。

代码:

#include <iostream>
#include <map>
#include <vector>

using namespace std;

struct Node {
    int val;
    Node *left, *right, *parent;
    Node(int v = 0, Node *p = nullptr) : val(v), parent(p) {
        left = right = nullptr;
    }
};

int getDepth(Node *n) {
    int ret = -1;
    while (n) {
        ret++;
        n = n->parent;
    }
    return ret;
}

Node *lca(Node *n1, Node *n2) {
    int d1 = getDepth(n1), d2 = getDepth(n2);
    if (d1 < 0 || d2 < 0) return nullptr;
    int diff = d1 - d2;
    if (diff < 0) {
        diff = -diff;
        Node *tmp = n1;
        n1 = n2;
        n2 = tmp;
    }
    while (diff--) n1 = n1->parent;
    while (n1 && n2) {
        if (n1 == n2) return n1;
        n1 = n1->parent, n2 = n2->parent;
    }
    return nullptr;
}

void delTree(Node *root) {
    if (!root) return;
    delTree(root->left);
    delTree(root->right);
    delete root;
}

int main(void) {
    // init the tree
    Node *root = new Node(1);
    root->left = new Node(2, root);
    root->right = new Node(3, root);
    root->left->left = new Node(4, root->left);
    root->left->right = new Node(5, root->left);
    root->right->left = new Node(6, root->right);
    root->right->right = new Node(7, root->right);

    // get LCA
    Node *n1 = root->left->left, *n2 = root->left->right;
    Node *ans;
    printf("LCA of %d and %d is %d\n", n1 ? n1->val : -1, n2 ? n2->val : -1, (ans = lca(n1, n2)) ? ans->val : -1);

    n1 = root->right->left, n2 = root->left->right;
    ans = NULL;
    printf("LCA of %d and %d is %d\n", n1 ? n1->val : -1, n2 ? n2->val : -1, (ans = lca(n1, n2)) ? ans->val : -1);

    n1 = nullptr, n2 = nullptr;
    ans = NULL;
    printf("LCA of %d and %d is %d\n", n1 ? n1->val : -1, n2 ? n2->val : -1, (ans = lca(n1, n2)) ? ans->val : -1);

    n1 = root->left, n2 = root->left->right;
    ans = NULL;
    printf("LCA of %d and %d is %d\n", n1 ? n1->val : -1, n2 ? n2->val : -1, (ans = lca(n1, n2)) ? ans->val : -1);

    n1 = root, n2 = root->left->left;
    ans = NULL;
    printf("LCA of %d and %d is %d\n", n1->val, n2->val, (ans = lca(n1, n2)) ? ans->val : -1);

    // destroy tree
    delTree(root);

    return 0;
}

 

第二种,和刚才BST中的解决策略类似,同样是考虑到上面所述的四种可能,只是我们不具有BST那样小的必然在左子树、大者必然在右子树那样的性质,于是我们需要遍历这棵树来解决,复杂度O(V)。

#include <iostream>
#include <vector>

using namespace std;

struct Node {
    int val;
    Node *left, *right;
    Node(int v = 0) : val(v) {
        left = right = nullptr;
    }
};

Node *lca(Node *root, int n1, int n2) {
    if (!root) return NULL;
    if (root->val == n1 || root->val == n2)
        return root;
    Node *left = lca(root->left, n1, n2);
    Node *right = lca(root->right, n1, n2);
    if (left && right) return root;
    return left ? left : right;
}

void delTree(Node *root) {
    if (!root) return;
    delTree(root->left);
    delTree(root->right);
    delete root;
}

int main(void) {
    // init the tree
    Node *root = new Node(1);
    root->left = new Node(2);
    root->right = new Node(3);
    root->left->left = new Node(4);
    root->left->right = new Node(5);
    root->right->left = new Node(6);
    root->right->right = new Node(7);

    // get LCA
    int n1 = 4, n2 = 5;
    Node *ans;
    printf("LCA of %d and %d is %d\n", n1, n2, (ans = lca(root, n1, n2)) ? ans->val : -1);

    n1 = 6, n2 = 5;
    ans = NULL;
    printf("LCA of %d and %d is %d\n", n1, n2, (ans = lca(root, n1, n2)) ? ans->val : -1);

    n1 = 8, n2 = 9;
    ans = NULL;
    printf("LCA of %d and %d is %d\n", n1, n2, (ans = lca(root, n1, n2)) ? ans->val : -1);

    n1 = 2, n2 = 5;
    ans = NULL;
    printf("LCA of %d and %d is %d\n", n1, n2, (ans = lca(root, n1, n2)) ? ans->val : -1);

    n1 = 1, n2 = 4;
    ans = NULL;
    printf("LCA of %d and %d is %d\n", n1, n2, (ans = lca(root, n1, n2)) ? ans->val : -1);

    // destroy tree
    delTree(root);

    return 0;
}

 

【CodeForces-400D】Dima and Bacteria【Disjoint Set + Floyd Warshall Algorithm】

【CodeForces-400D】Dima and Bacteria【Disjoint Set + Floyd Warshall Algorithm】

代码:

#include <iostream>
#include <vector>

using namespace std;
typedef long long LL;
struct Edge {
    int to, w;
};
const int NLIM = 1e5 + 10;
const int SLIM = 500 + 10;
const int INF = 0xfffffff;
vector<Edge> adjOfNode[NLIM];
int parent[NLIM];
int ns[NLIM];       /* Node Set */
int c[SLIM];
int adjOfSet[SLIM][SLIM];

inline int mini(int a, int b) {return a < b ? a : b;}

int Find(int x) {
    if (parent[x] != x) parent[x] = Find(parent[x]);
    return parent[x];
}

void Union(int x, int y) {
    int nx = Find(x);
    int ny = Find(y);
    if (nx != ny) parent[nx] = parent[ny];
}

int main(void) {
    ios::sync_with_stdio(false);
    cin.tie(NULL);
    int n, m, k;
    cin >> n >> m >> k;
    for (int i = 0; i < k; i++) {
        cin >> c[i];
        if (i) c[i] += c[i - 1];
    }
    for (int i = 0; i < k; i++) {
        int crt = i ? c[i - 1] + 1 : 1;
        while (crt <= c[i]) {
            // save each node belong to which set
            ns[crt] = i;
            // init parent array
            parent[crt] = crt;
            crt++;
        }
    }
    // init adjOfSet
    for (int i = 0; i < k; i++) {
        for (int j = 0; j < k; j++) {
            adjOfSet[i][j] = INF;
        }
    }
    // load the adjacent list and matrix
    int u, v;
    Edge tmp;
    for (int i = 0; i < m; i++) {
        cin >> u >> v >> tmp.w;
        tmp.to = v;
        adjOfNode[u].push_back(tmp);
        tmp.to = u;
        adjOfNode[v].push_back(tmp);
        // cout << u << " " << v << " " << ns[u] << " " << ns[v] << " " << tmp.w << endl;
        adjOfSet[ns[u]][ns[v]] = mini(adjOfSet[ns[u]][ns[v]], tmp.w);
        adjOfSet[ns[v]][ns[u]] = mini(adjOfSet[ns[v]][ns[u]], tmp.w);
    }
    // use union-find set to check validity
    bool flag = true;
    for (int i = 0; i < k; i++) {
        int crt = i ? c[i - 1] + 1 : 1;
        if (c[i] == crt) {
            adjOfSet[ns[crt]][ns[crt]] = 0;
        }
        while (crt <= c[i]) {
            for (int j = 0; j < adjOfNode[crt].size(); j++) {
                if (adjOfNode[crt][j].w == 0)
                    Union(crt, adjOfNode[crt][j].to);
            }
            crt++;
        }
        crt = i ? c[i - 1] + 1 : 1;
        int pre = Find(crt++);
        for (; crt <= c[i]; crt++) {
            if (Find(crt) != pre) break;
        }
        if (crt <= c[i]) {
            flag = false;
            break;
        }
    }
    if (flag) {
        // for (int i = 0; i < k; i++) {
        //     for (int j = 0; j < k; j++) {
        //         if (j) cout << ' ';
        //         cout << adjOfSet[i][j];
        //     }
        //     cout << endl;
        // }
        cout << "Yes\n";
        // run the FloydDP
        for (int r = 0; r < k; r++) {
            for (int i = 0; i < k; i++) {
                for (int j = 0; j < k; j++) {
                    adjOfSet[i][j] = mini(adjOfSet[i][j], adjOfSet[i][r] + adjOfSet[r][j]);
                }
            }
        }
        for (int i = 0; i < k; i++) {
            for (int j = 0; j < k; j++) {
                if (j) cout << " ";
                cout << (adjOfSet[i][j] == INF ? -1 : adjOfSet[i][j]);
            }
            cout << "\n";
        }
    } else {
        cout << "No\n";
    }
    return 0;
}